diff --git a/requirements.txt b/requirements.txt index 1ff157d6d27..558c833d0a9 100755 --- a/requirements.txt +++ b/requirements.txt @@ -28,4 +28,4 @@ pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@87a55290b7f89f9f8b97c611e0cf929399a9e0c5 -DBUILD_FAT_LIBROCKCOMPILER=On +ROCm/rocMLIR@ac15e0103315414b3cc99d999e2c714469caca03 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/include/migraphx/op/dot.hpp b/src/include/migraphx/op/dot.hpp index 0ef3787b38c..37e38f76264 100644 --- a/src/include/migraphx/op/dot.hpp +++ b/src/include/migraphx/op/dot.hpp @@ -57,15 +57,18 @@ struct dot auto s1 = b.to_dynamic(); std::vector out_dyn_dims; - // check outer dynamic dimensions are the same + // Check outer dynamic dimensions are compatible. + // Must allow for intersection because of how simplify_dyn_ops + // simplifies each broadcast_for_dot individually. bool same_outers = std::equal(s0.dyn_dims().begin(), s0.dyn_dims().end() - 2, s1.dyn_dims().begin(), s1.dyn_dims().end() - 2, [&](auto x, auto y) { - if(x == y) + auto intersect = x.intersection(y); + if(intersect.has_value()) { - out_dyn_dims.push_back(x); + out_dyn_dims.push_back(intersect.value()); return true; } return false; diff --git a/test/op_shape_test.cpp b/test/op_shape_test.cpp index 333a7f9ef63..9e4d833b448 100644 --- a/test/op_shape_test.cpp +++ b/test/op_shape_test.cpp @@ -810,6 +810,16 @@ TEST_CASE(dot_dyn_static_test1) s_m2); } +TEST_CASE(dot_dyn_static_test2) +{ + migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {3, 3}, {5, 5}, {5, 5}}}; + migraphx::shape s_m2{migraphx::shape::float_type, {2, 3, 5, 8}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {{2, 2}, {3, 3}, {5, 5}, {8, 8}}}, + migraphx::make_op("dot"), + s_m1, + s_m2); +} + TEST_CASE(dot_dyn_test0) { migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}}}; @@ -832,7 +842,7 @@ TEST_CASE(dot_dyn_test1) TEST_CASE(dot_dyn_test2) { - migraphx::shape s_m1{migraphx::shape::float_type, {{1, 1}, {5, 5}, {5, 5}}}; + migraphx::shape s_m1{migraphx::shape::float_type, {{1, 20}, {5, 5}, {5, 5}}}; migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}}; expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 1}, {5, 5}, {8, 8}}}, migraphx::make_op("dot"), @@ -853,10 +863,10 @@ TEST_CASE(dot_dyn_test3) TEST_CASE(dot_dyn_test4) { - // Note how the inner dimensions have an intersection in range - migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4}, {5, 5}, {4, 8}}}; - migraphx::shape s_m2{migraphx::shape::float_type, {{1, 4}, {5, 9}, {8, 8}}}; - expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4}, {5, 5}, {8, 8}}}, + std::size_t max_val = std::numeric_limits::max(); + migraphx::shape s_m1{migraphx::shape::float_type, {{0, max_val}, {5, 5}, {0, max_val}}}; + migraphx::shape s_m2{migraphx::shape::float_type, {{4, 8}, {5, 5}, {8, 8}}}; + expect_shape(migraphx::shape{migraphx::shape::float_type, {{4, 8}, {5, 5}, {8, 8}}}, migraphx::make_op("dot"), s_m1, s_m2); diff --git a/tools/docker/sles.docker b/tools/docker/sles.docker index 5006593a500..b38645a2d95 100644 --- a/tools/docker/sles.docker +++ b/tools/docker/sles.docker @@ -5,7 +5,7 @@ RUN sh -c 'echo -e "\ name=rocm\n\ baseurl=https://repo.radeon.com/rocm/zyp/6.0.2/main\n\ enabled=1\n\ -gpgcheck=1\n\ +gpgcheck=0\n\ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key\n\ " > /etc/zypp/repos.d/rocm.repo'