Skip to content

Commit

Permalink
Remove rocblas on windows (#2966)
Browse files Browse the repository at this point in the history
  • Loading branch information
tvukovic-amd authored May 9, 2024
1 parent 17ec660 commit 69e3981
Show file tree
Hide file tree
Showing 15 changed files with 89 additions and 47 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ else()
option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON)
endif()

option(MIGRAPHX_USE_ROCBLAS "Enable MIGraphX to use rocBLAS" ON)

# By default build shared libraries
option(BUILD_SHARED_LIBS "Create shared libraries" ON)

Expand Down Expand Up @@ -334,11 +336,15 @@ else()
set(DEPENDS_HIP_RUNTIME "hip-runtime-amd" )
endif()

if(MIGRAPHX_USE_ROCBLAS)
list(APPEND PACKAGE_DEPENDS rocblas)
endif()

rocm_create_package(
NAME MIGraphX
DESCRIPTION "AMD's graph optimizer"
MAINTAINER "AMDMIGraphX Maintainer <[email protected]>"
LDCONFIG
PTH
DEPENDS miopen-hip rocblas ${DEPENDS_HIP_RUNTIME} hip-base half ${PACKAGE_DEPENDS}
DEPENDS miopen-hip ${DEPENDS_HIP_RUNTIME} hip-base half ${PACKAGE_DEPENDS}
)
5 changes: 4 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,10 @@ target_link_libraries(migraphx_all_targets INTERFACE migraphx_cpu)
target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_CPU)
endif()
if(MIGRAPHX_ENABLE_GPU)
list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE MIOpen PACKAGE rocblas)
if(MIGRAPHX_USE_ROCBLAS)
list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE rocblas)
endif()
list(APPEND MIGRAPHX_CONFIG_DEPENDS PACKAGE MIOpen)
add_subdirectory(targets/gpu)
target_link_libraries(migraphx_all_targets INTERFACE migraphx_gpu)
target_compile_definitions(migraphx_all_targets INTERFACE -DHAVE_GPU)
Expand Down
68 changes: 43 additions & 25 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ endif()
find_package(miopen REQUIRED)
message(STATUS "MIGraphX is using MIOpen")

# rocblas
find_package(rocblas REQUIRED)
message(STATUS "MIGraphX build with rocBLAS")
if(MIGRAPHX_USE_ROCBLAS)
# rocblas
find_package(rocblas REQUIRED)
message(STATUS "MIGraphX build with rocBLAS")
else()
message(STATUS "MIGraphX build without rocBLAS")
endif()

if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
Expand Down Expand Up @@ -189,10 +193,12 @@ register_op(migraphx_gpu
HEADER migraphx/gpu/rnn_variable_seq_lens.hpp
OPERATORS gpu::hip_rnn_var_sl_shift_sequence gpu::hip_rnn_var_sl_shift_output gpu::hip_rnn_var_sl_last_output
INCLUDES migraphx/gpu/context.hpp)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
if(MIGRAPHX_USE_ROCBLAS)
register_op(migraphx_gpu
HEADER migraphx/gpu/gemm.hpp
OPERATORS gpu::rocblas_gemm<op::dot> gpu::rocblas_gemm<op::quant_dot>
INCLUDES migraphx/gpu/context.hpp)
endif()
register_op(migraphx_gpu HEADER migraphx/gpu/convolution.hpp
OPERATORS gpu::miopen_convolution<op::convolution> gpu::miopen_convolution<op::convolution_backwards> gpu::miopen_convolution<op::quant_convolution>
INCLUDES migraphx/gpu/context.hpp)
Expand Down Expand Up @@ -260,13 +266,19 @@ target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_CXX_COMPILER="${CMAKE_CX

include(CheckLibraryExists)
get_target_property(MIOPEN_LOCATION MIOpen LOCATION)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)
check_library_exists(MIOpen "miopenFindSolutions" "${MIOPEN_LOCATION}" HAS_FIND_2_API)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
check_library_exists(MIOpen "miopenHiddenSetConvolutionFindMode" "${MIOPEN_LOCATION}" HAS_FIND_MODE_API)

if(MIGRAPHX_USE_ROCBLAS)
get_target_property(ROCBLAS_LOCATION roc::rocblas LOCATION)
target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_ROCBLAS=1)
# Beta API for automated GEMM tuning
check_library_exists(roc::rocblas "rocblas_gemm_ex_get_solutions" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
# rocblas FP8 API
check_library_exists(roc::rocblas "rocblas_gemm_strided_batched_ex3" "${ROCBLAS_LOCATION}" HAS_ROCBLAS_FP8_BETA_API)
else()
target_compile_definitions(migraphx_gpu PUBLIC MIGRAPHX_USE_ROCBLAS=0)
endif()

set(MIGRAPHX_USE_FIND_2_API "${HAS_FIND_2_API}" CACHE BOOL "")

Expand All @@ -289,21 +301,27 @@ else()
message(STATUS "MIOpen does not have find mode api")
endif()

if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()
if(MIGRAPHX_USE_ROCBLAS)
if(HAS_ROCBLAS_TUNING_BETA_FEATURE_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphx is using Beta API of rocBLAS")
else()
message(STATUS "rocBLAS does not have User Tuning Beta API")
endif()

if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
if(HAS_ROCBLAS_FP8_BETA_API)
target_compile_definitions(migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_FP8_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS)
message(STATUS "MIGraphX is using Beta API of rocBLAS for FP8 computations")
else()
message(STATUS "rocBLAS does not have Fp8 Beta API")
endif()


target_link_libraries(migraphx_gpu PUBLIC roc::rocblas)
endif()

target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas)
target_link_libraries(migraphx_gpu PUBLIC migraphx MIOpen)

target_link_libraries(migraphx_gpu PRIVATE migraphx_device migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
Expand Down
3 changes: 1 addition & 2 deletions src/targets/gpu/compile_miopen.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -29,7 +29,6 @@
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/rocblas.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down
2 changes: 1 addition & 1 deletion src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
{
#ifdef MIGRAPHX_MLIR
const auto& device_name = ctx == nullptr ? "" : ctx->get_current_device().get_gfx_name();
const bool is_navi = starts_with(device_name, "gfx110");
const bool is_navi = starts_with(device_name, "gfx11");

auto get_mode = [&](std::string_view option, mlir_mode m1, mlir_mode m2 = mlir_mode::fast) {
if(specific_op<rejected>(option))
Expand Down
4 changes: 4 additions & 0 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ struct find_conv_pointwise
}
};

#if MIGRAPHX_USE_ROCBLAS
struct find_gemm_pointwise
{
auto matcher() const
Expand Down Expand Up @@ -675,6 +676,7 @@ struct find_gemm_pointwise
m.replace_instruction(ins, gemm, inputs);
}
};
#endif

struct find_contiguous_tranpose_gemm
{
Expand Down Expand Up @@ -893,7 +895,9 @@ void fuse_ops::apply(module& m) const
match::find_matches(m, find_conv_pointwise{ctx}, find_conv_bias_relu{ctx}, find_conv_bias{ctx});
run_passes(m, {dead_code_elimination{}});
match::find_matches(m,
#if MIGRAPHX_USE_ROCBLAS
find_gemm_pointwise{},
#endif
find_layernorm_pointwise{},
find_concat_pointwise{},
find_contiguous_tranpose_gemm{},
Expand Down
4 changes: 2 additions & 2 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using microseconds = std::chrono::duration<double, std::micro>;
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

#if MIGRAPHX_USE_ROCBLAS
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value, whereas the "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is a facilitator to implicitly cast integer enum
Expand Down Expand Up @@ -678,7 +678,7 @@ int32_t gemm_finalize(context& ctx,
return gemm_finalize_impl(
ctx, output_shape, input_shapes, alpha, beta, compute_fp32, solution_idx);
}

#endif
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 1 addition & 1 deletion src/targets/gpu/include/migraphx/gpu/compile_miopen.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down
11 changes: 7 additions & 4 deletions src/targets/gpu/include/migraphx/gpu/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct hip_device
assert(mihandle.get() != nullptr);
return mihandle.get();
}

#if MIGRAPHX_USE_ROCBLAS
auto get_rocblas()
{
setup();
Expand All @@ -116,6 +116,7 @@ struct hip_device
assert(rbhandle.get() != nullptr);
return rbhandle.get();
}
#endif

void wait() const
{
Expand Down Expand Up @@ -144,10 +145,12 @@ struct hip_device
}

private:
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
shared<miopen_handle> mihandle = nullptr;
std::size_t id = 0;
shared<hip_stream_ptr> s = nullptr;
shared<miopen_handle> mihandle = nullptr;
#if MIGRAPHX_USE_ROCBLAS
shared<rocblas_handle_ptr> rbhandle = nullptr;
#endif
};

void add_stream() { streams.emplace_back(device_id); }
Expand Down
5 changes: 1 addition & 4 deletions src/targets/gpu/include/migraphx/gpu/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ struct rocblas_gemm
bool compute_fp32 = false;
unsigned trans_batch = 0;
int32_t solution_idx = 0;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
Expand Down Expand Up @@ -158,9 +157,7 @@ struct rocblas_gemm
#endif
}
};

#endif
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
7 changes: 5 additions & 2 deletions src/targets/gpu/include/migraphx/gpu/rocblas.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -25,11 +25,14 @@
#define MIGRAPHX_GUARD_MIGRAPHLIB_ROCBLAS_HPP
#include <migraphx/manage_ptr.hpp>
#include <migraphx/gpu/config.hpp>
#if MIGRAPHX_USE_ROCBLAS
#include <rocblas/rocblas.h>
#endif

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
#if MIGRAPHX_USE_ROCBLAS

using rocblas_handle_ptr = MIGRAPHX_MANAGE_PTR(rocblas_handle, rocblas_destroy_handle);

Expand All @@ -41,7 +44,7 @@ struct context;
MIGRAPHX_GPU_EXPORT bool get_compute_fp32_flag();

MIGRAPHX_GPU_EXPORT bool rocblas_fp8_available();

#endif
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
Expand Down
7 changes: 6 additions & 1 deletion src/targets/gpu/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ struct miopen_apply
{
assert(mod != nullptr);
assert(pass != nullptr);

#if MIGRAPHX_USE_ROCBLAS
compute_fp32 = get_compute_fp32_flag();
#endif
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;

add_generic_op("contiguous");
Expand All @@ -104,8 +105,10 @@ struct miopen_apply
add_convolution_op("convolution");
add_convolution_op("convolution_backwards");
add_convolution_op("quant_convolution");
#if MIGRAPHX_USE_ROCBLAS
add_gemm_op<op::dot>("dot");
add_gemm_op<op::quant_dot>("quant_dot");
#endif
add_if_op();
add_loop_op();
add_neg_op();
Expand Down Expand Up @@ -232,6 +235,7 @@ struct miopen_apply
return mod->insert_instruction(ins, make_op("allocate", {{"shape", to_value(s)}}));
}

#if MIGRAPHX_USE_ROCBLAS
template <typename Op>
void add_gemm_op(const std::string& name)
{
Expand All @@ -243,6 +247,7 @@ struct miopen_apply
return mod->replace_instruction(ins, rocblas_gemm<Op>{Op{}, 1, 0, compute_fp32}, refs);
});
}
#endif

void add_convolution_op(const std::string& name)
{
Expand Down
4 changes: 2 additions & 2 deletions src/targets/gpu/rocblas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

#if MIGRAPHX_USE_ROCBLAS
rocblas_handle_ptr create_rocblas_handle_ptr()
{
// add a call to rocblas_initialize() to workaround a rocblas bug SWDEV-438929
Expand Down Expand Up @@ -63,7 +63,7 @@ bool rocblas_fp8_available()
return gfx_has_fp8_intrinsics();
#endif
}

#endif
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 2 additions & 0 deletions src/targets/gpu/target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::tuple_type);
// whitelist supported Ops for the FP8
std::set<std::string> unsupported_fp8_ops = {};
#if MIGRAPHX_USE_ROCBLAS
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
unsupported_fp8_ops.insert("quant_dot");
}
#endif
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
if(not gpu::gfx_has_fp8_intrinsics())
Expand Down
4 changes: 3 additions & 1 deletion test/gpu/gemm_tune.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -47,6 +47,7 @@ void run_lowering(migraphx::program& p, bool offload_copy = false)
{migraphx::auto_contiguous{}, migraphx::gpu::lowering{&ctx, offload_copy}});
}

#if MIGRAPHX_USE_ROCBLAS
/**
* Tests the automatic GEMM tuning feature. In the finalize() method of the gemm op,
* rocBLAS API functions are called to quickly benchmark all the GEMM solutions
Expand Down Expand Up @@ -181,6 +182,7 @@ TEST_CASE(gemm_tune_strided_lowered)
EXPECT(0 == solution_idx.to<std::size_t>());
#endif
}
#endif

TEST_CASE(gemm_tune_invalid_sol_index)
{
Expand Down

0 comments on commit 69e3981

Please sign in to comment.