diff --git a/common/parser/block.go b/common/parser/block.go index c0374d9..705ab7d 100644 --- a/common/parser/block.go +++ b/common/parser/block.go @@ -3,6 +3,7 @@ package parser import ( "database/sql/driver" "math" + "unicode/utf8" "unsafe" "github.com/taosdata/driver-go/v3/common" @@ -108,28 +109,9 @@ type rawConvertFunc func(pStart unsafe.Pointer, row int, arg ...interface{}) dri type rawConvertVarDataFunc func(pHeader, pStart unsafe.Pointer, row int) driver.Value -var rawConvertFuncMap = map[uint8]rawConvertFunc{ - uint8(common.TSDB_DATA_TYPE_BOOL): rawConvertBool, - uint8(common.TSDB_DATA_TYPE_TINYINT): rawConvertTinyint, - uint8(common.TSDB_DATA_TYPE_SMALLINT): rawConvertSmallint, - uint8(common.TSDB_DATA_TYPE_INT): rawConvertInt, - uint8(common.TSDB_DATA_TYPE_BIGINT): rawConvertBigint, - uint8(common.TSDB_DATA_TYPE_UTINYINT): rawConvertUTinyint, - uint8(common.TSDB_DATA_TYPE_USMALLINT): rawConvertUSmallint, - uint8(common.TSDB_DATA_TYPE_UINT): rawConvertUInt, - uint8(common.TSDB_DATA_TYPE_UBIGINT): rawConvertUBigint, - uint8(common.TSDB_DATA_TYPE_FLOAT): rawConvertFloat, - uint8(common.TSDB_DATA_TYPE_DOUBLE): rawConvertDouble, - uint8(common.TSDB_DATA_TYPE_TIMESTAMP): rawConvertTime, -} - -var rawConvertVarDataMap = map[uint8]rawConvertVarDataFunc{ - uint8(common.TSDB_DATA_TYPE_BINARY): rawConvertBinary, - uint8(common.TSDB_DATA_TYPE_NCHAR): rawConvertNchar, - uint8(common.TSDB_DATA_TYPE_JSON): rawConvertJson, - uint8(common.TSDB_DATA_TYPE_VARBINARY): rawConvertVarBinary, - uint8(common.TSDB_DATA_TYPE_GEOMETRY): rawConvertGeometry, -} +var rawConvertFuncSlice = [15]rawConvertFunc{} + +var rawConvertVarDataSlice = [21]rawConvertVarDataFunc{} func ItemIsNull(pHeader unsafe.Pointer, row int) bool { offset := CharOffset(row) @@ -196,20 +178,23 @@ func rawConvertTime(pStart unsafe.Pointer, row int, arg ...interface{}) driver.V } func rawConvertVarBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { + return rawGetBytes(pHeader, pStart, row) +} + +func rawGetBytes(pHeader, pStart unsafe.Pointer, row int) []byte { offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) if offset == -1 { return nil } currentRow := pointer.AddUintptr(pStart, uintptr(offset)) clen := *((*uint16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]byte, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) + if clen == 0 { + return make([]byte, 0) } - return binaryVal[:] + currentRow = pointer.AddUintptr(currentRow, 2) + result := make([]byte, clen) + Copy(currentRow, result, 0, int(clen)) + return result } func rawConvertGeometry(pHeader, pStart unsafe.Pointer, row int) driver.Value { @@ -217,20 +202,11 @@ func rawConvertGeometry(pHeader, pStart unsafe.Pointer, row int) driver.Value { } func rawConvertBinary(pHeader, pStart unsafe.Pointer, row int) driver.Value { - offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) - if offset == -1 { + result := rawGetBytes(pHeader, pStart, row) + if result == nil { return nil } - currentRow := pointer.AddUintptr(pStart, uintptr(offset)) - clen := *((*uint16)(currentRow)) - currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]byte, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(unsafe.Pointer(uintptr(currentRow) + uintptr(index)))) - } - return string(binaryVal[:]) + return *(*string)(unsafe.Pointer(&result)) } func rawConvertNchar(pHeader, pStart unsafe.Pointer, row int) driver.Value { @@ -240,31 +216,22 @@ func rawConvertNchar(pHeader, pStart unsafe.Pointer, row int) driver.Value { } currentRow := pointer.AddUintptr(pStart, uintptr(offset)) clen := *((*uint16)(currentRow)) / 4 + if clen == 0 { + return "" + } currentRow = unsafe.Pointer(uintptr(currentRow) + 2) - - binaryVal := make([]rune, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*rune)(unsafe.Pointer(uintptr(currentRow) + uintptr(index*4)))) + utf8Bytes := make([]byte, clen*utf8.UTFMax) + index := 0 + utf32Slice := (*[1 << 30]rune)(currentRow)[:clen:clen] + for _, runeValue := range utf32Slice { + index += utf8.EncodeRune(utf8Bytes[index:], runeValue) } - return string(binaryVal) + utf8Bytes = utf8Bytes[:index] + return *(*string)(unsafe.Pointer(&utf8Bytes)) } func rawConvertJson(pHeader, pStart unsafe.Pointer, row int) driver.Value { - offset := *((*int32)(pointer.AddUintptr(pHeader, uintptr(row*4)))) - if offset == -1 { - return nil - } - currentRow := pointer.AddUintptr(pStart, uintptr(offset)) - clen := *((*uint16)(currentRow)) - currentRow = pointer.AddUintptr(currentRow, 2) - - binaryVal := make([]byte, clen) - - for index := uint16(0); index < clen; index++ { - binaryVal[index] = *((*byte)(pointer.AddUintptr(currentRow, uintptr(index)))) - } - return binaryVal[:] + return rawConvertVarBinary(pHeader, pStart, row) } func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { @@ -290,7 +257,7 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision for column := 0; column < colCount; column++ { colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] + convertF := rawConvertVarDataSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, Int32Size*uintptr(blockSize)) for row := 0; row < blockSize; row++ { if column == 0 { @@ -299,7 +266,7 @@ func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision r[row][column] = convertF(pHeader, pStart, row) } } else { - convertF := rawConvertFuncMap[colTypes[column]] + convertF := rawConvertFuncSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) for row := 0; row < blockSize; row++ { if column == 0 { @@ -326,11 +293,11 @@ func ReadRow(dest []driver.Value, block unsafe.Pointer, blockSize int, row int, for column := 0; column < colCount; column++ { colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] + convertF := rawConvertVarDataSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, Int32Size*uintptr(blockSize)) dest[column] = convertF(pHeader, pStart, row) } else { - convertF := rawConvertFuncMap[colTypes[column]] + convertF := rawConvertFuncSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) if ItemIsNull(pHeader, row) { dest[column] = nil @@ -352,7 +319,7 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin for column := 0; column < colCount; column++ { colLength := *((*int32)(pointer.AddUintptr(block, lengthOffset+uintptr(column)*Int32Size))) if IsVarDataType(colTypes[column]) { - convertF := rawConvertVarDataMap[colTypes[column]] + convertF := rawConvertVarDataSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, uintptr(4*blockSize)) for row := 0; row < blockSize; row++ { if column == 0 { @@ -361,7 +328,7 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin r[row][column] = convertF(pHeader, pStart, row) } } else { - convertF := rawConvertFuncMap[colTypes[column]] + convertF := rawConvertFuncSlice[colTypes[column]] pStart = pointer.AddUintptr(pHeader, nullBitMapOffset) for row := 0; row < blockSize; row++ { if column == 0 { @@ -381,12 +348,33 @@ func ReadBlockWithTimeFormat(block unsafe.Pointer, blockSize int, colTypes []uin func ItemRawBlock(colType uint8, pHeader, pStart unsafe.Pointer, row int, precision int, timeFormat FormatTimeFunc) driver.Value { if IsVarDataType(colType) { - return rawConvertVarDataMap[colType](pHeader, pStart, row) + return rawConvertVarDataSlice[colType](pHeader, pStart, row) } else { if ItemIsNull(pHeader, row) { return nil } else { - return rawConvertFuncMap[colType](pStart, row, precision, timeFormat) + return rawConvertFuncSlice[colType](pStart, row, precision, timeFormat) } } } + +func init() { + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_BOOL)] = rawConvertBool + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_TINYINT)] = rawConvertTinyint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_SMALLINT)] = rawConvertSmallint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_INT)] = rawConvertInt + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_BIGINT)] = rawConvertBigint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UTINYINT)] = rawConvertUTinyint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_USMALLINT)] = rawConvertUSmallint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UINT)] = rawConvertUInt + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_UBIGINT)] = rawConvertUBigint + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_FLOAT)] = rawConvertFloat + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_DOUBLE)] = rawConvertDouble + rawConvertFuncSlice[uint8(common.TSDB_DATA_TYPE_TIMESTAMP)] = rawConvertTime + + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_BINARY)] = rawConvertBinary + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_NCHAR)] = rawConvertNchar + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_JSON)] = rawConvertJson + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_VARBINARY)] = rawConvertVarBinary + rawConvertVarDataSlice[uint8(common.TSDB_DATA_TYPE_GEOMETRY)] = rawConvertGeometry +} diff --git a/common/parser/mem.go b/common/parser/mem.go new file mode 100644 index 0000000..f0d4b00 --- /dev/null +++ b/common/parser/mem.go @@ -0,0 +1,12 @@ +package parser + +import "unsafe" + +//go:noescape +func memmove(to, from unsafe.Pointer, n uintptr) + +//go:linkname memmove runtime.memmove + +func Copy(source unsafe.Pointer, data []byte, index int, length int) { + memmove(unsafe.Pointer(&data[index]), source, uintptr(length)) +} diff --git a/common/parser/mem.s b/common/parser/mem.s new file mode 100644 index 0000000..e69de29 diff --git a/common/parser/mem_test.go b/common/parser/mem_test.go new file mode 100644 index 0000000..d3e244b --- /dev/null +++ b/common/parser/mem_test.go @@ -0,0 +1,20 @@ +package parser + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestCopy(t *testing.T) { + data := []byte("World") + data1 := make([]byte, 10) + data1[0] = 'H' + data1[1] = 'e' + data1[2] = 'l' + data1[3] = 'l' + data1[4] = 'o' + Copy(unsafe.Pointer(&data[0]), data1, 5, 5) + assert.Equal(t, "HelloWorld", string(data1)) +} diff --git a/taosSql/rows.go b/taosSql/rows.go index aaf684d..54b61a8 100644 --- a/taosSql/rows.go +++ b/taosSql/rows.go @@ -19,7 +19,6 @@ type rows struct { block unsafe.Pointer blockOffset int blockSize int - lengthList []int result unsafe.Pointer precision int isStmt bool @@ -107,7 +106,6 @@ func (rs *rows) taosFetchBlock() error { } rs.blockSize = result.N rs.block = wrapper.TaosGetRawBlock(result.Res) - rs.lengthList = wrapper.FetchLengths(rs.result, len(rs.rowsHeader.ColLength)) rs.blockOffset = 0 return nil } diff --git a/taosWS/connection.go b/taosWS/connection.go index 5038ea4..c465815 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -9,6 +9,7 @@ import ( "fmt" "net/url" "strings" + "sync" "sync/atomic" "time" @@ -37,6 +38,11 @@ const ( STMTUseResult = "use_result" ) +const ( + BinaryQueryMessage uint64 = 6 + FetchRawBlockMessage uint64 = 7 +) + var ( NotQueryError = errors.New("sql is an update statement not a query statement") ReadTimeoutError = errors.New("read timeout") @@ -46,11 +52,21 @@ type taosConn struct { buf *bytes.Buffer client *websocket.Conn requestID uint64 + writeLock sync.Mutex readTimeout time.Duration writeTimeout time.Duration cfg *config + messageChan chan *message + messageError error endpoint string - closed atomic.Bool // set when conn is closed, + closed uint32 + closeCh chan struct{} +} + +type message struct { + mt int + message []byte + err error } func (tc *taosConn) generateReqID() uint64 { @@ -87,8 +103,12 @@ func newTaosConn(cfg *config) (*taosConn, error) { writeTimeout: cfg.writeTimeout, cfg: cfg, endpoint: endpoint, + closeCh: make(chan struct{}), + messageChan: make(chan *message, 10), } + go tc.ping() + go tc.read() err = tc.connect() if err != nil { tc.Close() @@ -96,15 +116,46 @@ func newTaosConn(cfg *config) (*taosConn, error) { return tc, nil } +func (tc *taosConn) ping() { + ticker := time.NewTicker(common.DefaultPingPeriod) + defer ticker.Stop() + for { + select { + case <-tc.closeCh: + return + case <-ticker.C: + tc.writePing() + } + } +} + +func (tc *taosConn) read() { + for { + mt, msg, err := tc.client.ReadMessage() + tc.messageChan <- &message{ + mt: mt, + message: msg, + err: err, + } + if err != nil { + tc.messageError = NewBadConnError(err) + break + } + if tc.isClosed() { + break + } + } +} + func (tc *taosConn) Begin() (driver.Tx, error) { return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support transaction"} } func (tc *taosConn) Close() (err error) { - if tc.closed.Swap(true) { - return nil + if !tc.isClosed() { + atomic.StoreUint32(&tc.closed, 1) + close(tc.closeCh) } - if tc.client != nil { err = tc.client.Close() } @@ -114,8 +165,12 @@ func (tc *taosConn) Close() (err error) { return err } +func (tc *taosConn) isClosed() bool { + return atomic.LoadUint32(&tc.closed) != 0 +} + func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { - if tc.closed.Load() { + if tc.isClosed() { return nil, driver.ErrBadConn } stmtID, err := tc.stmtInit() @@ -297,6 +352,18 @@ func WriteUint64(buffer *bytes.Buffer, v uint64) { buffer.WriteByte(byte(v >> 56)) } +func WriteUint32(buffer *bytes.Buffer, v uint32) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) +} + +func WriteUint16(buffer *bytes.Buffer, v uint16) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) +} + func (tc *taosConn) stmtAddBatch(stmtID uint64) error { reqID := tc.generateReqID() req := &StmtAddBatchRequest{ @@ -417,45 +484,8 @@ func (tc *taosConn) ExecContext(ctx context.Context, query string, args []driver return tc.execCtx(ctx, query, args) } -func (tc *taosConn) execCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - if tc.closed.Load() { - return nil, driver.ErrBadConn - } - if len(args) != 0 { - if !tc.cfg.interpolateParams { - return nil, driver.ErrSkip - } - // try to interpolate the parameters to save extra round trips for preparing and closing a statement - prepared, err := common.InterpolateParams(query, args) - if err != nil { - return nil, err - } - query = prepared - } - reqID := tc.generateReqID() - req := &WSQueryReq{ - ReqID: reqID, - SQL: query, - } - reqArgs, err := json.Marshal(req) - if err != nil { - return nil, err - } - action := &WSAction{ - Action: WSQuery, - Args: reqArgs, - } - tc.buf.Reset() - err = jsonI.NewEncoder(tc.buf).Encode(action) - if err != nil { - return nil, err - } - err = tc.writeText(tc.buf.Bytes()) - if err != nil { - return nil, err - } - var resp WSQueryResp - err = tc.readTo(&resp) +func (tc *taosConn) execCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + resp, err := tc.doQuery(ctx, query, args) if err != nil { return nil, err } @@ -473,8 +503,32 @@ func (tc *taosConn) QueryContext(ctx context.Context, query string, args []drive return tc.queryCtx(ctx, query, args) } -func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - if tc.closed.Load() { +func (tc *taosConn) queryCtx(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + resp, err := tc.doQuery(ctx, query, args) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + if resp.IsUpdate { + return nil, NotQueryError + } + rs := &rows{ + buf: &bytes.Buffer{}, + conn: tc, + resultID: resp.ID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + } + return rs, err +} + +func (tc *taosConn) doQuery(_ context.Context, query string, args []driver.NamedValue) (*WSQueryResp, error) { + if tc.isClosed() { return nil, driver.ErrBadConn } if len(args) != 0 { @@ -489,24 +543,15 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name query = prepared } reqID := tc.generateReqID() - req := &WSQueryReq{ - ReqID: reqID, - SQL: query, - } - reqArgs, err := json.Marshal(req) - if err != nil { - return nil, err - } - action := &WSAction{ - Action: WSQuery, - Args: reqArgs, - } tc.buf.Reset() - err = jsonI.NewEncoder(tc.buf).Encode(action) - if err != nil { - return nil, err - } - err = tc.writeText(tc.buf.Bytes()) + + WriteUint64(tc.buf, reqID) // req id + WriteUint64(tc.buf, 0) // message id + WriteUint64(tc.buf, BinaryQueryMessage) + WriteUint16(tc.buf, 1) // version + WriteUint32(tc.buf, uint32(len(query))) // sql length + tc.buf.WriteString(query) + err := tc.writeBinary(tc.buf.Bytes()) if err != nil { return nil, err } @@ -515,30 +560,14 @@ func (tc *taosConn) queryCtx(_ context.Context, query string, args []driver.Name if err != nil { return nil, err } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - if resp.IsUpdate { - return nil, NotQueryError - } - rs := &rows{ - buf: &bytes.Buffer{}, - conn: tc, - resultID: resp.ID, - fieldsCount: resp.FieldsCount, - fieldsNames: resp.FieldsNames, - fieldsTypes: resp.FieldsTypes, - fieldsLengths: resp.FieldsLengths, - precision: resp.Precision, - } - return rs, err + return &resp, nil } func (tc *taosConn) Ping(ctx context.Context) (err error) { - if tc.closed.Load() { + if tc.isClosed() { return driver.ErrBadConn } - return nil + return tc.writePing() } func (tc *taosConn) connect() error { @@ -577,14 +606,26 @@ func (tc *taosConn) connect() error { } func (tc *taosConn) writeText(data []byte) error { - return tc.write(data, websocket.TextMessage) + return tc.write(websocket.TextMessage, data) } func (tc *taosConn) writeBinary(data []byte) error { - return tc.write(data, websocket.BinaryMessage) + return tc.write(websocket.BinaryMessage, data) } -func (tc *taosConn) write(data []byte, messageType int) error { +func (tc *taosConn) writePing() error { + return tc.write(websocket.PingMessage, nil) +} + +func (tc *taosConn) write(messageType int, data []byte) error { + tc.writeLock.Lock() + defer tc.writeLock.Unlock() + if tc.isClosed() { + return driver.ErrBadConn + } + if tc.messageError != nil { + return tc.messageError + } tc.client.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) err := tc.client.WriteMessage(messageType, data) if err != nil { @@ -594,63 +635,50 @@ func (tc *taosConn) write(data []byte, messageType int) error { } func (tc *taosConn) readTo(to interface{}) error { - var outErr error - done := make(chan struct{}) - go func() { - defer func() { - close(done) - }() - mt, respBytes, err := tc.client.ReadMessage() - if err != nil { - outErr = NewBadConnError(err) - return - } - if mt != websocket.TextMessage { - outErr = NewBadConnErrorWithCtx(fmt.Errorf("readTo: got wrong message type %d", mt), formatBytes(respBytes)) - return - } - err = jsonI.Unmarshal(respBytes, to) - if err != nil { - outErr = NewBadConnErrorWithCtx(err, string(respBytes)) - return - } - }() - ctx, cancel := context.WithTimeout(context.Background(), tc.readTimeout) - defer cancel() - select { - case <-done: - return outErr - case <-ctx.Done(): - return NewBadConnError(ReadTimeoutError) + mt, respBytes, err := tc.readResponse() + if err != nil { + return err + } + if mt != websocket.TextMessage { + return NewBadConnErrorWithCtx(fmt.Errorf("readTo: got wrong message type %d", mt), formatBytes(respBytes)) } + err = jsonI.Unmarshal(respBytes, to) + if err != nil { + return NewBadConnErrorWithCtx(err, string(respBytes)) + } + return nil } func (tc *taosConn) readBytes() ([]byte, error) { - var respBytes []byte - var outErr error - done := make(chan struct{}) - go func() { - defer func() { - close(done) - }() - mt, message, err := tc.client.ReadMessage() - if err != nil { - outErr = NewBadConnError(err) - return - } - if mt != websocket.BinaryMessage { - outErr = NewBadConnErrorWithCtx(fmt.Errorf("readBytes: got wrong message type %d", mt), string(respBytes)) - return - } - respBytes = message - }() + mt, respBytes, err := tc.readResponse() + if err != nil { + return nil, err + } + if mt != websocket.BinaryMessage { + return nil, NewBadConnErrorWithCtx(fmt.Errorf("readBytes: got wrong message type %d", mt), string(respBytes)) + } + return respBytes, err +} + +func (tc *taosConn) readResponse() (int, []byte, error) { + if tc.isClosed() { + return 0, nil, driver.ErrBadConn + } + if tc.messageError != nil { + return 0, nil, tc.messageError + } ctx, cancel := context.WithTimeout(context.Background(), tc.readTimeout) defer cancel() select { - case <-done: - return respBytes, outErr + case <-tc.closeCh: + return 0, nil, driver.ErrBadConn + case msg := <-tc.messageChan: + if msg.err != nil { + return 0, nil, NewBadConnError(msg.err) + } + return msg.mt, msg.message, nil case <-ctx.Done(): - return nil, NewBadConnError(ReadTimeoutError) + return 0, nil, NewBadConnError(ReadTimeoutError) } } diff --git a/taosWS/rows.go b/taosWS/rows.go index 7677ab7..636f54c 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -3,14 +3,15 @@ package taosWS import ( "bytes" "database/sql/driver" + "encoding/binary" "encoding/json" + "fmt" "io" "reflect" "unsafe" "github.com/taosdata/driver-go/v3/common" "github.com/taosdata/driver-go/v3/common/parser" - "github.com/taosdata/driver-go/v3/common/pointer" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -86,75 +87,53 @@ func (rs *rows) Next(dest []driver.Value) error { func (rs *rows) taosFetchBlock() error { reqID := rs.conn.generateReqID() - req := &WSFetchReq{ - ReqID: reqID, - ID: rs.resultID, - } - args, err := json.Marshal(req) - if err != nil { - return err - } - action := &WSAction{ - Action: WSFetch, - Args: args, - } rs.buf.Reset() - - err = jsonI.NewEncoder(rs.buf).Encode(action) + WriteUint64(rs.buf, reqID) // req id + WriteUint64(rs.buf, rs.resultID) // message id + WriteUint64(rs.buf, FetchRawBlockMessage) + WriteUint16(rs.buf, 1) // version + err := rs.conn.writeBinary(rs.buf.Bytes()) if err != nil { return err } - err = rs.conn.writeText(rs.buf.Bytes()) + respBytes, err := rs.conn.readBytes() if err != nil { return err } - var resp WSFetchResp - err = rs.conn.readTo(&resp) - if err != nil { - return err + if len(respBytes) < 51 { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") + } + version := binary.LittleEndian.Uint16(respBytes[16:]) + if version != 1 { + return taosErrors.NewError(0xffff, fmt.Sprintf("unsupported fetch raw block version: %d", version)) + } + code := binary.LittleEndian.Uint32(respBytes[34:]) + msgLen := int(binary.LittleEndian.Uint32(respBytes[38:])) + if len(respBytes) < 51+msgLen { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") } - if resp.Code != 0 { - return taosErrors.NewError(resp.Code, resp.Message) + errMsg := string(respBytes[42 : 42+msgLen]) + if code != 0 { + return taosErrors.NewError(int(code), errMsg) } - if resp.Completed { + completed := respBytes[50+msgLen] == 1 + if completed { rs.blockSize = 0 return nil } else { - rs.blockSize = resp.Rows - return rs.fetchBlock() - } -} - -func (rs *rows) fetchBlock() error { - reqID := rs.conn.generateReqID() - req := &WSFetchBlockReq{ - ReqID: reqID, - ID: rs.resultID, - } - args, err := json.Marshal(req) - if err != nil { - return err - } - action := &WSAction{ - Action: WSFetchBlock, - Args: args, - } - rs.buf.Reset() - err = jsonI.NewEncoder(rs.buf).Encode(action) - if err != nil { - return err - } - err = rs.conn.writeText(rs.buf.Bytes()) - if err != nil { - return err - } - respBytes, err := rs.conn.readBytes() - if err != nil { - return err + if len(respBytes) < 55+msgLen { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") + } + blockLength := binary.LittleEndian.Uint32(respBytes[51+msgLen:]) + if len(respBytes) < 55+msgLen+int(blockLength) { + return taosErrors.NewError(0xffff, "invalid fetch raw block response") + } + rawBlock := respBytes[55+msgLen : 55+msgLen+int(blockLength)] + rs.block = rawBlock + rs.blockPtr = unsafe.Pointer(&rs.block[0]) + rs.blockSize = int(parser.RawBlockGetNumOfRows(rs.blockPtr)) + rs.blockOffset = 0 } - rs.block = respBytes - rs.blockPtr = pointer.AddUintptr(unsafe.Pointer(&rs.block[0]), 16) - rs.blockOffset = 0 return nil } diff --git a/taosWS/statement.go b/taosWS/statement.go index cd1e910..11f1f7e 100644 --- a/taosWS/statement.go +++ b/taosWS/statement.go @@ -28,7 +28,7 @@ type Stmt struct { } func (stmt *Stmt) Close() error { - if stmt.conn == nil || stmt.conn.closed.Load() { + if stmt.conn == nil || stmt.conn.isClosed() || stmt.conn.messageError != nil { return driver.ErrBadConn } err := stmt.conn.stmtClose(stmt.stmtID) @@ -45,10 +45,7 @@ func (stmt *Stmt) NumInput() int { } func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.conn.closed.Load() { - return nil, driver.ErrBadConn - } - if stmt.conn == nil { + if stmt.conn.isClosed() { return nil, driver.ErrBadConn } if len(args) != len(stmt.cols) { @@ -74,10 +71,7 @@ func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { } func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.conn.closed.Load() { - return nil, driver.ErrBadConn - } - if stmt.conn == nil { + if stmt.conn.isClosed() { return nil, driver.ErrBadConn } block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), param.NewColumnTypeWithValue(stmt.queryColTypes))