FP8 2D forward convolution using rocMLIR #2507

Merged
merged 128 commits on Dec 7, 2023
Commits
128 commits
df7f8a3
changes for the FP8 ref implementation
umangyadav Nov 9, 2023
9bc1828
cppcheck fixes
umangyadav Nov 9, 2023
155a2b1
move FNUZ as template parameter
umangyadav Nov 10, 2023
d9f11e3
Fix numeric limits
umangyadav Nov 10, 2023
4e9d51f
Working FNUZ and FN
umangyadav Nov 10, 2023
7639c28
use float equal
umangyadav Nov 10, 2023
a6372c5
add test for fp8e5m2
umangyadav Nov 10, 2023
439ea40
add test for fp8e5m2fnuz
umangyadav Nov 10, 2023
183db78
refactor add some comments
umangyadav Nov 10, 2023
ab653af
Review updates
umangyadav Nov 13, 2023
8319e01
Fix tidy
umangyadav Nov 14, 2023
9ee0418
Fix test failure
umangyadav Nov 14, 2023
355e4f6
fix isfinite
umangyadav Nov 14, 2023
ba471f4
Merge remote-tracking branch 'origin/develop' into ref_fp8
umangyadav Nov 14, 2023
6aec703
fix test for neg inf
umangyadav Nov 14, 2023
12aac37
fix warning
umangyadav Nov 14, 2023
6009232
add tests
umangyadav Nov 14, 2023
03f7139
Fix tests
umangyadav Nov 14, 2023
1e220c0
add stringstream tests
umangyadav Nov 14, 2023
a83e9dc
Remove clang diagnostics
umangyadav Nov 15, 2023
dfb35a6
Merge remote-tracking branch 'origin/develop' into ref_fp8
umangyadav Nov 15, 2023
26956f1
Remove NOLINTS
umangyadav Nov 15, 2023
269ce6d
Bugfixes and additional tests
umangyadav Nov 16, 2023
6414ee3
Fix undoing
umangyadav Nov 16, 2023
cd26ada
Handle underflow case separately to avoid sanitization errors
umangyadav Nov 16, 2023
1cf87ef
use std::min to avoid sanitization errors
umangyadav Nov 16, 2023
e7e5ba2
Merge branch 'develop' into ref_fp8
umangyadav Nov 16, 2023
98a838f
formatting
umangyadav Nov 16, 2023
61e4e1d
use 31 for min value
umangyadav Nov 16, 2023
a5c38eb
add note
umangyadav Nov 16, 2023
61775ea
Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX…
umangyadav Nov 16, 2023
3806427
Merge branch 'develop' into ref_fp8
umangyadav Nov 16, 2023
017d67e
add some more comments
umangyadav Nov 17, 2023
9e6d866
Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX…
umangyadav Nov 17, 2023
a9dd42f
port gpu changes
umangyadav Nov 17, 2023
d7339e8
use bit cast
umangyadav Nov 17, 2023
6094234
Make FNUZ template param and add numeric limits
umangyadav Nov 17, 2023
78ec77e
only compile for device
umangyadav Nov 17, 2023
3411649
remove non-JIT related code
umangyadav Nov 17, 2023
d2c25a0
Remove FP8_Lowest/Max
umangyadav Nov 17, 2023
5da68df
remove using for dtypes
umangyadav Nov 17, 2023
b36f72d
Update float8_impl
umangyadav Nov 17, 2023
85ba819
constructor from float works with constexpr
umangyadav Nov 17, 2023
aed1922
Remove unnecessary pragmas
umangyadav Nov 17, 2023
f975c63
Remove clang diagnostics
umangyadav Nov 17, 2023
32033d8
Add back floatequal
umangyadav Nov 17, 2023
e88d46a
disable DPP For FP8
umangyadav Nov 17, 2023
3ae93ca
Merge remote-tracking branch 'origin/develop' into gpu_fp8
umangyadav Nov 17, 2023
60dd1f4
formatting
umangyadav Nov 17, 2023
ef425d0
revert unwanted changes
umangyadav Nov 17, 2023
76f0318
Merge branch 'gpu_fp8' of https://github.com/ROCmSoftwarePlatform/AMD…
umangyadav Nov 17, 2023
bd0ae5f
add some more tests
umangyadav Nov 17, 2023
91cc9c7
Add math and reduce tests
umangyadav Nov 18, 2023
e2b0c40
Fix tidy and other errors
umangyadav Nov 18, 2023
9f50051
fixes
umangyadav Nov 18, 2023
249464c
add nolint
umangyadav Nov 18, 2023
1be9587
tidy fix
umangyadav Nov 18, 2023
13403ab
roialign, softmax, pow, acosh, atanh,pad tests are enabled now
umangyadav Nov 20, 2023
f550f81
add layernorm, remove constexpr for 1/r
umangyadav Nov 20, 2023
7e3444c
tidy fixes
umangyadav Nov 20, 2023
6155c78
use __builtin_is_constant_evaluated
umangyadav Nov 20, 2023
13ef414
add test for rsqrt and remove old-styple-cast
umangyadav Nov 20, 2023
8660572
add comment about c++20 extensions
umangyadav Nov 20, 2023
6fbd997
Remove old cast
umangyadav Nov 20, 2023
2acd265
Remove DPP
umangyadav Nov 20, 2023
836e201
Remove MIN max overloads
umangyadav Nov 20, 2023
f9542d5
Put numeric_max and numeeric lowest into float8
umangyadav Nov 20, 2023
480288f
use void for highest to match template candidates
umangyadav Nov 21, 2023
a6c5772
add float8 for tensorview
umangyadav Nov 21, 2023
3aa465f
compiles all right
umangyadav Nov 26, 2023
037205c
Works now
umangyadav Nov 26, 2023
87548b5
add ifdef to compile
umangyadav Nov 26, 2023
d473b80
add tests and fix cmake
umangyadav Nov 26, 2023
4604f2e
add tests
umangyadav Nov 26, 2023
ad9c25e
add eliminate_fp8 pass
umangyadav Nov 26, 2023
8734ffa
remove convert from lowering
umangyadav Nov 26, 2023
f014fb9
Fix eliminate_fp8 pass
umangyadav Nov 26, 2023
83ce487
Move pass before optimize module
umangyadav Nov 26, 2023
9a9e964
formatting
umangyadav Nov 26, 2023
c40a39c
fix cppcheck
umangyadav Nov 26, 2023
c4cee34
Merge branch 'develop' into rocblas_fp8
umangyadav Dec 1, 2023
f155b0e
merge changes
umangyadav Dec 1, 2023
38218ed
few changes
umangyadav Dec 1, 2023
379692f
few more cosmetic changes
umangyadav Dec 1, 2023
381b2d9
add half tests
umangyadav Dec 2, 2023
ce61ea6
add quant_dot support for fp8
umangyadav Dec 2, 2023
575fc04
mlir fp8
umangyadav Nov 26, 2023
afb55bd
add some MLIR fp8 tests for convolutions
umangyadav Nov 27, 2023
f293193
small example for fp8 fail case
umangyadav Nov 28, 2023
a8ef912
add test for conv_bn with 1e-1f
umangyadav Nov 28, 2023
32e0855
fix conv_bn eps
umangyadav Nov 28, 2023
f18418b
add pooling to unsupported ops
umangyadav Nov 28, 2023
9acd36a
update eps
umangyadav Nov 28, 2023
c314119
update eps
umangyadav Nov 28, 2023
88eb355
add conv tests supported by MLIR
umangyadav Nov 28, 2023
3f21332
remove half test and add it as template
umangyadav Dec 3, 2023
050184c
revert some changes
umangyadav Dec 3, 2023
4e07dfc
revert some changes
umangyadav Dec 3, 2023
370d18c
add quant_conv tests
umangyadav Dec 3, 2023
24c63d7
add comment for 1d convs
umangyadav Dec 3, 2023
c522d47
I dont' know why this test was disabled for the PGpu but enabling it …
umangyadav Dec 3, 2023
fe585d4
Disable FP8 tests for the non-gfx940 arches
umangyadav Dec 3, 2023
994d24b
use helper function to determine gfx940
umangyadav Dec 3, 2023
51ac4fd
fix naming
umangyadav Dec 3, 2023
d06dd8d
use generale_type
umangyadav Dec 3, 2023
40e7698
do not use brackets
umangyadav Dec 3, 2023
119a6b8
Try removing fusing converts
umangyadav Dec 3, 2023
fc093b0
formatting
umangyadav Dec 3, 2023
b6a436f
update MLIR commit hasH
umangyadav Dec 3, 2023
c60a4a4
Merge branch 'rocblas_mlir_fp8' of github.com:ROCmSoftwarePlatform/AM…
umangyadav Dec 4, 2023
5423577
use updated eliminate_fp8 pass
umangyadav Dec 4, 2023
402c66a
use eliminate_data_type pass instead of eliminate_fp8 pass
umangyadav Dec 5, 2023
8738f3b
Merge branch 'develop' into rocblas_fp8
umangyadav Dec 5, 2023
4ca90ec
remove older files
umangyadav Dec 5, 2023
b099a7d
remove header
umangyadav Dec 5, 2023
7d6e6ad
fix typo
umangyadav Dec 5, 2023
cf91c2b
add changes for the eliminate_data_type pass
umangyadav Dec 5, 2023
82f9847
add comments
umangyadav Dec 5, 2023
a9db2bf
fix typo
umangyadav Dec 5, 2023
aeaac20
remove else
umangyadav Dec 5, 2023
a196e90
disable tests that uses CK
umangyadav Dec 5, 2023
7e80f62
formatting
umangyadav Dec 5, 2023
a3d4b01
use same SHA as develop branch
umangyadav Dec 5, 2023
a98d86d
Merge branch 'rocblas_fp8' into rocblas_mlir_fp8
umangyadav Dec 5, 2023
de27b91
use angled brackets
umangyadav Dec 5, 2023
b6250a4
add comment
umangyadav Dec 6, 2023
b254223
formatting
umangyadav Dec 6, 2023
acd9bd3
Merge branch 'develop' into rocblas_mlir_fp8
umangyadav Dec 6, 2023
16 changes: 11 additions & 5 deletions src/include/migraphx/op/quant_convolution.hpp
@@ -27,6 +27,7 @@
#include <migraphx/op/common.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <migraphx/convolution.hpp>
#include <migraphx/value.hpp>
@@ -87,11 +88,13 @@ struct quant_convolution
}

// all input type must be int8_type and output is float_type
if(t != shape::int8_type)
std::set<migraphx::shape::type_t> supported_types = {shape::int8_type,
shape::fp8e4m3fnuz_type};
if(not contains(supported_types, t))
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t");
MIGRAPHX_THROW("QUANT_CONVOLUTION: only accept input and weights of type int8_t or "
"fp8e4m3fnuz_type");
}
t = shape::int32_type;

std::vector<size_t> output_lens{input.lens()[0], weights.lens()[0]};
auto padding_size = padding.size();
@@ -107,8 +110,11 @@ struct quant_convolution
stride[i] +
1)));
}

return inputs[0].with_lens(t, output_lens);
if(t == shape::int8_type)
{
return inputs[0].with_lens(shape::int32_type, output_lens);
} // else fp8 conv
return inputs[0].with_lens(shape::float_type, output_lens);
}

size_t kdims() const
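The compute_shape hunks above widen quant_convolution to accept fp8e4m3fnuz inputs and pick the accumulator type per input type (int8 accumulates into int32, fp8e4m3fnuz into float). A minimal standalone sketch of that rule, using illustrative stand-in names rather than the real migraphx::shape API:

```cpp
#include <set>
#include <stdexcept>

// Illustrative stand-ins for the migraphx::shape::type_t values involved.
enum class type_t { int8_type, fp8e4m3fnuz_type, int32_type, float_type };

// int8 inputs accumulate into int32; fp8e4m3fnuz inputs accumulate into float.
type_t quant_conv_output_type(type_t input)
{
    const std::set<type_t> supported = {type_t::int8_type, type_t::fp8e4m3fnuz_type};
    if(supported.count(input) == 0)
        throw std::runtime_error("quant_convolution: inputs must be int8 or fp8e4m3fnuz");
    return input == type_t::int8_type ? type_t::int32_type : type_t::float_type;
}
```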
6 changes: 6 additions & 0 deletions src/targets/gpu/device_name.cpp
@@ -49,6 +49,12 @@ std::string get_device_name()
return props.gcnArchName;
}

bool gfx_has_fp8_intrinsics()
{
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
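gfx_has_fp8_intrinsics() gates on the device name: a gfx9 part whose name compares at or above "gfx940". A self-contained sketch of the same comparison on a raw name string (the helper name and main() harness are illustrative; the real function reads get_device_name() and drops anything after the first ':'):

```cpp
#include <cassert>
#include <string>

// Standalone check mirroring gfx_has_fp8_intrinsics() above.
bool has_fp8_intrinsics(const std::string& device_name)
{
    return device_name.rfind("gfx9", 0) == 0 && device_name >= "gfx940";
}

int main()
{
    assert(has_fp8_intrinsics("gfx940"));
    assert(has_fp8_intrinsics("gfx942"));
    assert(!has_fp8_intrinsics("gfx90a"));  // lexicographically before "gfx940"
    assert(!has_fp8_intrinsics("gfx1100")); // not a gfx9 part
    return 0;
}
```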
24 changes: 18 additions & 6 deletions src/targets/gpu/fuse_mlir.cpp
@@ -218,13 +218,18 @@ auto is_mlir_conv(mlir_mode mode)
return false;
if(ins->name() != "convolution" and ins->name() != "quant_convolution")
return false;
auto input_arg_t = ins->inputs().front()->get_shape().type();
value v = ins->get_operator().to_value();
auto group = v.at("group").to<int>();
if(group != 1)
return false;
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if(ins->get_shape().lens().size() != 4)
return false;
if(ins->get_shape().type() == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::float_type and input_arg_t == shape::fp8e4m3fnuz_type)
return true;
if(ins->get_shape().type() == shape::int8_type)
return true;
if(mode == mlir_mode::int8)
@@ -292,6 +297,7 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
const auto result_type = i.get_shape().type();
const std::initializer_list<type_t> allowed_types = {type_t::float_type,
type_t::half_type,
type_t::fp8e4m3fnuz_type,
type_t::int8_type,
type_t::int32_type,
type_t::bool_type};
@@ -331,7 +337,8 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
"softmax",
"tanh",
};
bool is_float = contains({type_t::float_type, type_t::half_type}, result_type);
bool is_float =
contains({type_t::float_type, type_t::half_type, type_t::fp8e4m3fnuz_type}, result_type);
if(contains(any_type_ops, name))
return true;
if(result_type != type_t::bool_type and contains(no_bool_ops, name))
@@ -342,6 +349,10 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
// supported.
if(is_float and name == "convert")
{
if(result_type == shape::fp8e4m3fnuz_type)
{
return false;
} // else
return std::all_of(i.inputs().begin(), i.inputs().end(), [](const auto& arg) {
return contains({type_t::float_type, type_t::half_type}, arg->get_shape().type());
});
@@ -404,12 +415,13 @@ struct find_mlir_standalone_op
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto gemm_based_op = r.result;
//
// enable only for fp32/fp16/i8 types
// enable only for fp32/fp16/i8/fp8 types
if(std::any_of(gemm_based_op->inputs().begin(), gemm_based_op->inputs().end(), [&](auto i) {
return not contains(
{shape::type_t::float_type, shape::type_t::half_type, shape::type_t::int8_type},
i->get_shape().type());
return not contains({shape::type_t::float_type,
shape::type_t::half_type,
shape::type_t::int8_type,
shape::type_t::fp8e4m3fnuz_type},
i->get_shape().type());
}))
return;
static size_t counter = 0;
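The fuse_mlir.cpp changes extend both the pointwise path and the standalone gemm/conv path to fp8e4m3fnuz. A sketch of the input-type gate added in find_mlir_standalone_op, with an illustrative enum standing in for migraphx::shape::type_t:

```cpp
#include <algorithm>
#include <set>
#include <vector>

// Illustrative element-type enum; the real code uses migraphx::shape::type_t.
enum class type_t { float_type, half_type, int8_type, fp8e4m3fnuz_type, double_type };

// A gemm/conv op is only offloaded to MLIR when every input uses a supported type.
bool all_inputs_mlir_supported(const std::vector<type_t>& input_types)
{
    const std::set<type_t> supported = {
        type_t::float_type, type_t::half_type, type_t::int8_type, type_t::fp8e4m3fnuz_type};
    return std::all_of(input_types.begin(), input_types.end(), [&](type_t t) {
        return supported.count(t) != 0;
    });
}
```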
2 changes: 2 additions & 0 deletions src/targets/gpu/include/migraphx/gpu/device_name.hpp
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();

MIGRAPHX_GPU_EXPORT int get_device_id();

MIGRAPHX_GPU_EXPORT bool gfx_has_fp8_intrinsics();

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 2 additions & 0 deletions src/targets/gpu/mlir.cpp
@@ -300,6 +300,8 @@ struct mlir_program
result = mlirF32TypeGet(ctx.get());
else if(as.type_enum() == shape::half_type)
result = mlirF16TypeGet(ctx.get());
else if(as.type_enum() == shape::fp8e4m3fnuz_type)
result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
else if(as.type_enum() == shape::double_type)
result = mlirF64TypeGet(ctx.get());
else if(as.is_integral())
3 changes: 1 addition & 2 deletions src/targets/gpu/rocblas.cpp
@@ -58,8 +58,7 @@ bool rocblas_fp8_available()
#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
return false;
#else
const auto device_name = trim(split_string(get_device_name(), ':').front());
return (starts_with(device_name, "gfx9") and device_name >= "gfx940");
return gfx_has_fp8_intrinsics();
#endif
}

8 changes: 8 additions & 0 deletions src/targets/gpu/target.cpp
@@ -105,11 +105,19 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
unsupported_types.erase(shape::type_t::uint8_type);
unsupported_types.erase(shape::type_t::int32_type);
unsupported_types.erase(shape::type_t::tuple_type);
// whiltelist supported Ops for the FP8
std::set<std::string> unsupported_fp8_ops = {};
if(not gpu::rocblas_fp8_available())
{
unsupported_fp8_ops.insert("dot");
}
// MIOpen doesn't have support for fp8 pooling yet.
unsupported_fp8_ops.insert("pooling");
if(not gpu::gfx_has_fp8_intrinsics())
{
unsupported_fp8_ops.insert("convolution");
unsupported_fp8_ops.insert("quant_convolution");
}
// add all device kernels
unsupported_fp8_ops.insert("logsoftmax");
unsupported_fp8_ops.insert("nonzero");
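target.cpp assembles a capability-gated set of ops that must not run in fp8; per the commit history above, this set is handed to the eliminate_data_type pass. A sketch of that selection logic, with the two booleans standing in for gpu::rocblas_fp8_available() and gpu::gfx_has_fp8_intrinsics():

```cpp
#include <set>
#include <string>

// Ops excluded from fp8 execution, chosen by hardware/library capability.
std::set<std::string> unsupported_fp8_ops(bool rocblas_fp8, bool fp8_intrinsics)
{
    std::set<std::string> ops = {"pooling"}; // MIOpen has no fp8 pooling yet
    if(!rocblas_fp8)
        ops.insert("dot");
    if(!fp8_intrinsics)
    {
        ops.insert("convolution");
        ops.insert("quant_convolution");
    }
    return ops;
}
```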
1 change: 0 additions & 1 deletion test/verify/main.cpp
@@ -77,6 +77,5 @@ int main(int argc, const char* argv[])
"test_split_single_dyn_dim",
"test_instancenorm_large_3d<migraphx::shape::float_type>",
"test_instancenorm_large_3d<migraphx::shape::half_type>"});
rv.disable_test_for("gpu", {"test_conv_bn_add"});
rv.run(argc, argv);
}
10 changes: 7 additions & 3 deletions test/verify/quant_conv.cpp
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv : verify_program<quant_conv>
template <migraphx::shape::type_t DType>
struct quant_conv : verify_program<quant_conv<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::make_op("quant_convolution"), pa, pc);
return p;
}
};

template struct quant_conv<migraphx::shape::int8_type>;
template struct quant_conv<migraphx::shape::fp8e4m3fnuz_type>;
10 changes: 7 additions & 3 deletions test/verify/quant_conv_1.cpp
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>

struct quant_conv_1 : verify_program<quant_conv_1>
template <migraphx::shape::type_t DType>
struct quant_conv_1 : verify_program<quant_conv_1<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};

template struct quant_conv_1<migraphx::shape::int8_type>;
template struct quant_conv_1<migraphx::shape::fp8e4m3fnuz_type>;
11 changes: 8 additions & 3 deletions test/verify/quant_conv_1d.cpp
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv_1d : verify_program<quant_conv_1d>
template <migraphx::shape::type_t DType>
struct quant_conv_1d : verify_program<quant_conv_1d<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4}};
migraphx::shape a_shape{DType, {2, 3, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution",
@@ -45,3 +46,7 @@ struct quant_conv_1d : verify_program<quant_conv_1d>
return p;
}
};

template struct quant_conv_1d<migraphx::shape::int8_type>;
// MLIR 1D convolution is not supported in MIGraphX yet. Enable this through MIOpen route later.
// template struct quant_conv_1d<migraphx::shape::fp8e4m3fnuz_type>;
10 changes: 7 additions & 3 deletions test/verify/quant_conv_2.cpp
@@ -27,17 +27,21 @@
#include <migraphx/generate.hpp>
#include <migraphx/op/quant_convolution.hpp>

struct quant_conv_2 : verify_program<quant_conv_2>
template <migraphx::shape::type_t DType>
struct quant_conv_2 : verify_program<quant_conv_2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {16, 16, 4, 4}};
migraphx::shape a_shape{DType, {16, 16, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {16, 16, 3, 3}};
migraphx::shape c_shape{DType, {16, 16, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(migraphx::op::quant_convolution{{{0, 0}}, {{1, 1}}, {{1, 1}}}, pa, pc);
return p;
}
};

template struct quant_conv_2<migraphx::shape::int8_type>;
template struct quant_conv_2<migraphx::shape::fp8e4m3fnuz_type>;
10 changes: 7 additions & 3 deletions test/verify/quant_conv_padding.cpp
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv_padding : verify_program<quant_conv_padding>
template <migraphx::shape::type_t DType>
struct quant_conv_padding : verify_program<quant_conv_padding<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {1, 1}}}),
@@ -44,3 +45,6 @@ struct quant_conv_padding : verify_program<quant_conv_padding>
return p;
}
};

template struct quant_conv_padding<migraphx::shape::int8_type>;
template struct quant_conv_padding<migraphx::shape::fp8e4m3fnuz_type>;
9 changes: 6 additions & 3 deletions test/verify/quant_conv_padding_stride.cpp
@@ -27,15 +27,16 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
template <migraphx::shape::type_t DType>
struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::int8_type, {2, 3, 4, 4}};
migraphx::shape a_shape{DType, {2, 3, 4, 4}};
auto pa = mm->add_parameter("a", a_shape);
migraphx::shape c_shape{migraphx::shape::int8_type, {2, 3, 3, 3}};
migraphx::shape c_shape{DType, {2, 3, 3, 3}};
auto pc = mm->add_parameter("c", c_shape);
mm->add_instruction(
migraphx::make_op("quant_convolution", {{"padding", {1, 1}}, {"stride", {2, 2}}}),
@@ -45,3 +46,5 @@ struct quant_conv_padding_stride : verify_program<quant_conv_padding_stride>
return p;
}
};
template struct quant_conv_padding_stride<migraphx::shape::int8_type>;
template struct quant_conv_padding_stride<migraphx::shape::fp8e4m3fnuz_type>;
3 changes: 2 additions & 1 deletion test/verify/run_verify.cpp
@@ -142,7 +142,8 @@ std::vector<migraphx::argument> run_verify::run_ref(migraphx::program p,
{
migraphx::target t = migraphx::make_target("ref");
auto_print pp{p, t.name()};
compile_check(p, t, c_opts);
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
return p.eval(std::move(inputs));
}
std::pair<migraphx::program, std::vector<migraphx::argument>>
12 changes: 7 additions & 5 deletions test/verify/test_conv.cpp
@@ -27,17 +27,19 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_conv : verify_program<test_conv>
template <migraphx::shape::type_t DType>
struct test_conv : verify_program<test_conv<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("convolution"), input, weights);
return p;
}
};

template struct test_conv<migraphx::shape::float_type>;
template struct test_conv<migraphx::shape::fp8e4m3fnuz_type>;
11 changes: 6 additions & 5 deletions test/verify/test_conv2.cpp
@@ -27,16 +27,15 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct test_conv2 : verify_program<test_conv2>
template <migraphx::shape::type_t DType>
struct test_conv2 : verify_program<test_conv2<DType>>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto input =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 512, 28, 28}});
auto weights =
mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {256, 512, 1, 1}});
auto input = mm->add_parameter("x", migraphx::shape{DType, {1, 512, 28, 28}});
auto weights = mm->add_parameter("w", migraphx::shape{DType, {256, 512, 1, 1}});
mm->add_instruction(
migraphx::make_op("convolution",
{{"padding", {0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}}),
@@ -45,3 +44,5 @@ struct test_conv2 : verify_program<test_conv2>
return p;
}
};
template struct test_conv2<migraphx::shape::float_type>;
template struct test_conv2<migraphx::shape::fp8e4m3fnuz_type>;