Merge branch 'add_int4_decompression_example' of https://github.com/rupakroyintel/oneDNN into add_int4_decompression_example
rupakroyintel committed Oct 31, 2024
2 parents 53ecaea + 79c307d commit 1abe160
Showing 48 changed files with 435 additions and 138 deletions.
2 changes: 1 addition & 1 deletion cmake/configuring_primitive_list.cmake
@@ -58,7 +58,7 @@ if (DNNL_ENABLE_PRIMITIVE_GPU_ISA STREQUAL "ALL")
else()
foreach(isa ${DNNL_ENABLE_PRIMITIVE_GPU_ISA})
string(TOUPPER ${isa} uisa)
if(NOT "${uisa}" MATCHES "^(GEN9|GEN11|XELP|XEHP|XEHPG|XEHPC|XE2)$")
if(NOT "${uisa}" MATCHES "^(GEN9|GEN11|XELP|XEHP|XEHPG|XEHPC|XE2|XE3)$")
message(FATAL_ERROR "Unsupported primitive GPU ISA: ${uisa}")
endif()
set(BUILD_${uisa} TRUE)
2 changes: 1 addition & 1 deletion cmake/options.cmake
@@ -147,7 +147,7 @@ set(DNNL_ENABLE_PRIMITIVE_GPU_ISA "ALL" CACHE STRING
implementations will always be available. Valid values:
- ALL (the default). Includes all ISA to be enabled.
- <ISA_NAME>;<ISA_NAME>;... Includes only selected ISA to be enabled.
Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC, XE2.")
Possible values are: GEN9, GEN11, XELP, XEHP, XEHPG, XEHPC, XE2, XE3.")

set(ONEDNN_ENABLE_GEMM_KERNELS_ISA "ALL" CACHE STRING
"Specifies an ISA set of GeMM kernels residing in x64/gemm folder to be
2 changes: 1 addition & 1 deletion doc/build/build_options.md
@@ -118,7 +118,7 @@ Example that enables SSE41 and AVX2 sets:
#### ONEDNN_ENABLE_PRIMITIVE_GPU_ISA
This option supports several values: `ALL` (the default) which enables all
ISA implementations or any set of `GEN9`, `GEN11`, `XELP`, `XEHP`, `XEHPG`,
`XEHPC`, and `XE2`. Selected ISA will enable correspondent parts in
`XEHPC`, `XE2`, and `XE3`. Selected ISA will enable correspondent parts in
just-in-time kernel generation based implementations. OpenCL based kernels and
implementations will always be available. Example that enables XeLP and XeHP
set:
20 changes: 13 additions & 7 deletions doc/graph/programming_model/graph_basic_concepts.md
@@ -41,13 +41,19 @@ tensor as the edge between them.
## Graph

`Graph` (@ref dnnl::graph::graph) contains a set of operations. A graph object
is associated to a specific engine kind (@ref dnnl::engine::kind). Multiple
operations can be added (@ref dnnl::graph::graph::add_op) along with input and
output logical tensors to a graph. After finishing adding operations,
finalization API (@ref dnnl::graph::graph::finalize) can be called to indicate
that the graph is ready for partitioning. By calling partitioning API (@ref
dnnl::graph::graph::get_partitions), a group of partitions from the graph will
be returned .
is associated to a specific engine kind (@ref dnnl::engine::kind). In addition,
you can set the graph-level floating-point math mode through the setter API
(@ref dnnl::graph::graph::set_fpmath_mode) or in the constructor. The API
accepts two parameters: the floating-point math mode and an optional boolean
flag that indicates whether to use floating-point arithmetic for integral
operations.

Multiple operations can be added (@ref dnnl::graph::graph::add_op) along with
input and output logical tensors to a graph. After finishing adding the
operations, finalization API (@ref dnnl::graph::graph::finalize) can be called
to indicate that the graph is ready for partitioning. By calling partitioning
API (@ref dnnl::graph::graph::get_partitions), a group of partitions from the
graph will be returned.

## Partition

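A minimal usage sketch of the flow described in the updated document above, assuming the public `dnnl::graph` C++ API; the engine kind, math mode, and the omitted op construction are illustrative only:

```cpp
// Illustrative sketch only (not part of this change): construct a graph, set
// the graph-level fpmath mode, then finalize and partition it.
#include "oneapi/dnnl/dnnl_graph.hpp"

int main() {
    dnnl::graph::graph g(dnnl::engine::kind::cpu);

    // Allow bf16 down-conversion; the optional flag also applies the mode to
    // integral operations (both values are arbitrary choices for this sketch).
    g.set_fpmath_mode(dnnl::fpmath_mode::bf16, /*apply_to_int=*/true);

    // ... add operations and their logical tensors with g.add_op(...) ...

    g.finalize();                         // graph is ready for partitioning
    auto partitions = g.get_partitions(); // returns a group of partitions
    return 0;
}
```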
22 changes: 22 additions & 0 deletions include/oneapi/dnnl/dnnl_graph.h
@@ -590,6 +590,28 @@ dnnl_status_t DNNL_API dnnl_graph_graph_create_with_fpmath_mode(
/// otherwise.
dnnl_status_t DNNL_API dnnl_graph_graph_destroy(dnnl_graph_graph_t graph);

/// Set the floating point math mode for a graph.
///
/// @param graph The target graph.
/// @param mode The floating-point math mode.
/// @param apply_to_int The flag that controls whether to use floating-point
/// arithmetic for integral operations.
/// @returns #dnnl_success on success or a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_graph_graph_set_fpmath_mode(
dnnl_graph_graph_t graph, dnnl_fpmath_mode_t mode, int apply_to_int);

/// Get the floating point math mode for a graph.
///
/// @param graph The target graph.
/// @param mode The floating-point math mode.
/// @param apply_to_int The flag that controls whether to use floating-point
/// arithmetic for integral operations.
/// @returns #dnnl_success on success or a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_graph_graph_get_fpmath_mode(
dnnl_graph_graph_t graph, dnnl_fpmath_mode_t *mode, int *apply_to_int);

/// Adds an operation into a graph. The API will return failure if the operator
/// has already been added to the graph or the operation cannot pass the schema
/// check in the library (eg. input and output numbers and data types, the
35 changes: 35 additions & 0 deletions include/oneapi/dnnl/dnnl_graph.hpp
@@ -1373,6 +1373,10 @@ class graph : public graph_handle {
/// mode. All partitions returned from the graph will inherit the engine
/// kind and floating-point math mode.
///
/// Setting the floating-point math mode enables automatic down-conversion
/// of inputs for the given graph, which can improve performance by using
/// lower-precision data types when available.
///
/// @param engine_kind Engine kind.
/// @param mode Floating-point math mode.
graph(engine::kind engine_kind, fpmath_mode mode) {
@@ -1384,6 +1388,37 @@
reset(g);
}

/// Set the floating point math mode for a graph. Users can enforce the
/// graph to comply with the mode by specifying a boolean flag with the
/// setter function.
///
/// @param mode The floating-point math mode.
/// @param apply_to_int The flag that controls whether to use
/// floating-point arithmetic for integral operations.
void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) {
error::wrap_c_api(dnnl_graph_graph_set_fpmath_mode(
get(), convert_to_c(mode), apply_to_int),
"could not set fpmath mode graph attribute");
}

/// Get the floating point math mode and the boolean flag that specifies
/// whether the graph will be enforced to comply with the mode.
///
/// @param mode The floating-point math mode.
/// @param apply_to_int The flag that controls whether to use
/// floating-point arithmetic for integral operations.
void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const {
dnnl_fpmath_mode_t c_mode;
int c_apply_to_int;

error::wrap_c_api(dnnl_graph_graph_get_fpmath_mode(
get(), &c_mode, &c_apply_to_int),
"could not get fpmath mode graph attribute");

mode = fpmath_mode(c_mode);
apply_to_int = static_cast<bool>(c_apply_to_int);
}

/// Adds an op into the graph to construct a computational DAG. The API will
/// return failure if the operator has already been added to the graph or
/// the operation cannot pass the schema check in the library (eg. input and
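A short hedged sketch of the new C++ setter/getter pair in use; the engine kind and mode values are arbitrary and error handling is omitted:

```cpp
// Round-trip check of the accessors added above (illustrative values only).
dnnl::graph::graph g(dnnl::engine::kind::cpu);
g.set_fpmath_mode(dnnl::fpmath_mode::f16, /*apply_to_int=*/false);

dnnl::fpmath_mode mode = dnnl::fpmath_mode::strict;
bool apply_to_int = true;
g.get_fpmath_mode(mode, apply_to_int);
// Expected after the call: mode == dnnl::fpmath_mode::f16, apply_to_int == false.
```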
16 changes: 16 additions & 0 deletions src/cpu/aarch64/matmul/acl_matmul_utils.cpp
@@ -47,10 +47,26 @@ status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
// for e.g when ab in abcd is 1x1
bool batch_ok = IMPLICATION(src_batch > 1, wei_batch == 1)
&& IMPLICATION(wei_batch > 1, src_batch == 1);

ACL_CHECK_SUPPORT(src_d.ndims() == 4 && src_batch != wei_batch && !batch_ok,
"matmul broadcast supported only for 3D shapes and 4D shapes when "
"ab is 1x1");

if (src_d.ndims() == 4 && src_batch == wei_batch
&& src_d.dims()[0] != wei_d.dims()[0]) { // 4D broadcast occurred
if (src_d.dims()[0] == 1 && wei_d.dims()[0] != 1) { // Broadcast src
ACL_CHECK_SUPPORT(
IMPLICATION(src_d.dims()[1] != 1, wei_d.dims()[1] == 1),
"acl only broadcasts one of src or wei at once");
}

if (wei_d.dims()[0] == 1 && src_d.dims()[0] != 1) { // Broadcast wei
ACL_CHECK_SUPPORT(
IMPLICATION(src_d.dims()[1] == 1, wei_d.dims()[1] != 1),
"acl only broadcasts one of src or wei at once");
}
}

// ACL does not support bias
bool with_bias = md.bias_desc.format_kind != format_kind::undef;
ACL_CHECK_SUPPORT(with_bias, "ACL does not support bias for matmul");
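For readers unfamiliar with the `IMPLICATION` helper used in the checks above, it is assumed here to be oneDNN's logical-implication macro, i.e. the predicate sketched below (a hedged restatement, not the library's source):

```cpp
// Logical implication, the assumed meaning of IMPLICATION(cause, effect) above:
// the predicate is only false when the cause holds but the effect does not.
constexpr bool implication(bool cause, bool effect) { return !cause || effect; }

static_assert(implication(false, false), "vacuously true when the cause is false");
static_assert(implication(true, true), "true when both cause and effect hold");
static_assert(!implication(true, false), "the only falsifying combination");
```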
7 changes: 7 additions & 0 deletions src/gpu/intel/jit/gemm/generator/pieces/allocators.cpp
@@ -99,6 +99,13 @@ FlagRegister VirtualFlagAllocator::assignPhysical(VirtualFlag vflag)
return pflag.toPhysical();
}

bool VirtualFlagAllocator::lock(VirtualFlag vflag, bool allowAlreadyLocked) {
bool wasLocked = isLocked(vflag);
if (wasLocked && !allowAlreadyLocked) stub("Illegally locking an already-locked flag register");
locked |= mask(vflag);
return wasLocked;
}

bool VirtualFlagAllocator::canLock(int n) const
{
uint8_t unlocked = ~locked & ((1 << nflag) - 1);
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/allocators.hpp
@@ -78,7 +78,7 @@ class VirtualFlagAllocator {

bool isVirtual(VirtualFlag vflag) { return (vflag.idx >= nflag); }

bool lock(VirtualFlag vflag) { bool wasLocked = isLocked(vflag); locked |= mask(vflag); return wasLocked; }
bool lock(VirtualFlag vflag, bool allowAlreadyLocked = false);
void unlock(VirtualFlag vflag) { locked &= ~mask(vflag); }
bool isLocked(VirtualFlag vflag) const { return !(~locked & mask(vflag)); }
bool canLock(int n = 1) const;
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/copy.cxx
@@ -172,7 +172,7 @@ void BLASKernelGenerator<hw>::copyExecute(CopyPlan &&plan, CommonState &state)
if (!state.vflagsEnabled())
for (int i = 0; i < nflag; i++)
if (!raVFlag0.isFree(VirtualFlag{i}))
raVFlag0.lock(VirtualFlag{i});
raVFlag0.lock(VirtualFlag{i}, true);
auto raVFlag = raVFlag0;

// If we have enough free flags, use those.
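A stripped-down sketch of the locking contract these hunks introduce, using a plain bitmask instead of the real `VirtualFlag` machinery; all names below are illustrative and `stub()` is replaced by an assertion:

```cpp
#include <cassert>
#include <cstdint>

// Simplified stand-in for the allocator's lock bookkeeping: one bit per flag.
struct flag_locks_t {
    uint8_t locked = 0;
    bool is_locked(int idx) const { return (locked >> idx) & 1u; }
    // Returns whether the flag was already locked; re-locking is an error
    // unless explicitly allowed, mirroring the new allowAlreadyLocked
    // parameter that copyExecute passes as true above.
    bool lock(int idx, bool allow_already_locked = false) {
        const bool was_locked = is_locked(idx);
        assert((!was_locked || allow_already_locked) && "flag already locked");
        locked |= uint8_t(1u << idx);
        return was_locked;
    }
    void unlock(int idx) { locked &= uint8_t(~(1u << idx)); }
};
```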
20 changes: 12 additions & 8 deletions src/gpu/intel/jit/gemm/generator/pieces/copy_plan.cpp
@@ -733,10 +733,12 @@ void CopyPlan::planTypeConversions()
} else
planEmulatedHalveFloat(i);
}
} else if (st == DataType::hf8 && dt == DataType::hf && hw < HW::Xe3) {
planEmulatedHF8ToHF(i);
} else if (st == DataType::hf && dt == DataType::hf8 && hw < HW::Xe3) {
planEmulatedHFToHF8(i);
} else if (st == DataType::hf8 && dt == DataType::hf) {
if (hw < HW::Xe3)
planEmulatedHF8ToHF(i);
} else if (st == DataType::hf && dt == DataType::hf8) {
if (hw < HW::Xe3)
planEmulatedHFToHF8(i);
} else if (st != dt && (isFP8(st) || isFP8(dt))) {
copyThrough(i, DataType::hf, 1);
rerun = true;
@@ -1328,11 +1330,13 @@ void CopyPlan::legalizeRegions()
if (!i.dst) continue;

/* Check for special packed conversion cases */
if (i.op == Opcode::mov && s0t == DataType::hf && dt == DataType::bf8) {
// hf -> bf8: src0/dst must be packed unit stride, zero offset
if (i.src0.offset != 0 || i.src0.stride != 1)
if (i.op == Opcode::mov && ((s0t == DataType::hf && isFP8(dt))
|| (dt == DataType::hf && isFP8(s0t)))) {
// hf <-> bf8/hf8: src0/dst must be packed unit stride, zero offset
if (i.src0.offset != 0 || i.src0.stride != 1) {
repositionSrc(i, 0, 1, 0);
if (i.dst.offset != 0 || i.dst.stride != 1)
rerun = true;
} else if (i.dst.offset != 0 || i.dst.stride != 1)
repositionDst(i, 1, 0);
if (i.simd == 1) hw_unsupported();
continue;
4 changes: 3 additions & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/matrix_access.cxx
@@ -621,6 +621,8 @@ void BLASKernelGenerator<hw>::prepareSeriesRegisterBlockMasking(const vector<Reg
for (int startPreload = start; startPreload < nblocks; startPreload++) {
auto &block = layout[startPreload];

if (!block.isLoadBlock()) continue;

bool plFlag[2];
for (int i = 0; i <= 1; i++)
plFlag[i] = block.flag[i] && (block.flag[i] != state.blockEMask);
Expand All @@ -630,7 +632,7 @@ void BLASKernelGenerator<hw>::prepareSeriesRegisterBlockMasking(const vector<Reg

auto &flag = block.flag[plFlag[0] ? 0 : 1];
if (!state.raVFlag.canLock(flag.n)) break;
state.raVFlag.lock(getPhysicalFlag(flag, state));
state.raVFlag.lock(getPhysicalFlag(flag, state), true);
}
}
}
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/quantization.cxx
@@ -82,7 +82,7 @@ bool BLASKernelGenerator<hw>::gemmMake2DQuantizationLayouts(bool isA, const GEMM

bool int4SpecialPath = Tx_ext.isInt4() && one_of(Tx, Type::f16, Type::bf16, Type::f32);
if (int4SpecialPath)
Txo_int = Txs_int = Type::f16;
Txo_int = Txs_int = Tx_scaleOp = Type::f16;

// Get tile sizes, depending on whether A/B are copied to SLM.
// For late scaling (after compute), scales are always applied to the whole tile.
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/include/type.hpp
@@ -100,7 +100,7 @@ class Type {
constexpr Type baseType() const { return *this; }

template <typename U> constexpr friend int operator*(U a, Type t) {
return t.isInt4() ? int((a + 1) / 2) : int(a * (U(1) << t.log2Size()));
return t.isInt4() ? int((unsigned(a) + 1) >> 1) : int(a * (U(1) << t.log2Size()));
}
template <typename U> constexpr friend int operator*(Type t, U a) { return a * t; }
template <typename U> friend int operator*=(U &a, Type t) { a = a * t; return a; }
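The operator above maps an element count to a byte count; for 4-bit types two elements share one byte, so the count is rounded up with the unsigned shift shown. A small worked check of that rounding, with the same expression extracted into a standalone helper:

```cpp
// Byte size of n int4 elements, mirroring int((unsigned(n) + 1) >> 1) above:
// two nibbles per byte, rounded up.
constexpr int int4_bytes(int n) { return int((unsigned(n) + 1) >> 1); }

static_assert(int4_bytes(1) == 1, "a single nibble still needs a full byte");
static_assert(int4_bytes(7) == 4, "odd counts round up");
static_assert(int4_bytes(8) == 4, "even counts pack exactly");
```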
1 change: 1 addition & 0 deletions src/gpu/intel/jit/gemm/selector/db/kernel.db
@@ -104,6 +104,7 @@ auto _CATALOG_ = kcatalog::toFlatCatalog({
{{'C', "gemm", {"H", "H", "H"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, 8, -1}, {1, 1, 1}, ""}, "ab2x2 as16 ab l4 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 8, 16}, {2, 8, 1}, 1, (WGType) 0, 1, 0, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {256}}},
{{'C', "gemm", {"H", "H", "H"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "qxy"}, "sb2/1 su8x2 ab l4 ca1 wg 2x8 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {64, 16, 16}, {2, 8, 1}, 1, (WGType) 1, 1, 4096, 0, {2, 2, 2}, {false, false, true}}, {'W', 1, {1024}}},
{{'C', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab4 as8 ab l4 ca1 wg 2x8 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 16, 8}, {2, 8, 1}, 1, (WGType) 1, 1, 2048, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {512}}},
{{'C', "gemm", {"H", "H", "S"}, {"N", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, 16, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab4x2 ab8 ab wg 2x1x8 ikr kc4 acb ar sb32 bk0 np", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 2}, {4096, 4096, 16777216}, {4096, 4096, 16777216}, {32, 1, 8}, {2, 1, 8}, 1, (WGType) 0, 4101, 0, 256, {2, 2, 4}, {true, true, true}}, {'W', 1, {32}}},
{{'C', "gemm", {"H", "H", "H"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab2x2 ab2x2 ab k4 l4 vnc", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 32, 4}, {2, 8, 1}, 1, (WGType) 0, 1, 0, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {1024}}},
{{'C', "gemm", {"H", "H", "S"}, {"N", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "ab2 ab8 ab l4 cab1 wg 4x4 int", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 2048}, {4096, 4096, 2048}, {32, 16, 8}, {4, 4, 1}, 1, (WGType) 1, 1, 6144, 0, {2, 2, 4}, {true, true, true}}, {'W', 1, {512}}},
{{'C', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, ""}, "as4 as8 ab k8 l4 vnc", {8, (LoopType) 0, 128, {(LoopType) 0, (LoopType) 1, (LoopType) 255}, {4096, 4096, 1024}, {4096, 4096, 1024}, {32, 32, 8}, {2, 8, 1}, 1, (WGType) 0, 1, 0, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {1024}}},
56 changes: 30 additions & 26 deletions src/gpu/intel/jit/ngen/ngen_gen12.hpp
@@ -635,36 +635,40 @@ static inline constexpr14 TernaryOperand12 encodeTernaryOperand12(const Extended

static inline void encodeCommon12(Instruction12 &i, Opcode opcode, const InstructionModifier &mod, const RegData &dst, EncodingTag12 tag)
{
i.common.opcode = static_cast<unsigned>(opcode) | (mod.parts.autoSWSB << 7);
i.common.swsb = SWSBInfo12(mod.getSWSB(), opcode).raw();
i.common.execSize = mod.parts.eSizeField;
i.common.execOffset = mod.parts.chanOff;
i.common.flagReg = (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum;
i.common.predCtrl = mod.parts.predCtrl;
i.common.predInv = mod.parts.predInv;
i.common.cmptCtrl = mod.parts.cmptCtrl;
i.common.debugCtrl = mod.parts.debugCtrl;
i.common.maskCtrl = mod.parts.maskCtrl;
i.common.atomicCtrl = mod.parts.threadCtrl;
i.common.accWrCtrl = mod.parts.accWrCtrl;
i.common.saturate = mod.parts.saturate;
Instruction12 i2; /* separate variable to avoid gcc13 bug */
i2.common.opcode = static_cast<unsigned>(opcode) | (mod.parts.autoSWSB << 7);
i2.common.swsb = SWSBInfo12(mod.getSWSB(), opcode).raw();
i2.common.execSize = mod.parts.eSizeField;
i2.common.execOffset = mod.parts.chanOff;
i2.common.flagReg = (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum;
i2.common.predCtrl = mod.parts.predCtrl;
i2.common.predInv = mod.parts.predInv;
i2.common.cmptCtrl = mod.parts.cmptCtrl;
i2.common.debugCtrl = mod.parts.debugCtrl;
i2.common.maskCtrl = mod.parts.maskCtrl;
i2.common.atomicCtrl = mod.parts.threadCtrl;
i2.common.accWrCtrl = mod.parts.accWrCtrl;
i2.common.saturate = mod.parts.saturate;
i.common = i2.common;
}

static inline void encodeCommon12(Instruction12 &i, Opcode opcode, const InstructionModifier &mod, const RegData &dst, EncodingTagXeHPC tag)
{
i.common.opcode = static_cast<unsigned>(opcode) | (mod.parts.autoSWSB << 7);
i.commonXeHPC.swsb = SWSBInfoXeHPC(mod.getSWSB(), opcode).raw();
i.commonXeHPC.execSize = mod.parts.eSizeField;
i.commonXeHPC.flagReg = (mod.parts.flagRegNum1 << 2) | (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum;
i.commonXeHPC.execOffset = mod.parts.chanOff >> 1;
i.commonXeHPC.predCtrl = mod.parts.predCtrl;
i.common.predInv = mod.parts.predInv;
i.common.cmptCtrl = mod.parts.cmptCtrl;
i.common.debugCtrl = mod.parts.debugCtrl;
i.common.maskCtrl = mod.parts.maskCtrl;
i.common.atomicCtrl = mod.parts.threadCtrl;
i.commonXeHPC.dstExt = (dst.isIndirect() ? dst.getOffset() : dst.getByteOffset()) & 1;
i.common.saturate = mod.parts.saturate;
Instruction12 i2; /* separate variable to avoid gcc13 bug */
i2.common.opcode = static_cast<unsigned>(opcode) | (mod.parts.autoSWSB << 7);
i2.commonXeHPC.swsb = SWSBInfoXeHPC(mod.getSWSB(), opcode).raw();
i2.commonXeHPC.execSize = mod.parts.eSizeField;
i2.commonXeHPC.flagReg = (mod.parts.flagRegNum1 << 2) | (mod.parts.flagRegNum << 1) | mod.parts.flagSubRegNum;
i2.commonXeHPC.execOffset = mod.parts.chanOff >> 1;
i2.commonXeHPC.predCtrl = mod.parts.predCtrl;
i2.common.predInv = mod.parts.predInv;
i2.common.cmptCtrl = mod.parts.cmptCtrl;
i2.common.debugCtrl = mod.parts.debugCtrl;
i2.common.maskCtrl = mod.parts.maskCtrl;
i2.common.atomicCtrl = mod.parts.threadCtrl;
i2.commonXeHPC.dstExt = (dst.isIndirect() ? dst.getOffset() : dst.getByteOffset()) & 1;
i2.common.saturate = mod.parts.saturate;
i.common = i2.common;
}

template <typename Tag>
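Both hunks above assemble the instruction word in a local temporary and copy it back in one assignment instead of writing bitfields through the output reference; the inline comment attributes this to a gcc 13 issue. A generic sketch of the pattern with made-up field names (not the real ngen layout):

```cpp
// Illustrative pattern only: populate the bitfields of a local temporary and
// assign the whole sub-object back, rather than setting each field on `out`.
struct common_t { unsigned opcode : 8, execSize : 3, predInv : 1; };
struct insn_t { common_t common; };

void encode(insn_t &out, unsigned opcode, unsigned execSize, bool predInv) {
    insn_t tmp {};                   // separate local, as in the workaround above
    tmp.common.opcode = opcode & 0xffu;
    tmp.common.execSize = execSize & 0x7u;
    tmp.common.predInv = predInv ? 1u : 0u;
    out.common = tmp.common;         // single copy back into the output
}
```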
2 changes: 1 addition & 1 deletion src/gpu/intel/ocl/ref_matmul.hpp
@@ -79,7 +79,7 @@ struct ref_matmul_t : public gpu_primitive_t {
&& dst_dt_ == f32;
const bool is_f16 = src_dt_ == f16
&& utils::one_of(wei_dt_, f16, s8, u8, s4, u4)
&& utils::one_of(dst_dt_, u8, s8, f16);
&& utils::one_of(dst_dt_, u8, s8, f16, f32);
const bool is_f8
= (utils::one_of(src_dt_, f8_e5m2, f8_e4m3)
|| utils::one_of(wei_dt_, f8_e5m2, f8_e4m3))
4 changes: 3 additions & 1 deletion src/graph/backend/dnnl/dnnl_partition_impl.cpp
@@ -130,7 +130,9 @@ status_t dnnl_partition_impl_t::compile(

// Dispatch to fake kernel if one of the output dimensions is zero.
const std::vector<std::shared_ptr<op_t>> &fused_op = part->get_ops();
auto agraph = graph_t(fused_op, get_engine_kind(), get_fpmath_mode());
auto fpm = get_fpmath_mode();
auto agraph = graph_t(fused_op, get_engine_kind());
agraph.set_fpmath_mode(fpm.mode_, fpm.apply_to_int_);
agraph.set_user_inputs_outputs(inputs, outputs);
agraph.infer_shape();
for (const auto &val : agraph.get_output_values()) {