diff --git a/apsw/__init__.pyi b/apsw/__init__.pyi index 508844d3..4e733b0a 100644 --- a/apsw/__init__.pyi +++ b/apsw/__init__.pyi @@ -1777,9 +1777,14 @@ class Connection: totalchanges = total_changes ## OLD-NAME - def trace_v2(self, mask: int, callback: Optional[Callable[[dict], None]] = None) -> None: - """Registers a trace callback. The callback is called with a dict of relevant values based - on the code. + def trace_v2(self, mask: int, callback: Optional[Callable[[dict], None]] = None, *, id: Optional[Any] = None) -> None: + """Registers a trace callback. Multiple traces can be active at once + (implemented by APSW). A callback of :class:`None` unregisters a + trace. Registered callbacks are distinguished by their ``id`` - an + equality test is done to match ids. + + The callback is called with a dict of relevant values based on the + code. .. list-table:: :header-rows: 1 @@ -1808,6 +1813,10 @@ class Connection: The counters are reset each time a statement starts execution. + Note that SQLite ignores any errors from the trace callbacks, so + whatever was being traced will still proceed. Exceptions will be + delivered when your Python code resumes. + .. seealso:: * :ref:`Example ` diff --git a/apsw/tests.py b/apsw/tests.py index bfec8e9f..606cdbb0 100644 --- a/apsw/tests.py +++ b/apsw/tests.py @@ -3255,6 +3255,8 @@ def profile(*args): wasrun = [False] def profile(*args): + # should still be run despite there being a pending exception + # from the update hook wasrun[0] = True def uh(*args): @@ -3263,7 +3265,7 @@ def uh(*args): self.db.set_profile(profile) self.db.set_update_hook(uh) self.assertRaises(ZeroDivisionError, c.execute, "insert into foo values(3)") - self.assertEqual(wasrun[0], False) + self.assertEqual(wasrun[0], True) self.db.set_profile(None) self.db.set_update_hook(None) @@ -6285,9 +6287,59 @@ def tracehook(x): 1 / 0 self.db.trace_v2(apsw.SQLITE_TRACE_STMT, tracehook) - self.assertRaisesUnraisable(ZeroDivisionError, self.db.execute, query) + self.assertRaises(ZeroDivisionError, self.db.execute, query) self.assertEqual(0, len(results)) + # added with id parameter + counter = [0] + + def meth(): + while True: + self.db.trace_v2(apsw.SQLITE_TRACE_STMT, meth, id=f"hello{counter[0]}") + counter[0] += 1 + + with contextlib.suppress(MemoryError): + meth() + + self.assertGreater(counter[0], 1000) + + # ensure all unregistered + for i in range(0, counter[0]+1): + self.db.trace_v2(0, None, id = f"hello{i}") + self.db.trace_v2(0, None) # tracehook zero div above + + # should be fine + self.db.execute("select 3").get + + # have exceptions - ensure all called + counter = [0] + for i in range(10): + def meth(*args): + counter[0] += 1 + 1/0 + self.db.trace_v2(apsw.SQLITE_TRACE_ROW, meth, id = meth) + + with contextlib.suppress(ZeroDivisionError): + self.db.execute("select 4").get + + self.assertEqual(counter[0], 10) + + # bad equals for id checking + class bad_equals: + def __eq__(self, other): + 1 / 0 + + self.assertRaises(ZeroDivisionError, self.db.trace_v2, apsw.SQLITE_TRACE_ROW, bad_equals, id=bad_equals()) + + def harmless(*args): + counter[0] = 99 + 1/0 + + self.db.trace_v2(apsw.SQLITE_TRACE_CLOSE, harmless, id="jkhkjh") + + with contextlib.suppress(ZeroDivisionError): + self.db.close() + def testURIFilenames(self): assertRaises = self.assertRaises assertEqual = self.assertEqual diff --git a/doc/changes.rst b/doc/changes.rst index c2e16352..096080ed 100644 --- a/doc/changes.rst +++ b/doc/changes.rst @@ -17,6 +17,9 @@ Added :func:`recursive triggers ` and :func:`optimize ` to :mod:`apsw.bestpractice`. +Multiple callbacks can be present for :meth:`Connection.trace_v2` +(:issue:`502`) + 3.46.1.0 ======== diff --git a/src/apsw.docstrings b/src/apsw.docstrings index 23e9593b..5213d893 100644 --- a/src/apsw.docstrings +++ b/src/apsw.docstrings @@ -2193,9 +2193,14 @@ #define Connection_total_changes_USAGE "Connection.total_changes() -> int" #define Connection_total_changes_OLDDOC Connection_total_changes_USAGE "\n(Old less clear name totalchanges)" -#define Connection_trace_v2_DOC "trace_v2($self,mask,callback=None)\n--\n\nConnection.trace_v2(mask: int, callback: Optional[Callable[[dict], None]] = None) -> None\n\n" \ -"Registers a trace callback. The callback is called with a dict of relevant values based\n" \ -"on the code.\n" \ +#define Connection_trace_v2_DOC "trace_v2($self,mask,callback=None,*,id=None)\n--\n\nConnection.trace_v2(mask: int, callback: Optional[Callable[[dict], None]] = None, *, id: Optional[Any] = None) -> None\n\n" \ +"Registers a trace callback. Multiple traces can be active at once\n" \ +"(implemented by APSW). A callback of :class:`None` unregisters a\n" \ +"trace. Registered callbacks are distinguished by their ``id`` - an\n" \ +"equality test is done to match ids.\n" \ +"\n" \ +"The callback is called with a dict of relevant values based on the\n" \ +"code.\n" \ "\n" \ ".. list-table::\n" \ " :header-rows: 1\n" \ @@ -2224,6 +2229,10 @@ " The counters are reset each time a statement\n" \ " starts execution.\n" \ "\n" \ +"Note that SQLite ignores any errors from the trace callbacks, so\n" \ +"whatever was being traced will still proceed. Exceptions will be\n" \ +"delivered when your Python code resumes.\n" \ +"\n" \ ".. seealso::\n" \ "\n" \ " * :ref:`Example `\n" \ @@ -2232,13 +2241,15 @@ " * `sqlite3_trace_v2 `__\n" \ " * `sqlite3_stmt_status `__\n" -#define Connection_trace_v2_KWNAMES "mask", "callback" -#define Connection_trace_v2_USAGE "Connection.trace_v2(mask: int, callback: Optional[Callable[[dict], None]] = None) -> None" +#define Connection_trace_v2_KWNAMES "mask", "callback", "id" +#define Connection_trace_v2_USAGE "Connection.trace_v2(mask: int, callback: Optional[Callable[[dict], None]] = None, *, id: Optional[Any] = None) -> None" #define Connection_trace_v2_CHECK do { \ assert(__builtin_types_compatible_p(typeof(mask), int)); \ assert(__builtin_types_compatible_p(typeof(callback), PyObject *)); \ assert(callback == NULL); \ + assert(__builtin_types_compatible_p(typeof(id), PyObject *)); \ + assert(id == NULL); \ } while(0) diff --git a/src/connection.c b/src/connection.c index 4288008e..2de32729 100644 --- a/src/connection.c +++ b/src/connection.c @@ -59,6 +59,13 @@ typedef struct PyObject *inversefunc; /* inverse function */ } windowfunctioncontext; +struct tracehook +{ + unsigned mask; + PyObject *callback; + PyObject *id; +}; + /* CONNECTION TYPE */ struct Connection @@ -76,7 +83,6 @@ struct Connection /* registered hooks/handlers (NULL or callable) */ PyObject *busyhandler; PyObject *rollbackhook; - PyObject *profile; PyObject *updatehook; PyObject *commithook; PyObject *walhook; @@ -85,8 +91,10 @@ struct Connection PyObject *collationneeded; PyObject *exectrace; PyObject *rowtrace; - PyObject *tracehook; - int tracemask; + /* Array of tracehook. Entry 0 is reserved for the set_profile + callback. */ + struct tracehook *tracehooks; + unsigned tracehooks_count; /* if we are using one of our VFS since sqlite doesn't reference count them */ PyObject *vfs; @@ -167,7 +175,6 @@ Connection_internal_cleanup(Connection *self) Py_CLEAR(self->cursor_factory); Py_CLEAR(self->busyhandler); Py_CLEAR(self->rollbackhook); - Py_CLEAR(self->profile); Py_CLEAR(self->updatehook); Py_CLEAR(self->commithook); Py_CLEAR(self->walhook); @@ -176,10 +183,17 @@ Connection_internal_cleanup(Connection *self) Py_CLEAR(self->collationneeded); Py_CLEAR(self->exectrace); Py_CLEAR(self->rowtrace); - Py_CLEAR(self->tracehook); Py_CLEAR(self->vfs); Py_CLEAR(self->open_flags); Py_CLEAR(self->open_vfs); + for (unsigned i = 0; i < self->tracehooks_count; i++) + { + Py_CLEAR(self->tracehooks[i].callback); + Py_CLEAR(self->tracehooks[i].id); + } + PyMem_Del(self->tracehooks); + self->tracehooks = 0; + self->tracehooks_count = 0; } static void @@ -255,9 +269,11 @@ Connection_close_internal(Connection *self, int force) apsw_connection_remove(self); - PYSQLITE_VOID_CALL(res = sqlite3_close(self->db)); - + /* This ensures any SQLITE_TRACE_CLOSE callbacks see a closed + database */ + sqlite3 *tmp = self->db; self->db = 0; + PYSQLITE_VOID_CALL(res = sqlite3_close(tmp)); if (res != SQLITE_OK) { @@ -366,7 +382,6 @@ Connection_new(PyTypeObject *type, PyObject *Py_UNUSED(args), PyObject *Py_UNUSE self->stmtcache = 0; self->busyhandler = 0; self->rollbackhook = 0; - self->profile = 0; self->updatehook = 0; self->commithook = 0; self->walhook = 0; @@ -375,15 +390,22 @@ Connection_new(PyTypeObject *type, PyObject *Py_UNUSED(args), PyObject *Py_UNUSE self->collationneeded = 0; self->exectrace = 0; self->rowtrace = 0; - self->tracehook = 0; - self->tracemask = 0; self->vfs = 0; self->savepointlevel = 0; self->open_flags = 0; self->open_vfs = 0; self->weakreflist = 0; + self->tracehooks = PyMem_Malloc(sizeof(struct tracehook) * 1); + self->tracehooks_count = 0; + if (self->tracehooks) + { + self->tracehooks[0].callback = 0; + self->tracehooks[0].id = 0; + self->tracehooks[0].mask = 0; + self->tracehooks_count = 1; + } CALL_TRACK_INIT(xConnect); - if (self->dependents) + if (self->dependents && self->tracehooks) return (PyObject *)self; } @@ -1181,76 +1203,6 @@ Connection_set_rollback_hook(Connection *self, PyObject *const *fast_args, Py_ss Py_RETURN_NONE; } -static int -profilecb(unsigned event, void *context, void *stmt, void *elapsed) -{ - assert(event == SQLITE_TRACE_PROFILE); - PyGILState_STATE gilstate = PyGILState_Ensure(); - - PyObject *retval = NULL; - Connection *self = (Connection *)context; - const char *statement = sqlite3_sql((sqlite3_stmt *)stmt); - sqlite3_uint64 runtime = *(sqlite3_uint64 *)elapsed; - - assert(self); - assert(self->profile); - assert(!Py_IsNone(self->profile)); - - MakeExistingException(); - - if (PyErr_Occurred()) - goto finally; /* abort hook due to outstanding exception */ - PyObject *vargs[] = {NULL, PyUnicode_FromString(statement), PyLong_FromLongLong(runtime)}; - if (vargs[1] && vargs[2]) - retval = PyObject_Vectorcall(self->profile, vargs + 1, 2 | PY_VECTORCALL_ARGUMENTS_OFFSET, NULL); - Py_XDECREF(vargs[1]); - Py_XDECREF(vargs[2]); - -finally: - Py_XDECREF(retval); - PyGILState_Release(gilstate); - - return 0; -} - -/** .. method:: set_profile(callable: Optional[Callable[[str, int], None]]) -> None - - Sets a callable which is invoked at the end of execution of each - statement and passed the statement string and how long it took to - execute. (The execution time is in nanoseconds.) Note that it is - called only on completion. If for example you do a ``SELECT`` and only - read the first result, then you won't reach the end of the statement. - - -* sqlite3_trace_v2 -*/ - -static PyObject * -Connection_set_profile(Connection *self, PyObject *const *fast_args, Py_ssize_t fast_nargs, PyObject *fast_kwnames) -{ - int res; - PyObject *callable; - CHECK_USE(NULL); - CHECK_CLOSED(self, NULL); - - { - Connection_set_profile_CHECK; - ARG_PROLOG(1, Connection_set_profile_KWNAMES); - ARG_MANDATORY ARG_optional_Callable(callable); - ARG_EPILOG(NULL, Connection_set_profile_USAGE, ); - } - - PYSQLITE_CON_CALL(res = sqlite3_trace_v2(self->db, SQLITE_TRACE_PROFILE, callable ? profilecb : NULL, callable ? self : NULL)); - if (res == SQLITE_OK) - { - Py_XDECREF(self->profile); - self->profile = callable ? Py_NewRef(callable) : NULL; - Py_RETURN_NONE; - } - - SET_EXC(res, self->db); - return NULL; -} - static int tracehook_cb(unsigned code, void *vconnection, void *one, void *two) { @@ -1264,8 +1216,7 @@ tracehook_cb(unsigned code, void *vconnection, void *one, void *two) MakeExistingException(); - if (PyErr_Occurred()) - goto finally; + CHAIN_EXC_BEGIN switch (code) { @@ -1274,6 +1225,7 @@ tracehook_cb(unsigned code, void *vconnection, void *one, void *two) #define V(x) sqlite3_stmt_status(stmt, x, 1) stmt = (sqlite3_stmt *)one; + /* reset all the counters */ V(SQLITE_STMTSTATUS_FULLSCAN_STEP); V(SQLITE_STMTSTATUS_SORT); V(SQLITE_STMTSTATUS_AUTOINDEX); @@ -1282,23 +1234,35 @@ tracehook_cb(unsigned code, void *vconnection, void *one, void *two) V(SQLITE_STMTSTATUS_RUN); V(SQLITE_STMTSTATUS_FILTER_MISS); V(SQLITE_STMTSTATUS_FILTER_HIT); - if (connection->tracemask & SQLITE_TRACE_STMT) - param = Py_BuildValue("{s: i, s: s, s: O}", - "code", code, "sql", sqlite3_sql(stmt), "connection", connection); - break; #undef V + for (unsigned i = 1; i < connection->tracehooks_count; i++) + { + /* only calculate this if needed */ + if (connection->tracehooks[i].mask & SQLITE_TRACE_STMT) + { + + param = Py_BuildValue("{s: i, s: s, s: O}", + "code", code, "sql", sqlite3_sql(stmt), "connection", connection); + break; + } + } + break; + case SQLITE_TRACE_ROW: stmt = (sqlite3_stmt *)one; - if (connection->tracemask & SQLITE_TRACE_ROW) - param = Py_BuildValue("{s: i, s: s, s: O}", - "code", code, "sql", sqlite3_sql(stmt), "connection", connection); + param = Py_BuildValue("{s: i, s: s, s: O}", + "code", code, "sql", sqlite3_sql(stmt), "connection", connection); break; case SQLITE_TRACE_CLOSE: - if (connection->tracemask & SQLITE_TRACE_CLOSE) - param = Py_BuildValue("{s: i, s: O}", - "code", code, "connection", connection); + /* Checking the refcount is subtle but important. If the + Connection is being closed because there are no more references to it + then the ref count is zero when the callback fires and adding a + reference ressurects a mostly destroyed object which then hits zero + again and gets destroyed a second time. Too difficult to handle. */ + param = Py_BuildValue("{s: i, s: O}", + "code", code, "connection", Py_REFCNT(connection) ? (PyObject *)connection : Py_None); break; case SQLITE_TRACE_PROFILE: @@ -1308,47 +1272,148 @@ tracehook_cb(unsigned code, void *vconnection, void *one, void *two) stmt = (sqlite3_stmt *)one; nanoseconds = (sqlite3_int64 *)two; - if (connection->tracemask & SQLITE_TRACE_PROFILE) + for (unsigned i = 1; i < connection->tracehooks_count; i++) { - /* only SQLITE_STMTSTATUS_MEMUSED actually needs mutex */ - sqlite3_mutex_enter(sqlite3_db_mutex(connection->db)); - param = Py_BuildValue("{s: i, s: O, s: s, s: L, s: {" K K K K K K K K "s: i}}", - "code", code, "connection", connection, "sql", sqlite3_sql(stmt), - "nanoseconds", *nanoseconds, "stmt_status", - V(SQLITE_STMTSTATUS_FULLSCAN_STEP), - V(SQLITE_STMTSTATUS_SORT), - V(SQLITE_STMTSTATUS_AUTOINDEX), - V(SQLITE_STMTSTATUS_VM_STEP), - V(SQLITE_STMTSTATUS_REPREPARE), - V(SQLITE_STMTSTATUS_RUN), - V(SQLITE_STMTSTATUS_FILTER_MISS), - V(SQLITE_STMTSTATUS_FILTER_HIT), - V(SQLITE_STMTSTATUS_MEMUSED)); - sqlite3_mutex_leave(sqlite3_db_mutex(connection->db)); + /* only calculate this if needed */ + if (connection->tracehooks[i].mask & SQLITE_TRACE_PROFILE) + { + /* only SQLITE_STMTSTATUS_MEMUSED actually needs mutex */ + sqlite3_mutex_enter(sqlite3_db_mutex(connection->db)); + param = Py_BuildValue("{s: i, s: O, s: s, s: L, s: {" K K K K K K K K "s: i}}", + "code", code, "connection", connection, "sql", sqlite3_sql(stmt), + "nanoseconds", *nanoseconds, "stmt_status", + V(SQLITE_STMTSTATUS_FULLSCAN_STEP), + V(SQLITE_STMTSTATUS_SORT), + V(SQLITE_STMTSTATUS_AUTOINDEX), + V(SQLITE_STMTSTATUS_VM_STEP), + V(SQLITE_STMTSTATUS_REPREPARE), + V(SQLITE_STMTSTATUS_RUN), + V(SQLITE_STMTSTATUS_FILTER_MISS), + V(SQLITE_STMTSTATUS_FILTER_HIT), + V(SQLITE_STMTSTATUS_MEMUSED)); + sqlite3_mutex_leave(sqlite3_db_mutex(connection->db)); + break; + } } break; #undef K #undef V } - if (param) + if (PyErr_Occurred()) + goto finally; + + /* handle sqlite3_profile compatibility */ + if (code == SQLITE_TRACE_PROFILE && connection->tracehooks[0].callback) + { + CHAIN_EXC_BEGIN + PyObject *vargs[] = {NULL, PyUnicode_FromString(sqlite3_sql(stmt)), PyLong_FromLongLong(*nanoseconds)}; + if (vargs[1] && vargs[2]) + res = PyObject_Vectorcall(connection->tracehooks[0].callback, vargs + 1, 2 | PY_VECTORCALL_ARGUMENTS_OFFSET, NULL); + Py_XDECREF(vargs[1]); + Py_XDECREF(vargs[2]); + Py_CLEAR(res); + CHAIN_EXC_END; + } + + if (!PyErr_Occurred()) { PyObject *vargs[] = {NULL, param}; - res = PyObject_Vectorcall(connection->tracehook, vargs + 1, 1 | PY_VECTORCALL_ARGUMENTS_OFFSET, NULL); - if (!res) - apsw_write_unraisable(NULL); + for (unsigned i = 1; i < connection->tracehooks_count; i++) + { + if (connection->tracehooks[i].mask & code) + { + CHAIN_EXC_BEGIN + res = PyObject_Vectorcall(connection->tracehooks[i].callback, vargs + 1, 1 | PY_VECTORCALL_ARGUMENTS_OFFSET, NULL); + Py_CLEAR(res); + CHAIN_EXC_END; + } + } } + finally: - Py_XDECREF(res); + Py_CLEAR(res); Py_XDECREF(param); + + CHAIN_EXC_END; + PyGILState_Release(gilstate); return 0; } -/** .. method:: trace_v2(mask: int, callback: Optional[Callable[[dict], None]] = None) -> None +/* does sqlite3_trace_v2 call based on current tracehooks, called + after each change */ +PyObject * +Connection_update_trace_v2(Connection *self) +{ + /* Our callers do CHECK_USE and CHECK_CLOSED */ + unsigned mask = 0; + for (unsigned i = 0; i < self->tracehooks_count; i++) + mask |= self->tracehooks[i].mask; + + /* this ensures counters are reset on a per statement basis */ + if (mask & SQLITE_TRACE_PROFILE) + mask |= SQLITE_TRACE_STMT; + + int res; + + PYSQLITE_CON_CALL(res = sqlite3_trace_v2(self->db, mask, mask ? tracehook_cb : NULL, self)); + + if (res != SQLITE_OK) + { + SET_EXC(res, self->db); + return NULL; + } + Py_RETURN_NONE; +} + +/** .. method:: set_profile(callable: Optional[Callable[[str, int], None]]) -> None + + Sets a callable which is invoked at the end of execution of each + statement and passed the statement string and how long it took to + execute. (The execution time is in nanoseconds.) Note that it is + called only on completion. If for example you do a ``SELECT`` and only + read the first result, then you won't reach the end of the statement. + + -* sqlite3_trace_v2 +*/ - Registers a trace callback. The callback is called with a dict of relevant values based - on the code. +static PyObject * +Connection_set_profile(Connection *self, PyObject *const *fast_args, Py_ssize_t fast_nargs, PyObject *fast_kwnames) +{ + PyObject *callable; + CHECK_USE(NULL); + CHECK_CLOSED(self, NULL); + + { + Connection_set_profile_CHECK; + ARG_PROLOG(1, Connection_set_profile_KWNAMES); + ARG_MANDATORY ARG_optional_Callable(callable); + ARG_EPILOG(NULL, Connection_set_profile_USAGE, ); + } + + Py_CLEAR(self->tracehooks[0].callback); + + if (callable) + { + self->tracehooks[0].mask = SQLITE_TRACE_PROFILE; + self->tracehooks[0].callback = Py_NewRef(callable); + } + else + self->tracehooks[0].mask = 0; + + return Connection_update_trace_v2(self); +} + +/** .. method:: trace_v2(mask: int, callback: Optional[Callable[[dict], None]] = None, *, id: Optional[Any] = None) -> None + + Registers a trace callback. Multiple traces can be active at once + (implemented by APSW). A callback of :class:`None` unregisters a + trace. Registered callbacks are distinguished by their ``id`` - an + equality test is done to match ids. + + The callback is called with a dict of relevant values based on the + code. .. list-table:: :header-rows: 1 @@ -1377,6 +1442,10 @@ tracehook_cb(unsigned code, void *vconnection, void *one, void *two) The counters are reset each time a statement starts execution. + Note that SQLite ignores any errors from the trace callbacks, so + whatever was being traced will still proceed. Exceptions will be + delivered when your Python code resumes. + .. seealso:: * :ref:`Example ` @@ -1386,8 +1455,9 @@ tracehook_cb(unsigned code, void *vconnection, void *one, void *two) static PyObject * Connection_trace_v2(Connection *self, PyObject *const *fast_args, Py_ssize_t fast_nargs, PyObject *fast_kwnames) { - int mask = 0, res = 0; + int mask = 0; PyObject *callback = NULL; + PyObject *id = NULL; CHECK_USE(NULL); CHECK_CLOSED(self, NULL); @@ -1397,6 +1467,7 @@ Connection_trace_v2(Connection *self, PyObject *const *fast_args, Py_ssize_t fas ARG_PROLOG(2, Connection_trace_v2_KWNAMES); ARG_MANDATORY ARG_int(mask); ARG_OPTIONAL ARG_optional_Callable(callback); + ARG_OPTIONAL ARG_pyobject(id); ARG_EPILOG(NULL, Connection_trace_v2_USAGE, ); } @@ -1405,29 +1476,68 @@ Connection_trace_v2(Connection *self, PyObject *const *fast_args, Py_ssize_t fas if (mask == 0 && callback) return PyErr_Format(PyExc_ValueError, "mask selects no events, but callback provided"); - /* SQLite doesn't */ + /* Known values only */ if (mask & ~(SQLITE_TRACE_STMT | SQLITE_TRACE_PROFILE | SQLITE_TRACE_ROW | SQLITE_TRACE_CLOSE)) return PyErr_Format(PyExc_ValueError, "mask includes unknown trace values"); - /* what was actually requested */ - self->tracemask = mask; - - /* if profiling, we always want statement start to reset counters */ - if (mask | SQLITE_TRACE_PROFILE) - mask |= SQLITE_TRACE_STMT; + /* always clear out any matching id */ + for (unsigned i = 1; i < self->tracehooks_count; i++) + { + if (self->tracehooks[i].callback) + { + int eq; + /* handle either side being NULL */ + if ((!id || !self->tracehooks[i].id) && id != self->tracehooks[i].id) + eq = 0; + else + eq = PyObject_RichCompareBool(id, self->tracehooks[i].id, Py_EQ); - Py_CLEAR(self->tracehook); - Py_XINCREF(callback); - self->tracehook = callback; + if (eq == -1) + return NULL; + if (eq) + { + Py_CLEAR(self->tracehooks[i].callback); + Py_CLEAR(self->tracehooks[i].id); + self->tracehooks[i].mask = 0; + } + } + } - PYSQLITE_CON_CALL(res = sqlite3_trace_v2(self->db, mask, tracehook_cb, self)); - if (res != SQLITE_OK) + if (callback) { - SET_EXC(res, self->db); - return NULL; + /* find an empty slot */ + int found = 0; + for (unsigned i = 1; i < self->tracehooks_count; i++) + { + if (self->tracehooks[i].callback == 0) + { + self->tracehooks[i].mask = mask; + self->tracehooks[i].id = id ? Py_NewRef(id) : NULL; + self->tracehooks[i].callback = Py_NewRef(callback); + found = 1; + break; + } + } + if (!found) + { + /* increase tracehooks size - we have an arbitrary limit which + makes it easier to test exhaustion */ + struct tracehook *new_tracehooks = (self->tracehooks_count < 1024) ? PyMem_Realloc(self->tracehooks, sizeof(struct tracehook) * (self->tracehooks_count + 1)) : NULL; + if (!new_tracehooks) + { + /* not bothering to call update_trace - worst case there will + be extra trace calls that are ignored. */ + return PyErr_NoMemory(); + } + self->tracehooks = new_tracehooks; + self->tracehooks[self->tracehooks_count].mask = mask; + self->tracehooks[self->tracehooks_count].id = id ? Py_NewRef(id) : NULL; + self->tracehooks[self->tracehooks_count].callback = Py_NewRef(callback); + self->tracehooks_count++; + } } - Py_RETURN_NONE; + return Connection_update_trace_v2(self); } static int @@ -5235,7 +5345,6 @@ Connection_tp_traverse(Connection *self, visitproc visit, void *arg) { Py_VISIT(self->busyhandler); Py_VISIT(self->rollbackhook); - Py_VISIT(self->profile); Py_VISIT(self->updatehook); Py_VISIT(self->commithook); Py_VISIT(self->walhook); @@ -5244,10 +5353,14 @@ Connection_tp_traverse(Connection *self, visitproc visit, void *arg) Py_VISIT(self->collationneeded); Py_VISIT(self->exectrace); Py_VISIT(self->rowtrace); - Py_VISIT(self->tracehook); Py_VISIT(self->vfs); Py_VISIT(self->dependents); Py_VISIT(self->cursor_factory); + for (unsigned i = 0; i < self->tracehooks_count; i++) + { + Py_VISIT(self->tracehooks[i].callback); + Py_VISIT(self->tracehooks[i].id); + } return 0; } diff --git a/tools/gendocstrings.py b/tools/gendocstrings.py index d0a06991..c1d0f1fe 100644 --- a/tools/gendocstrings.py +++ b/tools/gendocstrings.py @@ -554,7 +554,8 @@ def do_argparse(item): pass elif param["type"] in { "PyObject", "Any", "Optional[type[BaseException]]", "Optional[BaseException]", - "Optional[types.TracebackType]", "Optional[VTModule]", "Optional[SQLiteValue]" + "Optional[types.TracebackType]", "Optional[VTModule]", "Optional[SQLiteValue]", + "Optional[Any]" }: type = "PyObject *" kind = "pyobject"