matmul: x64: added support for bf16,f16 bias dt
amakarev authored and tprimak committed Oct 22, 2024
1 parent bf58e72 commit 188ae7f
Showing 3 changed files with 13 additions and 8 deletions.

src/cpu/x64/brgemm/brgemm.cpp (6 additions, 4 deletions)

@@ -348,13 +348,15 @@ status_t brgemm_desc_set_postops(brgemm_desc_t *brg,
                     data_type::f16)))
         return status::unimplemented;
     const auto bias_f8_e5m2_compatible
-            = one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e5m2)
+            = one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16,
+                    data_type::f8_e5m2)
             && one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16,
-                    data_type::f8_e5m2, data_type::f8_e4m3);
+                    data_type::bf16, data_type::f8_e5m2, data_type::f8_e4m3);
     const auto bias_f8_e4m3_compatible
-            = one_of(dt_d, data_type::f32, data_type::f16, data_type::f8_e4m3)
+            = one_of(dt_d, data_type::f32, data_type::f16, data_type::bf16,
+                    data_type::f8_e4m3)
             && one_of(dt_bias, data_type::undef, data_type::f32, data_type::f16,
-                    data_type::f8_e4m3, data_type::f8_e5m2);
+                    data_type::bf16, data_type::f8_e4m3, data_type::f8_e5m2);
     if (!IMPLICATION(brg->is_fp8,
                 bias_f8_e5m2_compatible || bias_f8_e4m3_compatible))
         return status::unimplemented;
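
As a reading aid (not part of the commit), the relaxed f8_e5m2 branch can be restated as a standalone predicate. The dt enum and the simplified one_of helper below are hypothetical stand-ins for dnnl's data_type values and for the real helper in src/common/utils.hpp; the sketch only spells out which destination/bias type pairs the check now admits.

#include <initializer_list>

// Hypothetical stand-ins for the dnnl data types involved in the check.
enum class dt { undef, f32, bf16, f16, f8_e5m2, f8_e4m3 };

// Simplified "is val equal to any of the rest" helper.
template <typename T, typename... Args>
bool one_of(T val, Args... rest) {
    for (T v : {rest...})
        if (val == v) return true;
    return false;
}

// Mirrors bias_f8_e5m2_compatible after this change: bf16 is accepted both
// as a destination type and as a bias type alongside f32, f16, and fp8.
bool bias_f8_e5m2_compatible(dt dt_d, dt dt_bias) {
    return one_of(dt_d, dt::f32, dt::f16, dt::bf16, dt::f8_e5m2)
            && one_of(dt_bias, dt::undef, dt::f32, dt::f16, dt::bf16,
                    dt::f8_e5m2, dt::f8_e4m3);
}

The f8_e4m3 branch is symmetric, with f8_e4m3 in place of f8_e5m2 as the destination type.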

src/cpu/x64/matmul/brgemm_matmul.cpp (4 additions, 1 deletion)

@@ -70,7 +70,10 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
         const bool is_bia_dt_correct
                 = IMPLICATION(is_int8 == true,
                           one_of(bia_dt, f32, s32, s8, u8, bf16))
-                && IMPLICATION(!is_int8, one_of(bia_dt, f32, src_dt));
+                && IMPLICATION(
+                        is_f8 == true, one_of(bia_dt, f32, f16, bf16, src_dt))
+                && IMPLICATION(
+                        !(is_int8 || is_f8), one_of(bia_dt, f32, src_dt));
         return IMPLICATION(with_bias(), is_bia_dt_correct && is_bias_1xN());
     };

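At the API level, the relaxed is_bia_dt_correct check means the brgemm-based x64 matmul can now take an fp8 problem whose bias stays in bf16 (or f16). The snippet below is an illustrative sketch, not code from this commit: the shapes, the f32 destination type, and the plain ab format tags are assumptions, error handling is omitted, and fp8 support still depends on the CPU the library dispatches on.

#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;

    engine eng(engine::kind::cpu, 0);

    const memory::dim M = 4, K = 16, N = 8;
    // fp8 source and weights, f32 destination, bias kept in bf16: the
    // combination that the updated bias data type check starts to accept.
    memory::desc src_md({M, K}, memory::data_type::f8_e4m3, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::f8_e4m3, memory::format_tag::ab);
    memory::desc bia_md({1, N}, memory::data_type::bf16, memory::format_tag::ab);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    // Before this change the brgemm-based implementation did not accept a
    // bf16/f16 bias for fp8 src/weights; with it, the combination passes
    // the pd_t::init() bias check above.
    matmul::primitive_desc pd(eng, src_md, wei_md, bia_md, dst_md);
    matmul prim(pd);
    (void)prim;
    return 0;
}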

tests/benchdnn/inputs/matmul/test_matmul_fp8 (3 additions, 3 deletions)

@@ -4,7 +4,7 @@
 --dt=f8_e4m3:f8_e4m3:f32,f8_e4m3,f8_e5m2:f8_e5m2:f32,f8_e5m2
 --stag=ab,ba --wtag=ab,ba --dtag=ab
 --runtime_dims_masks=0,2:1,1:0,3:1
---bia_dt=undef,f32 --bia_mask=2
+--bia_dt=undef,f32,f16,bf16 --bia_mask=2

 --attr-scales=
 --attr-post-ops=

@@ -21,8 +21,8 @@

 --stag=ba --wtag=ab,ba --dtag=ab
 --runtime_dims_masks=3:1,3:3
---bia_dt=f8_e4m3,f8_e5m2 --bia_mask=1,2,3
---attr-scales=src:common:0.25+wei:common:0.5+dst:common:2.25
+--bia_dt=f8_e4m3,f8_e5m2,f16,bf16 --bia_mask=1,2,3
+--attr-scales=src:common:0.25+wei:common:0.5+dst:common:4
 --attr-post-ops=add:f32,sum+mul:s32:per_oc+linear:2:-1
 --batch=shapes_2d
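
A single benchdnn run along these lines should exercise the new bias data types directly (the shape and the f8_e4m3:f8_e4m3:f32 type combination are arbitrary illustrative choices, not taken from the harness files):

./benchdnn --matmul --engine=cpu --dt=f8_e4m3:f8_e4m3:f32 --bia_dt=bf16 --bia_mask=2 4x16:16x8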
