Skip to content

Commit

Permalink
Arithmetic: fix incorrect adding extra col to col_and_dtype which cau…
Browse files Browse the repository at this point in the history
…se task fail (primihub#751)
  • Loading branch information
phoenix20162016 authored Jan 12, 2024
1 parent 5058795 commit 8c112bb
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 39 deletions.
29 changes: 17 additions & 12 deletions src/primihub/algorithm/arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,9 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
int chunk_num = table->column(num_col - 1)->chunks().size();
int64_t array_len = 0;
for (int k = 0; k < chunk_num; k++) {
auto array = std::static_pointer_cast<DoubleArray>(
table->column(num_col - 1)->chunk(k));
// auto array = std::static_pointer_cast<DoubleArray>(
// table->column(num_col - 1)->chunk(k));
auto array = table->column(num_col - 1)->chunk(k);
array_len += array->length();
}

Expand All @@ -500,11 +501,15 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
// Force the same value count in every column.

for (int i = 0; i < num_col; i++) {
auto& col_name = col_names[i];
if (col_and_dtype_.find(col_name) == col_and_dtype_.end()) {
continue;
}
int chunk_num = table->column(i)->chunks().size();
if (col_and_dtype_[col_names[i]] == 0) {
auto data_type =
table->schema()->GetFieldByName(col_names[i])->type()->id();
if (data_type != arrow::Type::INT64) {
auto col_data_type =
table->schema()->GetFieldByName(col_name)->type()->id();
if (col_and_dtype_[col_name] == 0) {
if (col_data_type != arrow::Type::INT64) {
RaiseException("Local data type is inconsistent with the demand data "
"type!Demand data type is int,but local data type is "
"double!Please input consistent data type!");
Expand All @@ -523,13 +528,13 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
if (tmp_len != array_len) {
std::stringstream ss;
ss << "Party: " << this->party_name() << ", "
<< "Column " << col_names[i] << " has " << tmp_len
<< "Column " << col_name << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
RaiseException(ss.str());
errors = true;
break;
}
col_and_val_int.insert(std::make_pair(col_names[i], tmp_data));
col_and_val_int.insert(std::make_pair(col_name, tmp_data));
// std::pair<std::string, std::vector<int64_t>>(col_names[i], tmp_data));
// for (auto itr = col_and_val_int.begin(); itr != col_and_val_int.end();
// itr++) {
Expand All @@ -541,7 +546,7 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
} else {
std::vector<double> tmp_data;
int64_t tmp_len = 0;
if (table->schema()->GetFieldByName(col_names[i])->type()->id() == 9) {
if (col_data_type == arrow::Type::INT64) {
for (int k = 0; k < chunk_num; k++) {
auto array =
std::static_pointer_cast<Int64Array>(table->column(i)->chunk(k));
Expand All @@ -554,7 +559,7 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
if (tmp_len != array_len) {
std::stringstream ss;
ss << "Party: " << this->party_name() << ", "
<< "Column " << col_names[i] << " has " << tmp_len
<< "Column " << col_name << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
RaiseException(ss.str());
errors = true;
Expand All @@ -573,14 +578,14 @@ int ArithmeticExecutor<Dbit>::_LoadDatasetFromCSV(std::string &dataset_id) {
if (tmp_len != array_len) {
std::stringstream ss;
ss << "Party: " << this->party_name() << ", "
<< "Column " << col_names[i] << " has " << tmp_len
<< "Column " << col_name << " has " << tmp_len
<< " value, but other column has " << array_len << " value.";
RaiseException(ss.str());
errors = true;
break;
}
}
col_and_val_double.insert(std::make_pair(col_names[i], tmp_data));
col_and_val_double.insert(std::make_pair(col_name, tmp_data));
// pair<string, std::vector<double>>(col_names[i], tmp_data));
// for (auto itr = col_and_val_double.begin();
// itr != col_and_val_double.end(); itr++) {
Expand Down
54 changes: 27 additions & 27 deletions src/primihub/executor/express.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ int MPCExpressExecutor<Dbit>::ColumnConfig::importColumnOwner(
const std::string &col_name, const u32 &party_id) {
auto iter = col_owner_.find(col_name);
if (iter != col_owner_.end()) {
LOG(ERROR) << "Column " << col_name
<< "'s owner attr is imported, value is " << iter->second << ".";
return -1;
std::stringstream ss;
ss << "Column " << col_name
<< "'s owner attr is imported, value is " << iter->second << ".";
RaiseException(ss.str());
}

col_owner_.insert(std::make_pair(col_name, party_id));
Expand All @@ -68,8 +69,9 @@ int MPCExpressExecutor<Dbit>::ColumnConfig::getColumnDtype(
const std::string &col_name, ColDtype &dtype) {
auto iter = col_dtype_.find(col_name);
if (iter == col_dtype_.end()) {
LOG(ERROR) << "Can't find dtype attr for column " << col_name << ".";
return -1;
std::stringstream ss;
ss << "Can't find dtype attr for column " << col_name << ".";
RaiseException(ss.str());
}

dtype = iter->second;
Expand All @@ -79,15 +81,19 @@ int MPCExpressExecutor<Dbit>::ColumnConfig::getColumnDtype(
template <Decimal Dbit>
int MPCExpressExecutor<Dbit>::ColumnConfig::resolveLocalColumn(void) {
if (col_dtype_.size() != col_owner_.size()) {
LOG(ERROR) << "Count of owner attr and dtype attr must be the same.";
return -1;
std::stringstream ss;
ss << "Count of owner attr and dtype attr must be the same."
<< "col_dtype size: " << col_dtype_.size() << " "
<< "col_owner size: " << col_owner_.size();
RaiseException(ss.str());
}

for (auto iter = col_owner_.begin(); iter != col_owner_.end(); iter++) {
if (col_dtype_.find(iter->first) == col_dtype_.end()) {
LOG(ERROR) << "Import column " << iter->first
<< "'s owner attr, but don't import it's dtype attr.";
return -2;
std::stringstream ss;
ss << "Import column " << iter->first
<< "'s owner attr, but don't import it's dtype attr.";
RaiseException(ss.str());
}
}

Expand Down Expand Up @@ -124,9 +130,9 @@ int MPCExpressExecutor<Dbit>::ColumnConfig::getColumnLocality(
const std::string &col_name, bool &is_local) {
auto iter = local_col_.find(col_name);
if (iter == local_col_.end()) {
LOG(ERROR) << "Can't find column locality by column name " << col_name
<< ".";
return -1;
std::stringstream ss;
ss << "Can't find column locality by column name " << col_name << ".";
RaiseException(ss.str());
}

is_local = iter->second;
Expand Down Expand Up @@ -166,9 +172,9 @@ int MPCExpressExecutor<Dbit>::FeedDict::checkLocalColumn(
bool is_local = false;
int ret = col_config_->getColumnLocality(col_name, is_local);
if (ret) {
LOG(ERROR) << "Get column locality by column name " << col_name
<< " failed.";
return -1;
std::stringstream ss;
ss << "Get column locality by column name " << col_name << " failed.";
RaiseException(ss.str());
}

if (is_local == false)
Expand Down Expand Up @@ -572,21 +578,15 @@ void MPCExpressExecutor<Dbit>::initColumnConfig(const u32 &party_id) {
template <Decimal Dbit>
int MPCExpressExecutor<Dbit>::importColumnDtype(const std::string &col_name,
bool is_fp64) {
if (is_fp64)
if (is_fp64) {
LOG(INFO) << "Column " << col_name << "'s dtype is FP64.";
return col_config_->importColumnDtype(col_name,
ColumnConfig::ColDtype::FP64);
else
} else {
LOG(INFO) << "Column " << col_name << "'s dtype is I64.";
return col_config_->importColumnDtype(col_name,
ColumnConfig::ColDtype::INT64);

if (is_fp64)
LOG(INFO) << "Column " << col_name << "'s dtype is "
<< " FP64.";
else
LOG(INFO) << "Column " << col_name << "'s dtype is "
<< " I64.";

return 0;
}
}

template <Decimal Dbit>
Expand Down

0 comments on commit 8c112bb

Please sign in to comment.