diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 0ea2913..8e35cbb 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -12,8 +12,6 @@ on: required: true type: string -env: - SCCACHE_GHA_ENABLED: "true" jobs: build: @@ -44,8 +42,6 @@ jobs: cd TDengine echo "commit_id=$(git rev-parse HEAD)" >> $GITHUB_OUTPUT - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: Cache server by pr if: github.event_name == 'pull_request' @@ -77,7 +73,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + cmake .. -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 make -j 4 - name: package @@ -114,7 +110,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.14', '1.19' ] + go: [ '1.14', 'stable' ] name: Go ${{ matrix.go }} steps: - name: get cache server by pr @@ -137,11 +133,6 @@ jobs: restore-keys: | ${{ runner.os }}-build-${{ inputs.tbBranch }}- - - name: checkout - uses: actions/checkout@v3 - with: - path: 'driver-go' - - name: prepare install run: sudo apt install -y libgeos-dev @@ -150,6 +141,9 @@ jobs: tar -zxvf server.tar.gz cd release && sudo sh install.sh + - name: checkout + uses: actions/checkout@v3 + - name: shell run: | cat >start.sh<> $GITHUB_OUTPUT - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: Cache server id: cache-server @@ -43,8 +39,6 @@ jobs: if: steps.cache-server.outputs.cache-hit != 'true' run: sudo apt install -y libgeos-dev - - name: Run sccache-cache - uses: mozilla-actions/sccache-action@v0.0.3 - name: install TDengine if: steps.cache-server.outputs.cache-hit != 'true' @@ -52,7 +46,7 @@ jobs: cd TDengine mkdir debug cd debug - cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 -DCMAKE_C_COMPILER_LAUNCHER=sccache -DCMAKE_CXX_COMPILER_LAUNCHER=sccache + cmake .. -DBUILD_JDBC=false -DBUILD_TEST=off -DBUILD_HTTP=false -DVERNUMBER=3.9.9.9 make -j 4 - name: package @@ -87,7 +81,7 @@ jobs: needs: build strategy: matrix: - go: [ '1.14', '1.19' ] + go: [ '1.14', 'stable' ] name: Go ${{ matrix.go }} steps: - name: get cache server @@ -99,11 +93,6 @@ jobs: restore-keys: | ${{ runner.os }}-build-${{ github.ref_name }}- - - name: checkout - uses: actions/checkout@v3 - with: - path: 'driver-go' - - name: prepare install run: sudo apt install -y libgeos-dev @@ -112,6 +101,9 @@ jobs: tar -zxvf server.tar.gz cd release && sudo sh install.sh + - name: checkout + uses: actions/checkout@v3 + - name: shell run: | cat >start.sh<= c.size { return c diff --git a/common/param/param.go b/common/param/param.go index a14854b..cce09d0 100644 --- a/common/param/param.go +++ b/common/param/param.go @@ -20,6 +20,15 @@ func NewParam(size int) *Param { } } +func NewParamsWithRowValue(value []driver.Value) []*Param { + params := make([]*Param, len(value)) + for i, d := range value { + params[i] = NewParam(1) + params[i].AddValue(d) + } + return params +} + func (p *Param) SetBool(offset int, value bool) { if offset >= p.size { return diff --git a/common/parser/block.go b/common/parser/block.go index 7228e90..c0374d9 100644 --- a/common/parser/block.go +++ b/common/parser/block.go @@ -267,6 +267,18 @@ func rawConvertJson(pHeader, pStart unsafe.Pointer, row int) driver.Value { return binaryVal[:] } +func ReadBlockSimple(block unsafe.Pointer, precision int) [][]driver.Value { + blockSize := RawBlockGetNumOfRows(block) + colCount := RawBlockGetNumOfCols(block) + colInfo := make([]RawBlockColInfo, colCount) + RawBlockGetColInfo(block, colInfo) + colTypes := make([]uint8, colCount) + for i := int32(0); i < colCount; i++ { + colTypes[i] = uint8(colInfo[i].ColType) + } + return ReadBlock(block, int(blockSize), colTypes, precision) +} + // ReadBlock in-place func ReadBlock(block unsafe.Pointer, blockSize int, colTypes []uint8, precision int) [][]driver.Value { r := make([][]driver.Value, blockSize) diff --git a/common/parser/block_test.go b/common/parser/block_test.go index 5b7232a..42b2d26 100644 --- a/common/parser/block_test.go +++ b/common/parser/block_test.go @@ -663,12 +663,6 @@ func TestParseBlock(t *testing.T) { t.Error(errors.NewError(code, errStr)) return } - fileCount := wrapper.TaosNumFields(res) - rh, err := wrapper.ReadColumn(res, fileCount) - if err != nil { - t.Error(err) - return - } precision := wrapper.TaosResultPrecision(res) var data [][]driver.Value for { @@ -684,7 +678,7 @@ func TestParseBlock(t *testing.T) { break } version := RawBlockGetVersion(block) - assert.Equal(t, int32(1), version) + t.Log(version) length := RawBlockGetLength(block) assert.Equal(t, int32(447), length) rows := RawBlockGetNumOfRows(block) @@ -771,7 +765,7 @@ func TestParseBlock(t *testing.T) { }, infos, ) - d := ReadBlock(block, blockSize, rh.ColTypes, precision) + d := ReadBlockSimple(block, precision) data = append(data, d...) } wrapper.TaosFreeResult(res) diff --git a/common/parser/raw.go b/common/parser/raw.go new file mode 100644 index 0000000..61ad125 --- /dev/null +++ b/common/parser/raw.go @@ -0,0 +1,185 @@ +package parser + +import ( + "fmt" + "unsafe" + + "github.com/taosdata/driver-go/v3/common/pointer" +) + +type TMQRawDataParser struct { + block unsafe.Pointer + offset uintptr +} + +func NewTMQRawDataParser() *TMQRawDataParser { + return &TMQRawDataParser{} +} + +type TMQBlockInfo struct { + RawBlock unsafe.Pointer + Precision int + Schema []*TMQRawDataSchema + TableName string +} + +type TMQRawDataSchema struct { + ColType uint8 + Flag int8 + Bytes int64 + ColID int + Name string +} + +func (p *TMQRawDataParser) getTypeSkip(t int8) (int, error) { + skip := 8 + switch t { + case 1: + case 2, 3: + skip = 16 + default: + return 0, fmt.Errorf("unknown type %d", t) + } + return skip, nil +} + +func (p *TMQRawDataParser) skipHead() error { + v := p.parseInt8() + if v >= 100 { + skip := p.parseInt32() + p.skip(int(skip)) + return nil + } else { + skip, err := p.getTypeSkip(v) + if err != nil { + return err + } + p.skip(skip) + v = p.parseInt8() + skip, err = p.getTypeSkip(v) + if err != nil { + return err + } + p.skip(skip) + return nil + } +} + +func (p *TMQRawDataParser) skip(count int) { + p.offset += uintptr(count) +} + +func (p *TMQRawDataParser) parseBlockInfos() []*TMQBlockInfo { + blockNum := p.parseInt32() + blockInfos := make([]*TMQBlockInfo, blockNum) + withTableName := p.parseBool() + withSchema := p.parseBool() + for i := int32(0); i < blockNum; i++ { + blockInfo := &TMQBlockInfo{} + blockTotalLen := p.parseVariableByteInteger() + p.skip(17) + blockInfo.Precision = int(p.parseUint8()) + blockInfo.RawBlock = pointer.AddUintptr(p.block, p.offset) + p.skip(blockTotalLen - 18) + if withSchema { + cols := p.parseZigzagVariableByteInteger() + //version + _ = p.parseZigzagVariableByteInteger() + + blockInfo.Schema = make([]*TMQRawDataSchema, cols) + for j := 0; j < cols; j++ { + blockInfo.Schema[j] = p.parseSchema() + } + } + if withTableName { + blockInfo.TableName = p.parseName() + } + blockInfos[i] = blockInfo + } + return blockInfos +} + +func (p *TMQRawDataParser) parseZigzagVariableByteInteger() int { + return zigzagDecode(p.parseVariableByteInteger()) +} + +func (p *TMQRawDataParser) parseBool() bool { + v := *(*int8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v != 0 +} + +func (p *TMQRawDataParser) parseUint8() uint8 { + v := *(*uint8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt8() int8 { + v := *(*int8)(pointer.AddUintptr(p.block, p.offset)) + p.skip(1) + return v +} + +func (p *TMQRawDataParser) parseInt32() int32 { + v := *(*int32)(pointer.AddUintptr(p.block, p.offset)) + p.skip(4) + return v +} + +func (p *TMQRawDataParser) parseSchema() *TMQRawDataSchema { + colType := p.parseUint8() + flag := p.parseInt8() + bytes := int64(p.parseZigzagVariableByteInteger()) + colID := p.parseZigzagVariableByteInteger() + name := p.parseName() + return &TMQRawDataSchema{ + ColType: colType, + Flag: flag, + Bytes: bytes, + ColID: colID, + Name: name, + } +} + +func (p *TMQRawDataParser) parseName() string { + nameLen := p.parseVariableByteInteger() + name := make([]byte, nameLen-1) + for i := 0; i < nameLen-1; i++ { + name[i] = *(*byte)(pointer.AddUintptr(p.block, p.offset+uintptr(i))) + } + p.skip(nameLen) + return string(name) +} + +func (p *TMQRawDataParser) Parse(block unsafe.Pointer) ([]*TMQBlockInfo, error) { + p.reset(block) + err := p.skipHead() + if err != nil { + return nil, err + } + return p.parseBlockInfos(), nil +} + +func (p *TMQRawDataParser) reset(block unsafe.Pointer) { + p.block = block + p.offset = 0 +} + +func (p *TMQRawDataParser) parseVariableByteInteger() int { + multiplier := 1 + value := 0 + for { + encodedByte := p.parseUint8() + value += int(encodedByte&127) * multiplier + if encodedByte&128 == 0 { + break + } + multiplier *= 128 + } + return value +} + +func zigzagDecode(n int) int { + return (n >> 1) ^ (-(n & 1)) +} diff --git a/common/parser/raw_test.go b/common/parser/raw_test.go new file mode 100644 index 0000000..521a626 --- /dev/null +++ b/common/parser/raw_test.go @@ -0,0 +1,1049 @@ +package parser + +import ( + "database/sql/driver" + "fmt" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" +) + +func TestParse(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x01, 0x00, 0x00, 0x00, + + 0x01, + 0x01, + + 0xc5, 0x01, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x02, + + 0x02, 0x00, 0x00, 0x00, + 0xb3, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x82, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x5c, 0x00, 0x00, 0x00, + + 0x00, + 0xc0, 0xed, 0x82, 0x05, 0xc3, 0x1b, 0xab, 0x17, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x80, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x5a, 0x00, + 0x61, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, + 0x34, 0x61, + + 0x08, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x03, + 0x63, 0x31, 0x00, + + 0x06, + 0x01, + 0x08, + 0x06, + 0x03, + 0x63, 0x32, 0x00, + + 0x08, + 0x01, + 0x84, 0x02, + 0x08, + 0x03, 0x63, 0x33, 0x00, + + 0x05, + 0x63, 0x74, 0x62, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 2, blockInfos[0].Precision) + assert.Equal(t, 4, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "c1", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 3, + Name: "c2", + }, + { + ColType: 8, + Flag: 1, + Bytes: 130, + ColID: 4, + Name: "c3", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "ctb0", blockInfos[0].TableName) +} + +func TestParseTwoBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x02, 0x00, 0x00, 0x00, + + 0x00, // withTbName false + 0x01, // withSchema true + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf8, 0x6b, 0x75, 0x35, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x00, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x30, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x60, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x4e, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x0c, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x00, + 0xf9, 0x6b, 0x75, 0x35, + 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, + 0x63, 0x74, 0x31, + + 0x06, + 0x00, + + 0x09, + 0x00, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x00, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x08, + 0x00, + 0x18, + 0x06, + 0x02, + 0x6e, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 2, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 0, blockInfos[1].Precision) + assert.Equal(t, 3, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[0].Schema) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 0, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 0, + Bytes: 4, + ColID: 2, + Name: "v", + }, + { + ColType: 8, + Flag: 0, + Bytes: 12, + ColID: 3, + Name: "n", + }, + }, blockInfos[1].Schema) + assert.Equal(t, "", blockInfos[0].TableName) + assert.Equal(t, "", blockInfos[1].TableName) +} + +func TestParseTenBlock(t *testing.T) { + data := []byte{ + 0x01, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0a, 0x00, 0x00, 0x00, + 0x01, + 0x01, + + // block1 + 0x4e, + + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x01, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x31, 0x00, + + //block2 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x02, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x32, 0x00, + + //block3 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + + 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x03, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x33, 0x00, + + //block4 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x34, 0x00, + + // block5 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x35, 0x00, + + //block6 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x06, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x36, 0x00, + + //block7 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x07, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x37, 0x00, + + //block8 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x38, 0x00, + + //block9 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + + 0x03, + 0x74, 0x39, 0x00, + + //block10 + 0x4e, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x3c, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x52, 0xed, 0x5b, 0x3a, 0x8d, 0x01, 0x00, 0x00, + 0x00, + 0x0a, 0x00, 0x00, 0x00, + + 0x04, + 0x00, + + 0x09, + 0x01, + 0x10, + 0x02, + 0x03, + 0x74, 0x73, 0x00, + + 0x04, + 0x01, + 0x08, + 0x04, + 0x02, + 0x76, 0x00, + 0x04, + 0x74, 0x31, 0x30, 0x00, + + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 10, len(blockInfos)) + for i := 0; i < 10; i++ { + assert.Equal(t, 0, blockInfos[i].Precision) + assert.Equal(t, 2, len(blockInfos[i].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 2, + Name: "v", + }, + }, blockInfos[i].Schema) + assert.Equal(t, fmt.Sprintf("t%d", i+1), blockInfos[i].TableName) + value := ReadBlockSimple(blockInfos[i].RawBlock, blockInfos[i].Precision) + ts := time.Unix(0, 1706081119570000000).Local() + assert.Equal(t, [][]driver.Value{{ts, int32(i + 1)}}, value) + } +} + +func TestVersion100Block(t *testing.T) { + data := []byte{ + 0x64, //version + 0x12, 0x00, 0x00, 0x00, // skip 18 bytes + 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, //block count 1 + + 0x01, // with table name + 0x01, // with schema + + 0x92, 0x02, // block length 274 + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, + 0x00, + + 0x02, 0x00, 0x00, 0x00, + 0x00, 0x01, 0x00, 0x00, // 256 + 0x01, 0x00, 0x00, 0x00, // rows + 0x0e, 0x00, 0x00, 0x00, // cols + 0x00, 0x00, 0x00, 0x80, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x09, 0x08, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x04, 0x00, 0x00, 0x00, + 0x05, 0x08, 0x00, 0x00, 0x00, + 0x0b, 0x01, 0x00, 0x00, 0x00, + 0x0c, 0x02, 0x00, 0x00, 0x00, + 0x0d, 0x04, 0x00, 0x00, 0x00, + 0x0e, 0x08, 0x00, 0x00, 0x00, + 0x06, 0x04, 0x00, 0x00, 0x00, + 0x07, 0x08, 0x00, 0x00, 0x00, + 0x08, 0x16, 0x00, 0x00, 0x00, + 0x0a, 0x52, 0x00, 0x00, 0x00, + + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, + 0x16, 0x00, 0x00, 0x00, + + 0x00, + 0x9e, 0x37, 0x6a, 0x04, 0x8f, 0x01, 0x00, 0x00, + + 0x00, + 0x01, + + 0x00, + 0x02, + + 0x00, + 0x03, 0x00, + + 0x00, + 0x04, 0x00, 0x00, 0x00, + + 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, + 0x06, + + 0x00, + 0x07, 0x00, + + 0x00, + 0x08, 0x00, 0x00, 0x00, + + 0x00, + 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + 0x00, + 0xcf, 0xf7, 0x21, 0x41, + + 0x00, + 0xe5, 0xd0, 0x22, 0xdb, 0xf9, 0x3e, 0x26, 0x40, + + 0x00, 0x00, 0x00, 0x00, + 0x06, 0x00, + 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + 0x00, 0x00, 0x00, 0x00, + 0x14, 0x00, + 0x6e, 0x00, 0x00, 0x00, + 0x63, 0x00, 0x00, 0x00, + 0x68, 0x00, 0x00, 0x00, + 0x61, 0x00, 0x00, 0x00, + 0x72, 0x00, 0x00, 0x00, + + 0x00, // + + 0x1c, // cols 14 + 0x00, // version + + // col meta + 0x09, 0x01, 0x10, 0x02, 0x03, 0x74, 0x73, 0x00, + 0x01, 0x01, 0x02, 0x04, 0x03, 0x63, 0x31, 0x00, + 0x02, 0x01, 0x02, 0x06, 0x03, 0x63, 0x32, 0x00, + 0x03, 0x01, 0x04, 0x08, 0x03, 0x63, 0x33, 0x00, + 0x04, 0x01, 0x08, 0x0a, 0x03, 0x63, 0x34, 0x00, + 0x05, 0x01, 0x10, 0x0c, 0x03, 0x63, 0x35, 0x00, + 0x0b, 0x01, 0x02, 0x0e, 0x03, 0x63, 0x36, 0x00, + 0x0c, 0x01, 0x04, 0x10, 0x03, 0x63, 0x37, 0x00, + 0x0d, 0x01, 0x08, 0x12, 0x03, 0x63, 0x38, 0x00, + 0x0e, 0x01, 0x10, 0x14, 0x03, 0x63, 0x39, 0x00, + 0x06, 0x01, 0x08, 0x16, 0x04, 0x63, 0x31, 0x30, 0x00, + 0x07, 0x01, 0x10, 0x18, 0x04, 0x63, 0x31, 0x31, 0x00, + 0x08, 0x01, 0x2c, 0x1a, 0x04, 0x63, 0x31, 0x32, 0x00, + 0x0a, 0x01, 0xa4, 0x01, 0x1c, 0x04, 0x63, 0x31, 0x33, 0x00, + + 0x06, // table name + 0x74, 0x5f, 0x61, 0x6c, 0x6c, 0x00, + // sleep time + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + } + parser := NewTMQRawDataParser() + blockInfos, err := parser.Parse(unsafe.Pointer(&data[0])) + assert.NoError(t, err) + assert.Equal(t, 1, len(blockInfos)) + assert.Equal(t, 0, blockInfos[0].Precision) + assert.Equal(t, 14, len(blockInfos[0].Schema)) + assert.Equal(t, []*TMQRawDataSchema{ + { + ColType: 9, + Flag: 1, + Bytes: 8, + ColID: 1, + Name: "ts", + }, + { + ColType: 1, + Flag: 1, + Bytes: 1, + ColID: 2, + Name: "c1", + }, + { + ColType: 2, + Flag: 1, + Bytes: 1, + ColID: 3, + Name: "c2", + }, + { + ColType: 3, + Flag: 1, + Bytes: 2, + ColID: 4, + Name: "c3", + }, + { + ColType: 4, + Flag: 1, + Bytes: 4, + ColID: 5, + Name: "c4", + }, + { + ColType: 5, + Flag: 1, + Bytes: 8, + ColID: 6, + Name: "c5", + }, + { + ColType: 11, + Flag: 1, + Bytes: 1, + ColID: 7, + Name: "c6", + }, + { + ColType: 12, + Flag: 1, + Bytes: 2, + ColID: 8, + Name: "c7", + }, + { + ColType: 13, + Flag: 1, + Bytes: 4, + ColID: 9, + Name: "c8", + }, + { + ColType: 14, + Flag: 1, + Bytes: 8, + ColID: 10, + Name: "c9", + }, + { + ColType: 6, + Flag: 1, + Bytes: 4, + ColID: 11, + Name: "c10", + }, + { + ColType: 7, + Flag: 1, + Bytes: 8, + ColID: 12, + Name: "c11", + }, + { + ColType: 8, + Flag: 1, + Bytes: 22, + ColID: 13, + Name: "c12", + }, + { + ColType: 10, + Flag: 1, + Bytes: 82, + ColID: 14, + Name: "c13", + }, + }, blockInfos[0].Schema) + assert.Equal(t, "t_all", blockInfos[0].TableName) + value := ReadBlockSimple(blockInfos[0].RawBlock, blockInfos[0].Precision) + expect := []driver.Value{ + time.Unix(0, 1713766021022000000).Local(), + true, + int8(2), + int16(3), + int32(4), + int64(5), + uint8(6), + uint16(7), + uint32(8), + uint64(9), + float32(10.123), + float64(11.123), + "binary", + "nchar", + } + assert.Equal(t, [][]driver.Value{expect}, value) +} diff --git a/common/serializer/block.go b/common/serializer/block.go index 03a53f2..50d7a6e 100644 --- a/common/serializer/block.go +++ b/common/serializer/block.go @@ -37,7 +37,7 @@ func BMSetNull(c byte, n int) byte { return c + (1 << (7 - BitPos(n))) } -var ColumnNumerNotMatch = errors.New("number of columns does not match") +var ColumnNumberNotMatch = errors.New("number of columns does not match") var DataTypeWrong = errors.New("wrong data type") func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte, error) { @@ -48,7 +48,7 @@ func SerializeRawBlock(params []*param.Param, colType *param.ColumnType) ([]byte return nil, err } if len(colTypes) != columns { - return nil, ColumnNumerNotMatch + return nil, ColumnNumberNotMatch } var block []byte //version int32 diff --git a/examples/stmtoverws/main.go b/examples/stmtoverws/main.go index a9c0d5a..6e083df 100644 --- a/examples/stmtoverws/main.go +++ b/examples/stmtoverws/main.go @@ -19,7 +19,7 @@ func main() { defer db.Close() prepareEnv(db) - config := stmt.NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config := stmt.NewConfig("ws://127.0.0.1:6041", 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") config.SetConnectDB("example_ws_stmt") diff --git a/examples/tmq/main.go b/examples/tmq/main.go index eb110ce..01e2010 100644 --- a/examples/tmq/main.go +++ b/examples/tmq/main.go @@ -27,16 +27,15 @@ func main() { panic(err) } consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ - "group.id": "test", - "auto.offset.reset": "earliest", - "td.connect.ip": "127.0.0.1", - "td.connect.user": "root", - "td.connect.pass": "taosdata", - "td.connect.port": "6030", - "client.id": "test_tmq_client", - "enable.auto.commit": "false", - "experimental.snapshot.enable": "true", - "msg.with.table.name": "true", + "group.id": "test", + "auto.offset.reset": "earliest", + "td.connect.ip": "127.0.0.1", + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "td.connect.port": "6030", + "client.id": "test_tmq_client", + "enable.auto.commit": "false", + "msg.with.table.name": "true", }) if err != nil { panic(err) diff --git a/examples/tmqoverws/main.go b/examples/tmqoverws/main.go index b0fdf91..cac691a 100644 --- a/examples/tmqoverws/main.go +++ b/examples/tmqoverws/main.go @@ -18,7 +18,7 @@ func main() { defer db.Close() prepareEnv(db) consumer, err := tmq.NewConsumer(&tmqcommon.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, diff --git a/taosRestful/connection.go b/taosRestful/connection.go index 1f4ddd0..e7f21c5 100644 --- a/taosRestful/connection.go +++ b/taosRestful/connection.go @@ -211,6 +211,7 @@ func (tc *taosConn) taosQuery(ctx context.Context, sql string, bufferSize int) ( return nil, fmt.Errorf("server response: %s - %s", resp.Status, string(body)) } respBody := resp.Body + defer ioutil.ReadAll(respBody) if !tc.cfg.disableCompression && EqualFold(resp.Header.Get("Content-Encoding"), "gzip") { respBody, err = gzip.NewReader(resp.Body) if err != nil { @@ -320,6 +321,16 @@ func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, e row[column] = iter.ReadUint32() case common.TSDB_DATA_TYPE_UBIGINT: row[column] = iter.ReadUint64() + case common.TSDB_DATA_TYPE_VARBINARY, common.TSDB_DATA_TYPE_GEOMETRY: + data := iter.ReadStringAsSlice() + if len(data)%2 != 0 { + iter.ReportError("read varbinary", fmt.Sprintf("invalid length %s", string(data))) + } + value := make([]byte, len(data)/2) + for i := 0; i < len(data); i += 2 { + value[i/2] = hexCharToDigit(data[i])<<4 | hexCharToDigit(data[i+1]) + } + row[column] = value default: row[column] = nil iter.Skip() @@ -366,3 +377,14 @@ func lower(b byte) byte { } return b } + +func hexCharToDigit(char byte) uint8 { + switch { + case char >= '0' && char <= '9': + return char - '0' + case char >= 'a' && char <= 'f': + return char - 'a' + 10 + default: + panic("assertion failed: invalid hex char") + } +} diff --git a/taosRestful/connector_test.go b/taosRestful/connector_test.go index d471b22..f8bf6f7 100644 --- a/taosRestful/connector_test.go +++ b/taosRestful/connector_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "math/rand" + "reflect" + "strings" "testing" "time" @@ -11,11 +13,78 @@ import ( "github.com/taosdata/driver-go/v3/types" ) +func generateCreateTableSql(db string, withJson bool) string { + createSql := fmt.Sprintf("create table if not exists %s.alltype(ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20),"+ + "c14 varbinary(100),"+ + "c15 geometry(100)"+ + ")", + db) + if withJson { + createSql += " tags(t json)" + } + return createSql +} + +func generateValues() (value []interface{}, scanValue []interface{}, insertSql string) { + rand.Seed(time.Now().UnixNano()) + v1 := true + v2 := int8(rand.Int()) + v3 := int16(rand.Int()) + v4 := rand.Int31() + v5 := int64(rand.Int31()) + v6 := uint8(rand.Uint32()) + v7 := uint16(rand.Uint32()) + v8 := rand.Uint32() + v9 := uint64(rand.Uint32()) + v10 := rand.Float32() + v11 := rand.Float64() + v12 := "test_binary" + v13 := "test_nchar" + v14 := []byte("test_varbinary") + v15 := []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40} + ts := time.Now().Round(time.Millisecond).UTC() + var ( + cts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + c14 []byte + c15 []byte + ) + return []interface{}{ + ts, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + }, []interface{}{cts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15}, + fmt.Sprintf(`values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar','test_varbinary','point(100 100)')`, ts.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +} + // @author: xftan // @date: 2021/12/21 10:59 // @description: test restful query of all type func TestAllTypeQuery(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -26,57 +95,25 @@ func TestAllTypeQuery(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists restful_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test.t1 using restful_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -84,71 +121,27 @@ func TestAllTypeQuery(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2022/2/8 12:51 // @description: test query all null value func TestAllTypeQueryNull(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_null" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -159,42 +152,29 @@ func TestAllTypeQueryNull(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists restful_test_null.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test_null.t1 using restful_test_null.alltype tags('null') values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') values('%s'%s)`, database, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String())) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test_null.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -202,72 +182,32 @@ func TestAllTypeQueryNull(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } + var tt types.RawMessage + values[len(colValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - tt *types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) - assert.Equal(t, types.RawMessage("null"), *tt) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), *(values[len(values)-1]).(*types.RawMessage)) } // @author: xftan // @date: 2022/2/10 14:32 // @description: test restful query of all type with compression func TestAllTypeQueryCompression(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_compression" db, err := sql.Open("taosRestful", dataSourceNameWithCompression) if err != nil { t.Fatal(err) @@ -278,57 +218,25 @@ func TestAllTypeQueryCompression(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists restful_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test.t1 using restful_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -336,70 +244,27 @@ func TestAllTypeQueryCompression(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } + err := rows.Scan(dest...) + assert.NoError(t, err) } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2022/5/19 15:22 // @description: test restful query of all type without json (httpd) func TestAllTypeQueryWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_without_json" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -410,56 +275,25 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test_without_json") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test_without_json") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = false - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists restful_test_without_json.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test_without_json.all_type values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.alltype %s`, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test_without_json.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -467,60 +301,16 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + dest := make([]interface{}, len(scanValues)) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } } @@ -528,7 +318,7 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { // @date: 2022/5/19 15:22 // @description: test query all null value without json (httpd) func TestAllTypeQueryNullWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "restful_test_without_json_null" db, err := sql.Open("taosRestful", dataSourceName) if err != nil { t.Fatal(err) @@ -539,41 +329,30 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists restful_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists restful_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists restful_test_without_json_null.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into restful_test_without_json_null.all_type values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + insertSql := fmt.Sprintf(`insert into %s.alltype values('%s'%s)`, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String()) + _, err = db.Exec(insertSql) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from restful_test_without_json_null.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -581,60 +360,20 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - + } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) } } diff --git a/taosSql/statement.go b/taosSql/statement.go index e103a0e..9513e37 100644 --- a/taosSql/statement.go +++ b/taosSql/statement.go @@ -138,11 +138,11 @@ func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { case reflect.Bool: v.Value = types.TaosBool(rv.Bool()) case reflect.Float32, reflect.Float64: - v.Value = types.TaosBool(rv.Float() == 1) + v.Value = types.TaosBool(rv.Float() > 0) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.Value = types.TaosBool(rv.Int() == 1) + v.Value = types.TaosBool(rv.Int() > 0) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v.Value = types.TaosBool(rv.Uint() == 1) + v.Value = types.TaosBool(rv.Uint() > 0) case reflect.String: vv, err := strconv.ParseBool(rv.String()) if err != nil { diff --git a/taosWS/connection.go b/taosWS/connection.go index 71697e6..8228661 100644 --- a/taosWS/connection.go +++ b/taosWS/connection.go @@ -15,6 +15,7 @@ import ( "github.com/gorilla/websocket" jsoniter "github.com/json-iterator/go" "github.com/taosdata/driver-go/v3/common" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" taosErrors "github.com/taosdata/driver-go/v3/errors" ) @@ -26,6 +27,14 @@ const ( WSFetch = "fetch" WSFetchBlock = "fetch_block" WSFreeResult = "free_result" + + STMTInit = "init" + STMTPrepare = "prepare" + STMTAddBatch = "add_batch" + STMTExec = "exec" + STMTClose = "close" + STMTGetColFields = "get_col_fields" + STMTUseResult = "use_result" ) var ( @@ -51,17 +60,19 @@ func newTaosConn(cfg *config) (*taosConn, error) { endpointUrl := &url.URL{ Scheme: cfg.net, Host: fmt.Sprintf("%s:%d", cfg.addr, cfg.port), - Path: "/rest/ws", + Path: "/ws", } if cfg.token != "" { endpointUrl.RawQuery = fmt.Sprintf("token=%s", cfg.token) } endpoint := endpointUrl.String() - ws, _, err := common.DefaultDialer.Dial(endpoint, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = cfg.enableCompression + ws, _, err := dialer.Dial(endpoint, nil) if err != nil { return nil, err } - ws.SetReadLimit(common.BufferSize4M) + ws.EnableWriteCompression(cfg.enableCompression) ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(common.DefaultPongWait)) @@ -99,9 +110,297 @@ func (tc *taosConn) Close() (err error) { } func (tc *taosConn) Prepare(query string) (driver.Stmt, error) { - return nil, &taosErrors.TaosError{Code: 0xffff, ErrStr: "websocket does not support stmt"} + stmtID, err := tc.stmtInit() + if err != nil { + return nil, err + } + isInsert, err := tc.stmtPrepare(stmtID, query) + if err != nil { + tc.stmtClose(stmtID) + return nil, err + } + stmt := &Stmt{ + conn: tc, + stmtID: stmtID, + isInsert: isInsert, + pSql: query, + } + return stmt, nil +} + +func (tc *taosConn) stmtInit() (uint64, error) { + reqID := tc.generateReqID() + req := &StmtInitReq{ + ReqID: reqID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTInit, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtInitResp + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.StmtID, nil +} + +func (tc *taosConn) stmtPrepare(stmtID uint64, sql string) (bool, error) { + reqID := tc.generateReqID() + req := &StmtPrepareRequest{ + ReqID: reqID, + StmtID: stmtID, + SQL: sql, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return false, err + } + action := &WSAction{ + Action: STMTPrepare, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return false, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return false, err + } + var resp StmtPrepareResponse + err = tc.readTo(&resp) + if err != nil { + return false, err + } + if resp.Code != 0 { + return false, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.IsInsert, nil +} + +func (tc *taosConn) stmtClose(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtCloseRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTClose, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + return nil +} + +func (tc *taosConn) stmtGetColFields(stmtID uint64) ([]*stmtCommon.StmtField, error) { + reqID := tc.generateReqID() + req := &StmtGetColFieldsRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: STMTGetColFields, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp StmtGetColFieldsResponse + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Fields, nil } +func (tc *taosConn) stmtBindParam(stmtID uint64, block []byte) error { + reqID := tc.generateReqID() + tc.buf.Reset() + WriteUint64(tc.buf, reqID) + WriteUint64(tc.buf, stmtID) + WriteUint64(tc.buf, BindMessage) + tc.buf.Write(block) + err := tc.writeBinary(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtBindResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func WriteUint64(buffer *bytes.Buffer, v uint64) { + buffer.WriteByte(byte(v)) + buffer.WriteByte(byte(v >> 8)) + buffer.WriteByte(byte(v >> 16)) + buffer.WriteByte(byte(v >> 24)) + buffer.WriteByte(byte(v >> 32)) + buffer.WriteByte(byte(v >> 40)) + buffer.WriteByte(byte(v >> 48)) + buffer.WriteByte(byte(v >> 56)) +} + +func (tc *taosConn) stmtAddBatch(stmtID uint64) error { + reqID := tc.generateReqID() + req := &StmtAddBatchRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return err + } + action := &WSAction{ + Action: STMTAddBatch, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return err + } + var resp StmtAddBatchResponse + err = tc.readTo(&resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + return nil +} + +func (tc *taosConn) stmtExec(stmtID uint64) (int, error) { + reqID := tc.generateReqID() + req := &StmtExecRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return 0, err + } + action := &WSAction{ + Action: STMTExec, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return 0, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return 0, err + } + var resp StmtExecResponse + err = tc.readTo(&resp) + if err != nil { + return 0, err + } + if resp.Code != 0 { + return 0, taosErrors.NewError(resp.Code, resp.Message) + } + return resp.Affected, nil +} + +func (tc *taosConn) stmtUseResult(stmtID uint64) (*rows, error) { + reqID := tc.generateReqID() + req := &StmtUseResultRequest{ + ReqID: reqID, + StmtID: stmtID, + } + reqArgs, err := json.Marshal(req) + if err != nil { + return nil, err + } + action := &WSAction{ + Action: STMTUseResult, + Args: reqArgs, + } + tc.buf.Reset() + err = jsonI.NewEncoder(tc.buf).Encode(action) + if err != nil { + return nil, err + } + err = tc.writeText(tc.buf.Bytes()) + if err != nil { + return nil, err + } + var resp StmtUseResultResponse + err = tc.readTo(&resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + rs := &rows{ + buf: &bytes.Buffer{}, + conn: tc, + resultID: resp.ResultID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + isStmt: true, + } + return rs, nil +} func (tc *taosConn) Exec(query string, args []driver.Value) (driver.Result, error) { return tc.execCtx(context.Background(), query, common.ValueArgsToNamedValueArgs(args)) } @@ -261,8 +560,16 @@ func (tc *taosConn) connect() error { } func (tc *taosConn) writeText(data []byte) error { + return tc.write(data, websocket.TextMessage) +} + +func (tc *taosConn) writeBinary(data []byte) error { + return tc.write(data, websocket.BinaryMessage) +} + +func (tc *taosConn) write(data []byte, messageType int) error { tc.client.SetWriteDeadline(time.Now().Add(tc.writeTimeout)) - err := tc.client.WriteMessage(websocket.TextMessage, data) + err := tc.client.WriteMessage(messageType, data) if err != nil { return NewBadConnErrorWithCtx(err, string(data)) } diff --git a/taosWS/connector_test.go b/taosWS/connector_test.go index a40a251..d971a1d 100644 --- a/taosWS/connector_test.go +++ b/taosWS/connector_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "math/rand" + "reflect" + "strings" "testing" "time" @@ -11,11 +13,78 @@ import ( "github.com/taosdata/driver-go/v3/types" ) +func generateCreateTableSql(db string, withJson bool) string { + createSql := fmt.Sprintf("create table if not exists %s.alltype(ts timestamp,"+ + "c1 bool,"+ + "c2 tinyint,"+ + "c3 smallint,"+ + "c4 int,"+ + "c5 bigint,"+ + "c6 tinyint unsigned,"+ + "c7 smallint unsigned,"+ + "c8 int unsigned,"+ + "c9 bigint unsigned,"+ + "c10 float,"+ + "c11 double,"+ + "c12 binary(20),"+ + "c13 nchar(20),"+ + "c14 varbinary(100),"+ + "c15 geometry(100)"+ + ")", + db) + if withJson { + createSql += " tags(t json)" + } + return createSql +} + +func generateValues() (value []interface{}, scanValue []interface{}, insertSql string) { + rand.Seed(time.Now().UnixNano()) + v1 := true + v2 := int8(rand.Int()) + v3 := int16(rand.Int()) + v4 := rand.Int31() + v5 := int64(rand.Int31()) + v6 := uint8(rand.Uint32()) + v7 := uint16(rand.Uint32()) + v8 := rand.Uint32() + v9 := uint64(rand.Uint32()) + v10 := rand.Float32() + v11 := rand.Float64() + v12 := "test_binary" + v13 := "test_nchar" + v14 := []byte("test_varbinary") + v15 := []byte{0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x59, 0x40} + ts := time.Now().Round(time.Millisecond) + var ( + cts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + c14 []byte + c15 []byte + ) + return []interface{}{ + ts, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + }, []interface{}{cts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, c14, c15}, + fmt.Sprintf(`values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar','test_varbinary','point(100 100)')`, ts.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +} + // @author: xftan // @date: 2023/10/13 11:22 // @description: test all type query func TestAllTypeQuery(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -26,57 +95,25 @@ func TestAllTypeQuery(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists ws_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test.t1 using ws_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -84,71 +121,27 @@ func TestAllTypeQuery(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2023/10/13 11:22 // @description: test null value func TestAllTypeQueryNull(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_null" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -159,42 +152,29 @@ func TestAllTypeQueryNull(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists ws_test_null.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test_null.t1 using ws_test_null.alltype tags('null') values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') values('%s'%s)`, database, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String())) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test_null.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -202,72 +182,32 @@ func TestAllTypeQueryNull(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } + var tt types.RawMessage + values[len(colValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - tt *string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) - assert.Nil(t, tt) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) + } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), *(values[len(values)-1]).(*types.RawMessage)) } // @author: xftan // @date: 2023/10/13 11:24 // @description: test compression func TestAllTypeQueryCompression(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_compression" db, err := sql.Open("taosWS", dataSourceNameWithCompression) if err != nil { t.Fatal(err) @@ -278,57 +218,25 @@ func TestAllTypeQueryCompression(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = true - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists ws_test.alltype(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")" + - "tags(t json)", - ) + _, err = db.Exec(generateCreateTableSql(database, true)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test.t1 using ws_test.alltype tags('{"a":"b"}') values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.t1 using %s.alltype tags('{"a":"b"}') %s`, database, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test.alltype where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -336,70 +244,27 @@ func TestAllTypeQueryCompression(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + var tt types.RawMessage + dest := make([]interface{}, len(scanValues)+1) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } + dest[len(scanValues)] = &tt for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - tt types.RawMessage - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - &tt, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } + assert.Equal(t, types.RawMessage(`{"a":"b"}`), tt) } // @author: xftan // @date: 2023/10/13 11:24 // @description: test all type query without json func TestAllTypeQueryWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_without_json" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -410,56 +275,25 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test_without_json") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test_without_json") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - var ( - v1 = false - v2 = int8(rand.Int()) - v3 = int16(rand.Int()) - v4 = rand.Int31() - v5 = int64(rand.Int31()) - v6 = uint8(rand.Uint32()) - v7 = uint16(rand.Uint32()) - v8 = rand.Uint32() - v9 = uint64(rand.Uint32()) - v10 = rand.Float32() - v11 = rand.Float64() - v12 = "test_binary" - v13 = "test_nchar" - ) - - _, err = db.Exec("create table if not exists ws_test_without_json.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test_without_json.all_type values('%s',%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,'test_binary','test_nchar')`, now.Format(time.RFC3339Nano), v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)) + colValues, scanValues, insertSql := generateValues() + _, err = db.Exec(fmt.Sprintf(`insert into %s.alltype %s`, database, insertSql)) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test_without_json.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -467,60 +301,16 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + dest := make([]interface{}, len(scanValues)) + for i := range scanValues { + dest[i] = reflect.ValueOf(&scanValues[i]).Interface() + } for rows.Next() { - var ( - ts time.Time - c1 bool - c2 int8 - c3 int16 - c4 int32 - c5 int64 - c6 uint8 - c7 uint16 - c8 uint32 - c9 uint64 - c10 float32 - c11 float64 - c12 string - c13 string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Equal(t, v1, c1) - assert.Equal(t, v2, c2) - assert.Equal(t, v3, c3) - assert.Equal(t, v4, c4) - assert.Equal(t, v5, c5) - assert.Equal(t, v6, c6) - assert.Equal(t, v7, c7) - assert.Equal(t, v8, c8) - assert.Equal(t, v9, c9) - assert.Equal(t, v10, c10) - assert.Equal(t, v11, c11) - assert.Equal(t, v12, c12) - assert.Equal(t, v13, c13) - if err != nil { - t.Fatal(err) - } - if ts.IsZero() { - t.Fatal(ts) - } - + err := rows.Scan(dest...) + assert.NoError(t, err) + } + for i, v := range colValues { + assert.Equal(t, v, scanValues[i]) } } @@ -528,7 +318,7 @@ func TestAllTypeQueryWithoutJson(t *testing.T) { // @date: 2023/10/13 11:24 // @description: test all type query with null without json func TestAllTypeQueryNullWithoutJson(t *testing.T) { - rand.Seed(time.Now().UnixNano()) + database := "ws_test_without_json_null" db, err := sql.Open("taosWS", dataSourceName) if err != nil { t.Fatal(err) @@ -539,41 +329,30 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { t.Fatal(err) } defer func() { - _, err = db.Exec("drop database if exists ws_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("drop database if exists %s", database)) if err != nil { t.Fatal(err) } }() - _, err = db.Exec("create database if not exists ws_test_without_json_null") + _, err = db.Exec(fmt.Sprintf("create database if not exists %s", database)) if err != nil { t.Fatal(err) } - - _, err = db.Exec("create table if not exists ws_test_without_json_null.all_type(ts timestamp," + - "c1 bool," + - "c2 tinyint," + - "c3 smallint," + - "c4 int," + - "c5 bigint," + - "c6 tinyint unsigned," + - "c7 smallint unsigned," + - "c8 int unsigned," + - "c9 bigint unsigned," + - "c10 float," + - "c11 double," + - "c12 binary(20)," + - "c13 nchar(20)" + - ")", - ) + _, err = db.Exec(generateCreateTableSql(database, false)) if err != nil { t.Fatal(err) } - now := time.Now().Round(time.Millisecond) - _, err = db.Exec(fmt.Sprintf(`insert into ws_test_without_json_null.all_type values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)`, now.Format(time.RFC3339Nano))) + colValues, _, _ := generateValues() + builder := &strings.Builder{} + for i := 1; i < len(colValues); i++ { + builder.WriteString(",null") + } + insertSql := fmt.Sprintf(`insert into %s.alltype values('%s'%s)`, database, colValues[0].(time.Time).Format(time.RFC3339Nano), builder.String()) + _, err = db.Exec(insertSql) if err != nil { t.Fatal(err) } - rows, err := db.Query(fmt.Sprintf("select * from ws_test_without_json_null.all_type where ts = '%s'", now.Format(time.RFC3339Nano))) + rows, err := db.Query(fmt.Sprintf("select * from %s.alltype where ts = '%s'", database, colValues[0].(time.Time).Format(time.RFC3339Nano))) assert.NoError(t, err) columns, err := rows.Columns() assert.NoError(t, err) @@ -581,61 +360,21 @@ func TestAllTypeQueryNullWithoutJson(t *testing.T) { cTypes, err := rows.ColumnTypes() assert.NoError(t, err) t.Log(cTypes) + values := make([]interface{}, len(cTypes)) + values[0] = new(time.Time) + for i := 1; i < len(colValues); i++ { + var v interface{} + values[i] = &v + } for rows.Next() { - var ( - ts time.Time - c1 *bool - c2 *int8 - c3 *int16 - c4 *int32 - c5 *int64 - c6 *uint8 - c7 *uint16 - c8 *uint32 - c9 *uint64 - c10 *float32 - c11 *float64 - c12 *string - c13 *string - ) - err := rows.Scan( - &ts, - &c1, - &c2, - &c3, - &c4, - &c5, - &c6, - &c7, - &c8, - &c9, - &c10, - &c11, - &c12, - &c13, - ) - assert.Equal(t, now.UTC(), ts.UTC()) - assert.Nil(t, c1) - assert.Nil(t, c2) - assert.Nil(t, c3) - assert.Nil(t, c4) - assert.Nil(t, c5) - assert.Nil(t, c6) - assert.Nil(t, c7) - assert.Nil(t, c8) - assert.Nil(t, c9) - assert.Nil(t, c10) - assert.Nil(t, c11) - assert.Nil(t, c12) - assert.Nil(t, c13) + err := rows.Scan(values...) if err != nil { - t.Fatal(err) } - if ts.IsZero() { - t.Fatal(ts) - } - + } + assert.Equal(t, *values[0].(*time.Time), colValues[0].(time.Time)) + for i := 1; i < len(values)-1; i++ { + assert.Nil(t, *values[i].(*interface{})) } } diff --git a/taosWS/driver_test.go b/taosWS/driver_test.go index 9e12d7a..56b70e8 100644 --- a/taosWS/driver_test.go +++ b/taosWS/driver_test.go @@ -34,7 +34,7 @@ var ( port = 6041 dbName = "test_taos_ws" dataSourceName = fmt.Sprintf("%s:%s@ws(%s:%d)/", user, password, host, port) - dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?disableCompression=false", user, password, host, port) + dataSourceNameWithCompression = fmt.Sprintf("%s:%s@ws(%s:%d)/?enableCompression=true", user, password, host, port) ) type DBTest struct { diff --git a/taosWS/dsn.go b/taosWS/dsn.go index ca5ff22..aa8e4dd 100644 --- a/taosWS/dsn.go +++ b/taosWS/dsn.go @@ -29,6 +29,7 @@ type config struct { params map[string]string // Connection parameters interpolateParams bool // Interpolate placeholders into query string token string // cloud platform token + enableCompression bool // Enable write compression readTimeout time.Duration // read message timeout writeTimeout time.Duration // write message timeout } @@ -143,6 +144,11 @@ func parseDSNParams(cfg *config, params string) (err error) { } case "token": cfg.token = value + case "enableCompression": + cfg.enableCompression, err = strconv.ParseBool(value) + if err != nil { + return &errors.TaosError{Code: 0xffff, ErrStr: "invalid enableCompression value: " + value} + } case "readTimeout": cfg.readTimeout, err = time.ParseDuration(value) if err != nil { diff --git a/taosWS/dsn_test.go b/taosWS/dsn_test.go index edfd013..7d7c7d7 100644 --- a/taosWS/dsn_test.go +++ b/taosWS/dsn_test.go @@ -25,6 +25,15 @@ func TestParseDsn(t *testing.T) { {dsn: "user:passwd@wss(:0)/?interpolateParams=false&test=1", want: &config{user: "user", passwd: "passwd", net: "wss", params: map[string]string{"test": "1"}}}, {dsn: "user:passwd@wss(:0)/?interpolateParams=false&token=token", want: &config{user: "user", passwd: "passwd", net: "wss", token: "token"}}, {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m", want: &config{user: "user", passwd: "passwd", net: "wss", readTimeout: 10 * time.Minute, writeTimeout: 8 * time.Second, interpolateParams: true}}, + {dsn: "user:passwd@wss(:0)/?writeTimeout=8s&readTimeout=10m&enableCompression=true", want: &config{ + user: "user", + passwd: "passwd", + net: "wss", + readTimeout: 10 * time.Minute, + writeTimeout: 8 * time.Second, + interpolateParams: true, + enableCompression: true, + }}, } for _, tc := range tests { t.Run(tc.dsn, func(t *testing.T) { diff --git a/taosWS/proto.go b/taosWS/proto.go index fd2fb39..2731eec 100644 --- a/taosWS/proto.go +++ b/taosWS/proto.go @@ -1,6 +1,10 @@ package taosWS -import "encoding/json" +import ( + "encoding/json" + + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" +) type WSConnectReq struct { ReqID uint64 `json:"req_id"` @@ -69,3 +73,122 @@ type WSAction struct { Action string `json:"action"` Args json.RawMessage `json:"args"` } + +type StmtPrepareRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` + SQL string `json:"sql"` +} + +type StmtPrepareResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + IsInsert bool `json:"is_insert"` +} + +type StmtInitReq struct { + ReqID uint64 `json:"req_id"` +} + +type StmtInitResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} +type StmtCloseRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtCloseResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id,omitempty"` +} + +type StmtGetColFieldsRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtGetColFieldsResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Fields []*stmtCommon.StmtField `json:"fields"` +} + +const ( + BindMessage = 2 +) + +type StmtBindResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtAddBatchResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtExecResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + Affected int `json:"affected"` +} + +type StmtUseResultRequest struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type StmtUseResultResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} diff --git a/taosWS/rows.go b/taosWS/rows.go index b462f4e..7677ab7 100644 --- a/taosWS/rows.go +++ b/taosWS/rows.go @@ -27,6 +27,7 @@ type rows struct { fieldsTypes []uint8 fieldsLengths []int64 precision int + isStmt bool } func (rs *rows) Columns() []string { @@ -177,5 +178,5 @@ func (rs *rows) freeResult() error { if err != nil { return err } - return nil + return tc.writeText(rs.buf.Bytes()) } diff --git a/taosWS/statement.go b/taosWS/statement.go new file mode 100644 index 0000000..d313820 --- /dev/null +++ b/taosWS/statement.go @@ -0,0 +1,517 @@ +package taosWS + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/param" + "github.com/taosdata/driver-go/v3/common/serializer" + stmtCommon "github.com/taosdata/driver-go/v3/common/stmt" + "github.com/taosdata/driver-go/v3/types" +) + +type Stmt struct { + stmtID uint64 + conn *taosConn + buffer bytes.Buffer + pSql string + isInsert bool + cols []*stmtCommon.StmtField + colTypes *param.ColumnType + queryColTypes []*types.ColumnType +} + +func (stmt *Stmt) Close() error { + err := stmt.conn.stmtClose(stmt.stmtID) + stmt.buffer.Reset() + stmt.conn = nil + return err +} + +func (stmt *Stmt) NumInput() int { + if stmt.colTypes != nil { + return len(stmt.cols) + } + return -1 +} + +func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { + if stmt.conn == nil { + return nil, driver.ErrBadConn + } + if len(args) != len(stmt.cols) { + return nil, fmt.Errorf("stmt exec error: wrong number of parameters") + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), stmt.colTypes) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + affected, err := stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return driver.RowsAffected(affected), nil +} + +func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { + if stmt.conn == nil { + return nil, driver.ErrBadConn + } + block, err := serializer.SerializeRawBlock(param.NewParamsWithRowValue(args), param.NewColumnTypeWithValue(stmt.queryColTypes)) + if err != nil { + return nil, err + } + err = stmt.conn.stmtBindParam(stmt.stmtID, block) + if err != nil { + return nil, err + } + err = stmt.conn.stmtAddBatch(stmt.stmtID) + if err != nil { + return nil, err + } + _, err = stmt.conn.stmtExec(stmt.stmtID) + if err != nil { + return nil, err + } + return stmt.conn.stmtUseResult(stmt.stmtID) +} + +func (stmt *Stmt) CheckNamedValue(v *driver.NamedValue) error { + if stmt.isInsert { + if stmt.cols == nil { + cols, err := stmt.conn.stmtGetColFields(stmt.stmtID) + if err != nil { + return err + } + colTypes := make([]*types.ColumnType, len(cols)) + for i, col := range cols { + t, err := col.GetType() + if err != nil { + return err + } + colTypes[i] = t + } + stmt.cols = cols + stmt.colTypes = param.NewColumnTypeWithValue(colTypes) + } + if v.Ordinal > len(stmt.cols) { + return nil + } + if v.Value == nil { + return nil + } + switch stmt.cols[v.Ordinal-1].FieldType { + case common.TSDB_DATA_TYPE_NULL: + v.Value = nil + case common.TSDB_DATA_TYPE_BOOL: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBool(rv.Float() > 0) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBool(rv.Int() > 0) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBool(rv.Uint() > 0) + case reflect.String: + vv, err := strconv.ParseBool(rv.String()) + if err != nil { + return err + } + v.Value = types.TaosBool(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bool", v) + } + case common.TSDB_DATA_TYPE_TINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosTinyint(1) + } else { + v.Value = types.TaosTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint", v) + } + case common.TSDB_DATA_TYPE_SMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosSmallint(1) + } else { + v.Value = types.TaosSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint", v) + } + case common.TSDB_DATA_TYPE_INT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosInt(1) + } else { + v.Value = types.TaosInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int", v) + } + case common.TSDB_DATA_TYPE_BIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosBigint(1) + } else { + v.Value = types.TaosBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseInt(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint", v) + } + case common.TSDB_DATA_TYPE_FLOAT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosFloat(1) + } else { + v.Value = types.TaosFloat(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosFloat(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosFloat(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosFloat(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 32) + if err != nil { + return err + } + v.Value = types.TaosFloat(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to float", v) + } + case common.TSDB_DATA_TYPE_DOUBLE: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosDouble(1) + } else { + v.Value = types.TaosDouble(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosDouble(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosDouble(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseFloat(rv.String(), 64) + if err != nil { + return err + } + v.Value = types.TaosDouble(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to double", v) + } + case common.TSDB_DATA_TYPE_BINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to binary", v) + } + case common.TSDB_DATA_TYPE_VARBINARY: + switch v.Value.(type) { + case string: + v.Value = types.TaosVarBinary(v.Value.(string)) + case []byte: + v.Value = types.TaosVarBinary(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to varbinary", v) + } + + case common.TSDB_DATA_TYPE_GEOMETRY: + switch v.Value.(type) { + case string: + v.Value = types.TaosGeometry(v.Value.(string)) + case []byte: + v.Value = types.TaosGeometry(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to geometry", v) + } + + case common.TSDB_DATA_TYPE_TIMESTAMP: + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Float32, reflect.Float64: + t := common.TimestampConvertToTime(int64(rv.Float()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + t := common.TimestampConvertToTime(rv.Int(), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + t := common.TimestampConvertToTime(int64(rv.Uint()), int(stmt.cols[v.Ordinal-1].Precision)) + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + case reflect.String: + t, err := time.Parse(time.RFC3339Nano, rv.String()) + if err != nil { + return err + } + v.Value = types.TaosTimestamp{ + T: t, + Precision: int(stmt.cols[v.Ordinal-1].Precision), + } + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to timestamp", v) + } + case common.TSDB_DATA_TYPE_NCHAR: + switch v.Value.(type) { + case string: + v.Value = types.TaosNchar(v.Value.(string)) + case []byte: + v.Value = types.TaosNchar(v.Value.([]byte)) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to nchar", v) + } + case common.TSDB_DATA_TYPE_UTINYINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUTinyint(1) + } else { + v.Value = types.TaosUTinyint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUTinyint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUTinyint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUTinyint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 8) + if err != nil { + return err + } + v.Value = types.TaosUTinyint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to tinyint unsigned", v) + } + case common.TSDB_DATA_TYPE_USMALLINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUSmallint(1) + } else { + v.Value = types.TaosUSmallint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUSmallint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUSmallint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUSmallint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 16) + if err != nil { + return err + } + v.Value = types.TaosUSmallint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to smallint unsigned", v) + } + case common.TSDB_DATA_TYPE_UINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUInt(1) + } else { + v.Value = types.TaosUInt(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUInt(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUInt(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUInt(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 32) + if err != nil { + return err + } + v.Value = types.TaosUInt(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to int unsigned", v) + } + case common.TSDB_DATA_TYPE_UBIGINT: + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + if rv.Bool() { + v.Value = types.TaosUBigint(1) + } else { + v.Value = types.TaosUBigint(0) + } + case reflect.Float32, reflect.Float64: + v.Value = types.TaosUBigint(rv.Float()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosUBigint(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + case reflect.String: + vv, err := strconv.ParseUint(rv.String(), 0, 64) + if err != nil { + return err + } + v.Value = types.TaosUBigint(vv) + default: + return fmt.Errorf("CheckNamedValue:%v can not convert to bigint unsigned", v) + } + } + return nil + } else { + if v.Value == nil { + return errors.New("CheckNamedValue: value is nil") + } + if v.Ordinal == 1 { + stmt.queryColTypes = nil + } + if len(stmt.queryColTypes) < v.Ordinal { + tmp := stmt.queryColTypes + stmt.queryColTypes = make([]*types.ColumnType, v.Ordinal) + copy(stmt.queryColTypes, tmp) + } + t, is := v.Value.(time.Time) + if is { + v.Value = types.TaosBinary(t.Format(time.RFC3339Nano)) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBinaryType} + return nil + } + rv := reflect.ValueOf(v.Value) + switch rv.Kind() { + case reflect.Bool: + v.Value = types.TaosBool(rv.Bool()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBoolType} + case reflect.Float32, reflect.Float64: + v.Value = types.TaosDouble(rv.Float()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosDoubleType} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v.Value = types.TaosBigint(rv.Int()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosBigintType} + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v.Value = types.TaosUBigint(rv.Uint()) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{Type: types.TaosUBigintType} + case reflect.String: + strVal := rv.String() + v.Value = types.TaosBinary(strVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(strVal), + } + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + bsVal := rv.Bytes() + v.Value = types.TaosBinary(bsVal) + stmt.queryColTypes[v.Ordinal-1] = &types.ColumnType{ + Type: types.TaosBinaryType, + MaxLen: len(bsVal), + } + } else { + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + default: + return fmt.Errorf("CheckNamedValue: can not convert query value %v", v) + } + return nil + } +} diff --git a/taosWS/statement_test.go b/taosWS/statement_test.go new file mode 100644 index 0000000..1ab008c --- /dev/null +++ b/taosWS/statement_test.go @@ -0,0 +1,2159 @@ +package taosWS + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestStmtExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer func() { + t.Log("start3") + db.Close() + t.Log("done3") + }() + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + t.Log("done2") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(time.Now(), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + assert.NoError(t, err) + assert.Equal(t, int64(1), affected) + t.Log("done") +} + +func TestStmtQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + defer func() { + db.Exec("drop database if exists test_stmt_driver_ws_q") + }() + _, err = db.Exec("create database if not exists test_stmt_driver_ws_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table if not exists test_stmt_driver_ws_q.ct(ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare("insert into test_stmt_driver_ws_q.ct values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + now := time.Now() + result, err := stmt.Exec(now, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, "binary", "nchar") + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + stmt.Close() + stmt, err = db.Prepare("select * from test_stmt_driver_ws_q.ct where ts = ?") + if err != nil { + t.Error(err) + return + } + rows, err := stmt.Query(now) + if err != nil { + t.Error(err) + return + } + columns, err := rows.Columns() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, []string{"ts", "c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10", "c11", "c12", "c13"}, columns) + count := 0 + for rows.Next() { + count += 1 + var ( + ts time.Time + c1 bool + c2 int8 + c3 int16 + c4 int32 + c5 int64 + c6 uint8 + c7 uint16 + c8 uint32 + c9 uint64 + c10 float32 + c11 float64 + c12 string + c13 string + ) + err = rows.Scan(&ts, + &c1, + &c2, + &c3, + &c4, + &c5, + &c6, + &c7, + &c8, + &c9, + &c10, + &c11, + &c12, + &c13) + assert.NoError(t, err) + assert.Equal(t, now.UnixNano()/1e6, ts.UnixNano()/1e6) + assert.Equal(t, true, c1) + assert.Equal(t, int8(2), c2) + assert.Equal(t, int16(3), c3) + assert.Equal(t, int32(4), c4) + assert.Equal(t, int64(5), c5) + assert.Equal(t, uint8(6), c6) + assert.Equal(t, uint16(7), c7) + assert.Equal(t, uint32(8), c8) + assert.Equal(t, uint64(9), c9) + assert.Equal(t, float32(10), c10) + assert.Equal(t, float64(11), c11) + assert.Equal(t, "binary", c12) + assert.Equal(t, "nchar", c13) + } + assert.Equal(t, 1, count) +} + +func TestStmtConvertExec(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert") + if err != nil { + t.Error(err) + return + } + now := time.Now().Format(time.RFC3339Nano) + tests := []struct { + name string + tbType string + pos string + bind []interface{} + expectValue interface{} + expectError bool + }{ + //bool + { + name: "bool_null", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "bool_err", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, []int{123}}, + expectValue: nil, + expectError: true, + }, + { + name: "bool_bool_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: true, + }, + { + name: "bool_bool_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: false, + }, + { + name: "bool_float_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: true, + }, + { + name: "bool_float_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, float32(0)}, + expectValue: false, + }, + { + name: "bool_int_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(1)}, + expectValue: true, + }, + { + name: "bool_int_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, int32(0)}, + expectValue: false, + }, + { + name: "bool_uint_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(1)}, + expectValue: true, + }, + { + name: "bool_uint_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, uint32(0)}, + expectValue: false, + }, + { + name: "bool_string_true", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "true"}, + expectValue: true, + }, + { + name: "bool_string_false", + tbType: "ts timestamp,v bool", + pos: "?,?", + bind: []interface{}{now, "false"}, + expectValue: false, + }, + //tiny int + { + name: "tiny_nil", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "tiny_err", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "tiny_bool_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int8(1), + }, + { + name: "tiny_bool_0", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int8(0), + }, + { + name: "tiny_float_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int8(1), + }, + { + name: "tiny_int_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int8(1), + }, + { + name: "tiny_uint_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int8(1), + }, + { + name: "tiny_string_1", + tbType: "ts timestamp,v tinyint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int8(1), + }, + // small int + { + name: "small_nil", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "small_err", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "small_bool_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int16(1), + }, + { + name: "small_bool_0", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int16(0), + }, + { + name: "small_float_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int16(1), + }, + { + name: "small_int_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int16(1), + }, + { + name: "small_uint_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int16(1), + }, + { + name: "small_string_1", + tbType: "ts timestamp,v smallint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int16(1), + }, + // int + { + name: "int_nil", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "int_err", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "int_bool_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int32(1), + }, + { + name: "int_bool_0", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int32(0), + }, + { + name: "int_float_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int32(1), + }, + { + name: "int_int_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int32(1), + }, + { + name: "int_uint_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int32(1), + }, + { + name: "int_string_1", + tbType: "ts timestamp,v int", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int32(1), + }, + // big int + { + name: "big_nil", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "big_err", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "big_bool_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: int64(1), + }, + { + name: "big_bool_0", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: int64(0), + }, + { + name: "big_float_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: int64(1), + }, + { + name: "big_int_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: int64(1), + }, + { + name: "big_uint_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: int64(1), + }, + { + name: "big_string_1", + tbType: "ts timestamp,v bigint", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: int64(1), + }, + // float + { + name: "float_nil", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "float_err", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "float_bool_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float32(1), + }, + { + name: "float_bool_0", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float32(0), + }, + { + name: "float_float_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float32(1), + }, + { + name: "float_int_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float32(1), + }, + { + name: "float_uint_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float32(1), + }, + { + name: "float_string_1", + tbType: "ts timestamp,v float", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float32(1), + }, + //double + { + name: "double_nil", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "double_err", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "double_bool_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: float64(1), + }, + { + name: "double_bool_0", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: float64(0), + }, + { + name: "double_double_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: float64(1), + }, + { + name: "double_int_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: float64(1), + }, + { + name: "double_uint_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: float64(1), + }, + { + name: "double_string_1", + tbType: "ts timestamp,v double", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: float64(1), + }, + + //tiny int unsigned + { + name: "utiny_nil", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "utiny_err", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "utiny_bool_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint8(1), + }, + { + name: "utiny_bool_0", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint8(0), + }, + { + name: "utiny_float_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_int_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_uint_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint8(1), + }, + { + name: "utiny_string_1", + tbType: "ts timestamp,v tinyint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint8(1), + }, + // small int unsigned + { + name: "usmall_nil", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "usmall_err", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "usmall_bool_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint16(1), + }, + { + name: "usmall_bool_0", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint16(0), + }, + { + name: "usmall_float_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_int_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_uint_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint16(1), + }, + { + name: "usmall_string_1", + tbType: "ts timestamp,v smallint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint16(1), + }, + // int unsigned + { + name: "uint_nil", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "uint_err", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "uint_bool_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint32(1), + }, + { + name: "uint_bool_0", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint32(0), + }, + { + name: "uint_float_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint32(1), + }, + { + name: "uint_int_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint32(1), + }, + { + name: "uint_uint_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint32(1), + }, + { + name: "uint_string_1", + tbType: "ts timestamp,v int unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint32(1), + }, + // big int unsigned + { + name: "ubig_nil", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ubig_err", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ubig_bool_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, true}, + expectValue: uint64(1), + }, + { + name: "ubig_bool_0", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, false}, + expectValue: uint64(0), + }, + { + name: "ubig_float_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_int_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_uint_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: uint64(1), + }, + { + name: "ubig_string_1", + tbType: "ts timestamp,v bigint unsigned", + pos: "?,?", + bind: []interface{}{now, "1"}, + expectValue: uint64(1), + }, + //binary + { + name: "binary_nil", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "binary_err", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v binary(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + //nchar + { + name: "nchar_nil", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "nchar_err", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "binary_string_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, "中文"}, + expectValue: "中文", + }, + { + name: "binary_bytes_chinese", + tbType: "ts timestamp,v nchar(24)", + pos: "?,?", + bind: []interface{}{now, []byte("中文")}, + expectValue: "中文", + }, + // timestamp + { + name: "ts_nil", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, nil}, + expectValue: nil, + }, + { + name: "ts_err", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, []int{1}}, + expectValue: nil, + expectError: true, + }, + { + name: "ts_time_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, time.Unix(0, 1e6)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_float_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, float32(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_int_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, int(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_uint_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, uint(1)}, + expectValue: time.Unix(0, 1e6), + }, + { + name: "ts_string_1", + tbType: "ts timestamp,v timestamp", + pos: "?,?", + bind: []interface{}{now, "1970-01-01T00:00:00.001Z"}, + expectValue: time.Unix(0, 1e6), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tbName := fmt.Sprintf("test_%s", tt.name) + tbType := tt.tbType + drop := fmt.Sprintf("drop table if exists %s", tbName) + create := fmt.Sprintf("create table if not exists %s(%s)", tbName, tbType) + pos := tt.pos + sql := fmt.Sprintf("insert into %s values(%s)", tbName, pos) + var err error + if _, err = db.Exec(drop); err != nil { + t.Error(err) + return + } + if _, err = db.Exec(create); err != nil { + t.Error(err) + return + } + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + result, err := stmt.Exec(tt.bind...) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + affected, err := result.RowsAffected() + if err != nil { + t.Error(err) + return + } + assert.Equal(t, int64(1), affected) + rows, err := db.Query(fmt.Sprintf("select v from %s", tbName)) + if err != nil { + t.Error(err) + return + } + var data []driver.Value + tts, err := rows.ColumnTypes() + if err != nil { + t.Error(err) + return + } + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} + +func TestStmtConvertQuery(t *testing.T) { + db, err := sql.Open(driverName, dataSourceName) + if err != nil { + t.Error(err) + return + } + defer db.Close() + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + defer func() { + _, err = db.Exec("drop database if exists test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + }() + _, err = db.Exec("create database test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("use test_stmt_driver_ws_convert_q") + if err != nil { + t.Error(err) + return + } + _, err = db.Exec("create table t0 (ts timestamp," + + "c1 bool," + + "c2 tinyint," + + "c3 smallint," + + "c4 int," + + "c5 bigint," + + "c6 tinyint unsigned," + + "c7 smallint unsigned," + + "c8 int unsigned," + + "c9 bigint unsigned," + + "c10 float," + + "c11 double," + + "c12 binary(20)," + + "c13 nchar(20)" + + ")") + if err != nil { + t.Error(err) + return + } + now := time.Now() + after1s := now.Add(time.Second) + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',true,2,3,4,5,6,7,8,9,10,11,'binary','nchar')", now.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + _, err = db.Exec(fmt.Sprintf("insert into t0 values('%s',null,null,null,null,null,null,null,null,null,null,null,null,null)", after1s.Format(time.RFC3339Nano))) + if err != nil { + t.Error(err) + return + } + tests := []struct { + name string + field string + where string + bind interface{} + expectNoValue bool + expectValue driver.Value + expectError bool + }{ + //ts + { + name: "ts", + field: "ts", + where: "ts = ?", + bind: now, + expectValue: time.Unix(now.Unix(), int64((now.Nanosecond()/1e6)*1e6)).Local(), + }, + + //bool + { + name: "bool_true", + field: "c1", + where: "c1 = ?", + bind: true, + expectValue: true, + }, + { + name: "bool_false", + field: "c1", + where: "c1 = ?", + bind: false, + expectNoValue: true, + }, + { + name: "tinyint_int8", + field: "c2", + where: "c2 = ?", + bind: int8(2), + expectValue: int8(2), + }, + { + name: "tinyint_iny16", + field: "c2", + where: "c2 = ?", + bind: int16(2), + expectValue: int8(2), + }, + { + name: "tinyint_int32", + field: "c2", + where: "c2 = ?", + bind: int32(2), + expectValue: int8(2), + }, + { + name: "tinyint_int64", + field: "c2", + where: "c2 = ?", + bind: int64(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint8", + field: "c2", + where: "c2 = ?", + bind: uint8(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint16", + field: "c2", + where: "c2 = ?", + bind: uint16(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint32", + field: "c2", + where: "c2 = ?", + bind: uint32(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint64", + field: "c2", + where: "c2 = ?", + bind: uint64(2), + expectValue: int8(2), + }, + { + name: "tinyint_float32", + field: "c2", + where: "c2 = ?", + bind: float32(2), + expectValue: int8(2), + }, + { + name: "tinyint_float64", + field: "c2", + where: "c2 = ?", + bind: float64(2), + expectValue: int8(2), + }, + { + name: "tinyint_int", + field: "c2", + where: "c2 = ?", + bind: int(2), + expectValue: int8(2), + }, + { + name: "tinyint_uint", + field: "c2", + where: "c2 = ?", + bind: uint(2), + expectValue: int8(2), + }, + + // smallint + { + name: "smallint_int8", + field: "c3", + where: "c3 = ?", + bind: int8(3), + expectValue: int16(3), + }, + { + name: "smallint_iny16", + field: "c3", + where: "c3 = ?", + bind: int16(3), + expectValue: int16(3), + }, + { + name: "smallint_int32", + field: "c3", + where: "c3 = ?", + bind: int32(3), + expectValue: int16(3), + }, + { + name: "smallint_int64", + field: "c3", + where: "c3 = ?", + bind: int64(3), + expectValue: int16(3), + }, + { + name: "smallint_uint8", + field: "c3", + where: "c3 = ?", + bind: uint8(3), + expectValue: int16(3), + }, + { + name: "smallint_uint16", + field: "c3", + where: "c3 = ?", + bind: uint16(3), + expectValue: int16(3), + }, + { + name: "smallint_uint32", + field: "c3", + where: "c3 = ?", + bind: uint32(3), + expectValue: int16(3), + }, + { + name: "smallint_uint64", + field: "c3", + where: "c3 = ?", + bind: uint64(3), + expectValue: int16(3), + }, + { + name: "smallint_float32", + field: "c3", + where: "c3 = ?", + bind: float32(3), + expectValue: int16(3), + }, + { + name: "smallint_float64", + field: "c3", + where: "c3 = ?", + bind: float64(3), + expectValue: int16(3), + }, + { + name: "smallint_int", + field: "c3", + where: "c3 = ?", + bind: int(3), + expectValue: int16(3), + }, + { + name: "smallint_uint", + field: "c3", + where: "c3 = ?", + bind: uint(3), + expectValue: int16(3), + }, + + //int + { + name: "int_int8", + field: "c4", + where: "c4 = ?", + bind: int8(4), + expectValue: int32(4), + }, + { + name: "int_iny16", + field: "c4", + where: "c4 = ?", + bind: int16(4), + expectValue: int32(4), + }, + { + name: "int_int32", + field: "c4", + where: "c4 = ?", + bind: int32(4), + expectValue: int32(4), + }, + { + name: "int_int64", + field: "c4", + where: "c4 = ?", + bind: int64(4), + expectValue: int32(4), + }, + { + name: "int_uint8", + field: "c4", + where: "c4 = ?", + bind: uint8(4), + expectValue: int32(4), + }, + { + name: "int_uint16", + field: "c4", + where: "c4 = ?", + bind: uint16(4), + expectValue: int32(4), + }, + { + name: "int_uint32", + field: "c4", + where: "c4 = ?", + bind: uint32(4), + expectValue: int32(4), + }, + { + name: "int_uint64", + field: "c4", + where: "c4 = ?", + bind: uint64(4), + expectValue: int32(4), + }, + { + name: "int_float32", + field: "c4", + where: "c4 = ?", + bind: float32(4), + expectValue: int32(4), + }, + { + name: "int_float64", + field: "c4", + where: "c4 = ?", + bind: float64(4), + expectValue: int32(4), + }, + { + name: "int_int", + field: "c4", + where: "c4 = ?", + bind: int(4), + expectValue: int32(4), + }, + { + name: "int_uint", + field: "c4", + where: "c4 = ?", + bind: uint(4), + expectValue: int32(4), + }, + + //bigint + { + name: "bigint_int8", + field: "c5", + where: "c5 = ?", + bind: int8(5), + expectValue: int64(5), + }, + { + name: "bigint_iny16", + field: "c5", + where: "c5 = ?", + bind: int16(5), + expectValue: int64(5), + }, + { + name: "bigint_int32", + field: "c5", + where: "c5 = ?", + bind: int32(5), + expectValue: int64(5), + }, + { + name: "bigint_int64", + field: "c5", + where: "c5 = ?", + bind: int64(5), + expectValue: int64(5), + }, + { + name: "bigint_uint8", + field: "c5", + where: "c5 = ?", + bind: uint8(5), + expectValue: int64(5), + }, + { + name: "bigint_uint16", + field: "c5", + where: "c5 = ?", + bind: uint16(5), + expectValue: int64(5), + }, + { + name: "bigint_uint32", + field: "c5", + where: "c5 = ?", + bind: uint32(5), + expectValue: int64(5), + }, + { + name: "bigint_uint64", + field: "c5", + where: "c5 = ?", + bind: uint64(5), + expectValue: int64(5), + }, + { + name: "bigint_float32", + field: "c5", + where: "c5 = ?", + bind: float32(5), + expectValue: int64(5), + }, + { + name: "bigint_float64", + field: "c5", + where: "c5 = ?", + bind: float64(5), + expectValue: int64(5), + }, + { + name: "bigint_int", + field: "c5", + where: "c5 = ?", + bind: int(5), + expectValue: int64(5), + }, + { + name: "bigint_uint", + field: "c5", + where: "c5 = ?", + bind: uint(5), + expectValue: int64(5), + }, + + //utinyint + { + name: "utinyint_int8", + field: "c6", + where: "c6 = ?", + bind: int8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_iny16", + field: "c6", + where: "c6 = ?", + bind: int16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int32", + field: "c6", + where: "c6 = ?", + bind: int32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int64", + field: "c6", + where: "c6 = ?", + bind: int64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint8", + field: "c6", + where: "c6 = ?", + bind: uint8(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint16", + field: "c6", + where: "c6 = ?", + bind: uint16(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint32", + field: "c6", + where: "c6 = ?", + bind: uint32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint64", + field: "c6", + where: "c6 = ?", + bind: uint64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float32", + field: "c6", + where: "c6 = ?", + bind: float32(6), + expectValue: uint8(6), + }, + { + name: "utinyint_float64", + field: "c6", + where: "c6 = ?", + bind: float64(6), + expectValue: uint8(6), + }, + { + name: "utinyint_int", + field: "c6", + where: "c6 = ?", + bind: int(6), + expectValue: uint8(6), + }, + { + name: "utinyint_uint", + field: "c6", + where: "c6 = ?", + bind: uint(6), + expectValue: uint8(6), + }, + + //usmallint + { + name: "usmallint_int8", + field: "c7", + where: "c7 = ?", + bind: int8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_iny16", + field: "c7", + where: "c7 = ?", + bind: int16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int32", + field: "c7", + where: "c7 = ?", + bind: int32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int64", + field: "c7", + where: "c7 = ?", + bind: int64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint8", + field: "c7", + where: "c7 = ?", + bind: uint8(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint16", + field: "c7", + where: "c7 = ?", + bind: uint16(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint32", + field: "c7", + where: "c7 = ?", + bind: uint32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint64", + field: "c7", + where: "c7 = ?", + bind: uint64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float32", + field: "c7", + where: "c7 = ?", + bind: float32(7), + expectValue: uint16(7), + }, + { + name: "usmallint_float64", + field: "c7", + where: "c7 = ?", + bind: float64(7), + expectValue: uint16(7), + }, + { + name: "usmallint_int", + field: "c7", + where: "c7 = ?", + bind: int(7), + expectValue: uint16(7), + }, + { + name: "usmallint_uint", + field: "c7", + where: "c7 = ?", + bind: uint(7), + expectValue: uint16(7), + }, + + //uint + { + name: "uint_int8", + field: "c8", + where: "c8 = ?", + bind: int8(8), + expectValue: uint32(8), + }, + { + name: "uint_iny16", + field: "c8", + where: "c8 = ?", + bind: int16(8), + expectValue: uint32(8), + }, + { + name: "uint_int32", + field: "c8", + where: "c8 = ?", + bind: int32(8), + expectValue: uint32(8), + }, + { + name: "uint_int64", + field: "c8", + where: "c8 = ?", + bind: int64(8), + expectValue: uint32(8), + }, + { + name: "uint_uint8", + field: "c8", + where: "c8 = ?", + bind: uint8(8), + expectValue: uint32(8), + }, + { + name: "uint_uint16", + field: "c8", + where: "c8 = ?", + bind: uint16(8), + expectValue: uint32(8), + }, + { + name: "uint_uint32", + field: "c8", + where: "c8 = ?", + bind: uint32(8), + expectValue: uint32(8), + }, + { + name: "uint_uint64", + field: "c8", + where: "c8 = ?", + bind: uint64(8), + expectValue: uint32(8), + }, + { + name: "uint_float32", + field: "c8", + where: "c8 = ?", + bind: float32(8), + expectValue: uint32(8), + }, + { + name: "uint_float64", + field: "c8", + where: "c8 = ?", + bind: float64(8), + expectValue: uint32(8), + }, + { + name: "uint_int", + field: "c8", + where: "c8 = ?", + bind: int(8), + expectValue: uint32(8), + }, + { + name: "uint_uint", + field: "c8", + where: "c8 = ?", + bind: uint(8), + expectValue: uint32(8), + }, + + //ubigint + { + name: "ubigint_int8", + field: "c9", + where: "c9 = ?", + bind: int8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_iny16", + field: "c9", + where: "c9 = ?", + bind: int16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int32", + field: "c9", + where: "c9 = ?", + bind: int32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int64", + field: "c9", + where: "c9 = ?", + bind: int64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint8", + field: "c9", + where: "c9 = ?", + bind: uint8(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint16", + field: "c9", + where: "c9 = ?", + bind: uint16(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint32", + field: "c9", + where: "c9 = ?", + bind: uint32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint64", + field: "c9", + where: "c9 = ?", + bind: uint64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float32", + field: "c9", + where: "c9 = ?", + bind: float32(9), + expectValue: uint64(9), + }, + { + name: "ubigint_float64", + field: "c9", + where: "c9 = ?", + bind: float64(9), + expectValue: uint64(9), + }, + { + name: "ubigint_int", + field: "c9", + where: "c9 = ?", + bind: int(9), + expectValue: uint64(9), + }, + { + name: "ubigint_uint", + field: "c9", + where: "c9 = ?", + bind: uint(9), + expectValue: uint64(9), + }, + + //float + { + name: "float_int8", + field: "c10", + where: "c10 = ?", + bind: int8(10), + expectValue: float32(10), + }, + { + name: "float_iny16", + field: "c10", + where: "c10 = ?", + bind: int16(10), + expectValue: float32(10), + }, + { + name: "float_int32", + field: "c10", + where: "c10 = ?", + bind: int32(10), + expectValue: float32(10), + }, + { + name: "float_int64", + field: "c10", + where: "c10 = ?", + bind: int64(10), + expectValue: float32(10), + }, + { + name: "float_uint8", + field: "c10", + where: "c10 = ?", + bind: uint8(10), + expectValue: float32(10), + }, + { + name: "float_uint16", + field: "c10", + where: "c10 = ?", + bind: uint16(10), + expectValue: float32(10), + }, + { + name: "float_uint32", + field: "c10", + where: "c10 = ?", + bind: uint32(10), + expectValue: float32(10), + }, + { + name: "float_uint64", + field: "c10", + where: "c10 = ?", + bind: uint64(10), + expectValue: float32(10), + }, + { + name: "float_float32", + field: "c10", + where: "c10 = ?", + bind: float32(10), + expectValue: float32(10), + }, + { + name: "float_float64", + field: "c10", + where: "c10 = ?", + bind: float64(10), + expectValue: float32(10), + }, + { + name: "float_int", + field: "c10", + where: "c10 = ?", + bind: int(10), + expectValue: float32(10), + }, + { + name: "float_uint", + field: "c10", + where: "c10 = ?", + bind: uint(10), + expectValue: float32(10), + }, + + //double + { + name: "double_int8", + field: "c11", + where: "c11 = ?", + bind: int8(11), + expectValue: float64(11), + }, + { + name: "double_iny16", + field: "c11", + where: "c11 = ?", + bind: int16(11), + expectValue: float64(11), + }, + { + name: "double_int32", + field: "c11", + where: "c11 = ?", + bind: int32(11), + expectValue: float64(11), + }, + { + name: "double_int64", + field: "c11", + where: "c11 = ?", + bind: int64(11), + expectValue: float64(11), + }, + { + name: "double_uint8", + field: "c11", + where: "c11 = ?", + bind: uint8(11), + expectValue: float64(11), + }, + { + name: "double_uint16", + field: "c11", + where: "c11 = ?", + bind: uint16(11), + expectValue: float64(11), + }, + { + name: "double_uint32", + field: "c11", + where: "c11 = ?", + bind: uint32(11), + expectValue: float64(11), + }, + { + name: "double_uint64", + field: "c11", + where: "c11 = ?", + bind: uint64(11), + expectValue: float64(11), + }, + { + name: "double_float32", + field: "c11", + where: "c11 = ?", + bind: float32(11), + expectValue: float64(11), + }, + { + name: "double_float64", + field: "c11", + where: "c11 = ?", + bind: float64(11), + expectValue: float64(11), + }, + { + name: "double_int", + field: "c11", + where: "c11 = ?", + bind: int(11), + expectValue: float64(11), + }, + { + name: "double_uint", + field: "c11", + where: "c11 = ?", + bind: uint(11), + expectValue: float64(11), + }, + + // binary + { + name: "binary_string", + field: "c12", + where: "c12 = ?", + bind: "binary", + expectValue: "binary", + }, + { + name: "binary_bytes", + field: "c12", + where: "c12 = ?", + bind: []byte("binary"), + expectValue: "binary", + }, + { + name: "binary_string_like", + field: "c12", + where: "c12 like ?", + bind: "bin%", + expectValue: "binary", + }, + + // nchar + { + name: "nchar_string", + field: "c13", + where: "c13 = ?", + bind: "nchar", + expectValue: "nchar", + }, + { + name: "nchar_bytes", + field: "c13", + where: "c13 = ?", + bind: []byte("nchar"), + expectValue: "nchar", + }, + { + name: "nchar_string", + field: "c13", + where: "c13 like ?", + bind: "nch%", + expectValue: "nchar", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sql := fmt.Sprintf("select %s from t0 where %s", tt.field, tt.where) + + stmt, err := db.Prepare(sql) + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + rows, err := stmt.Query(tt.bind) + if tt.expectError { + assert.NotNil(t, err) + stmt.Close() + return + } + if err != nil { + t.Error(err) + return + } + tts, err := rows.ColumnTypes() + typesL := make([]reflect.Type, 1) + for i, tp := range tts { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + typesL[i] = st + } + var data []driver.Value + for rows.Next() { + values := make([]interface{}, 1) + for i := range values { + values[i] = reflect.New(typesL[i]).Interface() + } + err = rows.Scan(values...) + if err != nil { + t.Error(err) + return + } + v, err := values[0].(driver.Valuer).Value() + if err != nil { + t.Error(err) + } + data = append(data, v) + } + if tt.expectNoValue { + if len(data) > 0 { + t.Errorf("expect no value got %#v", data) + return + } + return + } + if len(data) != 1 { + t.Errorf("expect %d got %d", 1, len(data)) + return + } + if data[0] != tt.expectValue { + t.Errorf("expect %v got %v", tt.expectValue, data[0]) + return + } + }) + } +} diff --git a/wrapper/tmq_test.go b/wrapper/tmq_test.go index b4992cf..4e201fa 100644 --- a/wrapper/tmq_test.go +++ b/wrapper/tmq_test.go @@ -1074,7 +1074,7 @@ func TestTMQModify(t *testing.T) { h2 := cgo.NewHandle(c2) targetConn, err := TaosConnect("", "root", "taosdata", "tmq_test_db_modify_target", 0) assert.NoError(t, err) - defer TaosFreeResult(targetConn) + defer TaosClose(targetConn) result = TaosQuery(conn, "create table stb (ts timestamp,"+ "c1 bool,"+ "c2 tinyint,"+ @@ -1170,70 +1170,41 @@ func TestTMQModify(t *testing.T) { } d, err := query(targetConn, "describe stb") assert.NoError(t, err) - if len(d[0]) == 4 { - assert.Equal(t, [][]driver.Value{ - {"ts", "TIMESTAMP", int32(8), ""}, - {"c1", "BOOL", int32(1), ""}, - {"c2", "TINYINT", int32(1), ""}, - {"c3", "SMALLINT", int32(2), ""}, - {"c4", "INT", int32(4), ""}, - {"c5", "BIGINT", int32(8), ""}, - {"c6", "TINYINT UNSIGNED", int32(1), ""}, - {"c7", "SMALLINT UNSIGNED", int32(2), ""}, - {"c8", "INT UNSIGNED", int32(4), ""}, - {"c9", "BIGINT UNSIGNED", int32(8), ""}, - {"c10", "FLOAT", int32(4), ""}, - {"c11", "DOUBLE", int32(8), ""}, - {"c12", "VARCHAR", int32(20), ""}, - {"c13", "NCHAR", int32(20), ""}, - {"tts", "TIMESTAMP", int32(8), "TAG"}, - {"tc1", "BOOL", int32(1), "TAG"}, - {"tc2", "TINYINT", int32(1), "TAG"}, - {"tc3", "SMALLINT", int32(2), "TAG"}, - {"tc4", "INT", int32(4), "TAG"}, - {"tc5", "BIGINT", int32(8), "TAG"}, - {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, - {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, - {"tc8", "INT UNSIGNED", int32(4), "TAG"}, - {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, - {"tc10", "FLOAT", int32(4), "TAG"}, - {"tc11", "DOUBLE", int32(8), "TAG"}, - {"tc12", "VARCHAR", int32(20), "TAG"}, - {"tc13", "NCHAR", int32(20), "TAG"}, - }, d) - } else { - assert.Equal(t, [][]driver.Value{ - {"ts", "TIMESTAMP", int32(8), "", ""}, - {"c1", "BOOL", int32(1), "", ""}, - {"c2", "TINYINT", int32(1), "", ""}, - {"c3", "SMALLINT", int32(2), "", ""}, - {"c4", "INT", int32(4), "", ""}, - {"c5", "BIGINT", int32(8), "", ""}, - {"c6", "TINYINT UNSIGNED", int32(1), "", ""}, - {"c7", "SMALLINT UNSIGNED", int32(2), "", ""}, - {"c8", "INT UNSIGNED", int32(4), "", ""}, - {"c9", "BIGINT UNSIGNED", int32(8), "", ""}, - {"c10", "FLOAT", int32(4), "", ""}, - {"c11", "DOUBLE", int32(8), "", ""}, - {"c12", "VARCHAR", int32(20), "", ""}, - {"c13", "NCHAR", int32(20), "", ""}, - {"tts", "TIMESTAMP", int32(8), "TAG", ""}, - {"tc1", "BOOL", int32(1), "TAG", ""}, - {"tc2", "TINYINT", int32(1), "TAG", ""}, - {"tc3", "SMALLINT", int32(2), "TAG", ""}, - {"tc4", "INT", int32(4), "TAG", ""}, - {"tc5", "BIGINT", int32(8), "TAG", ""}, - {"tc6", "TINYINT UNSIGNED", int32(1), "TAG", ""}, - {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG", ""}, - {"tc8", "INT UNSIGNED", int32(4), "TAG", ""}, - {"tc9", "BIGINT UNSIGNED", int32(8), "TAG", ""}, - {"tc10", "FLOAT", int32(4), "TAG", ""}, - {"tc11", "DOUBLE", int32(8), "TAG", ""}, - {"tc12", "VARCHAR", int32(20), "TAG", ""}, - {"tc13", "NCHAR", int32(20), "TAG", ""}, - }, d) + expect := [][]driver.Value{ + {"ts", "TIMESTAMP", int32(8), ""}, + {"c1", "BOOL", int32(1), ""}, + {"c2", "TINYINT", int32(1), ""}, + {"c3", "SMALLINT", int32(2), ""}, + {"c4", "INT", int32(4), ""}, + {"c5", "BIGINT", int32(8), ""}, + {"c6", "TINYINT UNSIGNED", int32(1), ""}, + {"c7", "SMALLINT UNSIGNED", int32(2), ""}, + {"c8", "INT UNSIGNED", int32(4), ""}, + {"c9", "BIGINT UNSIGNED", int32(8), ""}, + {"c10", "FLOAT", int32(4), ""}, + {"c11", "DOUBLE", int32(8), ""}, + {"c12", "VARCHAR", int32(20), ""}, + {"c13", "NCHAR", int32(20), ""}, + {"tts", "TIMESTAMP", int32(8), "TAG"}, + {"tc1", "BOOL", int32(1), "TAG"}, + {"tc2", "TINYINT", int32(1), "TAG"}, + {"tc3", "SMALLINT", int32(2), "TAG"}, + {"tc4", "INT", int32(4), "TAG"}, + {"tc5", "BIGINT", int32(8), "TAG"}, + {"tc6", "TINYINT UNSIGNED", int32(1), "TAG"}, + {"tc7", "SMALLINT UNSIGNED", int32(2), "TAG"}, + {"tc8", "INT UNSIGNED", int32(4), "TAG"}, + {"tc9", "BIGINT UNSIGNED", int32(8), "TAG"}, + {"tc10", "FLOAT", int32(4), "TAG"}, + {"tc11", "DOUBLE", int32(8), "TAG"}, + {"tc12", "VARCHAR", int32(20), "TAG"}, + {"tc13", "NCHAR", int32(20), "TAG"}, + } + for rowIndex, values := range d { + for i := 0; i < 4; i++ { + assert.Equal(t, expect[rowIndex][i], values[i]) + } } - }) TMQUnsubscribe(tmq) diff --git a/ws/schemaless/config.go b/ws/schemaless/config.go index 58f65b0..d62eb3b 100644 --- a/ws/schemaless/config.go +++ b/ws/schemaless/config.go @@ -10,14 +10,15 @@ const ( ) type Config struct { - url string - chanLength uint - user string - password string - db string - readTimeout time.Duration - writeTimeout time.Duration - errorHandler func(error) + url string + chanLength uint + user string + password string + db string + readTimeout time.Duration + writeTimeout time.Duration + errorHandler func(error) + enableCompression bool } func NewConfig(url string, chanLength uint, opts ...func(*Config)) *Config { @@ -64,3 +65,9 @@ func SetErrorHandler(errorHandler func(error)) func(*Config) { c.errorHandler = errorHandler } } + +func SetEnableCompression(enableCompression bool) func(*Config) { + return func(c *Config) { + c.enableCompression = enableCompression + } +} diff --git a/ws/schemaless/schemaless.go b/ws/schemaless/schemaless.go index db44e09..5494745 100644 --- a/ws/schemaless/schemaless.go +++ b/ws/schemaless/schemaless.go @@ -44,10 +44,11 @@ func NewSchemaless(config *Config) (*Schemaless, error) { if wsUrl.Scheme != "ws" && wsUrl.Scheme != "wss" { return nil, errors.New("config url scheme error") } - if len(wsUrl.Path) == 0 || wsUrl.Path != "/rest/schemaless" { - wsUrl.Path = "/rest/schemaless" - } - ws, _, err := common.DefaultDialer.Dial(wsUrl.String(), nil) + wsUrl.Path = "/ws" + dialer := common.DefaultDialer + dialer.EnableCompression = config.enableCompression + ws, _, err := dialer.Dial(wsUrl.String(), nil) + ws.EnableWriteCompression(config.enableCompression) if err != nil { return nil, fmt.Errorf("dial ws error: %s", err) } diff --git a/ws/schemaless/schemaless_test.go b/ws/schemaless/schemaless_test.go index bfa3199..d3754e2 100644 --- a/ws/schemaless/schemaless_test.go +++ b/ws/schemaless/schemaless_test.go @@ -62,12 +62,13 @@ func TestSchemaless_Insert(t *testing.T) { } defer func() { _ = after() }() - s, err := NewSchemaless(NewConfig("ws://localhost:6041/rest/schemaless", 1, + s, err := NewSchemaless(NewConfig("ws://localhost:6041", 1, SetDb("test_schemaless_ws"), SetReadTimeout(10*time.Second), SetWriteTimeout(10*time.Second), SetUser("root"), SetPassword("taosdata"), + SetEnableCompression(true), SetErrorHandler(func(err error) { t.Fatal(err) }), diff --git a/ws/stmt/config.go b/ws/stmt/config.go index 332ac55..7eab614 100644 --- a/ws/stmt/config.go +++ b/ws/stmt/config.go @@ -6,15 +6,16 @@ import ( ) type Config struct { - Url string - ChanLength uint - MessageTimeout time.Duration - WriteWait time.Duration - ErrorHandler func(connector *Connector, err error) - CloseHandler func() - User string - Password string - DB string + Url string + ChanLength uint + MessageTimeout time.Duration + WriteWait time.Duration + ErrorHandler func(connector *Connector, err error) + CloseHandler func() + User string + Password string + DB string + EnableCompression bool } func NewConfig(url string, chanLength uint) *Config { @@ -60,3 +61,7 @@ func (c *Config) SetErrorHandler(f func(connector *Connector, err error)) { func (c *Config) SetCloseHandler(f func()) { c.CloseHandler = f } + +func (c *Config) SetEnableCompression(enableCompression bool) { + c.EnableCompression = enableCompression +} diff --git a/ws/stmt/connector.go b/ws/stmt/connector.go index 01ba361..0cb2d39 100644 --- a/ws/stmt/connector.go +++ b/ws/stmt/connector.go @@ -3,8 +3,10 @@ package stmt import ( "container/list" "context" + "encoding/binary" "errors" "fmt" + "net/url" "sync" "sync/atomic" "time" @@ -44,10 +46,18 @@ func NewConnector(config *Config) (*Connector, error) { if config.WriteWait > 0 { writeTimeout = config.WriteWait } - ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + dialer := common.DefaultDialer + dialer.EnableCompression = config.EnableCompression + u, err := url.Parse(config.Url) if err != nil { return nil, err } + u.Path = "/ws" + ws, _, err := dialer.Dial(u.String(), nil) + if err != nil { + return nil, err + } + ws.EnableWriteCompression(config.EnableCompression) defer func() { if connector == nil { ws.Close() @@ -121,6 +131,7 @@ func NewConnector(config *Config) (*Connector, error) { } wsClient.TextMessageHandler = connector.handleTextMessage + wsClient.BinaryMessageHandler = connector.handleBinaryMessage wsClient.ErrorHandler = connector.handleError go wsClient.WritePump() go wsClient.ReadPump() @@ -150,6 +161,17 @@ func (c *Connector) handleTextMessage(message []byte) { c.listLock.Unlock() } +func (c *Connector) handleBinaryMessage(message []byte) { + reqID := binary.LittleEndian.Uint64(message[8:16]) + c.listLock.Lock() + element := c.findOutChanByID(reqID) + if element != nil { + element.Value.(*IndexedChan).channel <- message + c.sendChanList.Remove(element) + } + c.listLock.Unlock() +} + type IndexedChan struct { index uint64 channel chan []byte diff --git a/ws/stmt/proto.go b/ws/stmt/proto.go index 2fed0ab..b5dc92d 100644 --- a/ws/stmt/proto.go +++ b/ws/stmt/proto.go @@ -15,6 +15,10 @@ const ( STMTAddBatch = "add_batch" STMTExec = "exec" STMTClose = "close" + STMTUseResult = "use_result" + WSFetch = "fetch" + WSFetchBlock = "fetch_block" + WSFreeResult = "free_result" ) type ConnectReq struct { @@ -134,3 +138,50 @@ type CloseReq struct { ReqID uint64 `json:"req_id"` StmtID uint64 `json:"stmt_id"` } + +type UseResultReq struct { + ReqID uint64 `json:"req_id"` + StmtID uint64 `json:"stmt_id"` +} + +type UseResultResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + StmtID uint64 `json:"stmt_id"` + ResultID uint64 `json:"result_id"` + FieldsCount int `json:"fields_count"` + FieldsNames []string `json:"fields_names"` + FieldsTypes []uint8 `json:"fields_types"` + FieldsLengths []int64 `json:"fields_lengths"` + Precision int `json:"precision"` +} + +type WSFetchReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSFetchResp struct { + Code int `json:"code"` + Message string `json:"message"` + Action string `json:"action"` + ReqID uint64 `json:"req_id"` + Timing int64 `json:"timing"` + ID uint64 `json:"id"` + Completed bool `json:"completed"` + Lengths []int `json:"lengths"` + Rows int `json:"rows"` +} + +type WSFetchBlockReq struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} + +type WSFreeResultRequest struct { + ReqID uint64 `json:"req_id"` + ID uint64 `json:"id"` +} diff --git a/ws/stmt/rows.go b/ws/stmt/rows.go new file mode 100644 index 0000000..78f6c75 --- /dev/null +++ b/ws/stmt/rows.go @@ -0,0 +1,171 @@ +package stmt + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "io" + "reflect" + "unsafe" + + "github.com/taosdata/driver-go/v3/common" + "github.com/taosdata/driver-go/v3/common/parser" + "github.com/taosdata/driver-go/v3/common/pointer" + taosErrors "github.com/taosdata/driver-go/v3/errors" + "github.com/taosdata/driver-go/v3/ws/client" +) + +type Rows struct { + buf *bytes.Buffer + blockPtr unsafe.Pointer + blockOffset int + blockSize int + resultID uint64 + block []byte + conn *Connector + fieldsCount int + fieldsNames []string + fieldsTypes []uint8 + fieldsLengths []int64 + precision int +} + +func (rs *Rows) Columns() []string { + return rs.fieldsNames +} + +func (rs *Rows) ColumnTypeDatabaseTypeName(i int) string { + return common.TypeNameMap[int(rs.fieldsTypes[i])] +} + +func (rs *Rows) ColumnTypeLength(i int) (length int64, ok bool) { + return rs.fieldsLengths[i], ok +} + +func (rs *Rows) ColumnTypeScanType(i int) reflect.Type { + t, exist := common.ColumnTypeMap[int(rs.fieldsTypes[i])] + if !exist { + return common.UnknownType + } + return t +} + +func (rs *Rows) Close() error { + rs.blockPtr = nil + rs.block = nil + return rs.freeResult() +} + +func (rs *Rows) Next(dest []driver.Value) error { + if rs.blockPtr == nil || rs.blockOffset >= rs.blockSize { + err := rs.taosFetchBlock() + if err != nil { + return err + } + } + if rs.blockSize == 0 { + rs.blockPtr = nil + rs.block = nil + return io.EOF + } + parser.ReadRow(dest, rs.blockPtr, rs.blockSize, rs.blockOffset, rs.fieldsTypes, rs.precision) + rs.blockOffset += 1 + return nil +} + +func (rs *Rows) taosFetchBlock() error { + reqID := rs.conn.generateReqID() + req := &WSFetchReq{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := json.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: WSFetch, + Args: args, + } + rs.buf.Reset() + envelope := rs.conn.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + rs.conn.client.PutEnvelope(envelope) + return err + } + respBytes, err := rs.conn.sendText(reqID, envelope) + if err != nil { + return err + } + var resp WSFetchResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return err + } + if resp.Code != 0 { + return taosErrors.NewError(resp.Code, resp.Message) + } + if resp.Completed { + rs.blockSize = 0 + return nil + } else { + rs.blockSize = resp.Rows + return rs.fetchBlock() + } +} + +func (rs *Rows) fetchBlock() error { + req := &WSFetchBlockReq{ + ReqID: rs.resultID, + ID: rs.resultID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: WSFetchBlock, + Args: args, + } + rs.buf.Reset() + envelope := rs.conn.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + rs.conn.client.PutEnvelope(envelope) + return err + } + respBytes, err := rs.conn.sendText(rs.resultID, envelope) + if err != nil { + return err + } + rs.block = respBytes + rs.blockPtr = pointer.AddUintptr(unsafe.Pointer(&rs.block[0]), 16) + rs.blockOffset = 0 + return nil +} + +func (rs *Rows) freeResult() error { + reqID := rs.conn.generateReqID() + req := &WSFreeResultRequest{ + ReqID: reqID, + ID: rs.resultID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return err + } + action := &client.WSAction{ + Action: WSFreeResult, + Args: args, + } + rs.buf.Reset() + envelope := rs.conn.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + rs.conn.client.PutEnvelope(envelope) + return err + } + rs.conn.sendTextWithoutResp(envelope) + return nil +} diff --git a/ws/stmt/stmt.go b/ws/stmt/stmt.go index e3c4c74..9833647 100644 --- a/ws/stmt/stmt.go +++ b/ws/stmt/stmt.go @@ -1,6 +1,7 @@ package stmt import ( + "bytes" "encoding/binary" "github.com/taosdata/driver-go/v3/common/param" @@ -230,6 +231,50 @@ func (s *Stmt) GetAffectedRows() int { return s.lastAffected } +func (s *Stmt) UseResult() (*Rows, error) { + reqID := s.connector.generateReqID() + req := &UseResultReq{ + ReqID: reqID, + StmtID: s.id, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: STMTUseResult, + Args: args, + } + envelope := s.connector.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + s.connector.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := s.connector.sendText(reqID, envelope) + if err != nil { + return nil, err + } + var resp UseResultResp + err = client.JsonI.Unmarshal(respBytes, &resp) + if err != nil { + return nil, err + } + if resp.Code != 0 { + return nil, taosErrors.NewError(resp.Code, resp.Message) + } + return &Rows{ + buf: &bytes.Buffer{}, + conn: s.connector, + resultID: resp.ResultID, + fieldsCount: resp.FieldsCount, + fieldsNames: resp.FieldsNames, + fieldsTypes: resp.FieldsTypes, + fieldsLengths: resp.FieldsLengths, + precision: resp.Precision, + }, nil +} + func (s *Stmt) Close() error { reqID := s.connector.generateReqID() req := &CloseReq{ diff --git a/ws/stmt/stmt_test.go b/ws/stmt/stmt_test.go index 4cc1631..7cd9496 100644 --- a/ws/stmt/stmt_test.go +++ b/ws/stmt/stmt_test.go @@ -18,12 +18,12 @@ import ( "github.com/taosdata/driver-go/v3/ws/client" ) -func prepareEnv() error { +func prepareEnv(db string) error { var err error steps := []string{ - "drop database if exists test_ws_stmt", - "create database test_ws_stmt", - "create table test_ws_stmt.all_json(ts timestamp," + + "drop database if exists " + db, + "create database " + db, + "create table " + db + ".all_json(ts timestamp," + "c1 bool," + "c2 tinyint," + "c3 smallint," + @@ -39,7 +39,7 @@ func prepareEnv() error { "c13 nchar(20)" + ")" + "tags(t json)", - "create table test_ws_stmt.all_all(" + + "create table " + db + ".all_all(" + "ts timestamp," + "c1 bool," + "c2 tinyint," + @@ -80,11 +80,11 @@ func prepareEnv() error { return nil } -func cleanEnv() error { +func cleanEnv(db string) error { var err error time.Sleep(2 * time.Second) steps := []string{ - "drop database if exists test_ws_stmt", + "drop database if exists " + db, } for _, step := range steps { err = doRequest(step) @@ -151,19 +151,20 @@ func query(payload string) (*common.TDEngineRestfulResp, error) { // @date: 2023/10/13 11:35 // @description: test stmt over websocket func TestStmt(t *testing.T) { - err := prepareEnv() + err := prepareEnv("test_ws_stmt") if err != nil { t.Error(err) return } - defer cleanEnv() + defer cleanEnv("test_ws_stmt") now := time.Now() - config := NewConfig("ws://127.0.0.1:6041/rest/stmt", 0) + config := NewConfig("ws://127.0.0.1:6041", 0) config.SetConnectUser("root") config.SetConnectPass("taosdata") config.SetConnectDB("test_ws_stmt") config.SetMessageTimeout(common.DefaultMessageTimeout) config.SetWriteWait(common.DefaultWriteWait) + config.SetEnableCompression(true) config.SetErrorHandler(func(connector *Connector, err error) { t.Log(err) }) @@ -614,3 +615,405 @@ func marshalBody(body io.Reader, bufferSize int) (*common.TDEngineRestfulResp, e } return &result, nil } + +func TestSTMTQuery(t *testing.T) { + err := prepareEnv("test_ws_stmt_query") + if err != nil { + t.Error(err) + return + } + defer cleanEnv("test_ws_stmt_query") + now := time.Now() + config := NewConfig("ws://127.0.0.1:6041", 0) + config.SetConnectUser("root") + config.SetConnectPass("taosdata") + config.SetConnectDB("test_ws_stmt_query") + config.SetMessageTimeout(common.DefaultMessageTimeout) + config.SetWriteWait(common.DefaultWriteWait) + config.SetEnableCompression(true) + config.SetErrorHandler(func(connector *Connector, err error) { + t.Log(err) + }) + config.SetCloseHandler(func() { + t.Log("stmt websocket closed") + }) + connector, err := NewConnector(config) + if err != nil { + t.Error(err) + return + } + defer connector.Close() + { + stmt, err := connector.Init() + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + err = stmt.Prepare("insert into ? using all_json tags(?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTableName("tb1") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTags(param.NewParam(1).AddJson([]byte(`{"tb":1}`)), param.NewColumnType(1).AddJson(0)) + if err != nil { + t.Error(err) + return + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + t.Error(err) + return + } + err = stmt.AddBatch() + if err != nil { + t.Error(err) + return + } + err = stmt.Exec() + if err != nil { + t.Error(err) + return + } + affected := stmt.GetAffectedRows() + if !assert.Equal(t, 3, affected) { + return + } + err = stmt.Prepare("select * from all_json where ts >=? order by ts") + assert.NoError(t, err) + queryTime := now.Format(time.RFC3339Nano) + params = []*param.Param{param.NewParam(1).AddBinary([]byte(queryTime))} + paramTypes = param.NewColumnType(1).AddBinary(len(queryTime)) + err = stmt.BindParam(params, paramTypes) + assert.NoError(t, err) + err = stmt.AddBatch() + assert.NoError(t, err) + err = stmt.Exec() + assert.NoError(t, err) + rows, err := stmt.UseResult() + assert.NoError(t, err) + columns := rows.Columns() + assert.Equal(t, 15, len(columns)) + expectColumns := []string{ + "ts", + "c1", + "c2", + "c3", + "c4", + "c5", + "c6", + "c7", + "c8", + "c9", + "c10", + "c11", + "c12", + "c13", + "t", + } + for i := 0; i < 14; i++ { + assert.Equal(t, columns[i], expectColumns[i]) + rows.ColumnTypeDatabaseTypeName(i) + rows.ColumnTypeLength(i) + rows.ColumnTypeScanType(i) + } + var result [][]driver.Value + for { + values := make([]driver.Value, 15) + err = rows.Next(values) + if err != nil { + if err == io.EOF { + break + } + assert.NoError(t, err) + } + result = append(result, values) + } + assert.Equal(t, 3, len(result)) + row1 := result[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1]) + assert.Equal(t, int8(1), row1[2]) + assert.Equal(t, int16(1), row1[3]) + assert.Equal(t, int32(1), row1[4]) + assert.Equal(t, int64(1), row1[5]) + assert.Equal(t, uint8(1), row1[6]) + assert.Equal(t, uint16(1), row1[7]) + assert.Equal(t, uint32(1), row1[8]) + assert.Equal(t, uint64(1), row1[9]) + assert.Equal(t, float32(1), row1[10]) + assert.Equal(t, float64(1), row1[11]) + assert.Equal(t, "test_binary", row1[12]) + assert.Equal(t, "test_nchar", row1[13]) + assert.Equal(t, []byte(`{"tb":1}`), row1[14]) + row2 := result[1] + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, []byte(`{"tb":1}`), row2[14]) + row3 := result[2] + assert.Equal(t, now.Add(time.Second*2).UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1]) + assert.Equal(t, int8(1), row3[2]) + assert.Equal(t, int16(1), row3[3]) + assert.Equal(t, int32(1), row3[4]) + assert.Equal(t, int64(1), row3[5]) + assert.Equal(t, uint8(1), row3[6]) + assert.Equal(t, uint16(1), row3[7]) + assert.Equal(t, uint32(1), row3[8]) + assert.Equal(t, uint64(1), row3[9]) + assert.Equal(t, float32(1), row3[10]) + assert.Equal(t, float64(1), row3[11]) + assert.Equal(t, "test_binary", row3[12]) + assert.Equal(t, "test_nchar", row3[13]) + assert.Equal(t, []byte(`{"tb":1}`), row3[14]) + } + { + stmt, err := connector.Init() + if err != nil { + t.Error(err) + return + } + defer stmt.Close() + err = stmt.Prepare("insert into ? using all_all tags(?,?,?,?,?,?,?,?,?,?,?,?,?,?) values(?,?,?,?,?,?,?,?,?,?,?,?,?,?)") + err = stmt.SetTableName("tb1") + if err != nil { + t.Error(err) + return + } + + err = stmt.SetTableName("tb2") + if err != nil { + t.Error(err) + return + } + err = stmt.SetTags( + param.NewParam(14). + AddTimestamp(now, 0). + AddBool(true). + AddTinyint(2). + AddSmallint(2). + AddInt(2). + AddBigint(2). + AddUTinyint(2). + AddUSmallint(2). + AddUInt(2). + AddUBigint(2). + AddFloat(2). + AddDouble(2). + AddBinary([]byte("tb2")). + AddNchar("tb2"), + param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0), + ) + if err != nil { + t.Error(err) + return + } + params := []*param.Param{ + param.NewParam(3).AddTimestamp(now, 0).AddTimestamp(now.Add(time.Second), 0).AddTimestamp(now.Add(time.Second*2), 0), + param.NewParam(3).AddBool(true).AddNull().AddBool(true), + param.NewParam(3).AddTinyint(1).AddNull().AddTinyint(1), + param.NewParam(3).AddSmallint(1).AddNull().AddSmallint(1), + param.NewParam(3).AddInt(1).AddNull().AddInt(1), + param.NewParam(3).AddBigint(1).AddNull().AddBigint(1), + param.NewParam(3).AddUTinyint(1).AddNull().AddUTinyint(1), + param.NewParam(3).AddUSmallint(1).AddNull().AddUSmallint(1), + param.NewParam(3).AddUInt(1).AddNull().AddUInt(1), + param.NewParam(3).AddUBigint(1).AddNull().AddUBigint(1), + param.NewParam(3).AddFloat(1).AddNull().AddFloat(1), + param.NewParam(3).AddDouble(1).AddNull().AddDouble(1), + param.NewParam(3).AddBinary([]byte("test_binary")).AddNull().AddBinary([]byte("test_binary")), + param.NewParam(3).AddNchar("test_nchar").AddNull().AddNchar("test_nchar"), + } + paramTypes := param.NewColumnType(14). + AddTimestamp(). + AddBool(). + AddTinyint(). + AddSmallint(). + AddInt(). + AddBigint(). + AddUTinyint(). + AddUSmallint(). + AddUInt(). + AddUBigint(). + AddFloat(). + AddDouble(). + AddBinary(0). + AddNchar(0) + err = stmt.BindParam(params, paramTypes) + if err != nil { + t.Error(err) + return + } + err = stmt.AddBatch() + if err != nil { + t.Error(err) + return + } + err = stmt.Exec() + if err != nil { + t.Error(err) + return + } + affected := stmt.GetAffectedRows() + if !assert.Equal(t, 3, affected) { + return + } + err = stmt.Prepare("select * from all_all where ts >=? order by ts") + assert.NoError(t, err) + queryTime := now.Format(time.RFC3339Nano) + params = []*param.Param{param.NewParam(1).AddBinary([]byte(queryTime))} + paramTypes = param.NewColumnType(1).AddBinary(len(queryTime)) + err = stmt.BindParam(params, paramTypes) + assert.NoError(t, err) + err = stmt.AddBatch() + assert.NoError(t, err) + err = stmt.Exec() + assert.NoError(t, err) + rows, err := stmt.UseResult() + assert.NoError(t, err) + columns := rows.Columns() + assert.Equal(t, 28, len(columns)) + var result [][]driver.Value + for { + values := make([]driver.Value, 28) + err = rows.Next(values) + if err != nil { + if err == io.EOF { + break + } + assert.NoError(t, err) + } + result = append(result, values) + } + assert.Equal(t, 3, len(result)) + row1 := result[0] + assert.Equal(t, now.UnixNano()/1e6, row1[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[1]) + assert.Equal(t, int8(1), row1[2]) + assert.Equal(t, int16(1), row1[3]) + assert.Equal(t, int32(1), row1[4]) + assert.Equal(t, int64(1), row1[5]) + assert.Equal(t, uint8(1), row1[6]) + assert.Equal(t, uint16(1), row1[7]) + assert.Equal(t, uint32(1), row1[8]) + assert.Equal(t, uint64(1), row1[9]) + assert.Equal(t, float32(1), row1[10]) + assert.Equal(t, float64(1), row1[11]) + assert.Equal(t, "test_binary", row1[12]) + assert.Equal(t, "test_nchar", row1[13]) + assert.Equal(t, now.UnixNano()/1e6, row1[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[15]) + assert.Equal(t, int8(2), row1[16]) + assert.Equal(t, int16(2), row1[17]) + assert.Equal(t, int32(2), row1[18]) + assert.Equal(t, int64(2), row1[19]) + assert.Equal(t, uint8(2), row1[20]) + assert.Equal(t, uint16(2), row1[21]) + assert.Equal(t, uint32(2), row1[22]) + assert.Equal(t, uint64(2), row1[23]) + assert.Equal(t, float32(2), row1[24]) + assert.Equal(t, float64(2), row1[25]) + assert.Equal(t, "tb2", row1[26]) + assert.Equal(t, "tb2", row1[27]) + row2 := result[1] + assert.Equal(t, now.Add(time.Second).UnixNano()/1e6, row2[0].(time.Time).UnixNano()/1e6) + for i := 1; i < 14; i++ { + assert.Nil(t, row2[i]) + } + assert.Equal(t, now.UnixNano()/1e6, row1[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row1[15]) + assert.Equal(t, int8(2), row1[16]) + assert.Equal(t, int16(2), row1[17]) + assert.Equal(t, int32(2), row1[18]) + assert.Equal(t, int64(2), row1[19]) + assert.Equal(t, uint8(2), row1[20]) + assert.Equal(t, uint16(2), row1[21]) + assert.Equal(t, uint32(2), row1[22]) + assert.Equal(t, uint64(2), row1[23]) + assert.Equal(t, float32(2), row1[24]) + assert.Equal(t, float64(2), row1[25]) + assert.Equal(t, "tb2", row1[26]) + assert.Equal(t, "tb2", row1[27]) + row3 := result[2] + assert.Equal(t, now.Add(time.Second*2).UnixNano()/1e6, row3[0].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[1]) + assert.Equal(t, int8(1), row3[2]) + assert.Equal(t, int16(1), row3[3]) + assert.Equal(t, int32(1), row3[4]) + assert.Equal(t, int64(1), row3[5]) + assert.Equal(t, uint8(1), row3[6]) + assert.Equal(t, uint16(1), row3[7]) + assert.Equal(t, uint32(1), row3[8]) + assert.Equal(t, uint64(1), row3[9]) + assert.Equal(t, float32(1), row3[10]) + assert.Equal(t, float64(1), row3[11]) + assert.Equal(t, "test_binary", row3[12]) + assert.Equal(t, "test_nchar", row3[13]) + assert.Equal(t, now.UnixNano()/1e6, row3[14].(time.Time).UnixNano()/1e6) + assert.Equal(t, true, row3[15]) + assert.Equal(t, int8(2), row3[16]) + assert.Equal(t, int16(2), row3[17]) + assert.Equal(t, int32(2), row3[18]) + assert.Equal(t, int64(2), row3[19]) + assert.Equal(t, uint8(2), row3[20]) + assert.Equal(t, uint16(2), row3[21]) + assert.Equal(t, uint32(2), row3[22]) + assert.Equal(t, uint64(2), row3[23]) + assert.Equal(t, float32(2), row3[24]) + assert.Equal(t, float64(2), row3[25]) + assert.Equal(t, "tb2", row3[26]) + assert.Equal(t, "tb2", row3[27]) + } +} diff --git a/ws/tmq/config.go b/ws/tmq/config.go index 88ed25b..36ec47e 100644 --- a/ws/tmq/config.go +++ b/ws/tmq/config.go @@ -22,6 +22,7 @@ type config struct { AutoCommitIntervalMS string SnapshotEnable string WithTableName string + EnableCompression bool } func newConfig(url string, chanLength uint) *config { @@ -138,3 +139,12 @@ func (c *config) setWithTableName(withTableName tmq.ConfigValue) error { } return nil } + +func (c *config) setEnableCompression(enableCompression tmq.ConfigValue) error { + var ok bool + c.EnableCompression, ok = enableCompression.(bool) + if !ok { + return fmt.Errorf("ws.message.enableCompression requires bool got %T", enableCompression) + } + return nil +} diff --git a/ws/tmq/consumer.go b/ws/tmq/consumer.go index df1889b..ec8e40e 100644 --- a/ws/tmq/consumer.go +++ b/ws/tmq/consumer.go @@ -6,6 +6,8 @@ import ( "encoding/binary" "errors" "fmt" + "net/url" + "strconv" "sync" "sync/atomic" "time" @@ -21,26 +23,27 @@ import ( ) type Consumer struct { - client *client.Client - requestID uint64 - err error - latestMessageID uint64 - listLock sync.RWMutex - sendChanList *list.List - messageTimeout time.Duration - url string - user string - password string - groupID string - clientID string - offsetRest string - autoCommit string - autoCommitIntervalMS string - snapshotEnable string - withTableName string - closeOnce sync.Once - closeChan chan struct{} - topics []string + client *client.Client + requestID uint64 + err error + dataParser *parser.TMQRawDataParser + listLock sync.RWMutex + sendChanList *list.List + messageTimeout time.Duration + autoCommit bool + autoCommitInterval time.Duration + nextAutoCommitTime time.Time + url string + user string + password string + groupID string + clientID string + offsetRest string + snapshotEnable string + withTableName string + closeOnce sync.Once + closeChan chan struct{} + topics []string } type IndexedChan struct { @@ -63,37 +66,60 @@ func NewConsumer(conf *tmq.ConfigMap) (*Consumer, error) { if err != nil { return nil, err } - ws, _, err := common.DefaultDialer.Dial(config.Url, nil) + autoCommit := true + if config.AutoCommit == "false" { + autoCommit = false + } + autoCommitInterval := time.Second * 5 + if config.AutoCommitIntervalMS != "" { + interval, err := strconv.ParseUint(config.AutoCommitIntervalMS, 10, 64) + if err != nil { + return nil, err + } + autoCommitInterval = time.Millisecond * time.Duration(interval) + } + + dialer := common.DefaultDialer + dialer.EnableCompression = config.EnableCompression + u, err := url.Parse(config.Url) + if err != nil { + return nil, err + } + u.Path = "/rest/tmq" + ws, _, err := dialer.Dial(u.String(), nil) if err != nil { return nil, err } + ws.EnableWriteCompression(config.EnableCompression) wsClient := client.NewClient(ws, config.ChanLength) - tmq := &Consumer{ - client: wsClient, - requestID: 0, - sendChanList: list.New(), - messageTimeout: config.MessageTimeout, - url: config.Url, - user: config.User, - password: config.Password, - groupID: config.GroupID, - clientID: config.ClientID, - offsetRest: config.OffsetRest, - autoCommit: config.AutoCommit, - autoCommitIntervalMS: config.AutoCommitIntervalMS, - snapshotEnable: config.SnapshotEnable, - withTableName: config.WithTableName, - closeChan: make(chan struct{}), + + consumer := &Consumer{ + client: wsClient, + requestID: 0, + sendChanList: list.New(), + messageTimeout: config.MessageTimeout, + url: config.Url, + user: config.User, + password: config.Password, + groupID: config.GroupID, + clientID: config.ClientID, + offsetRest: config.OffsetRest, + autoCommit: autoCommit, + autoCommitInterval: autoCommitInterval, + snapshotEnable: config.SnapshotEnable, + withTableName: config.WithTableName, + closeChan: make(chan struct{}), + dataParser: parser.NewTMQRawDataParser(), } if config.WriteWait > 0 { wsClient.WriteWait = config.WriteWait } - wsClient.BinaryMessageHandler = tmq.handleBinaryMessage - wsClient.TextMessageHandler = tmq.handleTextMessage - wsClient.ErrorHandler = tmq.handleError + wsClient.BinaryMessageHandler = consumer.handleBinaryMessage + wsClient.TextMessageHandler = consumer.handleTextMessage + wsClient.ErrorHandler = consumer.handleError go wsClient.WritePump() go wsClient.ReadPump() - return tmq, nil + return consumer, nil } func configMapToConfig(m *tmq.ConfigMap) (*config, error) { @@ -153,6 +179,10 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + enableCompression, err := m.Get("ws.message.enableCompression", false) + if err != nil { + return nil, err + } config := newConfig(url.(string), chanLen.(uint)) err = config.setMessageTimeout(messageTimeout.(time.Duration)) if err != nil { @@ -198,6 +228,10 @@ func configMapToConfig(m *tmq.ConfigMap) (*config, error) { if err != nil { return nil, err } + err = config.setEnableCompression(enableCompression) + if err != nil { + return nil, err + } return config, nil } @@ -284,8 +318,7 @@ func (c *Consumer) findOutChanByID(index uint64) *list.Element { const ( TMQSubscribe = "subscribe" TMQPoll = "poll" - TMQFetch = "fetch" - TMQFetchBlock = "fetch_block" + TMQFetchRaw = "fetch_raw" TMQFetchJsonMeta = "fetch_json_meta" TMQCommit = "commit" TMQUnsubscribe = "unsubscribe" @@ -294,7 +327,6 @@ const ( TMQCommitOffset = "commit_offset" TMQCommitted = "committed" TMQPosition = "position" - TMQListTopics = "list_topics" ) var ClosedErr = errors.New("connection closed") @@ -338,17 +370,16 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err } reqID := c.generateReqID() req := &SubscribeReq{ - ReqID: reqID, - User: c.user, - Password: c.password, - GroupID: c.groupID, - ClientID: c.clientID, - OffsetRest: c.offsetRest, - Topics: topics, - AutoCommit: c.autoCommit, - AutoCommitIntervalMS: c.autoCommitIntervalMS, - SnapshotEnable: c.snapshotEnable, - WithTableName: c.withTableName, + ReqID: reqID, + User: c.user, + Password: c.password, + GroupID: c.groupID, + ClientID: c.clientID, + OffsetRest: c.offsetRest, + Topics: topics, + AutoCommit: "false", + SnapshotEnable: c.snapshotEnable, + WithTableName: c.withTableName, } args, err := client.JsonI.Marshal(req) if err != nil { @@ -384,7 +415,17 @@ func (c *Consumer) SubscribeTopics(topics []string, rebalanceCb RebalanceCb) err // Poll messages func (c *Consumer) Poll(timeoutMs int) tmq.Event { if c.err != nil { - panic(c.err) + return tmq.NewTMQErrorWithErr(c.err) + } + if c.autoCommit { + if c.nextAutoCommitTime.IsZero() { + c.nextAutoCommitTime = time.Now().Add(c.autoCommitInterval) + } else { + if time.Now().After(c.nextAutoCommitTime) { + c.doCommit() + c.nextAutoCommitTime = time.Now().Add(c.autoCommitInterval) + } + } } reqID := c.generateReqID() req := &PollReq{ @@ -417,7 +458,6 @@ func (c *Consumer) Poll(timeoutMs int) tmq.Event { if resp.Code != 0 { panic(taosErrors.NewError(resp.Code, resp.Message)) } - c.latestMessageID = resp.MessageID if resp.HaveMessage { switch resp.MessageType { case common.TMQ_RES_DATA: @@ -527,94 +567,67 @@ func (c *Consumer) fetchJsonMeta(messageID uint64) (*tmq.Meta, error) { } func (c *Consumer) fetch(messageID uint64) ([]*tmq.Data, error) { - var tmqData []*tmq.Data - for { - reqID := c.generateReqID() - req := &FetchReq{ - ReqID: reqID, - MessageID: messageID, - } - args, err := client.JsonI.Marshal(req) - if err != nil { - return nil, err - } - action := &client.WSAction{ - Action: TMQFetch, - Args: args, - } - envelope := c.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) - if err != nil { - c.client.PutEnvelope(envelope) - return nil, err - } - respBytes, err := c.sendText(reqID, envelope) - if err != nil { - return nil, err - } - var resp FetchResp - err = client.JsonI.Unmarshal(respBytes, &resp) - if err != nil { - return nil, err - } - if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - if resp.Completed { - break - } - // fetch block - { - req := &FetchBlockReq{ - ReqID: reqID, - MessageID: messageID, - } - args, err := client.JsonI.Marshal(req) - if err != nil { - return nil, err - } - action := &client.WSAction{ - Action: TMQFetchBlock, - Args: args, - } - envelope := c.client.GetEnvelope() - err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) - if err != nil { - c.client.PutEnvelope(envelope) - return nil, err - } - respBytes, err := c.sendText(reqID, envelope) - if err != nil { - return nil, err - } - block := respBytes[24:] - p := unsafe.Pointer(&block[0]) - data := parser.ReadBlock(p, resp.Rows, resp.FieldsTypes, resp.Precision) - tmqData = append(tmqData, &tmq.Data{ - TableName: resp.TableName, - Data: data, - }) + reqID := c.generateReqID() + req := &TMQFetchRawMetaReq{ + ReqID: reqID, + MessageID: messageID, + } + args, err := client.JsonI.Marshal(req) + if err != nil { + return nil, err + } + action := &client.WSAction{ + Action: TMQFetchRaw, + Args: args, + } + envelope := c.client.GetEnvelope() + err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) + if err != nil { + c.client.PutEnvelope(envelope) + return nil, err + } + respBytes, err := c.sendText(reqID, envelope) + if err != nil { + return nil, err + } + blockInfo, err := c.dataParser.Parse(unsafe.Pointer(&respBytes[38])) + if err != nil { + return nil, err + } + tmqData := make([]*tmq.Data, len(blockInfo)) + for i := 0; i < len(blockInfo); i++ { + tmqData[i] = &tmq.Data{ + TableName: blockInfo[i].TableName, + Data: parser.ReadBlockSimple(blockInfo[i].RawBlock, blockInfo[i].Precision), } } return tmqData, nil } func (c *Consumer) Commit() ([]tmq.TopicPartition, error) { - return c.doCommit(c.latestMessageID) + err := c.doCommit() + if err != nil { + return nil, err + } + partitions, err := c.Assignment() + if err != nil { + return nil, err + } + return c.Committed(partitions, 0) } -func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { +func (c *Consumer) doCommit() error { if c.err != nil { - return nil, c.err + return c.err } reqID := c.generateReqID() req := &CommitReq{ ReqID: reqID, - MessageID: messageID, + MessageID: 0, } args, err := client.JsonI.Marshal(req) if err != nil { - return nil, err + return err } action := &client.WSAction{ Action: TMQCommit, @@ -624,25 +637,21 @@ func (c *Consumer) doCommit(messageID uint64) ([]tmq.TopicPartition, error) { err = client.JsonI.NewEncoder(envelope.Msg).Encode(action) if err != nil { c.client.PutEnvelope(envelope) - return nil, err + return err } respBytes, err := c.sendText(reqID, envelope) if err != nil { - return nil, err + return err } var resp CommitResp err = client.JsonI.Unmarshal(respBytes, &resp) if err != nil { - return nil, err + return err } if resp.Code != 0 { - return nil, taosErrors.NewError(resp.Code, resp.Message) - } - partitions, err := c.Assignment() - if err != nil { - return nil, err + return taosErrors.NewError(resp.Code, resp.Message) } - return c.Committed(partitions, 0) + return nil } func (c *Consumer) Unsubscribe() error { diff --git a/ws/tmq/consumer_test.go b/ws/tmq/consumer_test.go index a7bf95e..dc394e9 100644 --- a/ws/tmq/consumer_test.go +++ b/ws/tmq/consumer_test.go @@ -124,7 +124,7 @@ func TestConsumer(t *testing.T) { } }() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, @@ -266,7 +266,7 @@ func TestSeek(t *testing.T) { } defer cleanSeekEnv() consumer, err := NewConsumer(&tmq.ConfigMap{ - "ws.url": "ws://127.0.0.1:6041/rest/tmq", + "ws.url": "ws://127.0.0.1:6041", "ws.message.channelLen": uint(0), "ws.message.timeout": common.DefaultMessageTimeout, "ws.message.writeWait": common.DefaultWriteWait, @@ -276,8 +276,8 @@ func TestSeek(t *testing.T) { "client.id": "test_consumer", "auto.offset.reset": "earliest", "enable.auto.commit": "false", - "experimental.snapshot.enable": "false", "msg.with.table.name": "true", + "ws.message.enableCompression": true, }) if err != nil { t.Error(err) @@ -350,3 +350,199 @@ func TestSeek(t *testing.T) { assert.Equal(t, "test_ws_tmq_seek_topic", *partitions[0].Topic) assert.GreaterOrEqual(t, partitions[0].Offset, messageOffset) } + +func prepareAutocommitEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_autocommit_topic", + "drop database if exists test_ws_tmq_autocommit", + "create database test_ws_tmq_autocommit vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_autocommit_topic as database test_ws_tmq_autocommit", + "create table test_ws_tmq_autocommit.t1(ts timestamp,v int)", + "insert into test_ws_tmq_autocommit.t1 values (now,1)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanAutocommitEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_autocommit_topic", + "drop database if exists test_ws_tmq_autocommit", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestAutoCommit(t *testing.T) { + err := prepareAutocommitEnv() + if err != nil { + t.Error(err) + return + } + defer cleanAutocommitEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_ws_tmq_autocommit_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + partitions, err := consumer.Assignment() + assert.NoError(t, err) + assert.Equal(t, 1, len(partitions)) + assert.Equal(t, "test_ws_tmq_autocommit_topic", *partitions[0].Topic) + assert.Equal(t, tmq.Offset(0), partitions[0].Offset) + + offset, err := consumer.Committed(partitions, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(offset)) + assert.Equal(t, tmq.OffsetInvalid, offset[0].Offset) + + //poll + messageOffset := tmq.Offset(0) + haveMessage := false + for i := 0; i < 5; i++ { + event := consumer.Poll(500) + if event != nil { + haveMessage = true + messageOffset = event.(*tmq.DataMessage).Offset() + } + } + assert.True(t, haveMessage) + + offset, err = consumer.Committed(partitions, 0) + assert.NoError(t, err) + assert.Equal(t, 1, len(offset)) + assert.GreaterOrEqual(t, offset[0].Offset, messageOffset) +} + +func prepareMultiBlockEnv() error { + var err error + steps := []string{ + "drop topic if exists test_ws_tmq_multi_block_topic", + "drop database if exists test_ws_tmq_multi_block", + "create database test_ws_tmq_multi_block vgroups 1 WAL_RETENTION_PERIOD 86400", + "create topic test_ws_tmq_multi_block_topic as database test_ws_tmq_multi_block", + "create table test_ws_tmq_multi_block.t1(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t2(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t3(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t4(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t5(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t6(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t7(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t8(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t9(ts timestamp,v int)", + "create table test_ws_tmq_multi_block.t10(ts timestamp,v int)", + "insert into test_ws_tmq_multi_block.t1 values (now,1) test_ws_tmq_multi_block.t2 values (now,2) " + + "test_ws_tmq_multi_block.t3 values (now,3) test_ws_tmq_multi_block.t4 values (now,4)" + + "test_ws_tmq_multi_block.t5 values (now,5) test_ws_tmq_multi_block.t6 values (now,6)" + + "test_ws_tmq_multi_block.t7 values (now,7) test_ws_tmq_multi_block.t8 values (now,8)" + + "test_ws_tmq_multi_block.t9 values (now,9) test_ws_tmq_multi_block.t10 values (now,10)", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func cleanMultiBlockEnv() error { + var err error + time.Sleep(2 * time.Second) + steps := []string{ + "drop topic if exists test_ws_tmq_multi_block_topic", + "drop database if exists test_ws_tmq_multi_block", + } + for _, step := range steps { + err = doRequest(step) + if err != nil { + return err + } + } + return nil +} + +func TestMultiBlock(t *testing.T) { + err := prepareMultiBlockEnv() + assert.NoError(t, err) + defer cleanMultiBlockEnv() + consumer, err := NewConsumer(&tmq.ConfigMap{ + "ws.url": "ws://127.0.0.1:6041", + "ws.message.channelLen": uint(0), + "ws.message.timeout": common.DefaultMessageTimeout, + "ws.message.writeWait": common.DefaultWriteWait, + "td.connect.user": "root", + "td.connect.pass": "taosdata", + "group.id": "test", + "client.id": "test_consumer", + "auto.offset.reset": "earliest", + "enable.auto.commit": "true", + "auto.commit.interval.ms": "1000", + "msg.with.table.name": "true", + }) + assert.NoError(t, err) + if err != nil { + t.Error(err) + return + } + defer func() { + consumer.Unsubscribe() + consumer.Close() + }() + topic := []string{"test_ws_tmq_multi_block_topic"} + err = consumer.SubscribeTopics(topic, nil) + if err != nil { + t.Error(err) + return + } + for i := 0; i < 10; i++ { + event := consumer.Poll(500) + if event == nil { + continue + } + switch e := event.(type) { + case *tmq.DataMessage: + data := e.Value().([]*tmq.Data) + assert.Equal(t, "test_ws_tmq_multi_block", e.DBName()) + assert.Equal(t, 10, len(data)) + return + } + } +} diff --git a/ws/tmq/proto.go b/ws/tmq/proto.go index d9b8c1d..3a17c8b 100644 --- a/ws/tmq/proto.go +++ b/ws/tmq/proto.go @@ -196,3 +196,8 @@ type PositionResp struct { Timing int64 `json:"timing"` Position []int64 `json:"position"` } + +type TMQFetchRawMetaReq struct { + ReqID uint64 `json:"req_id"` + MessageID uint64 `json:"message_id"` +}