From 55eb93b2688de99ada14c71804af99502276ac79 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Mon, 2 Sep 2024 10:14:04 -0700 Subject: [PATCH] [RISCV] Remove RISCVISD::FP_EXTEND_BF16. (#106939) I don't think we need this node. We can isel fp_extend directly. fp_extend to f64 requires two instructions, but we can emit them with an isel pattern. I have not removed RISCVISD::FP_ROUND_BF16 because f64->bf16 needs more work to fix the double rounding. --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 15 --------------- llvm/lib/Target/RISCV/RISCVISelLowering.h | 1 - .../Target/RISCV/RISCVInstrInfoVSDPatterns.td | 11 ++++------- llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td | 16 +++++++--------- 4 files changed, 11 insertions(+), 32 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index d02078372b24a2..250d1c60b9f59e 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -452,8 +452,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, setOperationAction(ISD::BITCAST, MVT::i16, Custom); setOperationAction(ISD::BITCAST, MVT::bf16, Custom); setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom); - setOperationAction(ISD::FP_EXTEND, MVT::f32, Custom); - setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom); setOperationAction(ISD::ConstantFP, MVT::bf16, Expand); setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand); setOperationAction(ISD::BR_CC, MVT::bf16, Expand); @@ -6500,18 +6498,6 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op, return SplitVectorOp(Op, DAG); return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget); case ISD::FP_EXTEND: { - SDLoc DL(Op); - EVT VT = Op.getValueType(); - SDValue Op0 = Op.getOperand(0); - EVT Op0VT = Op0.getValueType(); - if (VT == MVT::f32 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) - return DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); - if (VT == MVT::f64 && Op0VT == MVT::bf16 && Subtarget.hasStdExtZfbfmin()) { - SDValue FloatVal = - DAG.getNode(RISCVISD::FP_EXTEND_BF16, DL, MVT::f32, Op0); - return DAG.getNode(ISD::FP_EXTEND, DL, MVT::f64, FloatVal); - } - if (!Op.getValueType().isVector()) return Op; return lowerVectorFPExtendOrRoundLike(Op, DAG); @@ -20463,7 +20449,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(STRICT_FCVT_W_RV64) NODE_NAME_CASE(STRICT_FCVT_WU_RV64) NODE_NAME_CASE(FP_ROUND_BF16) - NODE_NAME_CASE(FP_EXTEND_BF16) NODE_NAME_CASE(FROUND) NODE_NAME_CASE(FCLASS) NODE_NAME_CASE(FSGNJX) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 9ae35173ba0cb3..29a16282ed001d 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -117,7 +117,6 @@ enum NodeType : unsigned { FCVT_WU_RV64, FP_ROUND_BF16, - FP_EXTEND_BF16, // Rounds an FP value to its corresponding integer in the same FP format. // First operand is the value to round, the second operand is the largest diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td index 0f435c4ff3d315..f12f82cb159529 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -677,8 +677,7 @@ multiclass VPatWidenBinaryFPSDNode_VV_VF_WV_WF_RM; multiclass VPatWidenFPMulAccSDNode_VV_VF_RM vtiToWtis, - PatFrags extop> { + list vtiToWtis> { foreach vtiToWti = vtiToWtis in { defvar vti = vtiToWti.Vti; defvar wti = vtiToWti.Wti; @@ -702,7 +701,7 @@ multiclass VPatWidenFPMulAccSDNode_VV_VF_RM; def : Pat<(fma (wti.Vector (SplatFPOp - (extop (vti.Scalar vti.ScalarRegClass:$rs1)))), + (fpext_oneuse (vti.Scalar vti.ScalarRegClass:$rs1)))), (wti.Vector (riscv_fpextend_vl_oneuse (vti.Vector vti.RegClass:$rs2), (vti.Mask true_mask), (XLenVT srcvalue))), @@ -1290,11 +1289,9 @@ foreach fvti = AllFloatVectors in { // 13.7. Vector Widening Floating-Point Fused Multiply-Add Instructions defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACC", - AllWidenableFloatVectors, - fpext_oneuse>; + AllWidenableFloatVectors>; defm : VPatWidenFPMulAccSDNode_VV_VF_RM<"PseudoVFWMACCBF16", - AllWidenableBFloatToFloatVectors, - riscv_fpextend_bf16_oneuse>; + AllWidenableBFloatToFloatVectors>; defm : VPatWidenFPNegMulAccSDNode_VV_VF_RM<"PseudoVFWNMACC">; defm : VPatWidenFPMulSacSDNode_VV_VF_RM<"PseudoVFWMSAC">; defm : VPatWidenFPNegMulSacSDNode_VV_VF_RM<"PseudoVFWNMSAC">; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td index 88b66e7fc49aad..bf6272317fda4d 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td @@ -19,17 +19,9 @@ def SDT_RISCVFP_ROUND_BF16 : SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>; -def SDT_RISCVFP_EXTEND_BF16 - : SDTypeProfile<1, 1, [SDTCisVT<0, f32>, SDTCisVT<1, bf16>]>; def riscv_fpround_bf16 : SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>; -def riscv_fpextend_bf16 - : SDNode<"RISCVISD::FP_EXTEND_BF16", SDT_RISCVFP_EXTEND_BF16>; -def riscv_fpextend_bf16_oneuse : PatFrag<(ops node:$A), - (riscv_fpextend_bf16 node:$A), [{ - return N->hasOneUse(); -}]>; //===----------------------------------------------------------------------===// // Instructions @@ -57,7 +49,7 @@ def : StPat; // f32 -> bf16, bf16 -> f32 def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)), (FCVT_BF16_S FPR32:$rs1, FRM_DYN)>; -def : Pat<(riscv_fpextend_bf16 (bf16 FPR16:$rs1)), +def : Pat<(fpextend (bf16 FPR16:$rs1)), (FCVT_S_BF16 FPR16:$rs1, FRM_DYN)>; // Moves (no conversion) @@ -87,3 +79,9 @@ def : Pat<(i64 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_BF16 $rs1 def : Pat<(bf16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>; def : Pat<(bf16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>; } + +let Predicates = [HasStdExtZfbfmin, HasStdExtD] in { +// bf16 -> f64 +def : Pat<(fpextend (bf16 FPR16:$rs1)), + (FCVT_D_S (FCVT_S_BF16 FPR16:$rs1, FRM_DYN), FRM_RNE)>; +}