Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(interactive): Cut off the string if it exceeds max_length. #4359

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
30 changes: 30 additions & 0 deletions flex/interactive/sdk/python/gs_interactive/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,36 @@ def create_graph_with_custom_pk_name(interactive_session):
delete_running_graph(interactive_session, graph_id)


@pytest.fixture(scope="function")
def create_graph_with_var_char_property(interactive_session):
modern_graph_custom_pk_name = copy.deepcopy(modern_graph_full)
for vertex_type in modern_graph_custom_pk_name["schema"]["vertex_types"]:
# replace each string property with var_char
for prop in vertex_type["properties"]:
if prop["property_type"]:
if "string" in prop["property_type"]:
prop["property_type"]["string"] = {"var_char": {"max_length": 2}}
create_graph_request = CreateGraphRequest.from_dict(modern_graph_custom_pk_name)
resp = interactive_session.create_graph(create_graph_request)
assert resp.is_ok()
graph_id = resp.get_value().graph_id
yield graph_id
delete_running_graph(interactive_session, graph_id)


@pytest.fixture(scope="function")
def create_graph_with_x_csr_params(interactive_session):
modern_graph_x_csr_params = modern_graph_full.copy()
for vertex_type in modern_graph_x_csr_params["schema"]["vertex_types"]:
vertex_type["x_csr_params"] = {"max_vertex_num": 1}
create_graph_request = CreateGraphRequest.from_dict(modern_graph_x_csr_params)
resp = interactive_session.create_graph(create_graph_request)
assert resp.is_ok()
graph_id = resp.get_value().graph_id
yield graph_id
delete_running_graph(interactive_session, graph_id)


def wait_job_finish(sess: Session, job_id: str):
assert job_id is not None
while True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,33 @@ def test_custom_pk_name(
)
records = result.fetch(1)
assert len(records) == 1 and records[0]["$f0"] == 2


def test_x_csr_params(
interactive_session, neo4j_session, create_graph_with_x_csr_params
):
print("[Test x csr params]")
import_data_to_full_modern_graph(
interactive_session, create_graph_with_x_csr_params
)
start_service_on_graph(interactive_session, create_graph_with_x_csr_params)
result = neo4j_session.run("MATCH (n: person) return count(n);")
# expect return value 0
records = result.fetch(1)
assert len(records) == 1 and records[0]["$f0"] > 1


def test_var_char_property(
interactive_session, neo4j_session, create_graph_with_var_char_property
):
print("[Test var char property]")
import_data_to_full_modern_graph(
interactive_session, create_graph_with_var_char_property
)
start_service_on_graph(interactive_session, create_graph_with_var_char_property)
result = neo4j_session.run("MATCH (n: person) return n.name AS personName;")
records = result.fetch(10)
assert len(records) == 4
for record in records:
# all string property in this graph is var char with max_length 2
assert len(record["personName"]) == 2
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,8 @@ class AbstractArrowFragmentLoader : public IFragmentLoader {
void addVertexBatchFromArray(
label_t v_label_id, IdIndexer<KEY_T, vid_t>& indexer,
std::shared_ptr<arrow::Array>& primary_key_col,
const std::vector<std::shared_ptr<arrow::Array>>& property_cols) {
const std::vector<std::shared_ptr<arrow::Array>>& property_cols,
std::shared_mutex& rw_mutex) {
size_t row_num = primary_key_col->length();
auto col_num = property_cols.size();
for (size_t i = 0; i < col_num; ++i) {
Expand All @@ -391,12 +392,15 @@ class AbstractArrowFragmentLoader : public IFragmentLoader {
std::unique_lock<std::mutex> lock(mtxs_[v_label_id]);
_add_vertex<KEY_T>()(primary_key_col, indexer, vids);
}
for (size_t j = 0; j < property_cols.size(); ++j) {
auto array = property_cols[j];
auto chunked_array = std::make_shared<arrow::ChunkedArray>(array);
set_properties_column(
basic_fragment_loader_.GetVertexTable(v_label_id).column_ptrs()[j],
chunked_array, vids);
{
std::shared_lock<std::shared_mutex> lock(rw_mutex);
for (size_t j = 0; j < property_cols.size(); ++j) {
auto array = property_cols[j];
auto chunked_array = std::make_shared<arrow::ChunkedArray>(array);
set_properties_column(
basic_fragment_loader_.GetVertexTable(v_label_id).column_ptrs()[j],
chunked_array, vids);
}
}

VLOG(10) << "Insert rows: " << row_num;
Expand Down Expand Up @@ -455,13 +459,20 @@ class AbstractArrowFragmentLoader : public IFragmentLoader {
},
idx);
}

std::atomic<size_t> offset(0);
std::shared_mutex rw_mutex;
for (unsigned idx = 0;
idx <
std::min(static_cast<unsigned>(8 * record_batch_supplier_vec.size()),
std::thread::hardware_concurrency());
++idx) {
work_threads.emplace_back(
[&](int i) {
// It is possible that the inserted data will exceed the size of
// the table, so we need to resize the table.
// basic_fragment_loader_.GetVertexTable(v_label_id).resize(vids.size());
auto& vtable = basic_fragment_loader_.GetVertexTable(v_label_id);
while (true) {
std::shared_ptr<arrow::RecordBatch> batch{nullptr};
auto ret = queue.Get(batch);
Expand All @@ -478,8 +489,25 @@ class AbstractArrowFragmentLoader : public IFragmentLoader {
other_columns_array.erase(other_columns_array.begin() +
primary_key_ind);

offset.fetch_add(primary_key_column->length());
size_t local_offset = offset.load();
size_t cur_row_num = std::max(vtable.row_num(), 1ul);
while (cur_row_num <
local_offset + primary_key_column->length()) {
cur_row_num *= 2;
}
if (cur_row_num > vtable.row_num()) {
std::unique_lock<std::shared_mutex> lock(rw_mutex);
if (cur_row_num > vtable.row_num()) {
LOG(INFO) << "Resize vertex table from " << vtable.row_num()
<< " to " << cur_row_num
<< ", local_offset: " << local_offset;
vtable.resize(cur_row_num);
}
}

addVertexBatchFromArray(v_label_id, indexer, primary_key_column,
other_columns_array);
other_columns_array, rw_mutex);
}
},
idx);
Expand Down Expand Up @@ -587,6 +615,22 @@ class AbstractArrowFragmentLoader : public IFragmentLoader {
basic_fragment_loader_.FinishAddingVertex(v_label_id, indexer_builder);
const auto& indexer = basic_fragment_loader_.GetLFIndexer(v_label_id);

auto& vtable = basic_fragment_loader_.GetVertexTable(v_label_id);
size_t total_row_num = 0;
for (auto& batch : batchs) {
for (auto& b : batch) {
total_row_num += b->num_rows();
}
}
if (total_row_num > vtable.row_num()) {
std::unique_lock<std::mutex> lock(mtxs_[v_label_id]);
if (total_row_num > vtable.row_num()) {
LOG(INFO) << "Resize vertex table from " << vtable.row_num() << " to "
<< total_row_num;
vtable.resize(total_row_num);
}
}

std::atomic<size_t> cur_batch_id(0);
for (unsigned i = 0; i < std::thread::hardware_concurrency(); ++i) {
work_threads.emplace_back(
Expand Down Expand Up @@ -635,10 +679,8 @@ class AbstractArrowFragmentLoader : public IFragmentLoader {
auto array = other_columns_array[j];
auto chunked_array =
std::make_shared<arrow::ChunkedArray>(array);
set_properties_column(
basic_fragment_loader_.GetVertexTable(v_label_id)
.column_ptrs()[j],
chunked_array, vids);
set_properties_column(vtable.column_ptrs()[j], chunked_array,
vids);
}
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ void BasicFragmentLoader::AddVertexBatch(
<< ", props[i] size: " << props[i].size();
}
auto dst_columns = table.column_ptrs();
// The table.row_num() is initialized to x_csr_params.max_vertex_num, we
// should resize it to the actual number of vertices.
if (table.row_num() < vids.size()) {
VLOG(10) << "Resize vertex table from " << table.row_num() << " to "
<< vids.size();
table.resize(vids.size());
}
for (size_t j = 0; j < props.size(); ++j) {
auto& cur_vec = props[j];
for (size_t i = 0; i < vids.size(); ++i) {
Expand Down
34 changes: 32 additions & 2 deletions flex/utils/property/column.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,33 @@

namespace gs {

std::string_view truncate_utf8(std::string_view str, size_t length) {
if (str.size() <= length) {
return str;
}
size_t byte_count = 0;

for (const char* p = str.data(); *p && byte_count < length;) {
unsigned char ch = *p;
size_t char_length = 0;
if ((ch & 0x80) == 0) {
char_length = 1;
} else if ((ch & 0xE0) == 0xC0) {
char_length = 2;
} else if ((ch & 0xF0) == 0xE0) {
char_length = 3;
} else if ((ch & 0xF8) == 0xF0) {
char_length = 4;
}
if (byte_count + char_length > length) {
break;
}
p += char_length;
byte_count += char_length;
}
return str.substr(0, byte_count);
}

template <typename T>
class TypedEmptyColumn : public ColumnBase {
public:
Expand Down Expand Up @@ -169,11 +196,14 @@ std::shared_ptr<ColumnBase> CreateColumn(
return std::make_shared<DayColumn>(strategy);
} else if (type == PropertyType::kStringMap) {
return std::make_shared<DefaultStringMapColumn>(strategy);
} else if (type == PropertyType::kStringView) {
return std::make_shared<StringColumn>(strategy);
} else if (type.type_enum == impl::PropertyTypeImpl::kVarChar) {
// We must check is varchar first, because in implementation of
// PropertyType::operator==(const PropertyType& other), we string_view is
// equal to varchar.
return std::make_shared<StringColumn>(
strategy, type.additional_type_info.max_length);
} else if (type == PropertyType::kStringView) {
return std::make_shared<StringColumn>(strategy);
} else if (type.type_enum == impl::PropertyTypeImpl::kRecordView) {
return std::make_shared<RecordViewColumn>(sub_types);
} else {
Expand Down
16 changes: 12 additions & 4 deletions flex/utils/property/column.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

namespace gs {

std::string_view truncate_utf8(std::string_view str, size_t length);

class ColumnBase {
public:
virtual ~ColumnBase() {}
Expand Down Expand Up @@ -499,12 +501,18 @@ class TypedColumn<std::string_view> : public ColumnBase {
PropertyType type() const override { return PropertyType::Varchar(width_); }

void set_value(size_t idx, const std::string_view& val) {
auto copied_val = val;
if (copied_val.size() >= width_) {
VLOG(1) << "String length" << copied_val.size()
<< " exceeds the maximum length: " << width_ << ", cut off.";
copied_val = truncate_utf8(copied_val, width_);
}
if (idx >= basic_size_ && idx < basic_size_ + extra_size_) {
size_t offset = pos_.fetch_add(val.size());
extra_buffer_.set(idx - basic_size_, offset, val);
size_t offset = pos_.fetch_add(copied_val.size());
extra_buffer_.set(idx - basic_size_, offset, copied_val);
} else if (idx < basic_size_) {
size_t offset = basic_pos_.fetch_add(val.size());
basic_buffer_.set(idx, offset, val);
size_t offset = basic_pos_.fetch_add(copied_val.size());
basic_buffer_.set(idx, offset, copied_val);
} else {
LOG(FATAL) << "Index out of range";
}
Expand Down
5 changes: 3 additions & 2 deletions flex/utils/property/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1408,9 +1408,10 @@ struct convert<gs::PropertyType> {
if (config["string"]["var_char"]["max_length"]) {
property_type = gs::PropertyType::Varchar(
config["string"]["var_char"]["max_length"].as<int32_t>());
} else {
property_type = gs::PropertyType::Varchar(
gs::PropertyType::STRING_DEFAULT_MAX_LENGTH);
}
property_type = gs::PropertyType::Varchar(
gs::PropertyType::STRING_DEFAULT_MAX_LENGTH);
} else {
LOG(ERROR) << "Unrecognized string type";
}
Expand Down
Loading