Skip to content

Commit

Permalink
Merge branch 'develop' into onnxruntime-sync-2024-04-26
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored May 3, 2024
2 parents ac0a968 + 2bdd02d commit f223caa
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@87a55290b7f89f9f8b97c611e0cf929399a9e0c5 -DBUILD_FAT_LIBROCKCOMPILER=On
ROCm/rocMLIR@ac15e0103315414b3cc99d999e2c714469caca03 -DBUILD_FAT_LIBROCKCOMPILER=On
9 changes: 6 additions & 3 deletions src/include/migraphx/op/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,18 @@ struct dot
auto s1 = b.to_dynamic();
std::vector<shape::dynamic_dimension> out_dyn_dims;

// check outer dynamic dimensions are the same
// Check outer dynamic dimensions are compatible.
// Must allow for intersection because of how simplify_dyn_ops
// simplifies each broadcast_for_dot individually.
bool same_outers = std::equal(s0.dyn_dims().begin(),
s0.dyn_dims().end() - 2,
s1.dyn_dims().begin(),
s1.dyn_dims().end() - 2,
[&](auto x, auto y) {
if(x == y)
auto intersect = x.intersection(y);
if(intersect.has_value())
{
out_dyn_dims.push_back(x);
out_dyn_dims.push_back(intersect.value());
return true;
}
return false;
Expand Down
20 changes: 15 additions & 5 deletions test/op_shape_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,16 @@ TEST_CASE(dot_dyn_static_test1)
s_m2);
}

TEST_CASE(dot_dyn_static_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{2, 2}, {3, 3}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}

TEST_CASE(dot_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}};
Expand All @@ -832,7 +842,7 @@ TEST_CASE(dot_dyn_test1)

TEST_CASE(dot_dyn_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 1}, {5, 5}, {5, 5}}};
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 20}, {5, 5}, {5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
Expand All @@ -853,10 +863,10 @@ TEST_CASE(dot_dyn_test3)

TEST_CASE(dot_dyn_test4)
{
// Note how the inner dimensions have an intersection in range
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {4, 8}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{1, 4}, {5, 9}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {5, 5}, {8, 8}}},
std::size_t max_val = std::numeric_limits<std::size_t>::max();
migraphx::shape s_m1{migraphx::shape::float_type, {{0, max_val}, {5, 5}, {0, max_val}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 8}, {5, 5}, {8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{4, 8}, {5, 5}, {8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
Expand Down
2 changes: 1 addition & 1 deletion tools/docker/sles.docker
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ RUN sh -c 'echo -e "\
name=rocm\n\
baseurl=https://repo.radeon.com/rocm/zyp/6.0.2/main\n\
enabled=1\n\
gpgcheck=1\n\
gpgcheck=0\n\
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\
" > /etc/zypp/repos.d/rocm.repo'

Expand Down

0 comments on commit f223caa

Please sign in to comment.