Skip to content

Commit

Permalink
gh-217: impl C bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorOrachyov committed Aug 27, 2023
1 parent c71b4b7 commit 4917a24
Show file tree
Hide file tree
Showing 19 changed files with 670 additions and 53 deletions.
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,16 @@ add_library(spla SHARED
src/vector.cpp
# C bindings
include/spla.h
src/binding/c_algorithm.cpp
src/binding/c_array.cpp
src/binding/c_config.hpp
src/binding/c_library.cpp
src/binding/c_matrix.cpp
src/binding/c_object.cpp
src/binding/c_op.cpp
src/binding/c_scalar.cpp
src/binding/c_type.cpp
src/binding/c_vector.cpp
# C++ optional part
${SRC_OPENCL})

Expand Down
128 changes: 125 additions & 3 deletions include/spla.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@

/**
* @file spla.h
* @author Egor Orachev
*
* @brief Spla library C++ API bindings for C language
* @brief Spla library C API bindings
*
* @note Bindings primary intended for exporting to other programming languages,
* such as Python, etc. For a manual usage prefer C++ library API declared in spla.hpp file.
*
* @see Source code: https://github.com/SparseLinearAlgebra/spla
* @see Python Reference API: https://SparseLinearAlgebra.github.io/spla/docs-python/spla
Expand All @@ -58,6 +62,10 @@
extern "C" {
#endif

//////////////////////////////////////////////////////////////////////////////////////

/* General definitions */

typedef enum spla_Status {
SPLA_STATUS_OK = 0,
SPLA_STATUS_ERROR = 1,
Expand All @@ -75,7 +83,7 @@ typedef enum spla_AcceleratorType {
SPLA_ACCELERATOR_TYPE_OPENCL = 1
} spla_AcceleratorType;

#define SPLA_NULL_HND NULL
#define SPLA_NULL NULL

typedef uint32_t spla_uint;

Expand All @@ -94,21 +102,108 @@ typedef struct spla_OpSelect_t* spla_OpSelect;

typedef void(spla_MessageCallback)(spla_Status, const char* message, const char* file, const char* function, int line, void* p_user_data);

//////////////////////////////////////////////////////////////////////////////////////

/* Library configuration and accessors */

SPLA_API void spla_Library_finalize();
SPLA_API spla_Status spla_Library_set_accelerator(spla_AcceleratorType accelerator);
SPLA_API spla_Status spla_Library_set_platform(int index);
SPLA_API spla_Status spla_Library_set_device(int index);
SPLA_API spla_Status spla_Library_set_queues_count(int count);
SPLA_API spla_Status spla_Library_set_message_callback(spla_MessageCallback callback, void* p_user_data);
SPLA_API spla_Status spla_Library_set_default_callback();
SPLA_API spla_Status spla_Library_get_accelerator_info(char* buffer, int length);

//////////////////////////////////////////////////////////////////////////////////////

/* Built-in predefined scalar values types for storage parametrization */

SPLA_API spla_Type spla_Type_int();
SPLA_API spla_Type spla_Type_uint();
SPLA_API spla_Type spla_Type_float();

//////////////////////////////////////////////////////////////////////////////////////

/* Built-in binary element-wise operations */

SPLA_API spla_OpBinary spla_OpBinary_PLUS_INT();
SPLA_API spla_OpBinary spla_OpBinary_PLUS_UINT();
SPLA_API spla_OpBinary spla_OpBinary_PLUS_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_MINUS_INT();
SPLA_API spla_OpBinary spla_OpBinary_MINUS_UINT();
SPLA_API spla_OpBinary spla_OpBinary_MINUS_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_MULT_INT();
SPLA_API spla_OpBinary spla_OpBinary_MULT_UINT();
SPLA_API spla_OpBinary spla_OpBinary_MULT_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_DIV_INT();
SPLA_API spla_OpBinary spla_OpBinary_DIV_UINT();
SPLA_API spla_OpBinary spla_OpBinary_DIV_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_MINUS_POW2_INT();
SPLA_API spla_OpBinary spla_OpBinary_MINUS_POW2_UINT();
SPLA_API spla_OpBinary spla_OpBinary_MINUS_POW2_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_FIRST_INT();
SPLA_API spla_OpBinary spla_OpBinary_FIRST_UINT();
SPLA_API spla_OpBinary spla_OpBinary_FIRST_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_SECOND_INT();
SPLA_API spla_OpBinary spla_OpBinary_SECOND_UINT();
SPLA_API spla_OpBinary spla_OpBinary_SECOND_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_ONE_INT();
SPLA_API spla_OpBinary spla_OpBinary_ONE_UINT();
SPLA_API spla_OpBinary spla_OpBinary_ONE_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_MIN_INT();
SPLA_API spla_OpBinary spla_OpBinary_MIN_UINT();
SPLA_API spla_OpBinary spla_OpBinary_MIN_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_MAX_INT();
SPLA_API spla_OpBinary spla_OpBinary_MAX_UINT();
SPLA_API spla_OpBinary spla_OpBinary_MAX_FLOAT();
SPLA_API spla_OpBinary spla_OpBinary_BOR_INT();
SPLA_API spla_OpBinary spla_OpBinary_BOR_UINT();
SPLA_API spla_OpBinary spla_OpBinary_BAND_INT();
SPLA_API spla_OpBinary spla_OpBinary_BAND_UINT();
SPLA_API spla_OpBinary spla_OpBinary_BXOR_INT();
SPLA_API spla_OpBinary spla_OpBinary_BXOR_UINT();

//////////////////////////////////////////////////////////////////////////////////////

/* Built-in selection operations */

SPLA_API spla_OpSelect spla_OpSelect_EQZERO_INT();
SPLA_API spla_OpSelect spla_OpSelect_EQZERO_UINT();
SPLA_API spla_OpSelect spla_OpSelect_EQZERO_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_NQZERO_INT();
SPLA_API spla_OpSelect spla_OpSelect_NQZERO_UINT();
SPLA_API spla_OpSelect spla_OpSelect_NQZERO_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_GTZERO_INT();
SPLA_API spla_OpSelect spla_OpSelect_GTZERO_UINT();
SPLA_API spla_OpSelect spla_OpSelect_GTZERO_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_GEZERO_INT();
SPLA_API spla_OpSelect spla_OpSelect_GEZERO_UINT();
SPLA_API spla_OpSelect spla_OpSelect_GEZERO_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_LTZERO_INT();
SPLA_API spla_OpSelect spla_OpSelect_LTZERO_UINT();
SPLA_API spla_OpSelect spla_OpSelect_LTZERO_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_LEZERO_INT();
SPLA_API spla_OpSelect spla_OpSelect_LEZERO_UINT();
SPLA_API spla_OpSelect spla_OpSelect_LEZERO_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_ALWAYS_INT();
SPLA_API spla_OpSelect spla_OpSelect_ALWAYS_UINT();
SPLA_API spla_OpSelect spla_OpSelect_ALWAYS_FLOAT();
SPLA_API spla_OpSelect spla_OpSelect_NEVER_INT();
SPLA_API spla_OpSelect spla_OpSelect_NEVER_UINT();
SPLA_API spla_OpSelect spla_OpSelect_NEVER_FLOAT();

//////////////////////////////////////////////////////////////////////////////////////

/* General base Object type methods */

SPLA_API spla_Status spla_Object_ref(spla_Object object);
SPLA_API spla_Status spla_Object_unref(spla_Object object);

//////////////////////////////////////////////////////////////////////////////////////

/* Scala container creation and manipulation */

SPLA_API spla_Status spla_Scalar_make(spla_Scalar* scalar, spla_Type type);
SPLA_API spla_Status spla_Scalar_set_int(spla_Scalar s, int value);
SPLA_API spla_Status spla_Scalar_set_uint(spla_Scalar s, unsigned int value);
Expand All @@ -117,6 +212,10 @@ SPLA_API spla_Status spla_Scalar_get_int(spla_Scalar s, int* value);
SPLA_API spla_Status spla_Scalar_get_uint(spla_Scalar s, unsigned int* value);
SPLA_API spla_Status spla_Scalar_get_float(spla_Scalar s, float* value);

//////////////////////////////////////////////////////////////////////////////////////

/* Array container creation and manipulation */

SPLA_API spla_Status spla_Array_make(spla_Array* v, spla_uint n_values, spla_Type type);
SPLA_API spla_Status spla_Array_set_int(spla_Array a, spla_uint i, int value);
SPLA_API spla_Status spla_Array_set_uint(spla_Array a, spla_uint i, unsigned int value);
Expand All @@ -126,6 +225,10 @@ SPLA_API spla_Status spla_Array_get_uint(spla_Array a, spla_uint i, unsigned int
SPLA_API spla_Status spla_Array_get_float(spla_Array a, spla_uint i, float* value);
SPLA_API spla_Status spla_Array_clear(spla_Array a);

//////////////////////////////////////////////////////////////////////////////////////

/* Vector container creation and manipulation */

SPLA_API spla_Status spla_Vector_make(spla_Vector* v, spla_uint n_rows, spla_Type type);
SPLA_API spla_Status spla_Vector_set_fill_value(spla_Vector v, spla_Scalar value);
SPLA_API spla_Status spla_Vector_set_reduce(spla_Vector v, spla_OpBinary reduce);
Expand All @@ -135,9 +238,15 @@ SPLA_API spla_Status spla_Vector_set_float(spla_Vector v, spla_uint row_id, floa
SPLA_API spla_Status spla_Vector_get_int(spla_Vector v, spla_uint row_id, int* value);
SPLA_API spla_Status spla_Vector_get_uint(spla_Vector v, spla_uint row_id, unsigned int* value);
SPLA_API spla_Status spla_Vector_get_float(spla_Vector v, spla_uint row_id, float* value);
SPLA_API spla_Status spla_Vector_build(spla_Vector v, spla_Array keys, spla_Array values);
SPLA_API spla_Status spla_Vector_read(spla_Vector v, spla_Array keys, spla_Array values);
SPLA_API spla_Status spla_Vector_clear(spla_Vector v);

SPLA_API spla_Status spla_Matrix_make(spla_Matrix* M, spla_uint n_rows, spla_Type type);
//////////////////////////////////////////////////////////////////////////////////////

/* Matrix container creation and manipulation */

SPLA_API spla_Status spla_Matrix_make(spla_Matrix* M, spla_uint n_rows, spla_uint n_cols, spla_Type type);
SPLA_API spla_Status spla_Matrix_set_fill_value(spla_Matrix M, spla_Scalar value);
SPLA_API spla_Status spla_Matrix_set_reduce(spla_Matrix M, spla_OpBinary reduce);
SPLA_API spla_Status spla_Matrix_set_int(spla_Matrix M, spla_uint row_id, spla_uint col_id, int value);
Expand All @@ -146,8 +255,21 @@ SPLA_API spla_Status spla_Matrix_set_float(spla_Matrix M, spla_uint row_id, spla
SPLA_API spla_Status spla_Matrix_get_int(spla_Matrix M, spla_uint row_id, spla_uint col_id, int* value);
SPLA_API spla_Status spla_Matrix_get_uint(spla_Matrix M, spla_uint row_id, spla_uint col_id, unsigned int* value);
SPLA_API spla_Status spla_Matrix_get_float(spla_Matrix M, spla_uint row_id, spla_uint col_id, float* value);
SPLA_API spla_Status spla_Matrix_build(spla_Matrix M, spla_Array keys1, spla_Array keys2, spla_Array values);
SPLA_API spla_Status spla_Matrix_read(spla_Matrix M, spla_Array keys1, spla_Array keys2, spla_Array values);
SPLA_API spla_Status spla_Matrix_clear(spla_Matrix M);

//////////////////////////////////////////////////////////////////////////////////////

/* Implemented some common graph algorithms using spla library */

SPLA_API spla_Status spla_Algorithm_bfs(spla_Vector v, spla_Matrix A, spla_uint s, spla_Descriptor descriptor);
SPLA_API spla_Status spla_Algorithm_sssp(spla_Vector v, spla_Matrix A, spla_uint s, spla_Descriptor descriptor);
SPLA_API spla_Status spla_Algorithm_pr(spla_Vector* p, spla_Matrix A, float alpha, float eps, spla_Descriptor descriptor);
SPLA_API spla_Status spla_Algorithm_tc(int* ntrins, spla_Matrix A, spla_Matrix B, spla_Descriptor descriptor);

//////////////////////////////////////////////////////////////////////////////////////

#if defined(__cplusplus)
}
#endif
Expand Down
1 change: 1 addition & 0 deletions include/spla/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ namespace spla {
SPLA_API virtual Status get_uint(uint i, T_UINT& value) = 0;
SPLA_API virtual Status get_float(uint i, T_FLOAT& value) = 0;
SPLA_API virtual Status resize(uint n_values) = 0;
SPLA_API virtual Status clear() = 0;

SPLA_API static ref_ptr<Array> make(uint n_values, const ref_ptr<Type>& type);
};
Expand Down
2 changes: 1 addition & 1 deletion include/spla/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace spla {
SPLA_API virtual Status get_uint(uint row_id, uint col_id, std::uint32_t& value) = 0;
SPLA_API virtual Status get_float(uint row_id, uint col_id, float& value) = 0;
SPLA_API virtual Status build(const ref_ptr<Array>& keys1, const ref_ptr<Array>& keys2, const ref_ptr<Array>& values) = 0;
SPLA_API virtual Status read(ref_ptr<Array>& keys1, ref_ptr<Array>& keys2, ref_ptr<Array>& values) = 0;
SPLA_API virtual Status read(const ref_ptr<Array>& keys1, const ref_ptr<Array>& keys2, const ref_ptr<Array>& values) = 0;
SPLA_API virtual Status clear() = 0;

/**
Expand Down
2 changes: 1 addition & 1 deletion include/spla/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ namespace spla {
SPLA_API virtual Status fill_noize(uint seed) = 0;
SPLA_API virtual Status fill_with(const ref_ptr<Scalar>& value) = 0;
SPLA_API virtual Status build(const ref_ptr<Array>& keys, const ref_ptr<Array>& values) = 0;
SPLA_API virtual Status read(ref_ptr<Array>& keys, ref_ptr<Array>& values) = 0;
SPLA_API virtual Status read(const ref_ptr<Array>& keys, const ref_ptr<Array>& values) = 0;
SPLA_API virtual Status clear() = 0;

/**
Expand Down
56 changes: 28 additions & 28 deletions python/pyspla/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"""

__all__ = [
"_spla_lib",
"_spla",
"_callback_t"
]

Expand All @@ -42,8 +42,8 @@
TARGET_SUFFIX = {'macos': '.dylib', 'linux': '.so', 'windows': '.dll'}[SYSTEM]
TARGET = {'macos': 'libspla', 'linux': 'libspla', 'windows': 'spla'}[SYSTEM] + "_" + ARCH + TARGET_SUFFIX

_spla_lib_path = None
_spla_lib = None
_spla_path = None
_spla = None
_int = None
_enum_t = None
_status_t = None
Expand All @@ -66,15 +66,15 @@


def load_library(lib_path):
global _spla_lib
global _spla
global _int
global _enum_t
global _status_t
global _object_t
global _p_object_t
global _callback_t

_spla_lib = ctypes.cdll.LoadLibrary(str(lib_path))
_spla = ctypes.cdll.LoadLibrary(str(lib_path))
_int = ctypes.c_int
_enum_t = ctypes.c_uint
_status_t = ctypes.c_uint
Expand All @@ -88,26 +88,26 @@ def load_library(lib_path):
ctypes.c_int,
ctypes.c_void_p)

_spla_lib.spla_Library_finalize.restype = _status_t
_spla_lib.spla_Library_finalize.argtypes = []
_spla.spla_Library_finalize.restype = _status_t
_spla.spla_Library_finalize.argtypes = []

_spla_lib.spla_Library_set_accelerator.restype = _status_t
_spla_lib.spla_Library_set_accelerator.argtypes = [_enum_t]
_spla.spla_Library_set_accelerator.restype = _status_t
_spla.spla_Library_set_accelerator.argtypes = [_enum_t]

_spla_lib.spla_Library_set_platform.restype = _status_t
_spla_lib.spla_Library_set_platform.argtypes = [_int]
_spla.spla_Library_set_platform.restype = _status_t
_spla.spla_Library_set_platform.argtypes = [_int]

_spla_lib.spla_Library_set_device.restype = _status_t
_spla_lib.spla_Library_set_device.argtypes = [_int]
_spla.spla_Library_set_device.restype = _status_t
_spla.spla_Library_set_device.argtypes = [_int]

_spla_lib.spla_Library_set_queues_count.restype = _status_t
_spla_lib.spla_Library_set_queues_count.argtypes = [_int]
_spla.spla_Library_set_queues_count.restype = _status_t
_spla.spla_Library_set_queues_count.argtypes = [_int]

_spla_lib.spla_Library_set_message_callback.restype = _status_t
_spla_lib.spla_Library_set_message_callback.argtypes = [_callback_t, ctypes.c_void_p]
_spla.spla_Library_set_message_callback.restype = _status_t
_spla.spla_Library_set_message_callback.argtypes = [_callback_t, ctypes.c_void_p]

_spla_lib.spla_Library_set_default_callback.restype = _status_t
_spla_lib.spla_Library_set_default_callback.argtypes = []
_spla.spla_Library_set_default_callback.restype = _status_t
_spla.spla_Library_set_default_callback.argtypes = []


def default_callback(status, msg, file, function, line, user_data):
Expand All @@ -117,13 +117,13 @@ def default_callback(status, msg, file, function, line, user_data):


def finalize():
if _spla_lib:
_spla_lib.spla_Library_finalize()
if _spla:
_spla.spla_Library_finalize()


def initialize():
global _spla_lib
global _spla_lib_path
global _spla
global _spla_path
global _callback_t
global _default_callback

Expand All @@ -134,26 +134,26 @@ def initialize():
except KeyError:
pass

_spla_lib_path = pathlib.Path(__file__).resolve().parent / TARGET
_spla_path = pathlib.Path(__file__).resolve().parent / TARGET

try:
# Override library path from ENV variable (for debug & custom build)
if os.environ["SPLA_PATH"]:
_spla_lib_path = pathlib.Path(os.environ["SPLA_PATH"])
_spla_path = pathlib.Path(os.environ["SPLA_PATH"])
except KeyError:
pass

if not _spla_lib_path.is_file():
if not _spla_path.is_file():
# Validate file before loading
raise Exception(f"no compiled spla file {TARGET} to load")

load_library(_spla_lib_path)
load_library(_spla_path)
_default_callback = _callback_t(default_callback)

try:
# If debug enable in ENV, setup default callback for messages on init
if int(os.environ["SPLA_DEBUG"]):
_spla_lib.spla_Library_set_message_callback(
_spla.spla_Library_set_message_callback(
_default_callback, ctypes.c_void_p(0))
except KeyError:
pass
Expand Down
Loading

0 comments on commit 4917a24

Please sign in to comment.