Skip to content

Commit

Permalink
Fix bug with last split
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieL7 committed May 30, 2024
1 parent f293d88 commit 4499a60
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
18 changes: 12 additions & 6 deletions src/onnx/parse_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,18 @@ auto parse_static_split(onnx_parser::node_info info,
std::to_string(info.num_outputs) + "!");
}
}

std::size_t chunk_size = tuned_axis_len / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
auto last = tuned_axis_len % num_outputs;
if(last)
vec_splits.back() = last;
if(tuned_axis_len % num_outputs == 0)
{
std::size_t chunk_size = tuned_axis_len / num_outputs;
vec_splits.resize(num_outputs, chunk_size);
}
else
{
std::size_t chunk_size = tuned_axis_len / num_outputs + 1;
std::size_t last_chunk_size = tuned_axis_len - chunk_size * (num_outputs - 1);
vec_splits.resize(num_outputs - 1, chunk_size);
vec_splits.push_back(last_chunk_size);
}
}

if(std::accumulate(vec_splits.begin(), vec_splits.end(), int64_t(0)) !=
Expand Down
6 changes: 3 additions & 3 deletions test/onnx/parse/split_minus_axis_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ TEST_CASE(split_minus_axis_test)
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {10, 15}});
auto r1 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {0}}, {"ends", {5}}}), input);
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), input);
auto r2 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {5}}, {"ends", {10}}}), input);
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {5}}, {"ends", {10}}}), input);
auto r3 = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {-1}}, {"starts", {10}}, {"ends", {15}}}), input);
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {10}}, {"ends", {15}}}), input);
mm->add_return({r1, r2, r3});

auto prog = read_onnx("split_minus_axis_test.onnx");
Expand Down

0 comments on commit 4499a60

Please sign in to comment.