Skip to content

Commit

Permalink
gh-218: add format control for mat & vec
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorOrachyov committed Aug 30, 2023
1 parent b12b632 commit 5d747e2
Show file tree
Hide file tree
Showing 15 changed files with 234 additions and 24 deletions.
23 changes: 23 additions & 0 deletions include/spla.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,27 @@ typedef enum spla_AcceleratorType {
SPLA_ACCELERATOR_TYPE_OPENCL = 1
} spla_AcceleratorType;

typedef enum spla_FormatMatrix {
SPLA_FORMAT_MATRIX_CPU_LIL = 0,
SPLA_FORMAT_MATRIX_CPU_DOK = 1,
SPLA_FORMAT_MATRIX_CPU_COO = 2,
SPLA_FORMAT_MATRIX_CPU_CSR = 3,
SPLA_FORMAT_MATRIX_CPU_CSC = 4,
SPLA_FORMAT_MATRIX_ACC_COO = 5,
SPLA_FORMAT_MATRIX_ACC_CSR = 6,
SPLA_FORMAT_MATRIX_ACC_CSC = 7,
SPLA_FORMAT_MATRIX_COUNT = 8
} spla_FormatMatrix;

typedef enum spla_FormatVector {
SPLA_FORMAT_VECTOR_CPU_DOK = 0,
SPLA_FORMAT_VECTOR_CPU_DENSE = 1,
SPLA_FORMAT_VECTOR_CPU_COO = 2,
SPLA_FORMAT_VECTOR_ACC_DENSE = 3,
SPLA_FORMAT_VECTOR_ACC_COO = 4,
SPLA_FORMAT_VECTOR_COUNT = 5
} spla_FormatVector;

#define SPLA_NULL NULL

typedef int32_t spla_bool;
Expand Down Expand Up @@ -257,6 +278,7 @@ SPLA_API spla_Status spla_Array_clear(spla_Array a);
/* Vector container creation and manipulation */

SPLA_API spla_Status spla_Vector_make(spla_Vector* v, spla_uint n_rows, spla_Type type);
SPLA_API spla_Status spla_Vector_set_format(spla_Vector v, int format);
SPLA_API spla_Status spla_Vector_set_fill_value(spla_Vector v, spla_Scalar value);
SPLA_API spla_Status spla_Vector_set_reduce(spla_Vector v, spla_OpBinary reduce);
SPLA_API spla_Status spla_Vector_set_int(spla_Vector v, spla_uint row_id, int value);
Expand All @@ -274,6 +296,7 @@ SPLA_API spla_Status spla_Vector_clear(spla_Vector v);
/* Matrix container creation and manipulation */

SPLA_API spla_Status spla_Matrix_make(spla_Matrix* M, spla_uint n_rows, spla_uint n_cols, spla_Type type);
SPLA_API spla_Status spla_Matrix_set_format(spla_Matrix M, int format);
SPLA_API spla_Status spla_Matrix_set_fill_value(spla_Matrix M, spla_Scalar value);
SPLA_API spla_Status spla_Matrix_set_reduce(spla_Matrix M, spla_OpBinary reduce);
SPLA_API spla_Status spla_Matrix_set_int(spla_Matrix M, spla_uint row_id, spla_uint col_id, int value);
Expand Down
15 changes: 0 additions & 15 deletions include/spla/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,21 +141,6 @@ namespace spla {
Count = 5
};

/**
* @class FormatArray
* @brief Named storage formats of array
*
* @warning Do not change order of values
*/
enum class FormatArray : uint {
/** CPU side data allocation */
Cpu = 0,
/** Acc side allocation */
Acc = 1,
/** Total formats count */
Count = 2
};

/**
* @class MessageCallback
* @brief Callback function called on library message event
Expand Down
10 changes: 6 additions & 4 deletions python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
t.assign(m, pyspla.Scalar(pyspla.INT, 10), pyspla.INT.SECOND, pyspla.INT.GEZERO)
print(t.to_list())

M = pyspla.Matrix((10, 10), pyspla.INT)
G = pyspla.Matrix.generate((10, 10), pyspla.INT, density=0.1, dist=[0, 10])

print(M.to_list())
print(G.to_lists())

print(G.reduce(pyspla.INT.PLUS))

M = pyspla.Matrix.from_lists([1, 2, 3], [1, 2, 3], [-1, 5, 10], (10, 10), pyspla.INT)
M.set_format(pyspla.FormatMatrix.ACC_CSR)
print(M.get(1, 0))
print(M.get(1, 1))
print(M.to_list())
2 changes: 2 additions & 0 deletions python/pyspla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@
"INT",
"UINT",
"FLOAT",
"FormatMatrix",
"FormatVector",
"Descriptor",
"Op",
"OpUnary",
Expand Down
1 change: 0 additions & 1 deletion python/pyspla/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ def clear(self):
"""
Clears array removing all elements, so it has 0 values.
"""

check(backend().spla_Array_clear(self._hnd))

def to_list(self):
Expand Down
38 changes: 37 additions & 1 deletion python/pyspla/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@
__all__ = [
"backend",
"check",
"is_docs"
"is_docs",
"FormatMatrix",
"FormatVector"
]

import os
import ctypes
import pathlib
import platform
import atexit
import enum

ARCH = {'AMD64': 'x64', 'x86_64': 'x64', 'arm64': 'arm64'}[platform.machine()]
SYSTEM = {'Darwin': 'macos', 'Linux': 'linux', 'Windows': 'windows'}[platform.system()]
Expand Down Expand Up @@ -95,6 +98,35 @@ class SplaNotImplemented(SplaError):
pass


class FormatMatrix(enum.Enum):
"""
Mapping for spla supported matrix storage formats enumeration.
"""

CPU_LIL = 0
CPU_DOK = 1
CPU_COO = 2
CPU_CSR = 3
CPU_CSC = 4
ACC_COO = 5
ACC_CSR = 6
ACC_CSC = 7
COUNT = 8


class FormatVector(enum.Enum):
"""
Mapping for spla supported vector storage formats enumeration.
"""

CPU_DOK = 0
CPU_DENSE = 1
CPU_COO = 2
ACC_DENSE = 3
ACC_COO = 4
COUNT = 5


_status_mapping = {
1: SplaError,
2: SplaNoAcceleration,
Expand Down Expand Up @@ -350,6 +382,7 @@ def load_library(lib_path):
_spla.spla_Array_clear.argtypes = [_object_t]

_spla.spla_Vector_make.restype = _status_t
_spla.spla_Vector_set_format.restype = _status_t
_spla.spla_Vector_set_fill_value.restype = _status_t
_spla.spla_Vector_set_reduce.restype = _status_t
_spla.spla_Vector_set_int.restype = _status_t
Expand All @@ -363,6 +396,7 @@ def load_library(lib_path):
_spla.spla_Vector_clear.restype = _status_t

_spla.spla_Vector_make.argtypes = [_p_object_t, _uint, _object_t]
_spla.spla_Vector_set_format.argtypes = [_object_t, ctypes.c_int]
_spla.spla_Vector_set_fill_value.argtypes = [_object_t, _object_t]
_spla.spla_Vector_set_reduce.argtypes = [_object_t, _object_t]
_spla.spla_Vector_set_int.argtypes = [_object_t, _uint, _int]
Expand All @@ -376,6 +410,7 @@ def load_library(lib_path):
_spla.spla_Vector_clear.argtypes = [_object_t]

_spla.spla_Matrix_make.restype = _status_t
_spla.spla_Matrix_set_format.restype = _status_t
_spla.spla_Matrix_set_fill_value.restype = _status_t
_spla.spla_Matrix_set_reduce.restype = _status_t
_spla.spla_Matrix_set_int.restype = _status_t
Expand All @@ -389,6 +424,7 @@ def load_library(lib_path):
_spla.spla_Matrix_clear.restype = _status_t

_spla.spla_Matrix_make.argtypes = [_p_object_t, _uint, _uint, _object_t]
_spla.spla_Matrix_set_format.argtypes = [_object_t, ctypes.c_int]
_spla.spla_Matrix_set_fill_value.argtypes = [_object_t, _object_t]
_spla.spla_Matrix_set_reduce.argtypes = [_object_t, _object_t]
_spla.spla_Matrix_set_int.argtypes = [_object_t, _uint, _uint, _int]
Expand Down
54 changes: 54 additions & 0 deletions python/pyspla/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,53 @@ def shape(self):

return self._shape

def set_format(self, fmt):
"""
Instruct container to format internal data with desired storage format.
Multiple different formats may be set at same time, data will be duplicated in different formats.
If selected data already in a selected format, then nothing to do.
See `FormatMatrix` enumeration for all supported formats.
:param fmt: FormatMatrix.
One of built-in storage formats to set.
"""

check(backend().spla_Matrix_set_format(self.hnd, ctypes.c_int(fmt.value)))

def set(self, i, j, v):
"""
Set value at specified index
:param i: uint.
Row index to set.
:param j: uint.
Column index to set.
:param v: any.
Value to set.
"""

check(self._dtype._matrix_set(self.hnd, ctypes.c_uint(i), ctypes.c_uint(j), self._dtype._c_type(v)))

def get(self, i, j):
"""
Get value at specified index.
:param i: uint.
Row index of value to get.
:param j: uint.
Column index of value to get.
:return: Value.
"""

c_value = self._dtype._c_type(0)
check(self._dtype._matrix_get(self.hnd, ctypes.c_uint(i), ctypes.c_uint(j), ctypes.byref(c_value)))
return self._dtype.cast_value(c_value)

def build(self, view_I: MemView, view_J: MemView, view_V: MemView):
"""
Builds matrix content from a raw memory view resources.
Expand Down Expand Up @@ -198,6 +245,13 @@ def to_lists(self):

return list(buffer_I), list(buffer_J), list(buffer_V)

def clear(self):
"""
Clears matrix removing all elements, so it has 0 values.
"""

check(backend().spla_Vector_clear(self.hnd))

def to_list(self):
"""
Read matrix data as a python lists of tuples where key and value stored together.
Expand Down
48 changes: 48 additions & 0 deletions python/pyspla/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,47 @@ def shape(self):

return self._shape

def set_format(self, fmt):
"""
Instruct container to format internal data with desired storage format.
Multiple different formats may be set at same time, data will be duplicated in different formats.
If selected data already in a selected format, then nothing to do.
See `FormatVector` enumeration for all supported formats.
:param fmt: FormatVector.
One of built-in storage formats to set.
"""

check(backend().spla_Vector_set_format(self.hnd, ctypes.c_int(fmt.value)))

def set(self, i, v):
"""
Set value at specified index
:param i: uint.
Row index to set.
:param v: any.
Value to set.
"""

check(self._dtype._vector_set(self.hnd, ctypes.c_uint(i), self._dtype._c_type(v)))

def get(self, i):
"""
Get value at specified index.
:param i: uint.
Row index of value to get.
:return: Value.
"""

c_value = self._dtype._c_type(0)
check(self._dtype._vector_get(self.hnd, ctypes.c_uint(i), ctypes.byref(c_value)))
return self._dtype.cast_value(c_value)

def build(self, view_I: MemView, view_V: MemView):
"""
Builds vector content from a raw memory view resources.
Expand Down Expand Up @@ -155,6 +196,13 @@ def read(self):
check(backend().spla_Vector_read(self.hnd, ctypes.byref(keys_view_hnd), ctypes.byref(values_view_hnd)))
return MemView(hnd=keys_view_hnd, owner=self), MemView(hnd=values_view_hnd, owner=self)

def clear(self):
"""
Clears vector removing all elements, so it has 0 values.
"""

check(backend().spla_Vector_clear(self.hnd))

def to_lists(self):
"""
Read vector data as a python lists of keys and values.
Expand Down
3 changes: 3 additions & 0 deletions src/binding/c_matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ spla_Status spla_Matrix_make(spla_Matrix* M, spla_uint n_rows, spla_uint n_cols,
*M = as_ptr<spla_Matrix_t>(matrix.release());
return SPLA_STATUS_OK;
}
spla_Status spla_Matrix_set_format(spla_Matrix M, int format) {
return to_c_status(as_ptr<spla::Matrix>(M)->set_format(static_cast<spla::FormatMatrix>(format)));
}
spla_Status spla_Matrix_set_fill_value(spla_Matrix M, spla_Scalar value) {
return to_c_status(as_ptr<spla::Matrix>(M)->set_fill_value(as_ref<spla::Scalar>(value)));
}
Expand Down
3 changes: 3 additions & 0 deletions src/binding/c_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ spla_Status spla_Vector_make(spla_Vector* v, spla_uint n_rows, spla_Type type) {
*v = as_ptr<spla_Vector_t>(vector.release());
return SPLA_STATUS_OK;
}
spla_Status spla_Vector_set_format(spla_Vector v, int format) {
return to_c_status(as_ptr<spla::Vector>(v)->set_format(static_cast<spla::FormatVector>(format)));
}
spla_Status spla_Vector_set_fill_value(spla_Vector v, spla_Scalar value) {
return to_c_status(as_ptr<spla::Vector>(v)->set_fill_value(as_ref<spla::Scalar>(value)));
}
Expand Down
6 changes: 3 additions & 3 deletions src/core/tmatrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,19 @@ namespace spla {

template<typename T>
Status TMatrix<T>::set_int(uint row_id, uint col_id, std::int32_t value) {
validate_rw(FormatMatrix::CpuLil);
validate_rwd(FormatMatrix::CpuLil);
cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
return Status::Ok;
}
template<typename T>
Status TMatrix<T>::set_uint(uint row_id, uint col_id, std::uint32_t value) {
validate_rw(FormatMatrix::CpuLil);
validate_rwd(FormatMatrix::CpuLil);
cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
return Status::Ok;
}
template<typename T>
Status TMatrix<T>::set_float(uint row_id, uint col_id, float value) {
validate_rw(FormatMatrix::CpuLil);
validate_rwd(FormatMatrix::CpuLil);
cpu_lil_add_element(row_id, col_id, static_cast<T>(value), *get<CpuLil<T>>());
return Status::Ok;
}
Expand Down
3 changes: 3 additions & 0 deletions src/core/tvector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ namespace spla {
template<typename T>
Status TVector<T>::set_int(uint row_id, std::int32_t value) {
if (is_valid(FormatVector::CpuDense)) {
validate_rwd(FormatVector::CpuDense);
get<CpuDenseVec<T>>()->Ax[row_id] = static_cast<T>(value);
return Status::Ok;
}
Expand All @@ -169,6 +170,7 @@ namespace spla {
template<typename T>
Status TVector<T>::set_uint(uint row_id, std::uint32_t value) {
if (is_valid(FormatVector::CpuDense)) {
validate_rwd(FormatVector::CpuDense);
get<CpuDenseVec<T>>()->Ax[row_id] = static_cast<T>(value);
return Status::Ok;
}
Expand All @@ -180,6 +182,7 @@ namespace spla {
template<typename T>
Status TVector<T>::set_float(uint row_id, float value) {
if (is_valid(FormatVector::CpuDense)) {
validate_rwd(FormatVector::CpuDense);
get<CpuDenseVec<T>>()->Ax[row_id] = static_cast<T>(value);
return Status::Ok;
}
Expand Down
Loading

0 comments on commit 5d747e2

Please sign in to comment.