Skip to content

Commit

Permalink
Merge pull request #270 from taosdata/fix/td-32478
Browse files Browse the repository at this point in the history
fix: migrate the get-meta-api-call  to the prepare method
  • Loading branch information
YamingPei authored Oct 12, 2024
2 parents 925a549 + 2145c59 commit 9a5d47d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 23 deletions.
7 changes: 4 additions & 3 deletions taos/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,12 @@ def statement2(self, sql=None, option=None):
return None
if option is not None:
option = option.get_impl()
stmt = taos_stmt2_init(self._conn, option)
_stmt2 = taos_stmt2_init(self._conn, option)
stmt2 = TaosStmt2(_stmt2, decode_binary=self.decode_binary)
if sql is not None:
taos_stmt2_prepare(stmt, sql)
stmt2.prepare(sql)

return TaosStmt2(stmt, decode_binary=self.decode_binary)
return stmt2

def load_table_info(self, tables):
# type: (str) -> None
Expand Down
45 changes: 25 additions & 20 deletions taos/statement2.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,14 @@ def obtainSchema(statement2):
#statement2.fields = [FieldType.C_TIMESTAMP, FieldType.C_BINARY, FieldType.C_BOOL, FieldType.C_INT]
#return len(statement2.fields) > 0

count, statement2.fields = statement2.get_fields(TAOS_FIELD_COL)
count, statement2.tag_fields = statement2.get_fields(TAOS_FIELD_TAG)
log.debug(f"obtain schema tag fields = {statement2.tag_fields}")
log.debug(f"obtain schema fields = {statement2.fields}")
try:
count, statement2.fields = statement2.get_fields(TAOS_FIELD_COL)
count, statement2.tag_fields = statement2.get_fields(TAOS_FIELD_TAG)
log.debug(f"obtain schema tag fields = {statement2.tag_fields}")
log.debug(f"obtain schema fields = {statement2.fields}")
except Exception as err:
log.debug(f"obtain schema tag/col fields failed, reason: {repr(err)}")
return False

return len(statement2.fields) > 0

Expand Down Expand Up @@ -229,6 +233,7 @@ def __init__(self, stmt2, decode_binary=True):
self.fields = None
self.tag_fields = None
self.types = None
self._is_insert = None
self.valid_field_types = [TAOS_FIELD_COL, TAOS_FIELD_TAG, TAOS_FIELD_QUERY, TAOS_FIELD_TBNAME]

def prepare(self, sql):
Expand All @@ -244,8 +249,16 @@ def prepare(self, sql):
if len(sql) == 0:
raise StatementError("sql is empty.")

self.fields = None
self.tag_fields = None
self._is_insert = None

taos_stmt2_prepare(self._stmt2, sql)

# obtain schema if insert
if self.is_insert():
if obtainSchema(self) is False:
raise StatementError(f"obtain schema failed. sql={sql}")

def bind_param(self, tbnames, tags, datas):
if self._stmt2 is None:
Expand All @@ -254,22 +267,11 @@ def bind_param(self, tbnames, tags, datas):
log.debug(f"bind_param tbnames = {tbnames} \n")
log.debug(f"bind_param tags = {tags} \n")
log.debug(f"bind_param datasTbs = {datas} \n")

# obtain schema if insert
if self.is_insert():
if tbnames is not None:
bindv = createBindV(self, tbnames, None, None)
if bindv == None:
raise StatementError("create stmt2 bindV failed.")
taos_stmt2_bind_param(self._stmt2, bindv.get_address(), -1)

if obtainSchema(self) is False:
raise StatementError(f"obtain schema failed. tbnames={tbnames}")

# check consistent
if checkConsistent(tbnames, tags, datas) == False:
raise StatementError("check consistent failed.")

# check consistent
if checkConsistent(tbnames, tags, datas) == False:
raise StatementError("check consistent failed.")

# bindV
bindv = createBindV(self, tbnames, tags, datas)
if bindv == None:
Expand Down Expand Up @@ -319,7 +321,10 @@ def is_insert(self):
if self._stmt2 is None:
raise StatementError("stmt2 object is null.")

return taos_stmt2_is_insert(self._stmt2)
if self._is_insert is None:
self._is_insert = taos_stmt2_is_insert(self._stmt2)

return self._is_insert

def get_fields(self, field_type):
if self._stmt2 is None:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_stmt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,31 @@ def insert_bind_param_normal_tables(conn, stmt2, dbname):
# check correct
checkResultCorrects(conn, dbname, None, ["ntb2"], [None], datas)

def insert_bind_param_with_table(conn, stmt2, dbname, stbname, ctb):

tbnames = None
tags = [
["grade2", 1]
]

# prepare data
datas = [
# table 1
[
# student
[1601481600000,1601481600004,"2024-09-19 10:00:00", "2024-09-19 10:00:01.123", datetime(2024,9,20,10,11,12,456)],
["Mary", "Tom", "Jack", "Jane", "alex" ],
[0, 1, 1, 0, 1 ],
[98, 80, 60, 100, 99 ]
]
]

stmt2.bind_param(tbnames, tags, datas)
stmt2.execute()

# check correct
checkResultCorrects(conn, dbname, stbname, [ctb], tags, datas)


# insert with single table (performance is lower)
def insert_bind_param_with_tables(conn, stmt2, dbname, stbname):
Expand Down Expand Up @@ -377,6 +402,13 @@ def test_stmt2_insert(conn):

try:
prepare(conn, dbname, stbname)

ctb = 'ctb'
stmt2 = conn.statement2(f"insert into {dbname}.{ctb} using {dbname}.{stbname} tags (?,?) values(?,?,?,?)")
insert_bind_param_with_table(conn, stmt2, dbname, stbname, ctb)
print("insert normal table ........................... ok\n")
stmt2.close()

# prepare
stmt2 = conn.statement2(f"insert into ? using {dbname}.{stbname} tags(?,?) values(?,?,?,?)")
print("insert prepare sql ............................ ok\n")
Expand Down

0 comments on commit 9a5d47d

Please sign in to comment.