BF16 fused_reduce compile fail #3728

Open
shivadbhavsar opened this issue Dec 20, 2024 · 1 comment
Labels: bug (Something isn't working)

@shivadbhavsar (Contributor)
Repro

# fuse_reduce.py

import numpy as np
import migraphx

p = migraphx.program()
m = p.get_main_module()

s1 = migraphx.shape(type="float_type", lens=[1, 24, 4608, 128])
x0 = m.add_parameter("x0", s1)
x1 = m.add_parameter("x1", s1)

# Constants: 1/128 (the last axis has length 128), the exponent, and epsilon
c1 = m.add_literal(np.array(0.0078125, dtype=np.float32))
c1 = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [c1])
c2 = m.add_literal(np.array(2, dtype=np.float32))
c2 = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [c2])
c4 = m.add_literal(np.array(9.98378e-07, dtype=np.float32))

# mean(x0^2) over axis 3, plus epsilon, then rsqrt
pow = m.add_instruction(migraphx.op("pow"), [x0, c2])
mul = m.add_instruction(migraphx.op("mul"), [pow, c1])
red = m.add_instruction(migraphx.op("reduce_sum", axes=[3]), [mul])
c4 = m.add_instruction(migraphx.op("multibroadcast", out_lens=red.shape().lens()), [c4])
add = m.add_instruction(migraphx.op("add"), [red, c4])
rsqrt = m.add_instruction(migraphx.op("rsqrt"), [add])

# Scale x0 by the broadcast rsqrt, then apply the elementwise weight x1
rsqrt_mb = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [rsqrt])
mul2 = m.add_instruction(migraphx.op("mul"), [x0, rsqrt_mb])
mul3 = m.add_instruction(migraphx.op("mul"), [mul2, x1])

Run:

migraphx-driver compile fuse_reduce.py --bf16

Error:

/long_pathname_so_that_rpms_can_package_the_debug_info/src/llvm-project/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp:1143: void VerifySDNode(llvm::SDNode*, const llvm::TargetLowering*): Assertion `(Op.getValueType() == EltVT || (EltVT.isInteger() && Op.getValueType().isInteger() && EltVT.bitsLE(Op.getValueType()))) && "Wrong operand type!"' failed.
@shivadbhavsar shivadbhavsar self-assigned this Dec 20, 2024
@shivadbhavsar shivadbhavsar added the bug Something isn't working label Dec 20, 2024
@shivadbhavsar (Contributor, Author)

It turns out that removing the last mul (mul3) avoids the issue; a sketch of that working variant follows the listings below. For reference, here are the produced sources for both cases:

No mul3 (compile works):

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>

namespace migraphx {

template<class Tx0>
__device__ __attribute__((const)) auto pointwise0(Tx0 x0) {
// @param:x0 -> float_type, {1}, {0}
// convert[target_type=15] -> bf16_type, {1}, {0}
auto zz1 = migraphx::convert<bf16>(migraphx::convert<bf16>(x0));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz1;
return zzreturn;

}
template<class Tx0>
__device__ __attribute__((const)) auto pointwise1(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(0.0078125);
// @literal -> bf16_type, {1}, {0}
auto zz1 = bf16(2);
// @param:x0 -> bf16_type, {1}, {0}
// pow -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::pow(x0, zz1));
// mul -> bf16_type, {1}, {0}
auto zz4 = migraphx::convert<bf16>(zz3 * zz0);
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz4;
return zzreturn;

}
template<class Tx0>
__device__ __attribute__((const)) auto pointwise2(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(9.9837779998779297e-07);
// @param:x0 -> bf16_type, {1}, {0}
// add -> bf16_type, {1}, {0}
auto zz2 = migraphx::convert<bf16>(x0 + zz0);
// rsqrt -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::rsqrt(zz2));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz3;
return zzreturn;

}
template<class Tx0, class Tx1>
__device__ __attribute__((const)) auto pointwise3(Tx0 x0,Tx1 x1) {
// @param:x1 -> bf16_type, {1}, {0}
// @param:x0 -> bf16_type, {1}, {0}
// mul -> bf16_type, {1}, {0}
auto zz2 = migraphx::convert<bf16>(x0 * x1);
// convert[target_type=2] -> float_type, {1}, {0}
auto zz3 = migraphx::convert<float>(migraphx::convert<float>(zz2));
// @return -> float_type, {1}, {0}
auto zzreturn = zz3;
return zzreturn;

}
template<class Tx0, class Tr, class Tout_idx>
__device__ __attribute__((const)) auto fused_reduce_op(Tx0 x0,Tr r,Tout_idx out_idx) {
(void)out_idx;
// @param:x0 -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz1 = r.inner([=](auto x0_lambda_param) { return pointwise0(x0_lambda_param); })(x0);
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz2 = r.lazy_inner([=](auto zz1_lambda_param) { return pointwise1(zz1_lambda_param); })(zz1);
// reduce_sum[axes={3}] -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz3 = op::id{}(r.reduce(op::sum{}, 0, op::id{})(zz2));
// pointwise -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz4 = pointwise2(zz3);
// multibroadcast[out_lens={1, 24, 4608, 128},out_dyn_dims={}] -> bf16_type, {1, 24, 4608, 128}, {110592, 4608, 1, 0}
auto zz5 = zz4;
// pointwise -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz6 = r.inner([=](auto zz1_lambda_param) { return pointwise3(zz1_lambda_param, zz5); })(zz1);
// @return -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zzreturn = make_tuple(zz6);
return zzreturn;

}


extern "C" {
MIGRAPHX_GLOBAL void convert_pow_mul_reduce_sum_add_rsqrt_mul_convert_kernel(void * private_p0,void * private_p1)
{
    transform_args(make_tensors(), vectorize<4, 1>(), rotate_and_pack_last<1>())(private_p0,private_p1)([](auto y, auto... xs) {
        fused_reduce<reduce::block, decltype(make_shape(index_ints<110592, 1>{}, index_ints<1, 1>{}))>(y, assign_none{}, partial(MIGRAPHX_LIFT(fused_reduce_op))(xs...));
    });
}
    
}

} // namespace migraphx

With mul3 (compile fails):

#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>

namespace migraphx {

template<class Tx0>
__device__ __attribute__((const)) auto pointwise0(Tx0 x0) {
// @param:x0 -> float_type, {1}, {0}
// convert[target_type=15] -> bf16_type, {1}, {0}
auto zz1 = migraphx::convert<bf16>(migraphx::convert<bf16>(x0));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz1;
return zzreturn;

}
template<class Tx0>
__device__ __attribute__((const)) auto pointwise1(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(0.0078125);
// @literal -> bf16_type, {1}, {0}
auto zz1 = bf16(2);
// @param:x0 -> bf16_type, {1}, {0}
// pow -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::pow(x0, zz1));
// mul -> bf16_type, {1}, {0}
auto zz4 = migraphx::convert<bf16>(zz3 * zz0);
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz4;
return zzreturn;

}
template<class Tx0>
__device__ __attribute__((const)) auto pointwise2(Tx0 x0) {
// @literal -> bf16_type, {1}, {0}
auto zz0 = bf16(9.9837779998779297e-07);
// @param:x0 -> bf16_type, {1}, {0}
// add -> bf16_type, {1}, {0}
auto zz2 = migraphx::convert<bf16>(x0 + zz0);
// rsqrt -> bf16_type, {1}, {0}
auto zz3 = migraphx::convert<bf16>(migraphx::rsqrt(zz2));
// @return -> bf16_type, {1}, {0}
auto zzreturn = zz3;
return zzreturn;

}
template<class Tx0, class Tx1, class Tx2>
__device__ __attribute__((const)) auto pointwise3(Tx0 x0,Tx1 x1,Tx2 x2) {
// @param:x0 -> float_type, {1}, {0}
// convert[target_type=15] -> bf16_type, {1}, {0}
auto zz1 = migraphx::convert<bf16>(migraphx::convert<bf16>(x0));
// @param:x2 -> bf16_type, {1}, {0}
// @param:x1 -> bf16_type, {1}, {0}
// mul -> bf16_type, {1}, {0}
auto zz4 = migraphx::convert<bf16>(x1 * x2);
// mul -> bf16_type, {1}, {0}
auto zz5 = migraphx::convert<bf16>(zz4 * zz1);
// convert[target_type=2] -> float_type, {1}, {0}
auto zz6 = migraphx::convert<float>(migraphx::convert<float>(zz5));
// @return -> float_type, {1}, {0}
auto zzreturn = zz6;
return zzreturn;

}
template<class Tx0, class Tx1, class Tr, class Tout_idx>
__device__ __attribute__((const)) auto fused_reduce_op(Tx0 x0,Tx1 x1,Tr r,Tout_idx out_idx) {
(void)out_idx;
// @param:x0 -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz1 = r.inner([=](auto x0_lambda_param) { return pointwise0(x0_lambda_param); })(x0);
// pointwise -> bf16_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz2 = r.lazy_inner([=](auto zz1_lambda_param) { return pointwise1(zz1_lambda_param); })(zz1);
// reduce_sum[axes={3}] -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz3 = op::id{}(r.reduce(op::sum{}, 0, op::id{})(zz2));
// pointwise -> bf16_type, {1, 24, 4608, 1}, {110592, 4608, 1, 1}
auto zz4 = pointwise2(zz3);
// multibroadcast[out_lens={1, 24, 4608, 128},out_dyn_dims={}] -> bf16_type, {1, 24, 4608, 128}, {110592, 4608, 1, 0}
auto zz5 = zz4;
// @param:x1 -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
// pointwise -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zz7 = r.inner([=](auto x1_lambda_param, auto zz1_lambda_param) { return pointwise3(x1_lambda_param, zz1_lambda_param, zz5); })(x1, zz1);
// @return -> float_type, {1, 24, 4608, 128}, {14155776, 589824, 128, 1}
auto zzreturn = make_tuple(zz7);
return zzreturn;

}


extern "C" {
MIGRAPHX_GLOBAL void convert_pow_mul_reduce_sum_add_rsqrt_convert_mul_mul_convert_kernel(void * private_p0,void * private_p1,void * private_p2)
{
    transform_args(make_tensors(), vectorize<4, 1>(), rotate_and_pack_last<1>())(private_p0,private_p1,private_p2)([](auto y, auto... xs) {
        fused_reduce<reduce::block, decltype(make_shape(index_ints<110592, 1>{}, index_ints<1, 1>{}))>(y, assign_none{}, partial(MIGRAPHX_LIFT(fused_reduce_op))(xs...));
    });
}
    
}

} // namespace migraphx
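
For completeness, here is a minimal sketch of the working variant mentioned above: the same repro with mul3 dropped, so mul2 is the final instruction. The filename is just illustrative; since x1 only fed mul3, it is omitted here.

# fuse_reduce_no_mul3.py -- the repro minus the final mul (mul3)

import numpy as np
import migraphx

p = migraphx.program()
m = p.get_main_module()

s1 = migraphx.shape(type="float_type", lens=[1, 24, 4608, 128])
x0 = m.add_parameter("x0", s1)
# x1 is no longer needed once mul3 is removed

c1 = m.add_literal(np.array(0.0078125, dtype=np.float32))
c1 = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [c1])
c2 = m.add_literal(np.array(2, dtype=np.float32))
c2 = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [c2])
c4 = m.add_literal(np.array(9.98378e-07, dtype=np.float32))

pow = m.add_instruction(migraphx.op("pow"), [x0, c2])
mul = m.add_instruction(migraphx.op("mul"), [pow, c1])
red = m.add_instruction(migraphx.op("reduce_sum", axes=[3]), [mul])
c4 = m.add_instruction(migraphx.op("multibroadcast", out_lens=red.shape().lens()), [c4])
add = m.add_instruction(migraphx.op("add"), [red, c4])
rsqrt = m.add_instruction(migraphx.op("rsqrt"), [add])
rsqrt_mb = m.add_instruction(migraphx.op("multibroadcast", out_lens=s1.lens()), [rsqrt])
mul2 = m.add_instruction(migraphx.op("mul"), [x0, rsqrt_mb])  # final instruction; no mul3

Compiled the same way (migraphx-driver compile fuse_reduce_no_mul3.py --bf16), this path does not hit the assertion.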
