Skip to content

Commit

Permalink
throw exception if task encountes error (primihub#714)
Browse files Browse the repository at this point in the history
  • Loading branch information
phoenix20162016 authored Nov 27, 2023
1 parent 6397e2d commit 73718d6
Show file tree
Hide file tree
Showing 12 changed files with 343 additions and 237 deletions.
92 changes: 58 additions & 34 deletions src/primihub/algorithm/arithmetic.cc
Original file line number Diff line number Diff line change
@@ -1,21 +1,40 @@
/*
* Copyright (c) 2023 by PrimiHub
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/primihub/algorithm/arithmetic.h"

#include <arrow/api.h>
#include <arrow/array.h>
#include <arrow/result.h>

#include <string>
#include <memory>
#include <utility>

#include "src/primihub/data_store/csv/csv_driver.h"
#include "src/primihub/data_store/factory.h"
#include "src/primihub/util/file_util.h"
#include "src/primihub/util/network/message_interface.h"
#include "src/primihub/util/util.h"
#include "src/primihub/common/value_check_util.h"

using arrow::Array;
using arrow::DoubleArray;
using arrow::Int64Array;
using arrow::Table;
using Array = arrow::Array;
using DoubleArray = arrow::DoubleArray;
using Int64Array = arrow::Int64Array;
using Table = arrow::Table;

namespace primihub {

Expand Down Expand Up @@ -103,7 +122,7 @@ int ArithmeticExecutor<Dbit>::loadParams(primihub::rpc::Task &task) {
// LOG(INFO) << col << ":" << dtype;
}
// LOG(INFO) << col_and_dtype;

expr_ = param_map["Expr"].value_string();
int comma_index = expr_.find(",");
cmp_col1 = expr_.substr(4, comma_index - 4);
Expand Down Expand Up @@ -170,17 +189,15 @@ int ArithmeticExecutor<Dbit>::loadDataset() {
auto iter1 = col_and_dtype_.find(cmp_col1);
if (iter1 == col_and_dtype_.end()) {
std::stringstream ss;
ss << "Can't find dtype of column " << cmp_col1;
LOG(ERROR) << ss.str() << ".";
throw std::runtime_error(ss.str());
ss << "Can't find dtype of column " << cmp_col1 << ".";
RaiseException(ss.str());
}

auto iter2 = col_and_dtype_.find(cmp_col2);
if (iter2 == col_and_dtype_.end()) {
std::stringstream ss;
ss << "Can't find dtype of column " << cmp_col1;
LOG(ERROR) << ss.str() << ".";
throw std::runtime_error(ss.str());
ss << "Can't find dtype of column " << cmp_col1 << ".";
RaiseException(ss.str());
}

if (iter1->second == iter2->second) {
Expand Down Expand Up @@ -212,9 +229,10 @@ int ArithmeticExecutor<Dbit>::loadDataset() {
auto iter3 = col_and_owner_.find(convert_col);
if (iter3 == col_and_owner_.end()) {
std::stringstream ss;
ss << "Can't find the party that column " << convert_col << " belong to.";
LOG(ERROR) << ss.str();
throw std::runtime_error(ss.str());
ss << "Party: " << this->party_name()
<< ", Can't find the party that column "
<< convert_col << " belong to.";
RaiseException(ss.str());
}

if (party_id_ != iter3->second) {
Expand All @@ -226,9 +244,9 @@ int ArithmeticExecutor<Dbit>::loadDataset() {
auto col_iter = col_and_val_int.find(convert_col);
if (col_iter == col_and_val_int.end()) {
std::stringstream ss;
ss << "Can't find column value with column name " << convert_col << ".";
LOG(ERROR) << ss.str();
throw std::runtime_error(ss.str());
ss << "Party: " << this->party_name() << ", "
<< "Can't find column value with column name " << convert_col << ".";
RaiseException(ss.str());
}

std::vector<int64_t> &col_src = col_iter->second;
Expand Down Expand Up @@ -345,8 +363,7 @@ int ArithmeticExecutor<Dbit>::execute() {
} catch (std::exception &e) {
std::stringstream ss;
ss << "Error occurs during MPC Compare, " << e.what();
LOG(ERROR) << ss.str();
throw std::runtime_error(ss.str());
RaiseException(ss.str());
}

return 0;
Expand All @@ -368,8 +385,7 @@ int ArithmeticExecutor<Dbit>::execute() {
} catch (const std::exception &e) {
std::stringstream ss;
ss << "Error occurs during MPC run, " << e.what() << ".";
LOG(ERROR) << ss.str();
throw std::runtime_error(ss.str());
RaiseException(ss.str());
}

return 0;
Expand Down Expand Up @@ -486,11 +502,12 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
for (int i = 0; i < num_col; i++) {
int chunk_num = table->column(i)->chunks().size();
if (col_and_dtype_[col_names[i]] == 0) {
if (table->schema()->GetFieldByName(col_names[i])->type()->id() != 9) {
LOG(ERROR) << "Local data type is inconsistent with the demand data "
"type!Demand data type is int,but local data type is "
"double!Please input consistent data type!";
return -1;
auto data_type =
table->schema()->GetFieldByName(col_names[i])->type()->id();
if (data_type != arrow::Type::INT64) {
RaiseException("Local data type is inconsistent with the demand data "
"type!Demand data type is int,but local data type is "
"double!Please input consistent data type!");
}
std::vector<int64_t> tmp_data;
int64_t tmp_len = 0;
Expand All @@ -504,8 +521,11 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
}
}
if (tmp_len != array_len) {
LOG(ERROR) << "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
std::stringstream ss;
ss << "Party: " << this->party_name() << ", "
<< "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
RaiseException(ss.str());
errors = true;
break;
}
Expand All @@ -532,9 +552,11 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
}
}
if (tmp_len != array_len) {
LOG(ERROR) << "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len
<< " value.";
std::stringstream ss;
ss << "Party: " << this->party_name() << ", "
<< "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
RaiseException(ss.str());
errors = true;
break;
}
Expand All @@ -549,9 +571,11 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
}
}
if (tmp_len != array_len) {
LOG(ERROR) << "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len
<< " value.";
std::stringstream ss;
ss << "Party: " << this->party_name() << ", "
<< "Column " << col_names[i] << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
RaiseException(ss.str());
errors = true;
break;
}
Expand Down
33 changes: 26 additions & 7 deletions src/primihub/algorithm/arithmetic.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
#include "src/primihub/algorithm/base.h"

/*
* Copyright (c) 2023 by PrimiHub
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef SRC_PRIMIHUB_ALGORITHM_ARITHMETIC_H_
#define SRC_PRIMIHUB_ALGORITHM_ARITHMETIC_H_
#include <stdlib.h>
#include <time.h>
#include <algorithm>
Expand All @@ -8,7 +23,10 @@
#include <sstream>
#include <string>
#include <vector>
#include <memory>
#include <map>

#include "src/primihub/algorithm/base.h"
#include "src/primihub/data_store/driver.h"
#include "src/primihub/executor/express.h"
// #include "src/primihub/service/dataset/service.h"
Expand All @@ -18,7 +36,7 @@ namespace primihub {

template <Decimal Dbit>
class ArithmeticExecutor : public AlgorithmBase {
public:
public:
explicit ArithmeticExecutor(PartyConfig &config,
std::shared_ptr<DatasetService> dataset_service);
int loadParams(primihub::rpc::Task &task) override;
Expand All @@ -27,7 +45,7 @@ class ArithmeticExecutor : public AlgorithmBase {
int saveModel(void);
retcode InitEngine() override;

private:
private:
int _LoadDatasetFromCSV(std::string &filename);

std::string res_name_;
Expand Down Expand Up @@ -67,9 +85,9 @@ class ArithmeticExecutor : public AlgorithmBase {
// class MPCSendRecvExecutor : public AlgorithmBase {
// public:
// explicit MPCSendRecvExecutor(PartyConfig &config,
// std::shared_ptr<DatasetService> dataset_service);
// std::shared_ptr<DatasetService> dataset_service); // NOLINT
// using TaskGetChannelFunc =
// std::function<std::shared_ptr<network::IChannel>(primihub::Node &node)>;
// std::function<std::shared_ptr<network::IChannel>(primihub::Node &node)>; // NOLINT
// using TaskGetRecvQueueFunc =
// std::function<ThreadSafeQueue<std::string> &(const std::string &key)>;

Expand Down Expand Up @@ -102,4 +120,5 @@ class ArithmeticExecutor : public AlgorithmBase {
// };
// #endif

} // namespace primihub
} // namespace primihub
#endif // SRC_PRIMIHUB_ALGORITHM_ARITHMETIC_H_
Loading

0 comments on commit 73718d6

Please sign in to comment.