From 8dc47fcc1f3848e44c75514c364d284cd10077fe Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 21 Dec 2023 16:10:51 +0800 Subject: [PATCH 1/6] enh: connector implements autocommit --- ws/tmq/consumer.go | 162 +++++++++++++++++++++++----------------- ws/tmq/consumer_test.go | 101 +++++++++++++++++++++++++ 2 files changed, 195 insertions(+), 68 deletions(-) diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index df1889b..428570f 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "strconv" "sync" "sync/atomic" "time" @@ -21,26 +22,26 @@ import ( ) type Consumer struct { - client *client.Client - requestID uint64 - err error - latestMessageID uint64 - listLock sync.RWMutex - sendChanList *list.List - messageTimeout time.Duration - url string - user string - password string - groupID string - clientID string - offsetRest string - autoCommit string - autoCommitIntervalMS string - snapshotEnable string - withTableName string - closeOnce sync.Once - closeChan chan struct{} - topics []string + client *client.Client + requestID uint64 + err error + listLock sync.RWMutex + sendChanList *list.List + messageTimeout time.Duration + autoCommit bool + autoCommitInterval time.Duration + nextAutoCommitTime time.Time + url string + user string + password string + groupID string + clientID string + offsetRest string + snapshotEnable string + withTableName string + closeOnce sync.Once + closeChan chan struct{} + topics []string } type IndexedChan struct { @@ -63,37 +64,51 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { if err != nil { return nil, err } + autoCommit := true + if config.AutoCommit == "false" { + autoCommit = false + } + autoCommitInterval := time.Second * 5 + if config.AutoCommitIntervalMS != "" { + interval, err := strconv.ParseUint(config.AutoCommitIntervalMS, 10, 64) + if err != nil { + return nil, err + } + autoCommitInterval = time.Millisecond * time.Duration(interval) + } + ws, _, err := common.DefaultDialer.Dial(config.Url, nil) if err != nil { return nil, err } wsClient := client.NewClient(ws, config.ChanLength) - tmq := &Consumer{ - client: wsClient, - requestID: 0, - sendChanList: list.New(), - messageTimeout: config.MessageTimeout, - url: config.Url, - user: config.User, - password: config.Password, - groupID: config.GroupID, - clientID: config.ClientID, - offsetRest: config.OffsetRest, - autoCommit: config.AutoCommit, - autoCommitIntervalMS: config.AutoCommitIntervalMS, - snapshotEnable: config.SnapshotEnable, - withTableName: config.WithTableName, - closeChan: make(chan struct{}), + + consumer := &Consumer{ + client: wsClient, + requestID: 0, + sendChanList: list.New(), + messageTimeout: config.MessageTimeout, + url: config.Url, + user: config.User, + password: config.Password, + groupID: config.GroupID, + clientID: config.ClientID, + offsetRest: config.OffsetRest, + autoCommit: autoCommit, + autoCommitInterval: autoCommitInterval, + snapshotEnable: config.SnapshotEnable, + withTableName: config.WithTableName, + closeChan: make(chan struct{}), } if config.WriteWait > 0 { wsClient.WriteWait = config.WriteWait } - wsClient.BinaryMessageHandler = tmq.handleBinaryMessage - wsClient.TextMessageHandler = tmq.handleTextMessage - wsClient.ErrorHandler = tmq.handleError + wsClient.BinaryMessageHandler = consumer.handleBinaryMessage + wsClient.TextMessageHandler = consumer.handleTextMessage + wsClient.ErrorHandler = consumer.handleError go wsClient.WritePump() go wsClient.ReadPump() - return tmq, nil + return consumer, nil } func configMapToConfig(m *tmq.ConfigMap) (*config, error) { @@ -294,7 +309,6 @@ const ( TMQCommitOffset = "commit_offset" TMQCommitted = "committed" TMQPosition = "position" - TMQListTopics = "list_topics" ) var ClosedErr = errors.New("connection closed") @@ -338,17 +352,16 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err } reqID := c.generateReqID() req := &SubscribeReq{ - ReqID: reqID, - User: c.user, - Password: c.password, - GroupID: c.groupID, - ClientID: c.clientID, - OffsetRest: c.offsetRest, - Topics: topics, - AutoCommit: c.autoCommit, - AutoCommitIntervalMS: c.autoCommitIntervalMS, - SnapshotEnable: c.snapshotEnable, - WithTableName: c.withTableName, + ReqID: reqID, + User: c.user, + Password: c.password, + GroupID: c.groupID, + ClientID: c.clientID, + OffsetRest: c.offsetRest, + Topics: topics, + AutoCommit: "false", + SnapshotEnable: c.snapshotEnable, + WithTableName: c.withTableName, } args, err := client.JsonI.Marshal(req) if err != nil { @@ -384,7 +397,17 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err // Poll messages func (c *Consumer) Poll(timeoutMs int) tmq.Event { if c.err != nil { - panic(c.err) + return tmq.NewTMQErrorWithErr(c.err) + } + if c.autoCommit { + if c.nextAutoCommitTime.IsZero() { + c.nextAutoCommitTime = time.Now().Add(c.autoCommitInterval) + } else { + if time.Now().After(c.nextAutoCommitTime) { + c.doCommit() + c.nextAutoCommitTime = time.Now().Add(c.autoCommitInterval) + } + } } reqID := c.generateReqID() req := &PollReq{ @@ -417,7 +440,6 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { if resp.Code != 0 { panic(taosErrors.NewError(resp.Code, resp.Message)) } - c.latestMessageID = resp.MessageID if resp.HaveMessage { switch resp.MessageType { case common.TMQ_RES_DATA: @@ -600,21 +622,29 @@ func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { } func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { - return c.doCommit(c.latestMessageID) + err := c.doCommit() + if err != nil { + return nil, err + } + partitions, err := c.Assignment() + if err != nil { + return nil, err + } + return c.Committed(partitions, 0) } -func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { +func (c *Consumer) doCommit() error { if c.err != nil { - return nil, c.err + return c.err } reqID := c.generateReqID() req := &CommitReq{ ReqID: reqID, - MessageID: messageID, + MessageID: 0, } args, err := client.JsonI.Marshal(req) if err != nil { - return nil, err + return err } action := &client.WSAction{ Action: TMQCommit, @@ -624,25 +654,21 @@ func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { c.client.PutEnvelope(envelope) - return nil, err + return err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return nil, err + return err } var resp CommitResp err = client.JsonI.Unmarshal(respBytes, &resp) if err != nil { - return nil, err + return err } if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - partitions, err := c.Assignment() - if err != nil { - return nil, err + return taosErrors.NewError(resp.Code, resp.Message) } - return c.Committed(partitions, 0) + return nil } func (c *Consumer) Unsubscribe() error { diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index a7bf95e..2da15b8 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -350,3 +350,104 @@ func TestSeek(t *testing.T) { assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) assert.GreaterOrEqual(t, partitions[0].Offset, messageOffset) } + +func prepareAutocommitEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_autocommit_topic", + "drop database if exists test_ws_tmq_autocommit", + "create database test_ws_tmq_autocommit vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_autocommit_topic as database test_ws_tmq_autocommit", + "create table test_ws_tmq_autocommit.t1(ts timestamp,v int)", + "insert into test_ws_tmq_autocommit.t1 values (now,1)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanAutocommitEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_autocommit_topic", + "drop database if exists test_ws_tmq_autocommit", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestAutoCommit(t *testing.T) { + err := prepareAutocommitEnv() + if err != nil { + t.Error(err) + return + } + defer cleanAutocommitEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "experimental.snapshot.enable": "false", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_ws_tmq_autocommit_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + partitions, err := consumer.Assignment() + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_autocommit_topic", *partitions[0].Topic) + assert.Equal(t, tmq.Offset(0), partitions[0].Offset) + + offset, err := consumer.Committed(partitions, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(offset)) + assert.Equal(t, tmq.OffsetInvalid, offset[0].Offset) + + //poll + messageOffset := tmq.Offset(0) + haveMessage := false + for i := 0; i < 5; i++ { + event := consumer.Poll(500) + if event != nil { + haveMessage = true + messageOffset = event.(*tmq.DataMessage).Offset() + } + } + assert.True(t, haveMessage) + + offset, err = consumer.Committed(partitions, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(offset)) + assert.GreaterOrEqual(t, offset[0].Offset, messageOffset) +} From 8974e359ca27cf83b41e10d90eecf93a06b31f02 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 28 Dec 2023 16:02:36 +0800 Subject: [PATCH 2/6] enh: enable compression --- README-CN.md | 1 + README.md | 1 + af/tmq/consumer_test.go | 17 ++++++++--------- examples/tmq/main.go | 19 +++++++++---------- taosWS/connection.go | 6 ++++-- taosWS/driver_test.go | 2 +- taosWS/dsn.go | 6 ++++++ taosWS/dsn_test.go | 9 +++++++++ ws/schemaless/config.go | 23 +++++++++++++++-------- ws/schemaless/schemaless.go | 5 ++++- ws/schemaless/schemaless_test.go | 1 + ws/stmt/config.go | 23 ++++++++++++++--------- ws/stmt/connector.go | 5 ++++- ws/stmt/stmt_test.go | 1 + ws/tmq/config.go | 10 ++++++++++ ws/tmq/consumer.go | 13 ++++++++++++- ws/tmq/consumer_test.go | 27 +++++++++++++-------------- 17 files changed, 113 insertions(+), 56 deletions(-) diff --git a/README-CN.md b/README-CN.md index e8b92da..6f22c5d 100644 --- a/README-CN.md +++ b/README-CN.md @@ -484,6 +484,7 @@ DSN 格式为: - `writeTimeout` 通过 websocket 发送数据的超时时间。 - `readTimeout` 通过 websocket 接收响应数据的超时时间。 +- `enableCompression` 是否压缩传输数据,默认为 `false` 不发送压缩数据。 ## 通过 websocket 使用 tmq diff --git a/README.md b/README.md index 11da9bd..a4125c5 100644 --- a/README.md +++ b/README.md @@ -485,6 +485,7 @@ Parameters: - `writeTimeout` The timeout to send data via websocket. - `readTimeout` The timeout to receive response data via websocket. +- `enableCompression` Whether to compress the transmitted data, the default is `false` and no compressed data is sent. ## Using tmq over websocket diff --git a/af/tmq/consumer_test.go b/af/tmq/consumer_test.go index 4dd024c..cd3430e 100644 --- a/af/tmq/consumer_test.go +++ b/af/tmq/consumer_test.go @@ -68,15 +68,14 @@ func TestTmq(t *testing.T) { assert.NoError(t, err) consumer, err := NewConsumer(&tmq.ConfigMap{ - "group.id": "test", - "auto.offset.reset": "earliest", - "td.connect.ip": "127.0.0.1", - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "td.connect.port": "6030", - "client.id": "test_tmq_c", - "enable.auto.commit": "false", - //"experimental.snapshot.enable": "true", + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_c", + "enable.auto.commit": "false", "msg.with.table.name": "true", }) if err != nil { diff --git a/examples/tmq/main.go b/examples/tmq/main.go index eb110ce..01e2010 100644 --- a/examples/tmq/main.go +++ b/examples/tmq/main.go @@ -27,16 +27,15 @@ func main() { panic(err) } consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ - "group.id": "test", - "auto.offset.reset": "earliest", - "td.connect.ip": "127.0.0.1", - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "td.connect.port": "6030", - "client.id": "test_tmq_client", - "enable.auto.commit": "false", - "experimental.snapshot.enable": "true", - "msg.with.table.name": "true", + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_client", + "enable.auto.commit": "false", + "msg.with.table.name": "true", }) if err != nil { panic(err) diff --git a/taosWS/connection.go b/taosWS/connection.go index 71697e6..df339bc 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -57,11 +57,13 @@ func newTaosConn(cfg *config) (*taosConn, error) { endpointUrl.RawQuery = fmt.Sprintf("token=%s", cfg.token) } endpoint := endpointUrl.String() - ws, _, err := common.DefaultDialer.Dial(endpoint, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = cfg.enableCompression + ws, _, err := dialer.Dial(endpoint, nil) if err != nil { return nil, err } - ws.SetReadLimit(common.BufferSize4M) + ws.EnableWriteCompression(cfg.enableCompression) ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go index 9e12d7a..56b70e8 100644 --- a/taosWS/driver_test.go +++ b/taosWS/driver_test.go @@ -34,7 +34,7 @@ var ( port = 6041 dbName = "test_taos_ws" dataSourceName = fmt.Sprintf("%s:%s@ws(%s:%d)/", user, password, host, port) - dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?disableCompression=false", user, password, host, port) + dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?enableCompression=true", user, password, host, port) ) type DBTest struct { diff --git a/taosWS/dsn.go b/taosWS/dsn.go index ca5ff22..aa8e4dd 100644 --- a/taosWS/dsn.go +++ b/taosWS/dsn.go @@ -29,6 +29,7 @@ type config struct { params map[string]string // Connection parameters interpolateParams bool // Interpolate placeholders into query string token string // cloud platform token + enableCompression bool // Enable write compression readTimeout time.Duration // read message timeout writeTimeout time.Duration // write message timeout } @@ -143,6 +144,11 @@ func parseDSNParams(cfg *config, params string) (err error) { } case "token": cfg.token = value + case "enableCompression": + cfg.enableCompression, err = strconv.ParseBool(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid enableCompression value: " + value} + } case "readTimeout": cfg.readTimeout, err = time.ParseDuration(value) if err != nil { diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go index edfd013..7d7c7d7 100644 --- a/taosWS/dsn_test.go +++ b/taosWS/dsn_test.go @@ -25,6 +25,15 @@ func TestParseDsn(t *testing.T) { {dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &config{user: "user", passwd: "passwd", net: "wss", params: map[string]string{"test": "1"}}}, {dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &config{user: "user", passwd: "passwd", net: "wss", token: "token"}}, {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &config{user: "user", passwd: "passwd", net: "wss", readTimeout: 10 * time.Minute, writeTimeout: 8 * time.Second, interpolateParams: true}}, + {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &config{ + user: "user", + passwd: "passwd", + net: "wss", + readTimeout: 10 * time.Minute, + writeTimeout: 8 * time.Second, + interpolateParams: true, + enableCompression: true, + }}, } for _, tc := range tests { t.Run(tc.dsn, func(t *testing.T) { diff --git a/ws/schemaless/config.go b/ws/schemaless/config.go index 58f65b0..d62eb3b 100644 --- a/ws/schemaless/config.go +++ b/ws/schemaless/config.go @@ -10,14 +10,15 @@ const ( ) type Config struct { - url string - chanLength uint - user string - password string - db string - readTimeout time.Duration - writeTimeout time.Duration - errorHandler func(error) + url string + chanLength uint + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + errorHandler func(error) + enableCompression bool } func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config { @@ -64,3 +65,9 @@ func SetErrorHandler(errorHandler func(error)) func(*Config) { c.errorHandler = errorHandler } } + +func SetEnableCompression(enableCompression bool) func(*Config) { + return func(c *Config) { + c.enableCompression = enableCompression + } +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index db44e09..c6dcf6b 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -47,7 +47,10 @@ func NewSchemaless(config *Config) (*Schemaless, error) { if len(wsUrl.Path) == 0 || wsUrl.Path != "/rest/schemaless" { wsUrl.Path = "/rest/schemaless" } - ws, _, err := common.DefaultDialer.Dial(wsUrl.String(), nil) + dialer := common.DefaultDialer + dialer.EnableCompression = config.enableCompression + ws, _, err := dialer.Dial(wsUrl.String(), nil) + ws.EnableWriteCompression(config.enableCompression) if err != nil { return nil, fmt.Errorf("dial ws error: %s", err) } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index bfa3199..14dbb0d 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -68,6 +68,7 @@ func TestSchemaless_Insert(t *testing.T) { SetWriteTimeout(10*time.Second), SetUser("root"), SetPassword("taosdata"), + SetEnableCompression(true), SetErrorHandler(func(err error) { t.Fatal(err) }), diff --git a/ws/stmt/config.go b/ws/stmt/config.go index 332ac55..7eab614 100644 --- a/ws/stmt/config.go +++ b/ws/stmt/config.go @@ -6,15 +6,16 @@ import ( ) type Config struct { - Url string - ChanLength uint - MessageTimeout time.Duration - WriteWait time.Duration - ErrorHandler func(connector *Connector, err error) - CloseHandler func() - User string - Password string - DB string + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + ErrorHandler func(connector *Connector, err error) + CloseHandler func() + User string + Password string + DB string + EnableCompression bool } func NewConfig(url string, chanLength uint) *Config { @@ -60,3 +61,7 @@ func (c *Config) SetErrorHandler(f func(connector *Connector, err error)) { func (c *Config) SetCloseHandler(f func()) { c.CloseHandler = f } + +func (c *Config) SetEnableCompression(enableCompression bool) { + c.EnableCompression = enableCompression +} diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index 01ba361..48830c2 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -44,10 +44,13 @@ func NewConnector(config *Config) (*Connector, error) { if config.WriteWait > 0 { writeTimeout = config.WriteWait } - ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = config.EnableCompression + ws, _, err := dialer.Dial(config.Url, nil) if err != nil { return nil, err } + ws.EnableWriteCompression(config.EnableCompression) defer func() { if connector == nil { ws.Close() diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 4cc1631..f8f8f82 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -164,6 +164,7 @@ func TestStmt(t *testing.T) { config.SetConnectDB("test_ws_stmt") config.SetMessageTimeout(common.DefaultMessageTimeout) config.SetWriteWait(common.DefaultWriteWait) + config.SetEnableCompression(true) config.SetErrorHandler(func(connector *Connector, err error) { t.Log(err) }) diff --git a/ws/tmq/config.go b/ws/tmq/config.go index 88ed25b..36ec47e 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -22,6 +22,7 @@ type config struct { AutoCommitIntervalMS string SnapshotEnable string WithTableName string + EnableCompression bool } func newConfig(url string, chanLength uint) *config { @@ -138,3 +139,12 @@ func (c *config) setWithTableName(withTableName tmq.ConfigValue) error { } return nil } + +func (c *config) setEnableCompression(enableCompression tmq.ConfigValue) error { + var ok bool + c.EnableCompression, ok = enableCompression.(bool) + if !ok { + return fmt.Errorf("ws.message.enableCompression requires bool got %T", enableCompression) + } + return nil +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 428570f..26da8e9 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -77,10 +77,13 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { autoCommitInterval = time.Millisecond * time.Duration(interval) } - ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = config.EnableCompression + ws, _, err := dialer.Dial(config.Url, nil) if err != nil { return nil, err } + ws.EnableWriteCompression(config.EnableCompression) wsClient := client.NewClient(ws, config.ChanLength) consumer := &Consumer{ @@ -168,6 +171,10 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + enableCompression, err := m.Get("ws.message.enableCompression", false) + if err != nil { + return nil, err + } config := newConfig(url.(string), chanLen.(uint)) err = config.setMessageTimeout(messageTimeout.(time.Duration)) if err != nil { @@ -213,6 +220,10 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + err = config.setEnableCompression(enableCompression) + if err != nil { + return nil, err + } return config, nil } diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 2da15b8..ab43da2 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -276,8 +276,8 @@ func TestSeek(t *testing.T) { "client.id": "test_consumer", "auto.offset.reset": "earliest", "enable.auto.commit": "false", - "experimental.snapshot.enable": "false", "msg.with.table.name": "true", + "ws.message.enableCompression": true, }) if err != nil { t.Error(err) @@ -394,19 +394,18 @@ func TestAutoCommit(t *testing.T) { } defer cleanAutocommitEnv() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", - "ws.message.channelLen": uint(0), - "ws.message.timeout": common.DefaultMessageTimeout, - "ws.message.writeWait": common.DefaultWriteWait, - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "group.id": "test", - "client.id": "test_consumer", - "auto.offset.reset": "earliest", - "enable.auto.commit": "true", - "auto.commit.interval.ms": "1000", - "experimental.snapshot.enable": "false", - "msg.with.table.name": "true", + "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", }) assert.NoError(t, err) if err != nil { From eaa0e81ff946c242c0be399124ec5548b41a6956 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Wed, 3 Jan 2024 16:09:42 +0800 Subject: [PATCH 3/6] enh: support websocket stmt query --- common/param/column.go | 4 + common/param/param.go | 9 + common/serializer/block.go | 4 +- taosSql/statement.go | 6 +- taosWS/connection.go | 311 +++++- taosWS/proto.go | 125 ++- taosWS/rows.go | 4 + taosWS/statement.go | 517 +++++++++ taosWS/statement_test.go | 2159 ++++++++++++++++++++++++++++++++++++ 9 files changed, 3130 insertions(+), 9 deletions(-) create mode 100644 taosWS/statement.go create mode 100644 taosWS/statement_test.go diff --git a/common/param/column.go b/common/param/column.go index de5dedc..1542f70 100644 --- a/common/param/column.go +++ b/common/param/column.go @@ -16,6 +16,10 @@ func NewColumnType(size int) *ColumnType { return &ColumnType{size: size, value: make([]*types.ColumnType, size)} } +func NewColumnTypeWithValue(value []*types.ColumnType) *ColumnType { + return &ColumnType{size: len(value), value: value, column: len(value)} +} + func (c *ColumnType) AddBool() *ColumnType { if c.column >= c.size { return c diff --git a/common/param/param.go b/common/param/param.go index a14854b..cce09d0 100644 --- a/common/param/param.go +++ b/common/param/param.go @@ -20,6 +20,15 @@ func NewParam(size int) *Param { } } +func NewParamsWithRowValue(value []driver.Value) []*Param { + params := make([]*Param, len(value)) + for i, d := range value { + params[i] = NewParam(1) + params[i].AddValue(d) + } + return params +} + func (p *Param) SetBool(offset int, value bool) { if offset >= p.size { return diff --git a/common/serializer/block.go b/common/serializer/block.go index 03a53f2..50d7a6e 100644 --- a/common/serializer/block.go +++ b/common/serializer/block.go @@ -37,7 +37,7 @@ func BMSetNull(c byte, n int) byte { return c + (1 << (7 - BitPos(n))) } -var ColumnNumerNotMatch = errors.New("number of columns does not match") +var ColumnNumberNotMatch = errors.New("number of columns does not match") var DataTypeWrong = errors.New("wrong data type") func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte, error) { @@ -48,7 +48,7 @@ func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte return nil, err } if len(colTypes) != columns { - return nil, ColumnNumerNotMatch + return nil, ColumnNumberNotMatch } var block []byte //version int32 diff --git a/taosSql/statement.go b/taosSql/statement.go index e103a0e..9513e37 100644 --- a/taosSql/statement.go +++ b/taosSql/statement.go @@ -138,11 +138,11 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { case reflect.Bool: v.Value = types.TaosBool(rv.Bool()) case reflect.Float32, reflect.Float64: - v.Value = types.TaosBool(rv.Float() == 1) + v.Value = types.TaosBool(rv.Float() > 0) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.Value = types.TaosBool(rv.Int() == 1) + v.Value = types.TaosBool(rv.Int() > 0) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v.Value = types.TaosBool(rv.Uint() == 1) + v.Value = types.TaosBool(rv.Uint() > 0) case reflect.String: vv, err := strconv.ParseBool(rv.String()) if err != nil { diff --git a/taosWS/connection.go b/taosWS/connection.go index df339bc..8228661 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -26,6 +27,14 @@ const ( WSFetch = "fetch" WSFetchBlock = "fetch_block" WSFreeResult = "free_result" + + STMTInit = "init" + STMTPrepare = "prepare" + STMTAddBatch = "add_batch" + STMTExec = "exec" + STMTClose = "close" + STMTGetColFields = "get_col_fields" + STMTUseResult = "use_result" ) var ( @@ -51,7 +60,7 @@ func newTaosConn(cfg *config) (*taosConn, error) { endpointUrl := &url.URL{ Scheme: cfg.net, Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), - Path: "/rest/ws", + Path: "/ws", } if cfg.token != "" { endpointUrl.RawQuery = fmt.Sprintf("token=%s", cfg.token) @@ -101,9 +110,297 @@ func (tc *taosConn) Close() (err error) { } func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { - return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support stmt"} + stmtID, err := tc.stmtInit() + if err != nil { + return nil, err + } + isInsert, err := tc.stmtPrepare(stmtID, query) + if err != nil { + tc.stmtClose(stmtID) + return nil, err + } + stmt := &Stmt{ + conn: tc, + stmtID: stmtID, + isInsert: isInsert, + pSql: query, + } + return stmt, nil +} + +func (tc *taosConn) stmtInit() (uint64, error) { + reqID := tc.generateReqID() + req := &StmtInitReq{ + ReqID: reqID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTInit, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtInitResp + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.StmtID, nil +} + +func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { + reqID := tc.generateReqID() + req := &StmtPrepareRequest{ + ReqID: reqID, + StmtID: stmtID, + SQL: sql, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return false, err + } + action := &WSAction{ + Action: STMTPrepare, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return false, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return false, err + } + var resp StmtPrepareResponse + err = tc.readTo(&resp) + if err != nil { + return false, err + } + if resp.Code != 0 { + return false, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.IsInsert, nil +} + +func (tc *taosConn) stmtClose(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtCloseRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTClose, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + return nil +} + +func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) { + reqID := tc.generateReqID() + req := &StmtGetColFieldsRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: STMTGetColFields, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp StmtGetColFieldsResponse + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Fields, nil } +func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error { + reqID := tc.generateReqID() + tc.buf.Reset() + WriteUint64(tc.buf, reqID) + WriteUint64(tc.buf, stmtID) + WriteUint64(tc.buf, BindMessage) + tc.buf.Write(block) + err := tc.writeBinary(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtBindResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func WriteUint64(buffer *bytes.Buffer, v uint64) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) + buffer.WriteByte(byte(v >> 32)) + buffer.WriteByte(byte(v >> 40)) + buffer.WriteByte(byte(v >> 48)) + buffer.WriteByte(byte(v >> 56)) +} + +func (tc *taosConn) stmtAddBatch(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtAddBatchRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTAddBatch, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtAddBatchResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { + reqID := tc.generateReqID() + req := &StmtExecRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTExec, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtExecResponse + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Affected, nil +} + +func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { + reqID := tc.generateReqID() + req := &StmtUseResultRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: STMTUseResult, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp StmtUseResultResponse + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + rs := &rows{ + buf: &bytes.Buffer{}, + conn: tc, + resultID: resp.ResultID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + isStmt: true, + } + return rs, nil +} func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) } @@ -263,8 +560,16 @@ func (tc *taosConn) connect() error { } func (tc *taosConn) writeText(data []byte) error { + return tc.write(data, websocket.TextMessage) +} + +func (tc *taosConn) writeBinary(data []byte) error { + return tc.write(data, websocket.BinaryMessage) +} + +func (tc *taosConn) write(data []byte, messageType int) error { tc.client.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) - err := tc.client.WriteMessage(websocket.TextMessage, data) + err := tc.client.WriteMessage(messageType, data) if err != nil { return NewBadConnErrorWithCtx(err, string(data)) } diff --git a/taosWS/proto.go b/taosWS/proto.go index fd2fb39..2731eec 100644 --- a/taosWS/proto.go +++ b/taosWS/proto.go @@ -1,6 +1,10 @@ package taosWS -import "encoding/json" +import ( + "encoding/json" + + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" +) type WSConnectReq struct { ReqID uint64 `json:"req_id"` @@ -69,3 +73,122 @@ type WSAction struct { Action string `json:"action"` Args json.RawMessage `json:"args"` } + +type StmtPrepareRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` +} + +type StmtPrepareResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` +} + +type StmtInitReq struct { + ReqID uint64 `json:"req_id"` +} + +type StmtInitResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} +type StmtCloseRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtCloseResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id,omitempty"` +} + +type StmtGetColFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtGetColFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Fields []*stmtCommon.StmtField `json:"fields"` +} + +const ( + BindMessage = 2 +) + +type StmtBindResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +type StmtUseResultRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtUseResultResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} diff --git a/taosWS/rows.go b/taosWS/rows.go index b462f4e..9b8f0be 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -27,6 +27,7 @@ type rows struct { fieldsTypes []uint8 fieldsLengths []int64 precision int + isStmt bool } func (rs *rows) Columns() []string { @@ -158,6 +159,9 @@ func (rs *rows) fetchBlock() error { } func (rs *rows) freeResult() error { + if rs.isStmt { + return nil + } tc := rs.conn reqID := tc.generateReqID() req := &WSFreeResultReq{ diff --git a/taosWS/statement.go b/taosWS/statement.go new file mode 100644 index 0000000..d313820 --- /dev/null +++ b/taosWS/statement.go @@ -0,0 +1,517 @@ +package taosWS + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/serializer" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + "github.com/taosdata/driver-go/v3/types" +) + +type Stmt struct { + stmtID uint64 + conn *taosConn + buffer bytes.Buffer + pSql string + isInsert bool + cols []*stmtCommon.StmtField + colTypes *param.ColumnType + queryColTypes []*types.ColumnType +} + +func (stmt *Stmt) Close() error { + err := stmt.conn.stmtClose(stmt.stmtID) + stmt.buffer.Reset() + stmt.conn = nil + return err +} + +func (stmt *Stmt) NumInput() int { + if stmt.colTypes != nil { + return len(stmt.cols) + } + return -1 +} + +func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.conn == nil { + return nil, driver.ErrBadConn + } + if len(args) != len(stmt.cols) { + return nil, fmt.Errorf("stmt exec error: wrong number of parameters") + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), stmt.colTypes) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + affected, err := stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return driver.RowsAffected(affected), nil +} + +func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.conn == nil { + return nil, driver.ErrBadConn + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), param.NewColumnTypeWithValue(stmt.queryColTypes)) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + _, err = stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return stmt.conn.stmtUseResult(stmt.stmtID) +} + +func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { + if stmt.isInsert { + if stmt.cols == nil { + cols, err := stmt.conn.stmtGetColFields(stmt.stmtID) + if err != nil { + return err + } + colTypes := make([]*types.ColumnType, len(cols)) + for i, col := range cols { + t, err := col.GetType() + if err != nil { + return err + } + colTypes[i] = t + } + stmt.cols = cols + stmt.colTypes = param.NewColumnTypeWithValue(colTypes) + } + if v.Ordinal > len(stmt.cols) { + return nil + } + if v.Value == nil { + return nil + } + switch stmt.cols[v.Ordinal-1].FieldType { + case common.TSDB_DATA_TYPE_NULL: + v.Value = nil + case common.TSDB_DATA_TYPE_BOOL: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBool(rv.Float() > 0) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBool(rv.Int() > 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBool(rv.Uint() > 0) + case reflect.String: + vv, err := strconv.ParseBool(rv.String()) + if err != nil { + return err + } + v.Value = types.TaosBool(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bool", v) + } + case common.TSDB_DATA_TYPE_TINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosTinyint(1) + } else { + v.Value = types.TaosTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint", v) + } + case common.TSDB_DATA_TYPE_SMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosSmallint(1) + } else { + v.Value = types.TaosSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint", v) + } + case common.TSDB_DATA_TYPE_INT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosInt(1) + } else { + v.Value = types.TaosInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int", v) + } + case common.TSDB_DATA_TYPE_BIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosBigint(1) + } else { + v.Value = types.TaosBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint", v) + } + case common.TSDB_DATA_TYPE_FLOAT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosFloat(1) + } else { + v.Value = types.TaosFloat(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosFloat(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosFloat(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosFloat(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 32) + if err != nil { + return err + } + v.Value = types.TaosFloat(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to float", v) + } + case common.TSDB_DATA_TYPE_DOUBLE: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosDouble(1) + } else { + v.Value = types.TaosDouble(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosDouble(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosDouble(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 64) + if err != nil { + return err + } + v.Value = types.TaosDouble(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to double", v) + } + case common.TSDB_DATA_TYPE_BINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to binary", v) + } + case common.TSDB_DATA_TYPE_VARBINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosVarBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosVarBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to varbinary", v) + } + + case common.TSDB_DATA_TYPE_GEOMETRY: + switch v.Value.(type) { + case string: + v.Value = types.TaosGeometry(v.Value.(string)) + case []byte: + v.Value = types.TaosGeometry(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to geometry", v) + } + + case common.TSDB_DATA_TYPE_TIMESTAMP: + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Float32, reflect.Float64: + t := common.TimestampConvertToTime(int64(rv.Float()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + t := common.TimestampConvertToTime(rv.Int(), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t := common.TimestampConvertToTime(int64(rv.Uint()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.String: + t, err := time.Parse(time.RFC3339Nano, rv.String()) + if err != nil { + return err + } + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to timestamp", v) + } + case common.TSDB_DATA_TYPE_NCHAR: + switch v.Value.(type) { + case string: + v.Value = types.TaosNchar(v.Value.(string)) + case []byte: + v.Value = types.TaosNchar(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to nchar", v) + } + case common.TSDB_DATA_TYPE_UTINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUTinyint(1) + } else { + v.Value = types.TaosUTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosUTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint unsigned", v) + } + case common.TSDB_DATA_TYPE_USMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUSmallint(1) + } else { + v.Value = types.TaosUSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosUSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint unsigned", v) + } + case common.TSDB_DATA_TYPE_UINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUInt(1) + } else { + v.Value = types.TaosUInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosUInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int unsigned", v) + } + case common.TSDB_DATA_TYPE_UBIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUBigint(1) + } else { + v.Value = types.TaosUBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosUBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint unsigned", v) + } + } + return nil + } else { + if v.Value == nil { + return errors.New("CheckNamedValue: value is nil") + } + if v.Ordinal == 1 { + stmt.queryColTypes = nil + } + if len(stmt.queryColTypes) < v.Ordinal { + tmp := stmt.queryColTypes + stmt.queryColTypes = make([]*types.ColumnType, v.Ordinal) + copy(stmt.queryColTypes, tmp) + } + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosBinary(t.Format(time.RFC3339Nano)) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBinaryType} + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBoolType} + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosDoubleType} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBigintType} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosUBigintType} + case reflect.String: + strVal := rv.String() + v.Value = types.TaosBinary(strVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(strVal), + } + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + bsVal := rv.Bytes() + v.Value = types.TaosBinary(bsVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(bsVal), + } + } else { + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + default: + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + return nil + } +} diff --git a/taosWS/statement_test.go b/taosWS/statement_test.go new file mode 100644 index 0000000..1ab008c --- /dev/null +++ b/taosWS/statement_test.go @@ -0,0 +1,2159 @@ +package taosWS + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStmtExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer func() { + t.Log("start3") + db.Close() + t.Log("done3") + }() + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + t.Log("done2") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(time.Now(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), affected) + t.Log("done") +} + +func TestStmtQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + defer func() { + db.Exec("drop database if exists test_stmt_driver_ws_q") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws_q.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + now := time.Now() + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + stmt.Close() + stmt, err = db.Prepare("select * from test_stmt_driver_ws_q.ct where ts = ?") + if err != nil { + t.Error(err) + return + } + rows, err := stmt.Query(now) + if err != nil { + t.Error(err) + return + } + columns, err := rows.Columns() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) + count := 0 + for rows.Next() { + count += 1 + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan(&ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13) + assert.NoError(t, err) + assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) + assert.Equal(t, true, c1) + assert.Equal(t, int8(2), c2) + assert.Equal(t, int16(3), c3) + assert.Equal(t, int32(4), c4) + assert.Equal(t, int64(5), c5) + assert.Equal(t, uint8(6), c6) + assert.Equal(t, uint16(7), c7) + assert.Equal(t, uint32(8), c8) + assert.Equal(t, uint64(9), c9) + assert.Equal(t, float32(10), c10) + assert.Equal(t, float64(11), c11) + assert.Equal(t, "binary", c12) + assert.Equal(t, "nchar", c13) + } + assert.Equal(t, 1, count) +} + +func TestStmtConvertExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + now := time.Now().Format(time.RFC3339Nano) + tests := []struct { + name string + tbType string + pos string + bind []interface{} + expectValue interface{} + expectError bool + }{ + //bool + { + name: "bool_null", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "bool_err", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, []int{123}}, + expectValue: nil, + expectError: true, + }, + { + name: "bool_bool_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: true, + }, + { + name: "bool_bool_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: false, + }, + { + name: "bool_float_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: true, + }, + { + name: "bool_float_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(0)}, + expectValue: false, + }, + { + name: "bool_int_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(1)}, + expectValue: true, + }, + { + name: "bool_int_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(0)}, + expectValue: false, + }, + { + name: "bool_uint_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(1)}, + expectValue: true, + }, + { + name: "bool_uint_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(0)}, + expectValue: false, + }, + { + name: "bool_string_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "true"}, + expectValue: true, + }, + { + name: "bool_string_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "false"}, + expectValue: false, + }, + //tiny int + { + name: "tiny_nil", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "tiny_err", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "tiny_bool_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int8(1), + }, + { + name: "tiny_bool_0", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int8(0), + }, + { + name: "tiny_float_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int8(1), + }, + { + name: "tiny_int_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int8(1), + }, + { + name: "tiny_uint_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int8(1), + }, + { + name: "tiny_string_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int8(1), + }, + // small int + { + name: "small_nil", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "small_err", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "small_bool_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int16(1), + }, + { + name: "small_bool_0", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int16(0), + }, + { + name: "small_float_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int16(1), + }, + { + name: "small_int_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int16(1), + }, + { + name: "small_uint_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int16(1), + }, + { + name: "small_string_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int16(1), + }, + // int + { + name: "int_nil", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "int_err", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "int_bool_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int32(1), + }, + { + name: "int_bool_0", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int32(0), + }, + { + name: "int_float_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int32(1), + }, + { + name: "int_int_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int32(1), + }, + { + name: "int_uint_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int32(1), + }, + { + name: "int_string_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int32(1), + }, + // big int + { + name: "big_nil", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "big_err", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "big_bool_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int64(1), + }, + { + name: "big_bool_0", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int64(0), + }, + { + name: "big_float_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int64(1), + }, + { + name: "big_int_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int64(1), + }, + { + name: "big_uint_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int64(1), + }, + { + name: "big_string_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int64(1), + }, + // float + { + name: "float_nil", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "float_err", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "float_bool_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float32(1), + }, + { + name: "float_bool_0", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float32(0), + }, + { + name: "float_float_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float32(1), + }, + { + name: "float_int_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float32(1), + }, + { + name: "float_uint_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float32(1), + }, + { + name: "float_string_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float32(1), + }, + //double + { + name: "double_nil", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "double_err", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "double_bool_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float64(1), + }, + { + name: "double_bool_0", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float64(0), + }, + { + name: "double_double_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float64(1), + }, + { + name: "double_int_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float64(1), + }, + { + name: "double_uint_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float64(1), + }, + { + name: "double_string_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float64(1), + }, + + //tiny int unsigned + { + name: "utiny_nil", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "utiny_err", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "utiny_bool_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint8(1), + }, + { + name: "utiny_bool_0", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint8(0), + }, + { + name: "utiny_float_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_int_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_uint_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_string_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint8(1), + }, + // small int unsigned + { + name: "usmall_nil", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "usmall_err", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "usmall_bool_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint16(1), + }, + { + name: "usmall_bool_0", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint16(0), + }, + { + name: "usmall_float_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_int_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_uint_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_string_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint16(1), + }, + // int unsigned + { + name: "uint_nil", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "uint_err", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "uint_bool_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint32(1), + }, + { + name: "uint_bool_0", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint32(0), + }, + { + name: "uint_float_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint32(1), + }, + { + name: "uint_int_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint32(1), + }, + { + name: "uint_uint_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint32(1), + }, + { + name: "uint_string_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint32(1), + }, + // big int unsigned + { + name: "ubig_nil", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ubig_err", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ubig_bool_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint64(1), + }, + { + name: "ubig_bool_0", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint64(0), + }, + { + name: "ubig_float_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_int_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_uint_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_string_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint64(1), + }, + //binary + { + name: "binary_nil", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "binary_err", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + //nchar + { + name: "nchar_nil", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "nchar_err", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + // timestamp + { + name: "ts_nil", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ts_err", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ts_time_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, time.Unix(0, 1e6)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_float_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_int_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_uint_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_string_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, "1970-01-01T00:00:00.001Z"}, + expectValue: time.Unix(0, 1e6), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tbName := fmt.Sprintf("test_%s", tt.name) + tbType := tt.tbType + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + pos := tt.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + if _, err = db.Exec(drop); err != nil { + t.Error(err) + return + } + if _, err = db.Exec(create); err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(tt.bind...) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + rows, err := db.Query(fmt.Sprintf("select v from %s", tbName)) + if err != nil { + t.Error(err) + return + } + var data []driver.Value + tts, err := rows.ColumnTypes() + if err != nil { + t.Error(err) + return + } + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} + +func TestStmtConvertQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table t0 (ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + now := time.Now() + after1s := now.Add(time.Second) + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", after1s.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + tests := []struct { + name string + field string + where string + bind interface{} + expectNoValue bool + expectValue driver.Value + expectError bool + }{ + //ts + { + name: "ts", + field: "ts", + where: "ts = ?", + bind: now, + expectValue: time.Unix(now.Unix(), int64((now.Nanosecond()/1e6)*1e6)).Local(), + }, + + //bool + { + name: "bool_true", + field: "c1", + where: "c1 = ?", + bind: true, + expectValue: true, + }, + { + name: "bool_false", + field: "c1", + where: "c1 = ?", + bind: false, + expectNoValue: true, + }, + { + name: "tinyint_int8", + field: "c2", + where: "c2 = ?", + bind: int8(2), + expectValue: int8(2), + }, + { + name: "tinyint_iny16", + field: "c2", + where: "c2 = ?", + bind: int16(2), + expectValue: int8(2), + }, + { + name: "tinyint_int32", + field: "c2", + where: "c2 = ?", + bind: int32(2), + expectValue: int8(2), + }, + { + name: "tinyint_int64", + field: "c2", + where: "c2 = ?", + bind: int64(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint8", + field: "c2", + where: "c2 = ?", + bind: uint8(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint16", + field: "c2", + where: "c2 = ?", + bind: uint16(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint32", + field: "c2", + where: "c2 = ?", + bind: uint32(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint64", + field: "c2", + where: "c2 = ?", + bind: uint64(2), + expectValue: int8(2), + }, + { + name: "tinyint_float32", + field: "c2", + where: "c2 = ?", + bind: float32(2), + expectValue: int8(2), + }, + { + name: "tinyint_float64", + field: "c2", + where: "c2 = ?", + bind: float64(2), + expectValue: int8(2), + }, + { + name: "tinyint_int", + field: "c2", + where: "c2 = ?", + bind: int(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint", + field: "c2", + where: "c2 = ?", + bind: uint(2), + expectValue: int8(2), + }, + + // smallint + { + name: "smallint_int8", + field: "c3", + where: "c3 = ?", + bind: int8(3), + expectValue: int16(3), + }, + { + name: "smallint_iny16", + field: "c3", + where: "c3 = ?", + bind: int16(3), + expectValue: int16(3), + }, + { + name: "smallint_int32", + field: "c3", + where: "c3 = ?", + bind: int32(3), + expectValue: int16(3), + }, + { + name: "smallint_int64", + field: "c3", + where: "c3 = ?", + bind: int64(3), + expectValue: int16(3), + }, + { + name: "smallint_uint8", + field: "c3", + where: "c3 = ?", + bind: uint8(3), + expectValue: int16(3), + }, + { + name: "smallint_uint16", + field: "c3", + where: "c3 = ?", + bind: uint16(3), + expectValue: int16(3), + }, + { + name: "smallint_uint32", + field: "c3", + where: "c3 = ?", + bind: uint32(3), + expectValue: int16(3), + }, + { + name: "smallint_uint64", + field: "c3", + where: "c3 = ?", + bind: uint64(3), + expectValue: int16(3), + }, + { + name: "smallint_float32", + field: "c3", + where: "c3 = ?", + bind: float32(3), + expectValue: int16(3), + }, + { + name: "smallint_float64", + field: "c3", + where: "c3 = ?", + bind: float64(3), + expectValue: int16(3), + }, + { + name: "smallint_int", + field: "c3", + where: "c3 = ?", + bind: int(3), + expectValue: int16(3), + }, + { + name: "smallint_uint", + field: "c3", + where: "c3 = ?", + bind: uint(3), + expectValue: int16(3), + }, + + //int + { + name: "int_int8", + field: "c4", + where: "c4 = ?", + bind: int8(4), + expectValue: int32(4), + }, + { + name: "int_iny16", + field: "c4", + where: "c4 = ?", + bind: int16(4), + expectValue: int32(4), + }, + { + name: "int_int32", + field: "c4", + where: "c4 = ?", + bind: int32(4), + expectValue: int32(4), + }, + { + name: "int_int64", + field: "c4", + where: "c4 = ?", + bind: int64(4), + expectValue: int32(4), + }, + { + name: "int_uint8", + field: "c4", + where: "c4 = ?", + bind: uint8(4), + expectValue: int32(4), + }, + { + name: "int_uint16", + field: "c4", + where: "c4 = ?", + bind: uint16(4), + expectValue: int32(4), + }, + { + name: "int_uint32", + field: "c4", + where: "c4 = ?", + bind: uint32(4), + expectValue: int32(4), + }, + { + name: "int_uint64", + field: "c4", + where: "c4 = ?", + bind: uint64(4), + expectValue: int32(4), + }, + { + name: "int_float32", + field: "c4", + where: "c4 = ?", + bind: float32(4), + expectValue: int32(4), + }, + { + name: "int_float64", + field: "c4", + where: "c4 = ?", + bind: float64(4), + expectValue: int32(4), + }, + { + name: "int_int", + field: "c4", + where: "c4 = ?", + bind: int(4), + expectValue: int32(4), + }, + { + name: "int_uint", + field: "c4", + where: "c4 = ?", + bind: uint(4), + expectValue: int32(4), + }, + + //bigint + { + name: "bigint_int8", + field: "c5", + where: "c5 = ?", + bind: int8(5), + expectValue: int64(5), + }, + { + name: "bigint_iny16", + field: "c5", + where: "c5 = ?", + bind: int16(5), + expectValue: int64(5), + }, + { + name: "bigint_int32", + field: "c5", + where: "c5 = ?", + bind: int32(5), + expectValue: int64(5), + }, + { + name: "bigint_int64", + field: "c5", + where: "c5 = ?", + bind: int64(5), + expectValue: int64(5), + }, + { + name: "bigint_uint8", + field: "c5", + where: "c5 = ?", + bind: uint8(5), + expectValue: int64(5), + }, + { + name: "bigint_uint16", + field: "c5", + where: "c5 = ?", + bind: uint16(5), + expectValue: int64(5), + }, + { + name: "bigint_uint32", + field: "c5", + where: "c5 = ?", + bind: uint32(5), + expectValue: int64(5), + }, + { + name: "bigint_uint64", + field: "c5", + where: "c5 = ?", + bind: uint64(5), + expectValue: int64(5), + }, + { + name: "bigint_float32", + field: "c5", + where: "c5 = ?", + bind: float32(5), + expectValue: int64(5), + }, + { + name: "bigint_float64", + field: "c5", + where: "c5 = ?", + bind: float64(5), + expectValue: int64(5), + }, + { + name: "bigint_int", + field: "c5", + where: "c5 = ?", + bind: int(5), + expectValue: int64(5), + }, + { + name: "bigint_uint", + field: "c5", + where: "c5 = ?", + bind: uint(5), + expectValue: int64(5), + }, + + //utinyint + { + name: "utinyint_int8", + field: "c6", + where: "c6 = ?", + bind: int8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_iny16", + field: "c6", + where: "c6 = ?", + bind: int16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int32", + field: "c6", + where: "c6 = ?", + bind: int32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int64", + field: "c6", + where: "c6 = ?", + bind: int64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint8", + field: "c6", + where: "c6 = ?", + bind: uint8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint16", + field: "c6", + where: "c6 = ?", + bind: uint16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint32", + field: "c6", + where: "c6 = ?", + bind: uint32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint64", + field: "c6", + where: "c6 = ?", + bind: uint64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float32", + field: "c6", + where: "c6 = ?", + bind: float32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float64", + field: "c6", + where: "c6 = ?", + bind: float64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int", + field: "c6", + where: "c6 = ?", + bind: int(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint", + field: "c6", + where: "c6 = ?", + bind: uint(6), + expectValue: uint8(6), + }, + + //usmallint + { + name: "usmallint_int8", + field: "c7", + where: "c7 = ?", + bind: int8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_iny16", + field: "c7", + where: "c7 = ?", + bind: int16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int32", + field: "c7", + where: "c7 = ?", + bind: int32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int64", + field: "c7", + where: "c7 = ?", + bind: int64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint8", + field: "c7", + where: "c7 = ?", + bind: uint8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint16", + field: "c7", + where: "c7 = ?", + bind: uint16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint32", + field: "c7", + where: "c7 = ?", + bind: uint32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint64", + field: "c7", + where: "c7 = ?", + bind: uint64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float32", + field: "c7", + where: "c7 = ?", + bind: float32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float64", + field: "c7", + where: "c7 = ?", + bind: float64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int", + field: "c7", + where: "c7 = ?", + bind: int(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint", + field: "c7", + where: "c7 = ?", + bind: uint(7), + expectValue: uint16(7), + }, + + //uint + { + name: "uint_int8", + field: "c8", + where: "c8 = ?", + bind: int8(8), + expectValue: uint32(8), + }, + { + name: "uint_iny16", + field: "c8", + where: "c8 = ?", + bind: int16(8), + expectValue: uint32(8), + }, + { + name: "uint_int32", + field: "c8", + where: "c8 = ?", + bind: int32(8), + expectValue: uint32(8), + }, + { + name: "uint_int64", + field: "c8", + where: "c8 = ?", + bind: int64(8), + expectValue: uint32(8), + }, + { + name: "uint_uint8", + field: "c8", + where: "c8 = ?", + bind: uint8(8), + expectValue: uint32(8), + }, + { + name: "uint_uint16", + field: "c8", + where: "c8 = ?", + bind: uint16(8), + expectValue: uint32(8), + }, + { + name: "uint_uint32", + field: "c8", + where: "c8 = ?", + bind: uint32(8), + expectValue: uint32(8), + }, + { + name: "uint_uint64", + field: "c8", + where: "c8 = ?", + bind: uint64(8), + expectValue: uint32(8), + }, + { + name: "uint_float32", + field: "c8", + where: "c8 = ?", + bind: float32(8), + expectValue: uint32(8), + }, + { + name: "uint_float64", + field: "c8", + where: "c8 = ?", + bind: float64(8), + expectValue: uint32(8), + }, + { + name: "uint_int", + field: "c8", + where: "c8 = ?", + bind: int(8), + expectValue: uint32(8), + }, + { + name: "uint_uint", + field: "c8", + where: "c8 = ?", + bind: uint(8), + expectValue: uint32(8), + }, + + //ubigint + { + name: "ubigint_int8", + field: "c9", + where: "c9 = ?", + bind: int8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_iny16", + field: "c9", + where: "c9 = ?", + bind: int16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int32", + field: "c9", + where: "c9 = ?", + bind: int32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int64", + field: "c9", + where: "c9 = ?", + bind: int64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint8", + field: "c9", + where: "c9 = ?", + bind: uint8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint16", + field: "c9", + where: "c9 = ?", + bind: uint16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint32", + field: "c9", + where: "c9 = ?", + bind: uint32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint64", + field: "c9", + where: "c9 = ?", + bind: uint64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float32", + field: "c9", + where: "c9 = ?", + bind: float32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float64", + field: "c9", + where: "c9 = ?", + bind: float64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int", + field: "c9", + where: "c9 = ?", + bind: int(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint", + field: "c9", + where: "c9 = ?", + bind: uint(9), + expectValue: uint64(9), + }, + + //float + { + name: "float_int8", + field: "c10", + where: "c10 = ?", + bind: int8(10), + expectValue: float32(10), + }, + { + name: "float_iny16", + field: "c10", + where: "c10 = ?", + bind: int16(10), + expectValue: float32(10), + }, + { + name: "float_int32", + field: "c10", + where: "c10 = ?", + bind: int32(10), + expectValue: float32(10), + }, + { + name: "float_int64", + field: "c10", + where: "c10 = ?", + bind: int64(10), + expectValue: float32(10), + }, + { + name: "float_uint8", + field: "c10", + where: "c10 = ?", + bind: uint8(10), + expectValue: float32(10), + }, + { + name: "float_uint16", + field: "c10", + where: "c10 = ?", + bind: uint16(10), + expectValue: float32(10), + }, + { + name: "float_uint32", + field: "c10", + where: "c10 = ?", + bind: uint32(10), + expectValue: float32(10), + }, + { + name: "float_uint64", + field: "c10", + where: "c10 = ?", + bind: uint64(10), + expectValue: float32(10), + }, + { + name: "float_float32", + field: "c10", + where: "c10 = ?", + bind: float32(10), + expectValue: float32(10), + }, + { + name: "float_float64", + field: "c10", + where: "c10 = ?", + bind: float64(10), + expectValue: float32(10), + }, + { + name: "float_int", + field: "c10", + where: "c10 = ?", + bind: int(10), + expectValue: float32(10), + }, + { + name: "float_uint", + field: "c10", + where: "c10 = ?", + bind: uint(10), + expectValue: float32(10), + }, + + //double + { + name: "double_int8", + field: "c11", + where: "c11 = ?", + bind: int8(11), + expectValue: float64(11), + }, + { + name: "double_iny16", + field: "c11", + where: "c11 = ?", + bind: int16(11), + expectValue: float64(11), + }, + { + name: "double_int32", + field: "c11", + where: "c11 = ?", + bind: int32(11), + expectValue: float64(11), + }, + { + name: "double_int64", + field: "c11", + where: "c11 = ?", + bind: int64(11), + expectValue: float64(11), + }, + { + name: "double_uint8", + field: "c11", + where: "c11 = ?", + bind: uint8(11), + expectValue: float64(11), + }, + { + name: "double_uint16", + field: "c11", + where: "c11 = ?", + bind: uint16(11), + expectValue: float64(11), + }, + { + name: "double_uint32", + field: "c11", + where: "c11 = ?", + bind: uint32(11), + expectValue: float64(11), + }, + { + name: "double_uint64", + field: "c11", + where: "c11 = ?", + bind: uint64(11), + expectValue: float64(11), + }, + { + name: "double_float32", + field: "c11", + where: "c11 = ?", + bind: float32(11), + expectValue: float64(11), + }, + { + name: "double_float64", + field: "c11", + where: "c11 = ?", + bind: float64(11), + expectValue: float64(11), + }, + { + name: "double_int", + field: "c11", + where: "c11 = ?", + bind: int(11), + expectValue: float64(11), + }, + { + name: "double_uint", + field: "c11", + where: "c11 = ?", + bind: uint(11), + expectValue: float64(11), + }, + + // binary + { + name: "binary_string", + field: "c12", + where: "c12 = ?", + bind: "binary", + expectValue: "binary", + }, + { + name: "binary_bytes", + field: "c12", + where: "c12 = ?", + bind: []byte("binary"), + expectValue: "binary", + }, + { + name: "binary_string_like", + field: "c12", + where: "c12 like ?", + bind: "bin%", + expectValue: "binary", + }, + + // nchar + { + name: "nchar_string", + field: "c13", + where: "c13 = ?", + bind: "nchar", + expectValue: "nchar", + }, + { + name: "nchar_bytes", + field: "c13", + where: "c13 = ?", + bind: []byte("nchar"), + expectValue: "nchar", + }, + { + name: "nchar_string", + field: "c13", + where: "c13 like ?", + bind: "nch%", + expectValue: "nchar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql := fmt.Sprintf("select %s from t0 where %s", tt.field, tt.where) + + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + rows, err := stmt.Query(tt.bind) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + tts, err := rows.ColumnTypes() + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + var data []driver.Value + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if tt.expectNoValue { + if len(data) > 0 { + t.Errorf("expect no value got %#v", data) + return + } + return + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} From 2d2a158ec1d5ddf78bce041caf81790f88cbd15a Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 19 Jan 2024 14:22:03 +0800 Subject: [PATCH 4/6] enh: adapts to new websocket interface --- examples/stmtoverws/main.go | 2 +- examples/tmqoverws/main.go | 2 +- ws/schemaless/schemaless.go | 4 +--- ws/schemaless/schemaless_test.go | 2 +- ws/stmt/connector.go | 8 +++++++- ws/stmt/stmt_test.go | 2 +- ws/tmq/consumer.go | 8 +++++++- ws/tmq/consumer_test.go | 6 +++--- 8 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/stmtoverws/main.go b/examples/stmtoverws/main.go index a9c0d5a..6e083df 100644 --- a/examples/stmtoverws/main.go +++ b/examples/stmtoverws/main.go @@ -19,7 +19,7 @@ func main() { defer db.Close() prepareEnv(db) - config := stmt.NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config := stmt.NewConfig("ws://127.0.0.1:6041", 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") config.SetConnectDB("example_ws_stmt") diff --git a/examples/tmqoverws/main.go b/examples/tmqoverws/main.go index b0fdf91..cac691a 100644 --- a/examples/tmqoverws/main.go +++ b/examples/tmqoverws/main.go @@ -18,7 +18,7 @@ func main() { defer db.Close() prepareEnv(db) consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index c6dcf6b..5494745 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -44,9 +44,7 @@ func NewSchemaless(config *Config) (*Schemaless, error) { if wsUrl.Scheme != "ws" && wsUrl.Scheme != "wss" { return nil, errors.New("config url scheme error") } - if len(wsUrl.Path) == 0 || wsUrl.Path != "/rest/schemaless" { - wsUrl.Path = "/rest/schemaless" - } + wsUrl.Path = "/ws" dialer := common.DefaultDialer dialer.EnableCompression = config.enableCompression ws, _, err := dialer.Dial(wsUrl.String(), nil) diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index 14dbb0d..d3754e2 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -62,7 +62,7 @@ func TestSchemaless_Insert(t *testing.T) { } defer func() { _ = after() }() - s, err := NewSchemaless(NewConfig("ws://localhost:6041/rest/schemaless", 1, + s, err := NewSchemaless(NewConfig("ws://localhost:6041", 1, SetDb("test_schemaless_ws"), SetReadTimeout(10*time.Second), SetWriteTimeout(10*time.Second), diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index 48830c2..c33e57b 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "net/url" "sync" "sync/atomic" "time" @@ -46,7 +47,12 @@ func NewConnector(config *Config) (*Connector, error) { } dialer := common.DefaultDialer dialer.EnableCompression = config.EnableCompression - ws, _, err := dialer.Dial(config.Url, nil) + u, err := url.Parse(config.Url) + if err != nil { + return nil, err + } + u.Path = "/ws" + ws, _, err := dialer.Dial(u.String(), nil) if err != nil { return nil, err } diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index f8f8f82..d0ffc69 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -158,7 +158,7 @@ func TestStmt(t *testing.T) { } defer cleanEnv() now := time.Now() - config := NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config := NewConfig("ws://127.0.0.1:6041", 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") config.SetConnectDB("test_ws_stmt") diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 26da8e9..46ea559 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "net/url" "strconv" "sync" "sync/atomic" @@ -79,7 +80,12 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { dialer := common.DefaultDialer dialer.EnableCompression = config.EnableCompression - ws, _, err := dialer.Dial(config.Url, nil) + u, err := url.Parse(config.Url) + if err != nil { + return nil, err + } + u.Path = "/rest/tmq" + ws, _, err := dialer.Dial(u.String(), nil) if err != nil { return nil, err } diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index ab43da2..4d382eb 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -124,7 +124,7 @@ func TestConsumer(t *testing.T) { } }() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, @@ -266,7 +266,7 @@ func TestSeek(t *testing.T) { } defer cleanSeekEnv() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, @@ -394,7 +394,7 @@ func TestAutoCommit(t *testing.T) { } defer cleanAutocommitEnv() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, From e5a0989f0c3931e49dd5211420825e0c81cfea6d Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Fri, 19 Jan 2024 14:38:27 +0800 Subject: [PATCH 5/6] ci: ci remove sccache --- .github/workflows/go.yml | 6 +----- .github/workflows/push.yml | 8 +------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 058d947..1284c07 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -13,8 +13,6 @@ on: required: true type: string -env: - SCCACHE_GHA_ENABLED: "true" jobs: build: @@ -45,8 +43,6 @@ jobs: cd TDengine echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: Cache server by pr if: github.event_name == 'pull_request' @@ -78,7 +74,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 make -j 4 - name: package diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 34fb993..35d1d97 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -7,8 +7,6 @@ on: - '3.0' - '3.1' -env: - SCCACHE_GHA_ENABLED: "true" jobs: build: @@ -30,8 +28,6 @@ jobs: cd TDengine echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: Cache server id: cache-server @@ -44,8 +40,6 @@ jobs: if: steps.cache-server.outputs.cache-hit != 'true' run: sudo apt install -y libgeos-dev - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: install TDengine if: steps.cache-server.outputs.cache-hit != 'true' @@ -53,7 +47,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 make -j 4 - name: package From 16ab09ce63e4ac6d1edbd39494c6d37eeecb3a07 Mon Sep 17 00:00:00 2001 From: t_max <1172915550@qq.com> Date: Thu, 25 Jan 2024 14:59:20 +0800 Subject: [PATCH 6/6] enh: use fetch_raw to reduce the number of calls to get results --- af/tmq/consumer.go | 39 +- af/tmq/consumer_test.go | 101 ++++- common/parser/block.go | 12 + common/parser/block_test.go | 10 +- common/parser/raw.go | 179 ++++++++ common/parser/raw_test.go | 797 ++++++++++++++++++++++++++++++++++++ ws/tmq/consumer.go | 104 ++--- ws/tmq/consumer_test.go | 96 +++++ ws/tmq/proto.go | 5 + 9 files changed, 1244 insertions(+), 99 deletions(-) create mode 100644 common/parser/raw.go create mode 100644 common/parser/raw_test.go diff --git a/af/tmq/consumer.go b/af/tmq/consumer.go index ef2338e..b6330d8 100644 --- a/af/tmq/consumer.go +++ b/af/tmq/consumer.go @@ -13,7 +13,8 @@ import ( ) type Consumer struct { - cConsumer unsafe.Pointer + cConsumer unsafe.Pointer + dataParser *parser.TMQRawDataParser } // NewConsumer Create new TMQ consumer with TMQ config @@ -28,7 +29,8 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { return nil, err } consumer := &Consumer{ - cConsumer: cConsumer, + cConsumer: cConsumer, + dataParser: parser.NewTMQRawDataParser(), } return consumer, nil } @@ -176,27 +178,22 @@ func (c *Consumer) getMeta(message unsafe.Pointer) (*tmq.Meta, error) { } func (c *Consumer) getData(message unsafe.Pointer) ([]*tmq.Data, error) { + errCode, raw := wrapper.TMQGetRaw(message) + if errCode != taosError.SUCCESS { + errStr := wrapper.TaosErrorStr(message) + err := taosError.NewError(int(errCode), errStr) + return nil, err + } + _, _, rawPtr := wrapper.ParseRawMeta(raw) + blockInfos, err := c.dataParser.Parse(rawPtr) + if err != nil { + return nil, err + } var tmqData []*tmq.Data - for { - blockSize, errCode, block := wrapper.TaosFetchRawBlock(message) - if errCode != int(taosError.SUCCESS) { - errStr := wrapper.TaosErrorStr(message) - err := taosError.NewError(errCode, errStr) - return nil, err - } - if blockSize == 0 { - break - } - tableName := wrapper.TMQGetTableName(message) - fileCount := wrapper.TaosNumFields(message) - rh, err := wrapper.ReadColumn(message, fileCount) - if err != nil { - return nil, err - } - precision := wrapper.TaosResultPrecision(message) + for i := 0; i < len(blockInfos); i++ { tmqData = append(tmqData, &tmq.Data{ - TableName: tableName, - Data: parser.ReadBlock(block, blockSize, rh.ColTypes, precision), + TableName: blockInfos[i].TableName, + Data: parser.ReadBlockSimple(blockInfos[i].RawBlock, blockInfos[i].Precision), }) } return tmqData, nil diff --git a/af/tmq/consumer_test.go b/af/tmq/consumer_test.go index cd3430e..20232e0 100644 --- a/af/tmq/consumer_test.go +++ b/af/tmq/consumer_test.go @@ -180,7 +180,7 @@ func TestSeek(t *testing.T) { } defer func() { - //execWithoutResult(conn, "drop database if exists "+db) + execWithoutResult(conn, "drop database if exists "+db) }() for _, sql := range sqls { err = execWithoutResult(conn, sql) @@ -308,3 +308,102 @@ func execWithoutResult(conn unsafe.Pointer, sql string) error { } return nil } + +func prepareMultiBlockEnv(conn unsafe.Pointer) error { + var err error + steps := []string{ + "drop topic if exists test_tmq_multi_block_topic", + "drop database if exists test_tmq_multi_block", + "create database test_tmq_multi_block vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_tmq_multi_block_topic as database test_tmq_multi_block", + "create table test_tmq_multi_block.t1(ts timestamp,v int)", + "create table test_tmq_multi_block.t2(ts timestamp,v int)", + "create table test_tmq_multi_block.t3(ts timestamp,v int)", + "create table test_tmq_multi_block.t4(ts timestamp,v int)", + "create table test_tmq_multi_block.t5(ts timestamp,v int)", + "create table test_tmq_multi_block.t6(ts timestamp,v int)", + "create table test_tmq_multi_block.t7(ts timestamp,v int)", + "create table test_tmq_multi_block.t8(ts timestamp,v int)", + "create table test_tmq_multi_block.t9(ts timestamp,v int)", + "create table test_tmq_multi_block.t10(ts timestamp,v int)", + "insert into test_tmq_multi_block.t1 values (now,1) test_tmq_multi_block.t2 values (now,2) " + + "test_tmq_multi_block.t3 values (now,3) test_tmq_multi_block.t4 values (now,4)" + + "test_tmq_multi_block.t5 values (now,5) test_tmq_multi_block.t6 values (now,6)" + + "test_tmq_multi_block.t7 values (now,7) test_tmq_multi_block.t8 values (now,8)" + + "test_tmq_multi_block.t9 values (now,9) test_tmq_multi_block.t10 values (now,10)", + } + for _, step := range steps { + err = execWithoutResult(conn, step) + if err != nil { + return err + } + } + return nil +} + +func cleanMultiBlockEnv(conn unsafe.Pointer) error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_tmq_multi_block_topic", + "drop database if exists test_tmq_multi_block", + } + for _, step := range steps { + err = execWithoutResult(conn, step) + if err != nil { + return err + } + } + return nil +} + +func TestMultiBlock(t *testing.T) { + conn, err := wrapper.TaosConnect("", "root", "taosdata", "", 0) + if err != nil { + t.Error(err) + return + } + defer wrapper.TaosClose(conn) + err = prepareMultiBlockEnv(conn) + assert.NoError(t, err) + defer cleanMultiBlockEnv(conn) + consumer, err := NewConsumer(&tmq.ConfigMap{ + "group.id": "test", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "auto.offset.reset": "earliest", + "client.id": "test_tmq_multi_block_topic", + "enable.auto.commit": "false", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_tmq_multi_block_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + data := e.Value().([]*tmq.Data) + assert.Equal(t, "test_tmq_multi_block", e.DBName()) + assert.Equal(t, 10, len(data)) + return + } + } +} diff --git a/common/parser/block.go b/common/parser/block.go index 7228e90..c0374d9 100644 --- a/common/parser/block.go +++ b/common/parser/block.go @@ -267,6 +267,18 @@ func rawConvertJson(pHeader, pStart unsafe.Pointer, row int) driver.Value { return binaryVal[:] } +func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { + blockSize := RawBlockGetNumOfRows(block) + colCount := RawBlockGetNumOfCols(block) + colInfo := make([]RawBlockColInfo, colCount) + RawBlockGetColInfo(block, colInfo) + colTypes := make([]uint8, colCount) + for i := int32(0); i < colCount; i++ { + colTypes[i] = uint8(colInfo[i].ColType) + } + return ReadBlock(block, int(blockSize), colTypes, precision) +} + // ReadBlock in-place func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int) [][]driver.Value { r := make([][]driver.Value, blockSize) diff --git a/common/parser/block_test.go b/common/parser/block_test.go index 5b7232a..42b2d26 100644 --- a/common/parser/block_test.go +++ b/common/parser/block_test.go @@ -663,12 +663,6 @@ func TestParseBlock(t *testing.T) { t.Error(errors.NewError(code, errStr)) return } - fileCount := wrapper.TaosNumFields(res) - rh, err := wrapper.ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } precision := wrapper.TaosResultPrecision(res) var data [][]driver.Value for { @@ -684,7 +678,7 @@ func TestParseBlock(t *testing.T) { break } version := RawBlockGetVersion(block) - assert.Equal(t, int32(1), version) + t.Log(version) length := RawBlockGetLength(block) assert.Equal(t, int32(447), length) rows := RawBlockGetNumOfRows(block) @@ -771,7 +765,7 @@ func TestParseBlock(t *testing.T) { }, infos, ) - d := ReadBlock(block, blockSize, rh.ColTypes, precision) + d := ReadBlockSimple(block, precision) data = append(data, d...) } wrapper.TaosFreeResult(res) diff --git a/common/parser/raw.go b/common/parser/raw.go new file mode 100644 index 0000000..21ede0c --- /dev/null +++ b/common/parser/raw.go @@ -0,0 +1,179 @@ +package parser + +import ( + "fmt" + "unsafe" + + "github.com/taosdata/driver-go/v3/common/pointer" +) + +type TMQRawDataParser struct { + block unsafe.Pointer + offset uintptr +} + +func NewTMQRawDataParser() *TMQRawDataParser { + return &TMQRawDataParser{} +} + +type TMQBlockInfo struct { + RawBlock unsafe.Pointer + Precision int + Schema []*TMQRawDataSchema + TableName string +} + +type TMQRawDataSchema struct { + ColType uint8 + Flag int8 + Bytes int64 + ColID int + Name string +} + +func (p *TMQRawDataParser) getTypeSkip(t int8) (int, error) { + skip := 8 + switch t { + case 1: + case 2, 3: + skip = 16 + default: + return 0, fmt.Errorf("unknown type %d", t) + } + return skip, nil +} + +func (p *TMQRawDataParser) skipHead() error { + t := p.parseInt8() + skip, err := p.getTypeSkip(t) + if err != nil { + return err + } + p.skip(skip) + t = p.parseInt8() + skip, err = p.getTypeSkip(t) + if err != nil { + return err + } + p.skip(skip) + return nil +} + +func (p *TMQRawDataParser) skip(count int) { + p.offset += uintptr(count) +} + +func (p *TMQRawDataParser) parseBlockInfos() []*TMQBlockInfo { + blockNum := p.parseInt32() + blockInfos := make([]*TMQBlockInfo, blockNum) + withTableName := p.parseBool() + withSchema := p.parseBool() + for i := int32(0); i < blockNum; i++ { + blockInfo := &TMQBlockInfo{} + blockTotalLen := p.parseVariableByteInteger() + p.skip(17) + blockInfo.Precision = int(p.parseUint8()) + blockInfo.RawBlock = pointer.AddUintptr(p.block, p.offset) + p.skip(blockTotalLen - 18) + if withSchema { + cols := p.parseZigzagVariableByteInteger() + //version + _ = p.parseZigzagVariableByteInteger() + + blockInfo.Schema = make([]*TMQRawDataSchema, cols) + for j := 0; j < cols; j++ { + blockInfo.Schema[j] = p.parseSchema() + } + } + if withTableName { + blockInfo.TableName = p.parseName() + } + blockInfos[i] = blockInfo + } + return blockInfos +} + +func (p *TMQRawDataParser) parseZigzagVariableByteInteger() int { + return zigzagDecode(p.parseVariableByteInteger()) +} + +func (p *TMQRawDataParser) parseBool() bool { + v := *(*int8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v != 0 +} + +func (p *TMQRawDataParser) parseUint8() uint8 { + v := *(*uint8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt8() int8 { + v := *(*int8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt32() int32 { + v := *(*int32)(pointer.AddUintptr(p.block, p.offset)) + p.skip(4) + return v +} + +func (p *TMQRawDataParser) parseSchema() *TMQRawDataSchema { + colType := p.parseUint8() + flag := p.parseInt8() + bytes := int64(p.parseZigzagVariableByteInteger()) + colID := p.parseZigzagVariableByteInteger() + name := p.parseName() + return &TMQRawDataSchema{ + ColType: colType, + Flag: flag, + Bytes: bytes, + ColID: colID, + Name: name, + } +} + +func (p *TMQRawDataParser) parseName() string { + nameLen := p.parseVariableByteInteger() + name := make([]byte, nameLen-1) + for i := 0; i < nameLen-1; i++ { + name[i] = *(*byte)(pointer.AddUintptr(p.block, p.offset+uintptr(i))) + } + p.skip(nameLen) + return string(name) +} + +func (p *TMQRawDataParser) Parse(block unsafe.Pointer) ([]*TMQBlockInfo, error) { + p.reset(block) + err := p.skipHead() + if err != nil { + return nil, err + } + return p.parseBlockInfos(), nil +} + +func (p *TMQRawDataParser) reset(block unsafe.Pointer) { + p.block = block + p.offset = 0 +} + +func (p *TMQRawDataParser) parseVariableByteInteger() int { + multiplier := 1 + value := 0 + for { + encodedByte := p.parseUint8() + value += int(encodedByte&127) * multiplier + if encodedByte&128 == 0 { + break + } + multiplier *= 128 + } + return value +} + +func zigzagDecode(n int) int { + return (n >> 1) ^ (-(n & 1)) +} diff --git a/common/parser/raw_test.go b/common/parser/raw_test.go new file mode 100644 index 0000000..ef6f56c --- /dev/null +++ b/common/parser/raw_test.go @@ -0,0 +1,797 @@ +package parser + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, 0x00, 0x00, 0x00, + + 0x01, + 0x01, + + 0xc5, 0x01, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x02, + + 0x02, 0x00, 0x00, 0x00, + 0xb3, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x82, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x5c, 0x00, 0x00, 0x00, + + 0x00, + 0xc0, 0xed, 0x82, 0x05, 0xc3, 0x1b, 0xab, 0x17, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x5a, 0x00, + 0x61, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x61, + + 0x08, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x03, + 0x63, 0x31, 0x00, + + 0x06, + 0x01, + 0x08, + 0x06, + 0x03, + 0x63, 0x32, 0x00, + + 0x08, + 0x01, + 0x84, 0x02, + 0x08, + 0x03, 0x63, 0x33, 0x00, + + 0x05, + 0x63, 0x74, 0x62, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 2, blockInfos[0].Precision) + assert.Equal(t, 4, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "c1", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 3, + Name: "c2", + }, + { + ColType: 8, + Flag: 1, + Bytes: 130, + ColID: 4, + Name: "c3", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "ctb0", blockInfos[0].TableName) +} + +func TestParseTwoBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x02, 0x00, 0x00, 0x00, + + 0x00, // withTbName false + 0x01, // withSchema true + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf8, 0x6b, 0x75, 0x35, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x30, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf9, 0x6b, 0x75, 0x35, + 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x31, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 2, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 0, blockInfos[1].Precision) + assert.Equal(t, 3, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[0].Schema) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[1].Schema) + assert.Equal(t, "", blockInfos[0].TableName) + assert.Equal(t, "", blockInfos[1].TableName) +} + +func TestParseTenBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, + 0x01, + + // block1 + 0x4e, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x31, 0x00, + + //block2 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x32, 0x00, + + //block3 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x03, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x33, 0x00, + + //block4 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x34, 0x00, + + // block5 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x35, 0x00, + + //block6 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x06, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x36, 0x00, + + //block7 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x07, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x37, 0x00, + + //block8 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x38, 0x00, + + //block9 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x39, 0x00, + + //block10 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x0a, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + 0x04, + 0x74, 0x31, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 10, len(blockInfos)) + for i := 0; i < 10; i++ { + assert.Equal(t, 0, blockInfos[i].Precision) + assert.Equal(t, 2, len(blockInfos[i].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "v", + }, + }, blockInfos[i].Schema) + assert.Equal(t, fmt.Sprintf("t%d", i+1), blockInfos[i].TableName) + value := ReadBlockSimple(blockInfos[i].RawBlock, blockInfos[i].Precision) + ts := time.Unix(0, 1706081119570000000).Local() + assert.Equal(t, [][]driver.Value{{ts, int32(i + 1)}}, value) + } +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index 46ea559..ec8e40e 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -26,6 +26,7 @@ type Consumer struct { client *client.Client requestID uint64 err error + dataParser *parser.TMQRawDataParser listLock sync.RWMutex sendChanList *list.List messageTimeout time.Duration @@ -108,6 +109,7 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { snapshotEnable: config.SnapshotEnable, withTableName: config.WithTableName, closeChan: make(chan struct{}), + dataParser: parser.NewTMQRawDataParser(), } if config.WriteWait > 0 { wsClient.WriteWait = config.WriteWait @@ -316,8 +318,7 @@ func (c *Consumer) findOutChanByID(index uint64) *list.Element { const ( TMQSubscribe = "subscribe" TMQPoll = "poll" - TMQFetch = "fetch" - TMQFetchBlock = "fetch_block" + TMQFetchRaw = "fetch_raw" TMQFetchJsonMeta = "fetch_json_meta" TMQCommit = "commit" TMQUnsubscribe = "unsubscribe" @@ -566,73 +567,38 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { } func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { - var tmqData []*tmq.Data - for { - reqID := c.generateReqID() - req := &FetchReq{ - ReqID: reqID, - MessageID: messageID, - } - args, err := client.JsonI.Marshal(req) - if err != nil { - return nil, err - } - action := &client.WSAction{ - Action: TMQFetch, - Args: args, - } - envelope := c.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) - if err != nil { - c.client.PutEnvelope(envelope) - return nil, err - } - respBytes, err := c.sendText(reqID, envelope) - if err != nil { - return nil, err - } - var resp FetchResp - err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return nil, err - } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - if resp.Completed { - break - } - // fetch block - { - req := &FetchBlockReq{ - ReqID: reqID, - MessageID: messageID, - } - args, err := client.JsonI.Marshal(req) - if err != nil { - return nil, err - } - action := &client.WSAction{ - Action: TMQFetchBlock, - Args: args, - } - envelope := c.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) - if err != nil { - c.client.PutEnvelope(envelope) - return nil, err - } - respBytes, err := c.sendText(reqID, envelope) - if err != nil { - return nil, err - } - block := respBytes[24:] - p := unsafe.Pointer(&block[0]) - data := parser.ReadBlock(p, resp.Rows, resp.FieldsTypes, resp.Precision) - tmqData = append(tmqData, &tmq.Data{ - TableName: resp.TableName, - Data: data, - }) + reqID := c.generateReqID() + req := &TMQFetchRawMetaReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQFetchRaw, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + blockInfo, err := c.dataParser.Parse(unsafe.Pointer(&respBytes[38])) + if err != nil { + return nil, err + } + tmqData := make([]*tmq.Data, len(blockInfo)) + for i := 0; i < len(blockInfo); i++ { + tmqData[i] = &tmq.Data{ + TableName: blockInfo[i].TableName, + Data: parser.ReadBlockSimple(blockInfo[i].RawBlock, blockInfo[i].Precision), } } return tmqData, nil diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index 4d382eb..dc394e9 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -450,3 +450,99 @@ func TestAutoCommit(t *testing.T) { assert.Equal(t, 1, len(offset)) assert.GreaterOrEqual(t, offset[0].Offset, messageOffset) } + +func prepareMultiBlockEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_multi_block_topic", + "drop database if exists test_ws_tmq_multi_block", + "create database test_ws_tmq_multi_block vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_multi_block_topic as database test_ws_tmq_multi_block", + "create table test_ws_tmq_multi_block.t1(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t2(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t3(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t4(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t5(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t6(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t7(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t8(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t9(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t10(ts timestamp,v int)", + "insert into test_ws_tmq_multi_block.t1 values (now,1) test_ws_tmq_multi_block.t2 values (now,2) " + + "test_ws_tmq_multi_block.t3 values (now,3) test_ws_tmq_multi_block.t4 values (now,4)" + + "test_ws_tmq_multi_block.t5 values (now,5) test_ws_tmq_multi_block.t6 values (now,6)" + + "test_ws_tmq_multi_block.t7 values (now,7) test_ws_tmq_multi_block.t8 values (now,8)" + + "test_ws_tmq_multi_block.t9 values (now,9) test_ws_tmq_multi_block.t10 values (now,10)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanMultiBlockEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_multi_block_topic", + "drop database if exists test_ws_tmq_multi_block", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestMultiBlock(t *testing.T) { + err := prepareMultiBlockEnv() + assert.NoError(t, err) + defer cleanMultiBlockEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_ws_tmq_multi_block_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + data := e.Value().([]*tmq.Data) + assert.Equal(t, "test_ws_tmq_multi_block", e.DBName()) + assert.Equal(t, 10, len(data)) + return + } + } +} diff --git a/ws/tmq/proto.go b/ws/tmq/proto.go index d9b8c1d..3a17c8b 100644 --- a/ws/tmq/proto.go +++ b/ws/tmq/proto.go @@ -196,3 +196,8 @@ type PositionResp struct { Timing int64 `json:"timing"` Position []int64 `json:"position"` } + +type TMQFetchRawMetaReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +}