diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1ddbfa300846b..c0671dd1f0087c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1663,6 +1663,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
       setOperationAction(ISD::BITCAST, VT, Custom);
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
+      setOperationAction(ISD::FP_EXTEND, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
@@ -4298,8 +4299,28 @@ static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
 SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
                                               SelectionDAG &DAG) const {
   EVT VT = Op.getValueType();
-  if (VT.isScalableVector())
+  if (VT.isScalableVector()) {
+    SDValue SrcVal = Op.getOperand(0);
+
+    if (SrcVal.getValueType().getScalarType() == MVT::bf16) {
+      // bf16 and f32 share the same exponent range so the conversion requires
+      // them to be aligned with the new mantissa bits zero'd. This is just a
+      // left shift that is best to isel directly.
+      if (VT == MVT::nxv2f32 || VT == MVT::nxv4f32)
+        return Op;
+
+      if (VT != MVT::nxv2f64)
+        return SDValue();
+
+      // Break other conversions in two with the first part converting to f32
+      // and the second using native f32->VT instructions.
+      SDLoc DL(Op);
+      return DAG.getNode(ISD::FP_EXTEND, DL, VT,
+                         DAG.getNode(ISD::FP_EXTEND, DL, MVT::nxv2f32, SrcVal));
+    }
+
     return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
+  }
 
   if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
     return LowerFixedLengthFPExtendToSVE(Op, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 4922fb280333bb..692cd66d38437d 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -2320,7 +2320,12 @@ let Predicates = [HasSVEorSME] in {
   def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
             (FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
 
-  // Signed integer -> Floating-point
+  def : Pat<(nxv4f32 (fpextend nxv4bf16:$op)),
+            (LSL_ZZI_S $op, (i32 16))>;
+  def : Pat<(nxv2f32 (fpextend nxv2bf16:$op)),
+            (LSL_ZZI_S $op, (i32 16))>;
+
+  // Signed integer -> Floating-point
   def : Pat<(nxv2f16 (AArch64scvtf_mt (nxv2i1 (SVEAllActive):$Pg),
                       (sext_inreg nxv2i64:$Zs, nxv2i16), nxv2f16:$Zd)),
             (SCVTF_ZPmZ_HtoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
diff --git a/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
new file mode 100644
index 00000000000000..d72f92c1dac1ff
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-bf16-converts.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+; RUN: llc -mattr=+sme -force-streaming < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+define <vscale x 2 x float> @fpext_nxv2bf16_to_nxv2f32(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x float>
+  ret <vscale x 2 x float> %res
+}
+
+define <vscale x 4 x float> @fpext_nxv4bf16_to_nxv4f32(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x float>
+  ret <vscale x 4 x float> %res
+}
+
+define <vscale x 8 x float> @fpext_nxv8bf16_to_nxv8f32(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z2.s, z0.h
+; CHECK-NEXT:    lsl z0.s, z1.s, #16
+; CHECK-NEXT:    lsl z1.s, z2.s, #16
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x float>
+  ret <vscale x 8 x float> %res
+}
+
+define <vscale x 2 x double> @fpext_nxv2bf16_to_nxv2f64(<vscale x 2 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv2bf16_to_nxv2f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    lsl z0.s, z0.s, #16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcvt z0.d, p0/m, z0.s
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 2 x bfloat> %a to <vscale x 2 x double>
+  ret <vscale x 2 x double> %res
+}
+
+define <vscale x 4 x double> @fpext_nxv4bf16_to_nxv4f64(<vscale x 4 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv4bf16_to_nxv4f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z2.s, z0.s, #16
+; CHECK-NEXT:    movprfx z0, z1
+; CHECK-NEXT:    fcvt z0.d, p0/m, z1.s
+; CHECK-NEXT:    movprfx z1, z2
+; CHECK-NEXT:    fcvt z1.d, p0/m, z2.s
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 4 x bfloat> %a to <vscale x 4 x double>
+  ret <vscale x 4 x double> %res
+}
+
+define <vscale x 8 x double> @fpext_nxv8bf16_to_nxv8f64(<vscale x 8 x bfloat> %a) {
+; CHECK-LABEL: fpext_nxv8bf16_to_nxv8f64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.s, z0.h
+; CHECK-NEXT:    uunpkhi z0.s, z0.h
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uunpklo z2.d, z1.s
+; CHECK-NEXT:    uunpkhi z1.d, z1.s
+; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpkhi z0.d, z0.s
+; CHECK-NEXT:    lsl z1.s, z1.s, #16
+; CHECK-NEXT:    lsl z2.s, z2.s, #16
+; CHECK-NEXT:    lsl z3.s, z3.s, #16
+; CHECK-NEXT:    lsl z4.s, z0.s, #16
+; CHECK-NEXT:    fcvt z1.d, p0/m, z1.s
+; CHECK-NEXT:    movprfx z0, z2
+; CHECK-NEXT:    fcvt z0.d, p0/m, z2.s
+; CHECK-NEXT:    movprfx z2, z3
+; CHECK-NEXT:    fcvt z2.d, p0/m, z3.s
+; CHECK-NEXT:    movprfx z3, z4
+; CHECK-NEXT:    fcvt z3.d, p0/m, z4.s
+; CHECK-NEXT:    ret
+  %res = fpext <vscale x 8 x bfloat> %a to <vscale x 8 x double>
+  ret <vscale x 8 x double> %res
+}
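
Note (illustration, not part of the patch): the LowerFP_EXTEND comment relies on bf16 sharing f32's exponent range, so extending bf16 to f32 only needs the 16 source bits moved into the top half of a 32-bit lane with the new mantissa bits zeroed, which is exactly the LSL_ZZI_S by #16 the new patterns select; the f64 result is produced in two steps, reusing the exact native f32->f64 convert (the fcvt z.d instructions in the tests). A minimal scalar C++ sketch of the same idea, with hypothetical helper names, for reference only:

// Scalar model of the SVE lowering above: bf16 -> f32 is a 16-bit left
// shift of the bit pattern; bf16 -> f64 goes via f32 and then a normal
// (exact) f32 -> f64 extension.
#include <cstdint>
#include <cstring>

float bf16_to_f32(uint16_t bf16_bits) {
  // Place the bf16 bits in the top half of a 32-bit word; the low 16 bits
  // (the new mantissa bits) are zero, so the value is preserved exactly.
  uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
  float f;
  std::memcpy(&f, &f32_bits, sizeof(f));
  return f;
}

double bf16_to_f64(uint16_t bf16_bits) {
  // Two-step conversion: shift to f32, then the lossless f32 -> f64 extend.
  return static_cast<double>(bf16_to_f32(bf16_bits));
}

Handling the f32 case as a plain shift avoids a predicated convert entirely, and the f64 case only adds the one fcvt per unpacked half, which matches the code checked in the new sve-bf16-converts.ll tests.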