Skip to content

Commit

Permalink
xe: jit: gemm: fixup f32->f16 scale downconversion logic
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad committed Oct 27, 2024
1 parent e89c66f commit 210c689
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
17 changes: 8 additions & 9 deletions src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ bool BLASKernelGenerator<hw>::gemmMake2DQuantizationLayouts(bool isA, const GEMM
auto &Tx_scaleOp = isA ? state.Ta_scaleOp : state.Tb_scaleOp;
auto &lateScale = isA ? state.lateScale2DA : state.lateScale2DB;

bool downScale = isA ? problem.downconvertAScales() : problem.downconvertBScales();

bool Tx_bf = problem.Ta_ext == Type::bf16 || problem.Tb_ext == Type::bf16;
Tx_scaleOp = (Tx_bf ? Type(Tx_ext.isInt4() ? Type::f16 : Type::f32) : Txs);
Txo_int = Txo.isInteger() ? Tx.asSignedInt() : Tx;
Expand All @@ -71,14 +73,11 @@ bool BLASKernelGenerator<hw>::gemmMake2DQuantizationLayouts(bool isA, const GEMM
int cpoDiv = 1;
if (Txo_int.isInt8()) Txo_int = Type::s16, cpoDiv = 2;

if (xs2D && (Txs.paddedSize() > Tx.paddedSize())) {
if (problem.aOffset == ABOffset::Calc || problem.bOffset == ABOffset::Calc){
lateScale = true;
Txs_int = Tx_scaleOp = problem.Tc;
} else if (!Tx_bf){
lateScale = false;
Txs_int = Tx_scaleOp = Type::f16;
}
if (downScale)
Tx_scaleOp = Tx;
else if (xs2D && (Txs.paddedSize() > Tx.paddedSize())) {
lateScale = true;
Txs_int = Tx_scaleOp = problem.Tc;
}

bool int4SpecialPath = Tx_ext.isInt4() && one_of(Tx, Type::f16, Type::bf16, Type::f32);
Expand Down Expand Up @@ -174,7 +173,7 @@ bool BLASKernelGenerator<hw>::gemmMake2DQuantizationLayouts(bool isA, const GEMM
int cpo = div_up(crosspack, cpoDiv);

auto makeQRepack = [&](Type Txq, Type Txq_int, vector<RegisterBlock> &repack, vector<RegisterBlock> &src, int m, int n, int cp) {
if (cp > 1 || (cColMajor && (cp != src[0].crosspack)) || Txq != Txq_int || lateScale)
if (cp > 1 || (cColMajor && (cp != src[0].crosspack)) || Txq != Txq_int)
makeUnbackedRegLayout(Txq_int, repack, m, n, wantCM, cp, tileR, tileC, false);
};

Expand Down
13 changes: 11 additions & 2 deletions src/gpu/intel/jit/gemm/include/problem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,17 @@ struct GEMMProblem : public CommonProblem {
bool quantized2DA() const { return (aoPtrDims == 2) || aScale2D; }
bool quantized2DB() const { return (boPtrDims == 2) || bScale2D; }

bool earlyDequantizeA() const { return (aOffset == ABOffset::Calc && Tao.asSigned().isSubsetOf(Ta)) || (aScale2D && ((aOffset != ABOffset::Calc && bOffset != ABOffset::Calc) || Ta_scale.isSubsetOf(Ta))); }
bool earlyDequantizeB() const { return (bOffset == ABOffset::Calc && Tbo.asSigned().isSubsetOf(Tb)) || (bScale2D && ((aOffset != ABOffset::Calc && bOffset != ABOffset::Calc) || Tb_scale.isSubsetOf(Tb))); }
bool downconvertAScales() const { return Ta == Type::f16 && Ta_scale == Type::f32; }
bool downconvertBScales() const { return Tb == Type::f16 && Tb_scale == Type::f32; }

bool earlyDequantizeA() const {
return (aOffset == ABOffset::Calc && Tao.asSigned().isSubsetOf(Ta))
|| (aScale2D && (Ta_scale.isSubsetOf(Ta) || downconvertAScales()));
}
bool earlyDequantizeB() const {
return (bOffset == ABOffset::Calc && Tbo.asSigned().isSubsetOf(Tb))
|| (bScale2D && (Tb_scale.isSubsetOf(Tb) || downconvertBScales()));
}

Type Tc_compute() const {
if (Ta.isInteger() && Tb.isInteger() && Tc == Type::f32)
Expand Down

0 comments on commit 210c689

Please sign in to comment.