Skip to content

Commit

Permalink
Merge branch 'develop' into onnxruntime-sync-2023-10-13
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Oct 19, 2023
2 parents e18497c + 07848b2 commit 87a127e
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 54 deletions.
2 changes: 1 addition & 1 deletion src/include/migraphx/argument.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct MIGRAPHX_EXPORT argument : raw_data<argument>
{
argument() = default;

argument(const shape& s);
explicit argument(const shape& s);

template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d)
Expand Down
4 changes: 2 additions & 2 deletions src/include/migraphx/op/allocate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ struct allocate
{
if(args.empty())
{
return {output_shape};
return argument{output_shape};
}
else
{
std::vector<std::size_t> output_dims(output_shape.ndim());
args.at(0).visit([&](auto a) { output_dims.assign(a.begin(), a.end()); });
return {shape{buf_type, output_dims}};
return argument{shape{buf_type, output_dims}};
}
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/include/migraphx/op/pooling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ struct pooling
// for dynamic GlobalPooling, there's no padding
kernel_dims.insert(kernel_dims.end(), input_lens.begin() + 2, input_lens.end());
output_shape = dyn_out.computed_shape;
result = dyn_out.computed_shape;
result = argument{dyn_out.computed_shape};
}
else if((padding_mode != op::padding_mode_t::default_))
{
Expand Down Expand Up @@ -439,7 +439,7 @@ struct pooling
{
kernel_dims = this->lengths;
output_shape = dyn_out.computed_shape;
result = dyn_out.computed_shape;
result = argument{dyn_out.computed_shape};
}

// Perform the computation and populate result
Expand Down
26 changes: 19 additions & 7 deletions src/targets/gpu/compile_hip_code_object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,27 @@ void hip_compile_options::set_launch_params(
global = compute_global(local);
}

static bool hip_accept_non_uniform_wg()
{
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
return non_uniform_wg;
}

std::function<std::size_t(std::size_t local)>
compute_global_for(context& ctx, std::size_t n, std::size_t over)
{
assert(over > 0);
std::size_t max_global = ctx.get_current_device().get_cu_count() *
ctx.get_current_device().get_max_workitems_per_cu();
return [n, over, max_global](std::size_t local) {
// hip require global workitems multiple of local workitems. It may degrade performance.
// [TODO]: consider adding "fno-hip-uniform-block" flag when it becomes available.
// https://reviews.llvm.org/D155213
std::size_t num_elements = ((n + local - 1) / local) * local;
std::size_t groups = (num_elements + local - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
std::size_t num_elements = n;
if(not hip_accept_non_uniform_wg())
{
num_elements = (1 + (n - 1) / local) * local;
}
std::size_t groups = 1 + (num_elements - 1) / local;
std::size_t max_blocks = max_global / local;
std::size_t nglobal = std::min(max_blocks * over, groups) * local;
return std::min(nglobal, num_elements);
};
}
Expand Down Expand Up @@ -183,6 +190,11 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.emplace_back("args.hpp", args_hpp);

if(options.global % options.local != 0 and hip_accept_non_uniform_wg())
options.params += " -fno-offload-uniform-block";
else
assert(options.global % options.local == 0);

options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " ");
Expand Down
6 changes: 3 additions & 3 deletions src/targets/gpu/include/migraphx/gpu/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ struct miopen_convolution
// MIOpen has APIs to pass pre-allocated buffers starting from rocm-5.6
preallocate = true;
#endif
auto x = preallocate ? to_gpu(generate_argument(x_shape)) : inputs[0];
auto w = preallocate ? to_gpu(generate_argument(w_shape)) : inputs[1];
auto y = preallocate ? allocate_gpu(output_shape) : inputs[2];
auto x = preallocate ? to_gpu(generate_argument(x_shape)) : argument{inputs[0]};
auto w = preallocate ? to_gpu(generate_argument(w_shape)) : argument{inputs[1]};
auto y = preallocate ? allocate_gpu(output_shape) : argument{inputs[2]};
auto workspace =
preallocate ? allocate_gpu(workspace_shape) : migraphx::argument(workspace_shape);

Expand Down
63 changes: 33 additions & 30 deletions src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
extern "C" __device__ size_t __ockl_get_enqueued_local_size(uint); // NOLINT
extern "C" __device__ size_t __ockl_get_local_size(uint); // NOLINT
#pragma clang diagnostic pop
#endif

namespace migraphx {

#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
Expand All @@ -45,43 +53,37 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
// This actualy works even when global is not divisible by local size.
// This doesnt actually do a multiplicatiosn. Instead it calls a device
// function to get the global size, which is why it works.
return blockDim.x * gridDim.x; // NOLINT
return blockDim.x * gridDim.x; // NOLINT
#endif
}

// We cant just use blockDim.x to get the local size since its broken on hip
// when global is not divisible by local size. In this case, we calulate the
// size for the last group.
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif

inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
const auto nlocal = MIGRAPHX_NLOCAL;
#else
const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
const auto ngroup = MIGRAPHX_NGROUP;
#ifdef MIGRAPHX_HAS_CONST_LOCAL
return MIGRAPHX_NLOCAL;
#else
const auto ngroup = gridDim.x; // NOLINT
// Returns block size. For the non-uniform block it returns the size of the non-uniform block.
return __ockl_get_local_size(0); // NOLINT
#endif
const auto group_id = blockIdx.x; // NOLINT
const auto nglobal = compute_global_size();
if(group_id == ngroup - 1)
{
return 1 + (nglobal - 1) % nlocal;
}
else
{
return nlocal; // NOLINT
}
}

#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
inline __device__ __attribute__((const)) index_int compute_max_local_size()
{
#ifdef MIGRAPHX_LOCAL
return MIGRAPHX_NLOCAL;
#else
// Returns the block size. When workgrop has non-uniform block, this returns size of the uniform
// block.
return __ockl_get_enqueued_local_size(0); // NOLINT
#endif
}

struct index
{
Expand Down Expand Up @@ -126,8 +128,8 @@ struct index
#else
__device__ index_int max_nlocal() const
{
MIGRAPHX_ASSERT(blockDim.x > 0);
return blockDim.x;
MIGRAPHX_ASSERT(compute_max_local_size() > 0);
return compute_max_local_size();
}
#endif

Expand Down Expand Up @@ -249,7 +251,8 @@ struct index
#endif
inline __device__ __attribute__((const)) index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
return index{
blockIdx.x * compute_max_local_size() + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
}

} // namespace migraphx
Expand Down
2 changes: 1 addition & 1 deletion test/eliminate_allocation_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct allocate
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand Down
4 changes: 2 additions & 2 deletions test/eliminate_concat_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct concat
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand Down Expand Up @@ -104,7 +104,7 @@ struct allocate
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand Down
2 changes: 1 addition & 1 deletion test/memory_coloring_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct allocate
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand Down
2 changes: 1 addition & 1 deletion test/normalize_ops_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct normalize_test_op
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand Down
4 changes: 2 additions & 2 deletions test/replace_allocate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct allocate_no_out : migraphx::auto_register_op<allocate_no_out>
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand All @@ -78,7 +78,7 @@ struct allocate_with_out : migraphx::auto_register_op<allocate_with_out>
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};

Expand Down
10 changes: 8 additions & 2 deletions tools/accuracy/accuracy_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,10 +220,16 @@ def main():
else:
test_input = np.zeros(in_shape).astype(get_np_datatype(in_type))
test_inputs[name] = test_input
params[name] = migraphx.argument(test_input)
migraphx_arg = migraphx.argument(test_input)
if not args.offload_copy:
migraphx_arg = migraphx.to_gpu(migraphx_arg)
params[name] = migraphx_arg

if not args.ort_run:
pred_migx = np.array(model.run(params)[-1])
if not args.offload_copy:
pred_migx = np.array(migraphx.from_gpu(model.run(params)[-1]))
else:
pred_migx = np.array(model.run(params)[-1])

if use_onnx:
sess_op = ort.SessionOptions()
Expand Down

0 comments on commit 87a127e

Please sign in to comment.