Skip to content

Commit

Permalink
Propagate python errors through cython FFI handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Aug 25, 2023
1 parent 7879fb8 commit 3a4167f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
7 changes: 7 additions & 0 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,13 @@ typedef void* TVMObjectHandle;
*/
TVM_DLL void TVMAPISetLastError(const char* msg);

/*!
* \brief Used for implementing C API function.
* Set last exception before return.
* \param py_object The python exception to be set
*/
TVM_DLL void TVMAPISetLastPythonError(void* py_object);

/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

from ..base import get_last_ffi_error
from ..base import raise_last_ffi_error
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
Expand Down Expand Up @@ -113,6 +113,7 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
# We mark the possibly long running function as nogil below.
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg)
void TVMAPISetLastPythonError(void* py_object) except +
const char *TVMGetLastError()
int TVMFuncGetGlobal(const char* name,
TVMPackedFuncHandle* out)
Expand Down Expand Up @@ -178,7 +179,7 @@ cdef inline int CHECK_CALL(int ret) except -2:
if ret == -2:
return -2
if ret != 0:
raise get_last_ffi_error()
raise_last_ffi_error()
return 0


Expand Down
18 changes: 17 additions & 1 deletion python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ cdef int tvm_callback(TVMValue* args,
pyargs.append(c_make_array(value.v_handle, True, False))
try:
rv = local_pyfunc(*pyargs)
except Exception:
except Exception as err:
msg = traceback.format_exc()
msg = py2cerror(msg)
TVMAPISetLastError(c_str(msg))
TVMAPISetLastPythonError(<void*>err)

return -1
if rv is not None:
if isinstance(rv, tuple):
Expand Down Expand Up @@ -368,3 +370,17 @@ def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object

# Py_INCREF and Py_DECREF are C macros, not function objects.
# Therefore, providing a wrapper function.
cdef void _py_incref_wrapper(void* py_object):
Py_INCREF(<object>py_object)
cdef void _py_decref_wrapper(void* py_object):
Py_DECREF(<object>py_object)

def _init_pythonapi_inc_def_ref():
register_func = TVMBackendRegisterEnvCAPI
register_func(c_str("Py_IncRef"), <void*>_py_incref_wrapper)
register_func(c_str("Py_DecRef"), <void*>_py_decref_wrapper)

_init_pythonapi_inc_def_ref()

0 comments on commit 3a4167f

Please sign in to comment.