Skip to content

Commit

Permalink
CP for Find-2.0 and 3D-Unet fix (#1448)
Browse files Browse the repository at this point in the history
* Find2.0 changes for the Quant  and De-Convolution (#1408)
* fix for 3d-UNet
  • Loading branch information
umangyadav authored Dec 6, 2022
1 parent 788ce62 commit fe19455
Show file tree
Hide file tree
Showing 12 changed files with 398 additions and 878 deletions.
18 changes: 14 additions & 4 deletions src/include/migraphx/reflect.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ auto reflect_impl(rank<0>, T&, Selector)
}

template <class T>
auto reflectable_impl(rank<1>, T&& x)
auto reflectable_impl(rank<1>, const T& x)
-> decltype(T::reflect(x, reflect_placeholder{}), std::true_type{});

template <class T>
auto reflectable_impl(rank<0>, T &&) -> decltype(std::false_type{});
auto reflectable_impl(rank<0>, const T&) -> decltype(std::false_type{});

template <class T>
struct remove_rvalue_reference
Expand Down Expand Up @@ -111,8 +111,18 @@ auto reflect(T& x, Selector f)
template <class T>
auto reflect_tie(T& x)
{
return reflect(x, [](auto&& y, auto&&...) { return detail::wrap<decltype(y)>(y); })(
[](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
return reflect(x, [](auto&& y, auto&&...) {
// cppcheck-suppress UnnecessaryElseStatement
if constexpr(is_reflectable<decltype(y)>{})
{
auto t = reflect_tie(y);
return detail::wrap<decltype(t)>(t);
}
else
{
return detail::wrap<decltype(y)>(y);
}
})([](auto&&... xs) { return detail::auto_tuple(xs.get()...); });
}

template <class T, class F>
Expand Down
16 changes: 16 additions & 0 deletions src/include/migraphx/streamutils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

#include <ostream>
#include <algorithm>
#include <migraphx/reflect.hpp>
#include <migraphx/rank.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/config.hpp>
#include <vector>

Expand Down Expand Up @@ -83,6 +85,20 @@ auto stream_write_value_impl(rank<0>, std::ostream& os, const Range& r)
os << "}";
}

template <class T, MIGRAPHX_REQUIRES(is_reflectable<T>{})>
void stream_write_value_impl(rank<0>, std::ostream& os, const T& x)
{
char delim = '{';
reflect_each(x, [&](auto&& y, auto name) {
os << delim;
os << name << "=";
stream_write_value_impl(rank<2>{}, os, y);
delim = ',';
});
if(delim == ',')
os << "}";
}

} // namespace detail

template <class T>
Expand Down
15 changes: 7 additions & 8 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ add_library(migraphx_gpu
compile_hip.cpp
compile_hip_code_object.cpp
compiler.cpp
convolution.cpp
deconvolution.cpp
device_name.cpp
elu.cpp
fuse_mlir.cpp
Expand All @@ -110,7 +108,6 @@ add_library(migraphx_gpu
pad.cpp
perfdb.cpp
pooling.cpp
quant_convolution.cpp
reverse.cpp
rnn_variable_seq_lens.cpp
rocblas.cpp
Expand Down Expand Up @@ -146,14 +143,11 @@ register_migraphx_gpu_ops(hip_
register_migraphx_gpu_ops(miopen_
abs
contiguous
convolution
deconvolution
elu
int8_conv_pack
leaky_relu
lrn
pooling
quant_convolution
)
register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
Expand All @@ -167,6 +161,9 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::deconvolution> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp)
rocm_set_soversion(migraphx_gpu ${MIGRAPHX_SO_VERSION})
rocm_clang_tidy_check(migraphx_gpu)

Expand Down Expand Up @@ -239,11 +236,13 @@ get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)

if(HAS_FIND_2_API)
set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")

if(MIGRAPHX_USE_FIND_2_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_HAS_FIND_2_API)
message(STATUS "MIGraphx is using Find-2.0 API of MIOpen")
else()
message(STATUS "MIOpen does not have Find-2.0 API")
message(STATUS "MIGraphx is using legacy Find API in MIOpen")
endif()

if(HAS_FIND_MODE_API)
Expand Down
271 changes: 0 additions & 271 deletions src/targets/gpu/convolution.cpp

This file was deleted.

Loading

0 comments on commit fe19455

Please sign in to comment.