PR #17222: [XLA:CPU] Allow convert natively on supported CPUs
Imported from GitHub PR #17222

The performance of some workloads dropped, and git bisect points to this [commit](c48011a) in XLA as the cause. The comments there indicate that LLVM optimizations are suppressed when converting from FP32 to BF16 and back, since they may cause performance degradation on other CPUs. Since some CPUs can handle the BF16 conversion efficiently, the suppression is not required there and can be bypassed.
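The gist of the change, as a minimal sketch (names and feature-flag strings mirror the diff below; this is an illustration rather than the exact emitter wiring):

```cpp
#include <string>

#include "absl/strings/match.h"

// True when the LLVM target feature string advertises native FP32<->BF16
// conversion support (AVX-NE-CONVERT or AMX-BF16).
bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string) {
  return absl::StrContains(feature_string, "+avxneconvert") ||
         absl::StrContains(feature_string, "+amx-bf16");
}

// The elemental IR emitter then only requests the explicit truncating
// FP32->BF16 lowering when the CPU lacks native support, roughly:
//   Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/
//           !IsNativeConvertSupportedOnTargetCPU(feature_string)}
```

So on a Sapphire Rapids target (feature string containing "+amx-bf16") the truncation workaround is skipped and LLVM is free to use the native instructions, while on targets such as Skylake the existing behavior is unchanged.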
Copybara import of the project:

--
36d2883 by Kanvi Khanna <[email protected]>:

allow convert natively

--
a7f6f71 by Kanvi Khanna <[email protected]>:

Address comments

--
1422583 by Kanvi Khanna <[email protected]>:

Add test

--
9693d9e by Kanvi Khanna <[email protected]>:

address comment

Merging this change closes #17222

FUTURE_COPYBARA_INTEGRATE_REVIEW=#17222 from Intel-tensorflow:kanvi/native_convert_support 9693d9e
PiperOrigin-RevId: 683640752
kanvi-nervana authored and Google-ML-Automation committed Oct 11, 2024
1 parent eeb60b9 commit 2961c38
Showing 6 changed files with 82 additions and 1 deletion.
10 changes: 9 additions & 1 deletion xla/service/cpu/ir_emitter.cc
@@ -109,13 +109,21 @@ using llvm_ir::SetToFirstInsertPoint;

namespace cpu {

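// Returns true when the target feature string advertises native FP32<->BF16
// conversion support (AVX-NE-CONVERT or AMX-BF16).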
bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string) {
  return (absl::StrContains(feature_string, "+avxneconvert") ||
          absl::StrContains(feature_string, "+amx-bf16"));
}

class IrEmitter::CpuElementalIrEmitter : public ElementalIrEmitter {
 public:
  CpuElementalIrEmitter(const HloModuleConfig& module_config,
                        IrEmitter* ir_emitter, llvm::Module* module)
      : ElementalIrEmitter(
            module, ir_emitter->b(),
            Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/true}),
            Options{/*xla_cpu_use_truncate_f32_to_bf16_conversion=*/
                !IsNativeConvertSupportedOnTargetCPU(
                    ir_emitter->target_machine_features_
                        .get_target_feature_string())}),
        hlo_module_config_(module_config),
        ir_emitter_(ir_emitter) {}

2 changes: 2 additions & 0 deletions xla/service/cpu/ir_emitter.h
@@ -69,6 +69,8 @@ namespace cpu {
// Forward declare emitter for XLA:CPU thunks.
class IrEmitter2;

bool IsNativeConvertSupportedOnTargetCPU(std::string feature_string);

// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
// functions.
59 changes: 59 additions & 0 deletions xla/service/cpu/ir_emitter_test.cc
@@ -120,5 +120,64 @@ TEST_F(IrEmitterTest, ComputeFuncStack) {
ir_emitter.PopComputeFunction();
}

TEST_F(IrEmitterTest, CheckNativeConvertSupportOnTargetCPU) {
std::string spr_feature_string =
"+prfchw,+cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,+"
"avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,+avx512ifma,+xsave,+sse4.2,+"
"tsxldtrk,-sm3,+ptwrite,-widekl,+invpcid,+64bit,+xsavec,-avx10.1-512,+"
"avx512vpopcntdq,+cmov,-avx512vp2intersect,+avx512cd,+movbe,-avxvnniint8,"
"-ccmp,+amx-int8,-kl,-avx10.1-256,+evex512,+avxvnni,-rtm,+adx,+avx2,-"
"hreset,+movdiri,+serialize,-sha512,+vpclmulqdq,+avx512vl,+uintr,-cf,+"
"clflushopt,-raoint,-cmpccxadd,+bmi,+amx-tile,+sse,-avx10.2-256,+gfni,-"
"avxvnniint16,-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,+avx512f,+amx-bf16,+"
"avx512bf16,+avx512vnni,-push2pop2,+cx8,+avx512bw,+sse3,+pku,-nf,+"
"fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,+sha,+movdir64b,-ppx,+wbnoinvd,+"
"enqcmd,-avx10.2-512,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+"
"cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,+avx512bitalg,-rdpru,+clwb,+mmx,+"
"sse2,+rdseed,+avx512vbmi2,-prefetchi,+rdpid,-fma4,+avx512vbmi,+shstk,+"
"vaes,+waitpkg,-sgx,+fxsr,+avx512dq,-sse4a";

std::string skx_feature_string =
"+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-"
"avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,+sse4.2,-"
"tsxldtrk,-sm3,-ptwrite,-widekl,+invpcid,+64bit,+xsavec,-avx10.1-512,-"
"avx512vpopcntdq,+cmov,-avx512vp2intersect,+avx512cd,+movbe,-avxvnniint8,"
"-ccmp,-amx-int8,-kl,-avx10.1-256,+evex512,-avxvnni,+rtm,+adx,+avx2,-"
"hreset,-movdiri,-serialize,-sha512,-vpclmulqdq,+avx512vl,-uintr,-cf,+"
"clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,-avx10.2-256,-gfni,-"
"avxvnniint16,-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,+avx512f,-amx-bf16,-"
"avx512bf16,-avx512vnni,-push2pop2,+cx8,+avx512bw,+sse3,+pku,-nf,+"
"fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,-sha,-movdir64b,-ppx,-wbnoinvd,-"
"enqcmd,-avx10.2-512,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+"
"cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,+clwb,+mmx,+"
"sse2,+rdseed,-avx512vbmi2,-prefetchi,-rdpid,-fma4,-avx512vbmi,-shstk,-"
"vaes,-waitpkg,-sgx,+fxsr,+avx512dq,-sse4a";

std::string srf_feature_string =
"+prfchw,+cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-"
"avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,+sse4.2,-"
"tsxldtrk,-sm3,+ptwrite,-widekl,+invpcid,+64bit,+xsavec,-avx10.1-512,-"
"avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,+avxvnniint8,"
"-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,+avxvnni,-rtm,+adx,+avx2,-"
"hreset,+movdiri,+serialize,+vpclmulqdq,-avx512vl,+uintr,-cf,+clflushopt,"
"-raoint,+cmpccxadd,+bmi,-amx-tile,+sse,-avx10.2-256,+gfni,-avxvnniint16,"
"-amx-fp16,-zu,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-"
"avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,+pku,-nf,+fsgsbase,-clzero,-"
"mwaitx,-lwp,+lzcnt,+sha,+movdir64b,-ppx,+wbnoinvd,+enqcmd,-avx10.2-512,+"
"avxneconvert,-tbm,+pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,"
"+avxifma,+f16c,-avx512bitalg,-rdpru,+clwb,+mmx,+sse2,+rdseed,-"
"avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,+waitpkg,+"
"sgx,+fxsr,-avx512dq,-sse4a";

// Testing sapphire-rapids target
ASSERT_TRUE(IsNativeConvertSupportedOnTargetCPU(spr_feature_string));

// Testing skylake target
ASSERT_FALSE(IsNativeConvertSupportedOnTargetCPU(skx_feature_string));

// Testing sierra-forest target
ASSERT_TRUE(IsNativeConvertSupportedOnTargetCPU(srf_feature_string));
}

} // namespace
} // namespace xla::cpu
4 changes: 4 additions & 0 deletions xla/service/cpu/target_machine_features.cc
@@ -54,5 +54,9 @@ int64_t LLVMTargetMachineFeatures::minimum_alignment_for_allocation(
cpu_function_runtime::MinAlign());
}

std::string LLVMTargetMachineFeatures::get_target_feature_string() const {
  return target_machine_->getTargetFeatureString().str();
}

} // namespace cpu
} // namespace xla
4 changes: 4 additions & 0 deletions xla/service/cpu/target_machine_features.h
@@ -59,6 +59,8 @@ class TargetMachineFeatures {
// this functionality).
virtual int vector_register_count(const llvm::Function& function) const = 0;

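// Returns the LLVM target feature string for this target machine
// (comma-separated flags, e.g. "+avx2,...,+amx-bf16").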
virtual std::string get_target_feature_string() const = 0;

// Returns the minimum alignment for a buffer of size size_bytes.
virtual int64_t minimum_alignment_for_allocation(
int64_t size_bytes) const = 0;
@@ -102,6 +104,8 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {

int64_t minimum_alignment_for_allocation(int64_t size_bytes) const override;

std::string get_target_feature_string() const override;

private:
llvm::TargetTransformInfo* GetTargetTransformInfoFor(
const llvm::Function& function) const;
4 changes: 4 additions & 0 deletions xla/service/cpu/target_machine_features_fake.h
@@ -52,6 +52,10 @@ class TargetMachineFeaturesWithFakeAlignmentLogic
return fake_alignment_logic_(size_bytes);
}

  std::string get_target_feature_string() const override {
    LOG(FATAL) << "Unexpected call to " << __func__;
  }

private:
std::function<int64_t(int64_t)> fake_alignment_logic_;
};
