From 776facdd01004dea5b8c493aaa57aab8e5ccce7e Mon Sep 17 00:00:00 2001 From: Egor Orachev Date: Sun, 3 Sep 2023 13:56:54 +0300 Subject: [PATCH] gh-225: add matrix row / col extraction --- include/spla.h | 2 + include/spla/exec.hpp | 44 ++++++++ python/example.py | 12 ++- python/pyspla/bridge.py | 6 ++ python/pyspla/matrix.py | 124 ++++++++++++++++++++++ src/binding/c_exec.cpp | 6 ++ src/cpu/cpu_algo_registry.cpp | 12 +++ src/cpu/cpu_format_coo_vec.hpp | 20 ++++ src/cpu/cpu_m_extract_column.hpp | 177 +++++++++++++++++++++++++++++++ src/cpu/cpu_m_extract_row.hpp | 176 ++++++++++++++++++++++++++++++ src/cpu/cpu_m_transpose.hpp | 24 ++--- src/cpu/cpu_v_eadd.hpp | 62 ++++++++++- src/cpu/cpu_v_emult.hpp | 3 + src/exec.cpp | 32 ++++++ src/schedule/schedule_tasks.cpp | 42 ++++++++ src/schedule/schedule_tasks.hpp | 38 +++++++ 16 files changed, 763 insertions(+), 17 deletions(-) create mode 100644 src/cpu/cpu_m_extract_column.hpp create mode 100644 src/cpu/cpu_m_extract_row.hpp diff --git a/include/spla.h b/include/spla.h index 7175bb4a3..907865fab 100644 --- a/include/spla.h +++ b/include/spla.h @@ -377,6 +377,8 @@ SPLA_API spla_Status spla_Exec_m_reduce_by_row(spla_Vector r, spla_Matrix M, spl SPLA_API spla_Status spla_Exec_m_reduce_by_column(spla_Vector r, spla_Matrix M, spla_OpBinary op_reduce, spla_Scalar init, spla_Descriptor desc, spla_ScheduleTask* task); SPLA_API spla_Status spla_Exec_m_reduce(spla_Scalar r, spla_Scalar s, spla_Matrix M, spla_OpBinary op_reduce, spla_Descriptor desc, spla_ScheduleTask* task); SPLA_API spla_Status spla_Exec_m_transpose(spla_Matrix R, spla_Matrix M, spla_OpUnary op_apply, spla_Descriptor desc, spla_ScheduleTask* task); +SPLA_API spla_Status spla_Exec_m_extract_row(spla_Vector r, spla_Matrix M, spla_uint index, spla_OpUnary op_apply, spla_Descriptor desc, spla_ScheduleTask* task); +SPLA_API spla_Status spla_Exec_m_extract_column(spla_Vector r, spla_Matrix M, spla_uint index, spla_OpUnary op_apply, spla_Descriptor desc, 
spla_ScheduleTask* task); SPLA_API spla_Status spla_Exec_v_eadd(spla_Vector r, spla_Vector u, spla_Vector v, spla_OpBinary op, spla_Descriptor desc, spla_ScheduleTask* task); SPLA_API spla_Status spla_Exec_v_emult(spla_Vector r, spla_Vector u, spla_Vector v, spla_OpBinary op, spla_Descriptor desc, spla_ScheduleTask* task); SPLA_API spla_Status spla_Exec_v_eadd_fdb(spla_Vector r, spla_Vector v, spla_Vector fdb, spla_OpBinary op, spla_Descriptor desc, spla_ScheduleTask* task); diff --git a/include/spla/exec.hpp b/include/spla/exec.hpp index 1f9e87aeb..d24029ffb 100644 --- a/include/spla/exec.hpp +++ b/include/spla/exec.hpp @@ -322,6 +322,50 @@ namespace spla { ref_ptr desc = ref_ptr(), ref_ptr* task_hnd = nullptr); + /** + * @brief Execute (schedule) matrix row extract + * + * @note Pass valid `task_hnd` to store as a task, rather than execute immediately. + * + * @param r Result vector + * @param M Source matrix + * @param index Index of row + * @param op_apply Unary op to transform value + * @param desc Scheduled task descriptor; default is null + * @param task_hnd Optional task hnd; pass not-null pointer to store task + * + * @return Status on task execution or status on hnd creation + */ + SPLA_API Status exec_m_extract_row( + ref_ptr r, + ref_ptr M, + uint index, + ref_ptr op_apply, + ref_ptr desc = ref_ptr(), + ref_ptr* task_hnd = nullptr); + + /** + * @brief Execute (schedule) matrix column extract + * + * @note Pass valid `task_hnd` to store as a task, rather than execute immediately. 
+ * + * @param r Result vector + * @param M Source matrix + * @param index Index of column + * @param op_apply Unary op to transform value + * @param desc Scheduled task descriptor; default is null + * @param task_hnd Optional task hnd; pass not-null pointer to store task + * + * @return Status on task execution or status on hnd creation + */ + SPLA_API Status exec_m_extract_column( + ref_ptr r, + ref_ptr M, + uint index, + ref_ptr op_apply, + ref_ptr desc = ref_ptr(), + ref_ptr* task_hnd = nullptr); + /** * @brief Execute (schedule) element-wise addition by structure of two vectors * diff --git a/python/example.py b/python/example.py index e794d4a4e..d8ffc0844 100644 --- a/python/example.py +++ b/python/example.py @@ -1,7 +1,11 @@ from pyspla import * -M = Matrix.from_lists([0, 0, 1], [0, 1, 1], [1, 2, 3], (2, 2), INT) -print(M.eadd(INT.MULT, M.transpose())) +M = Matrix.from_lists([0, 0, 1, 2], [1, 2, 3, 0], [-1, 1, 2, 3], (3, 4), INT) +print(M) +print(M.extract_row(0)) +print(M.extract_row(0, op_apply=INT.AINV)) -M = Matrix.from_lists([0, 0, 1], [0, 1, 1], [1, 2, 3], (2, 2), INT) -print(M.emult(INT.MULT, M.transpose())) +M = Matrix.from_lists([0, 1, 1, 2], [1, 0, 3, 1], [-1, 1, 2, 3], (3, 4), INT) +print(M) +print(M.extract_column(1)) +print(M.extract_column(1, op_apply=INT.AINV)) diff --git a/python/pyspla/bridge.py b/python/pyspla/bridge.py index dee94519a..808f6ff17 100644 --- a/python/pyspla/bridge.py +++ b/python/pyspla/bridge.py @@ -558,6 +558,8 @@ def load_library(lib_path): _spla.spla_Exec_m_reduce_by_column.restype = _status_t _spla.spla_Exec_m_reduce.restype = _status_t _spla.spla_Exec_m_transpose.restype = _status_t + _spla.spla_Exec_m_extract_row.restype = _status_t + _spla.spla_Exec_m_extract_column.restype = _status_t _spla.spla_Exec_v_eadd.restype = _status_t _spla.spla_Exec_v_emult.restype = _status_t _spla.spla_Exec_v_eadd_fdb.restype = _status_t @@ -588,6 +590,10 @@ def load_library(lib_path): [_object_t, _object_t, _object_t, _object_t, 
_object_t, _p_object_t] _spla.spla_Exec_m_transpose.argtypes = \ [_object_t, _object_t, _object_t, _object_t, _p_object_t] + _spla.spla_Exec_m_extract_row.argtypes = \ + [_object_t, _object_t, _uint, _object_t, _object_t, _p_object_t] + _spla.spla_Exec_m_extract_column.argtypes = \ + [_object_t, _object_t, _uint, _object_t, _object_t, _p_object_t] _spla.spla_Exec_v_eadd.argtypes = \ [_object_t, _object_t, _object_t, _object_t, _object_t, _p_object_t] _spla.spla_Exec_v_emult.argtypes = \ diff --git a/python/pyspla/matrix.py b/python/pyspla/matrix.py index 05e7d22d2..5756c48df 100644 --- a/python/pyspla/matrix.py +++ b/python/pyspla/matrix.py @@ -1338,6 +1338,130 @@ def transpose(self, out=None, op_apply=None, desc=None): return out + def extract_row(self, index, out=None, op_apply=None, desc=None): + """ + Extract matrix row. + + >>> M = Matrix.from_lists([0, 0, 1, 2], [1, 2, 3, 0], [-1, 1, 2, 3], (3, 4), INT) + >>> print(M) + ' + 0 1 2 3 + 0| .-1 1 .| 0 + 1| . . . 2| 1 + 2| 3 . . .| 2 + 0 1 2 3 + ' + + >>> print(M.extract_row(0)) + ' + 0| . + 1|-1 + 2| 1 + 3| . + ' + + >>> print(M.extract_row(0, op_apply=INT.AINV)) + ' + 0| . + 1| 1 + 2|-1 + 3| . + ' + + :param index: int. + Index of row to extract. + + :param out: optional: Vector: default: none. + Optional vector to store result. + + :param op_apply: optional: OpUnary. default: None. + Optional unary function to apply on extraction. + + :param desc: optional: Descriptor. default: None. + Optional descriptor object to configure the execution. + + :return: Vector. 
+ """ + + from .vector import Vector + + if out is None: + out = Vector(shape=self.n_cols, dtype=self.dtype) + if op_apply is None: + op_apply = self.dtype.IDENTITY + + assert out + assert op_apply + assert out.dtype == self.dtype + assert out.n_rows == self.n_cols + assert 0 <= index < self.n_rows + + check(backend().spla_Exec_m_extract_row(out.hnd, self.hnd, ctypes.c_uint(index), op_apply.hnd, + self._get_desc(desc), self._get_task(None))) + + return out + + def extract_column(self, index, out=None, op_apply=None, desc=None): + """ + Extract matrix column. + + >>> M = Matrix.from_lists([0, 1, 1, 2], [1, 0, 3, 1], [-1, 1, 2, 3], (3, 4), INT) + >>> print(M) + ' + 0 1 2 3 + 0| .-1 . .| 0 + 1| 1 . . 2| 1 + 2| . 3 . .| 2 + 0 1 2 3 + ' + + >>> print(M.extract_column(1)) + ' + 0|-1 + 1| . + 2| 3 + ' + + >>> print(M.extract_column(1, op_apply=INT.AINV)) + ' + 0| 1 + 1| . + 2|-3 + ' + + :param index: int. + Index of column to extract. + + :param out: optional: Vector: default: none. + Optional vector to store result. + + :param op_apply: optional: OpUnary. default: None. + Optional unary function to apply on extraction. + + :param desc: optional: Descriptor. default: None. + Optional descriptor object to configure the execution. + + :return: Vector. 
+ """ + + from .vector import Vector + + if out is None: + out = Vector(shape=self.n_rows, dtype=self.dtype) + if op_apply is None: + op_apply = self.dtype.IDENTITY + + assert out + assert op_apply + assert out.dtype == self.dtype + assert out.n_rows == self.n_rows + assert 0 <= index < self.n_cols + + check(backend().spla_Exec_m_extract_column(out.hnd, self.hnd, ctypes.c_uint(index), op_apply.hnd, + self._get_desc(desc), self._get_task(None))) + + return out + def __str__(self): return self.to_string() diff --git a/src/binding/c_exec.cpp b/src/binding/c_exec.cpp index fcc370791..cda0ffe27 100644 --- a/src/binding/c_exec.cpp +++ b/src/binding/c_exec.cpp @@ -76,6 +76,12 @@ spla_Status spla_Exec_m_reduce(spla_Scalar r, spla_Scalar s, spla_Matrix M, spla spla_Status spla_Exec_m_transpose(spla_Matrix R, spla_Matrix M, spla_OpUnary op_apply, spla_Descriptor desc, spla_ScheduleTask* task) { SPLA_WRAP_EXEC(exec_m_transpose, AS_M(R), AS_M(M), AS_OU(op_apply)); } +spla_Status spla_Exec_m_extract_row(spla_Vector r, spla_Matrix M, spla_uint index, spla_OpUnary op_apply, spla_Descriptor desc, spla_ScheduleTask* task) { + SPLA_WRAP_EXEC(exec_m_extract_row, AS_V(r), AS_M(M), index, AS_OU(op_apply)); +} +spla_Status spla_Exec_m_extract_column(spla_Vector r, spla_Matrix M, spla_uint index, spla_OpUnary op_apply, spla_Descriptor desc, spla_ScheduleTask* task) { + SPLA_WRAP_EXEC(exec_m_extract_column, AS_V(r), AS_M(M), index, AS_OU(op_apply)); +} spla_Status spla_Exec_v_eadd(spla_Vector r, spla_Vector u, spla_Vector v, spla_OpBinary op, spla_Descriptor desc, spla_ScheduleTask* task) { SPLA_WRAP_EXEC(exec_v_eadd, AS_V(r), AS_V(u), AS_V(v), AS_OB(op)); } diff --git a/src/cpu/cpu_algo_registry.cpp b/src/cpu/cpu_algo_registry.cpp index cf06035b3..4944eb756 100644 --- a/src/cpu/cpu_algo_registry.cpp +++ b/src/cpu/cpu_algo_registry.cpp @@ -34,6 +34,8 @@ #include #include #include +#include +#include #include #include #include @@ -121,6 +123,16 @@ namespace spla { 
g_registry->add(MAKE_KEY_CPU_0("m_transpose", UINT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("m_transpose", FLOAT), std::make_shared>()); + // algorthm m_extract_row + g_registry->add(MAKE_KEY_CPU_0("m_extract_row", INT), std::make_shared>()); + g_registry->add(MAKE_KEY_CPU_0("m_extract_row", UINT), std::make_shared>()); + g_registry->add(MAKE_KEY_CPU_0("m_extract_row", FLOAT), std::make_shared>()); + + // algorthm m_extract_column + g_registry->add(MAKE_KEY_CPU_0("m_extract_column", INT), std::make_shared>()); + g_registry->add(MAKE_KEY_CPU_0("m_extract_column", UINT), std::make_shared>()); + g_registry->add(MAKE_KEY_CPU_0("m_extract_column", FLOAT), std::make_shared>()); + // algorthm mxv_masked g_registry->add(MAKE_KEY_CPU_0("mxv_masked", INT), std::make_shared>()); g_registry->add(MAKE_KEY_CPU_0("mxv_masked", UINT), std::make_shared>()); diff --git a/src/cpu/cpu_format_coo_vec.hpp b/src/cpu/cpu_format_coo_vec.hpp index 921788d3f..1cf86bf31 100644 --- a/src/cpu/cpu_format_coo_vec.hpp +++ b/src/cpu/cpu_format_coo_vec.hpp @@ -30,6 +30,9 @@ #include +#include +#include + namespace spla { /** @@ -37,6 +40,23 @@ namespace spla { * @{ */ + template + void cpu_coo_vec_sort(CpuCooVec& vec) { + std::vector> buffer; + buffer.reserve(vec.values); + + for (uint i = 0; i < vec.values; i++) { + buffer.emplace_back(vec.Ai[i], vec.Ax[i]); + } + + std::sort(buffer.begin(), buffer.end(), [](auto& a, auto& b) { return a.first < b.first; }); + + for (uint i = 0; i < vec.values; i++) { + vec.Ai[i] = buffer[i].first; + vec.Ax[i] = buffer[i].second; + } + } + template void cpu_coo_vec_resize(const uint n_values, CpuCooVec& vec) { diff --git a/src/cpu/cpu_m_extract_column.hpp b/src/cpu/cpu_m_extract_column.hpp new file mode 100644 index 000000000..ddf9c53e8 --- /dev/null +++ b/src/cpu/cpu_m_extract_column.hpp @@ -0,0 +1,177 @@ +/**********************************************************************************/ +/* This file is part of spla project */ +/* 
https://github.com/JetBrains-Research/spla */ +/**********************************************************************************/ +/* MIT License */ +/* */ +/* Copyright (c) 2023 SparseLinearAlgebra */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy */ +/* of this software and associated documentation files (the "Software"), to deal */ +/* in the Software without restriction, including without limitation the rights */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be included in all */ +/* copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ +/* SOFTWARE. 
*/ +/**********************************************************************************/ + +#ifndef SPLA_CPU_M_EXTRACT_COLUMN_HPP +#define SPLA_CPU_M_EXTRACT_COLUMN_HPP + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace spla { + + template + class Algo_m_extract_column_cpu final : public RegistryAlgo { + public: + ~Algo_m_extract_column_cpu() override = default; + + std::string get_name() override { + return "m_extract_column"; + } + + std::string get_description() override { + return "extract matrix column on cpu sequentially"; + } + + Status execute(const DispatchContext& ctx) override { + auto t = ctx.task.template cast_safe(); + auto M = t->M.template cast_safe>(); + + if (M->is_valid(FormatMatrix::CpuCsr)) { + return execute_csr(ctx); + } + if (M->is_valid(FormatMatrix::CpuLil)) { + return execute_lil(ctx); + } + if (M->is_valid(FormatMatrix::CpuDok)) { + return execute_dok(ctx); + } + + return execute_csr(ctx); + } + + private: + Status execute_dok(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/m_extract_column_dok"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); + uint index = t->index; + + r->validate_wd(FormatVector::CpuCoo); + M->validate_rw(FormatMatrix::CpuDok); + + CpuCooVec* p_coo_r = r->template get>(); + const CpuDok* p_dok_M = M->template get>(); + auto& func_apply = op_apply->function; + + for (const auto [key, value] : p_dok_M->Ax) { + if (key.second == index) { + p_coo_r->values += 1; + p_coo_r->Ai.push_back(key.first); + p_coo_r->Ax.push_back(func_apply(value)); + } + } + + cpu_coo_vec_sort(*p_coo_r); + + return Status::Ok; + } + + Status execute_lil(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/m_extract_column_lil"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + 
ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); + uint index = t->index; + + r->validate_wd(FormatVector::CpuCoo); + M->validate_rw(FormatMatrix::CpuLil); + + CpuCooVec* p_coo_r = r->template get>(); + const CpuLil* p_lil_M = M->template get>(); + auto& func_apply = op_apply->function; + + for (uint i = 0; i < M->get_n_rows(); i++) { + const auto& row = p_lil_M->Ar[i]; + + typename CpuLil::Entry fake{index, T()}; + auto query = std::lower_bound(row.begin(), row.end(), fake, [](auto& a, auto& b) { return a.first < b.first; }); + + if (query != row.end() && query->first == index) { + p_coo_r->values += 1; + p_coo_r->Ai.push_back(i); + p_coo_r->Ax.push_back(func_apply(query->second)); + } + } + + return Status::Ok; + } + + Status execute_csr(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/m_extract_column_csr"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); + uint index = t->index; + + r->validate_wd(FormatVector::CpuCoo); + M->validate_rw(FormatMatrix::CpuCsr); + + CpuCooVec* p_coo_r = r->template get>(); + const CpuCsr* p_csr_M = M->template get>(); + auto& func_apply = op_apply->function; + + for (uint i = 0; i < M->get_n_rows(); i++) { + const auto row_begin = p_csr_M->Aj.begin() + p_csr_M->Ap[i]; + const auto row_end = p_csr_M->Aj.begin() + p_csr_M->Ap[i + 1]; + + auto query = std::lower_bound(row_begin, row_end, index); + + if (query != row_end && *query == index) { + p_coo_r->values += 1; + p_coo_r->Ai.push_back(i); + p_coo_r->Ax.push_back(func_apply(p_csr_M->Ax[std::distance(p_csr_M->Aj.begin(), query)])); + } + } + + return Status::Ok; + } + }; + +}// namespace spla + +#endif//SPLA_CPU_M_EXTRACT_COLUMN_HPP diff --git a/src/cpu/cpu_m_extract_row.hpp b/src/cpu/cpu_m_extract_row.hpp new file mode 100644 index 000000000..9e9d202a0 --- /dev/null +++ 
b/src/cpu/cpu_m_extract_row.hpp @@ -0,0 +1,176 @@ +/**********************************************************************************/ +/* This file is part of spla project */ +/* https://github.com/JetBrains-Research/spla */ +/**********************************************************************************/ +/* MIT License */ +/* */ +/* Copyright (c) 2023 SparseLinearAlgebra */ +/* */ +/* Permission is hereby granted, free of charge, to any person obtaining a copy */ +/* of this software and associated documentation files (the "Software"), to deal */ +/* in the Software without restriction, including without limitation the rights */ +/* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell */ +/* copies of the Software, and to permit persons to whom the Software is */ +/* furnished to do so, subject to the following conditions: */ +/* */ +/* The above copyright notice and this permission notice shall be included in all */ +/* copies or substantial portions of the Software. */ +/* */ +/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR */ +/* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, */ +/* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE */ +/* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER */ +/* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, */ +/* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE */ +/* SOFTWARE. 
*/ +/**********************************************************************************/ + +#ifndef SPLA_CPU_M_EXTRACT_ROW_HPP +#define SPLA_CPU_M_EXTRACT_ROW_HPP + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace spla { + + template + class Algo_m_extract_row_cpu final : public RegistryAlgo { + public: + ~Algo_m_extract_row_cpu() override = default; + + std::string get_name() override { + return "m_extract_row"; + } + + std::string get_description() override { + return "extract matrix row on cpu sequentially"; + } + + Status execute(const DispatchContext& ctx) override { + auto t = ctx.task.template cast_safe(); + auto M = t->M.template cast_safe>(); + + if (M->is_valid(FormatMatrix::CpuCsr)) { + return execute_csr(ctx); + } + if (M->is_valid(FormatMatrix::CpuLil)) { + return execute_lil(ctx); + } + if (M->is_valid(FormatMatrix::CpuDok)) { + return execute_dok(ctx); + } + + return execute_csr(ctx); + } + + private: + Status execute_dok(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/m_extract_row_dok"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); + uint index = t->index; + + r->validate_wd(FormatVector::CpuCoo); + M->validate_rw(FormatMatrix::CpuDok); + + CpuCooVec* p_coo_r = r->template get>(); + const CpuDok* p_dok_M = M->template get>(); + auto& func_apply = op_apply->function; + + for (const auto [key, value] : p_dok_M->Ax) { + if (key.first == index) { + p_coo_r->values += 1; + p_coo_r->Ai.push_back(key.second); + p_coo_r->Ax.push_back(func_apply(value)); + } + } + + cpu_coo_vec_sort(*p_coo_r); + + return Status::Ok; + } + + Status execute_lil(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/m_extract_row_lil"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> M = t->M.template 
cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); + uint index = t->index; + + r->validate_wd(FormatVector::CpuCoo); + M->validate_rw(FormatMatrix::CpuLil); + + CpuCooVec* p_coo_r = r->template get>(); + const CpuLil* p_lil_M = M->template get>(); + auto& func_apply = op_apply->function; + + assert(index < M->get_n_rows()); + + p_coo_r->Ai.reserve(p_lil_M->Ar[index].size()); + p_coo_r->Ax.reserve(p_lil_M->Ar[index].size()); + + for (const auto [key, value] : p_lil_M->Ar[index]) { + p_coo_r->values += 1; + p_coo_r->Ai.push_back(key); + p_coo_r->Ax.push_back(func_apply(value)); + } + + return Status::Ok; + } + + Status execute_csr(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/m_extract_row_csr"); + + auto t = ctx.task.template cast_safe(); + + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); + uint index = t->index; + + r->validate_wd(FormatVector::CpuCoo); + M->validate_rw(FormatMatrix::CpuCsr); + + CpuCooVec* p_coo_r = r->template get>(); + const CpuCsr* p_csr_M = M->template get>(); + auto& func_apply = op_apply->function; + + assert(index < M->get_n_rows()); + + const uint start = p_csr_M->Ap[index]; + const uint end = p_csr_M->Ap[index + 1]; + const uint count = end - start; + + p_coo_r->Ai.reserve(count); + p_coo_r->Ax.reserve(count); + + for (uint k = start; k < end; k++) { + p_coo_r->values += 1; + p_coo_r->Ai.push_back(p_csr_M->Aj[k]); + p_coo_r->Ax.push_back(func_apply(p_csr_M->Ax[k])); + } + + return Status::Ok; + } + }; + +}// namespace spla + +#endif//SPLA_CPU_M_EXTRACT_ROW_HPP diff --git a/src/cpu/cpu_m_transpose.hpp b/src/cpu/cpu_m_transpose.hpp index 9a6f05205..da4bfba07 100644 --- a/src/cpu/cpu_m_transpose.hpp +++ b/src/cpu/cpu_m_transpose.hpp @@ -79,16 +79,16 @@ namespace spla { auto t = ctx.task.template cast_safe(); - ref_ptr> R = t->R.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> op_reduce = 
t->op_apply.template cast_safe>(); + ref_ptr> R = t->R.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); R->validate_wd(FormatMatrix::CpuDok); M->validate_rw(FormatMatrix::CpuDok); CpuDok* p_dok_R = R->template get>(); const CpuDok* p_dok_M = M->template get>(); - auto& func_apply = op_reduce->function; + auto& func_apply = op_apply->function; assert(p_dok_R->Ax.empty()); @@ -108,16 +108,16 @@ namespace spla { auto t = ctx.task.template cast_safe(); - ref_ptr> R = t->R.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> op_reduce = t->op_apply.template cast_safe>(); + ref_ptr> R = t->R.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); R->validate_wd(FormatMatrix::CpuLil); M->validate_rw(FormatMatrix::CpuLil); CpuLil* p_lil_R = R->template get>(); const CpuLil* p_lil_M = M->template get>(); - auto& func_apply = op_reduce->function; + auto& func_apply = op_apply->function; const uint DM = M->get_n_rows(); const uint DN = M->get_n_cols(); @@ -143,16 +143,16 @@ namespace spla { auto t = ctx.task.template cast_safe(); - ref_ptr> R = t->R.template cast_safe>(); - ref_ptr> M = t->M.template cast_safe>(); - ref_ptr> op_reduce = t->op_apply.template cast_safe>(); + ref_ptr> R = t->R.template cast_safe>(); + ref_ptr> M = t->M.template cast_safe>(); + ref_ptr> op_apply = t->op_apply.template cast_safe>(); R->validate_wd(FormatMatrix::CpuCsr); M->validate_rw(FormatMatrix::CpuCsr); CpuCsr* p_csr_R = R->template get>(); const CpuCsr* p_csr_M = M->template get>(); - auto& func_apply = op_reduce->function; + auto& func_apply = op_apply->function; const uint DM = M->get_n_rows(); const uint DN = M->get_n_cols(); diff --git a/src/cpu/cpu_v_eadd.hpp b/src/cpu/cpu_v_eadd.hpp index 0bbf24003..64bb855ee 100644 --- a/src/cpu/cpu_v_eadd.hpp +++ b/src/cpu/cpu_v_eadd.hpp @@ -57,14 +57,74 @@ namespace spla { ref_ptr> u = 
t->u.template cast_safe>(); ref_ptr> v = t->v.template cast_safe>(); + if (u->is_valid(FormatVector::CpuCoo) && v->is_valid(FormatVector::CpuCoo)) { + return execute_spNsp(ctx); + } if (u->is_valid(FormatVector::CpuDense) && v->is_valid(FormatVector::CpuDense)) { return execute_dnNdn(ctx); } - return execute_dnNdn(ctx); + return execute_spNsp(ctx); } private: + Status execute_spNsp(const DispatchContext& ctx) { + TIME_PROFILE_SCOPE("cpu/vector_eadd_spNsp"); + + auto t = ctx.task.template cast_safe(); + ref_ptr> r = t->r.template cast_safe>(); + ref_ptr> u = t->u.template cast_safe>(); + ref_ptr> v = t->v.template cast_safe>(); + ref_ptr> op = t->op.template cast_safe>(); + + r->validate_wd(FormatVector::CpuCoo); + u->validate_rw(FormatVector::CpuCoo); + v->validate_rw(FormatVector::CpuCoo); + + CpuCooVec* p_r = r->template get>(); + const CpuCooVec* p_u = u->template get>(); + const CpuCooVec* p_v = v->template get>(); + const auto& function = op->function; + + assert(p_r->Ax.empty()); + + const auto u_count = p_u->values; + const auto v_count = p_v->values; + uint u_iter = 0; + uint v_iter = 0; + + while (u_iter < u_count && v_iter < v_count) { + if (p_u->Ai[u_iter] < p_v->Ai[v_iter]) { + p_r->Ai.push_back(p_u->Ai[u_iter]); + p_r->Ax.push_back(p_u->Ax[u_iter]); + u_iter += 1; + } else if (p_v->Ai[v_iter] < p_u->Ai[u_iter]) { + p_r->Ai.push_back(p_v->Ai[v_iter]); + p_r->Ax.push_back(p_v->Ax[v_iter]); + v_iter += 1; + } else { + p_r->Ai.push_back(p_u->Ai[u_iter]); + p_r->Ax.push_back(function(p_u->Ax[u_iter], p_v->Ax[v_iter])); + u_iter += 1; + v_iter += 1; + } + p_r->values += 1; + } + while (u_iter < u_count) { + p_r->Ai.push_back(p_u->Ai[u_iter]); + p_r->Ax.push_back(p_u->Ax[u_iter]); + u_iter += 1; + p_r->values += 1; + } + while (v_iter < v_count) { + p_r->Ai.push_back(p_v->Ai[v_iter]); + p_r->Ax.push_back(p_v->Ax[v_iter]); + v_iter += 1; + p_r->values += 1; + } + + return Status::Ok; + } Status execute_dnNdn(const DispatchContext& ctx) { 
TIME_PROFILE_SCOPE("cpu/vector_eadd_dnNdn"); diff --git a/src/cpu/cpu_v_emult.hpp b/src/cpu/cpu_v_emult.hpp index ceeadf49f..a414ccbf5 100644 --- a/src/cpu/cpu_v_emult.hpp +++ b/src/cpu/cpu_v_emult.hpp @@ -103,6 +103,7 @@ namespace spla { } else if (p_v->Ai[v_iter] < p_u->Ai[u_iter]) { v_iter += 1; } else { + p_r->values += 1; p_r->Ai.push_back(p_u->Ai[u_iter]); p_r->Ax.push_back(function(p_u->Ax[u_iter], p_v->Ax[v_iter])); u_iter += 1; @@ -138,6 +139,7 @@ namespace spla { const uint i = p_u->Ai[k]; if (p_v->Ax[i] != skip) { + p_r->values += 1; p_r->Ai.push_back(i); p_r->Ax.push_back(function(p_u->Ax[k], p_v->Ax[i])); } @@ -171,6 +173,7 @@ namespace spla { const uint i = p_v->Ai[k]; if (p_u->Ax[i] != skip) { + p_r->values += 1; p_r->Ai.push_back(i); p_r->Ax.push_back(function(p_u->Ax[i], p_v->Ax[k])); } diff --git a/src/exec.cpp b/src/exec.cpp index 7759a0f56..d62647fba 100644 --- a/src/exec.cpp +++ b/src/exec.cpp @@ -265,6 +265,38 @@ namespace spla { EXEC_OR_MAKE_TASK } + Status exec_m_extract_row( + ref_ptr r, + ref_ptr M, + uint index, + ref_ptr op_apply, + ref_ptr desc, + ref_ptr* task_hnd) { + auto task = make_ref(); + task->r = std::move(r); + task->M = std::move(M); + task->index = index; + task->op_apply = std::move(op_apply); + task->desc = std::move(desc); + EXEC_OR_MAKE_TASK + } + + Status exec_m_extract_column( + ref_ptr r, + ref_ptr M, + uint index, + ref_ptr op_apply, + ref_ptr desc, + ref_ptr* task_hnd) { + auto task = make_ref(); + task->r = std::move(r); + task->M = std::move(M); + task->index = index; + task->op_apply = std::move(op_apply); + task->desc = std::move(desc); + EXEC_OR_MAKE_TASK + } + Status exec_v_eadd( ref_ptr r, ref_ptr u, diff --git a/src/schedule/schedule_tasks.cpp b/src/schedule/schedule_tasks.cpp index 42e5c9f83..a65ef2934 100644 --- a/src/schedule/schedule_tasks.cpp +++ b/src/schedule/schedule_tasks.cpp @@ -300,6 +300,48 @@ namespace spla { return {R.as(), M.as(), op_apply.as()}; } + std::string 
ScheduleTask_m_extract_row::get_name() { + return "m_extract_row"; + } + std::string ScheduleTask_m_extract_row::get_key() { + std::stringstream key; + key << get_name() + << TYPE_KEY(r->get_type()); + + return key.str(); + } + std::string ScheduleTask_m_extract_row::get_key_full() { + std::stringstream key; + key << get_name() + << OP_KEY(op_apply); + + return key.str(); + } + std::vector> ScheduleTask_m_extract_row::get_args() { + return {r.as(), M.as(), op_apply.as()}; + } + + std::string ScheduleTask_m_extract_column::get_name() { + return "m_extract_column"; + } + std::string ScheduleTask_m_extract_column::get_key() { + std::stringstream key; + key << get_name() + << TYPE_KEY(r->get_type()); + + return key.str(); + } + std::string ScheduleTask_m_extract_column::get_key_full() { + std::stringstream key; + key << get_name() + << OP_KEY(op_apply); + + return key.str(); + } + std::vector> ScheduleTask_m_extract_column::get_args() { + return {r.as(), M.as(), op_apply.as()}; + } + std::string ScheduleTask_v_eadd::get_name() { return "v_eadd"; } diff --git a/src/schedule/schedule_tasks.hpp b/src/schedule/schedule_tasks.hpp index 5f5f66641..49b0d87a5 100644 --- a/src/schedule/schedule_tasks.hpp +++ b/src/schedule/schedule_tasks.hpp @@ -295,6 +295,44 @@ namespace spla { ref_ptr op_apply; }; + /** + * @class ScheduleTask_m_extract_row + * @brief Matrix extract vector + */ + class ScheduleTask_m_extract_row final : public ScheduleTaskBase { + public: + ~ScheduleTask_m_extract_row() override = default; + + std::string get_name() override; + std::string get_key() override; + std::string get_key_full() override; + std::vector> get_args() override; + + ref_ptr r; + ref_ptr M; + uint index; + ref_ptr op_apply; + }; + + /** + * @class ScheduleTask_m_extract_column + * @brief Matrix extract vector + */ + class ScheduleTask_m_extract_column final : public ScheduleTaskBase { + public: + ~ScheduleTask_m_extract_column() override = default; + + std::string get_name() override; + 
std::string get_key() override; + std::string get_key_full() override; + std::vector> get_args() override; + + ref_ptr r; + ref_ptr M; + uint index; + ref_ptr op_apply; + }; + /** * @class ScheduleTask_v_eadd * @brief Vector ewise add