return nullptr;
}
-/// Given the operands for an FMul, see if we can fold the result
-static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
- const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
- return C;
-
- if (Constant *C = simplifyFPBinop(Op0, Op1))
- return C;
-
+static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
// fmul X, 1.0 ==> X
if (match(Op1, m_FPOne()))
return Op0;
return nullptr;
}
+/// Given the operands for an FMul, see if we can fold the result: constant folding first, then simplifyFPBinop, then the folds shared with FMA in SimplifyFMAFMul.
+static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))
+ return C;
+
+ if (Constant *C = simplifyFPBinop(Op0, Op1))
+ return C;
+
+ // Now apply simplifications that do not require rounding; they live in SimplifyFMAFMul so the fma intrinsic fold can call them without the constant folding above.
+ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse);
+}
+
Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const SimplifyQuery &Q) {
return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit);
return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit);
}
+Value *llvm::SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
+ const SimplifyQuery &Q) {
+ return ::SimplifyFMAFMul(Op0, Op1, FMF, Q, RecursionLimit); // Forward to the file-local static overload, passing RecursionLimit as its MaxRecurse argument.
+}
+
static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
const SimplifyQuery &Q, unsigned) {
if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))
return replaceInstUsesWith(*II, Add);
}
+ // Try to simplify the underlying FMul; if it folds to V, return an fadd of V and the third call operand that copies the call's fast-math flags.
+ if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
+ II->getFastMathFlags(),
+ SQ.getWithInstruction(II))) {
+ auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
+ FAdd->copyFastMathFlags(II);
+ return FAdd;
+ }
+
LLVM_FALLTHROUGH;
}
case Intrinsic::fma: {
return II;
}
- // fma x, 1, z -> fadd x, z
- if (match(Src1, m_FPOne())) {
- auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2));
+ // Try to simplify the underlying FMul. We can only apply simplifications
+ // that do not require rounding.
+ if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1),
+ II->getFastMathFlags(),
+ SQ.getWithInstruction(II))) {
+ auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
FAdd->copyFastMathFlags(II);
return FAdd;
}
define <2 x double> @fmuladd_a_0_b(<2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: @fmuladd_a_0_b(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT: ret <2 x double> [[RES]]
+; CHECK-NEXT: ret <2 x double> [[B:%.*]]
;
entry:
%res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b)
define <2 x double> @fmuladd_0_a_b(<2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: @fmuladd_0_a_b(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT: ret <2 x double> [[RES]]
+; CHECK-NEXT: ret <2 x double> [[B:%.*]]
;
entry:
%res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b)
define <2 x double> @fma_a_0_b(<2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: @fma_a_0_b(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT: ret <2 x double> [[RES]]
+; CHECK-NEXT: ret <2 x double> [[B:%.*]]
;
entry:
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> %a, <2 x double> zeroinitializer, <2 x double> %b)
define <2 x double> @fma_0_a_b(<2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: @fma_0_a_b(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> [[A:%.*]], <2 x double> zeroinitializer, <2 x double> [[B:%.*]])
-; CHECK-NEXT: ret <2 x double> [[RES]]
+; CHECK-NEXT: ret <2 x double> [[B:%.*]]
;
entry:
%res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> %a, <2 x double> %b)
define <2 x double> @fma_sqrt(<2 x double> %a, <2 x double> %b) {
; CHECK-LABEL: @fma_sqrt(
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[SQRT:%.*]] = call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> [[A:%.*]])
-; CHECK-NEXT: [[RES:%.*]] = call fast <2 x double> @llvm.fma.v2f64(<2 x double> [[SQRT]], <2 x double> [[SQRT]], <2 x double> [[B:%.*]])
+; CHECK-NEXT: [[RES:%.*]] = fadd fast <2 x double> [[A:%.*]], [[B:%.*]]
; CHECK-NEXT: ret <2 x double> [[RES]]
;
entry:
ret <2 x double> %res
}
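+; We do not fold constant multiplies in FMAs, as they could require rounding, unless the second multiplicand is 0.0 or 1.0 (a 0.0 or 1.0 first multiplicand is currently left unfolded; see the CHECK lines of fma_const_fmul_zero and fma_const_fmul_one).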
+define <2 x double> @fma_const_fmul(<2 x double> %b) { ; Both multiplicands are constants other than 0.0/1.0: the CHECK lines expect the fma call to be kept.
+; CHECK-LABEL: @fma_const_fmul(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> <double 1.291820e-08, double 9.123000e-06>, <2 x double> [[B:%.*]])
+; CHECK-NEXT: ret <2 x double> [[RES]]
+;
+entry:
+ %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0000000129182, double 0.000009123>, <2 x double> %b)
+ ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_zero(<2 x double> %b) { ; 0.0 as the FIRST multiplicand: the CHECK lines expect the fma call to be kept (operands swapped, not folded).
+; CHECK-LABEL: @fma_const_fmul_zero(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> zeroinitializer, <2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> [[B:%.*]])
+; CHECK-NEXT: ret <2 x double> [[RES]]
+;
+entry:
+ %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 0.0, double 0.0>, <2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> %b)
+ ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_zero2(<2 x double> %b) { ; 0.0 as the SECOND multiplicand (with nnan nsz): the CHECK lines expect the whole call to fold to %b.
+; CHECK-LABEL: @fma_const_fmul_zero2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: ret <2 x double> [[B:%.*]]
+;
+entry:
+ %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0, double 0.0>, <2 x double> %b)
+ ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_one(<2 x double> %b) { ; 1.0 as the FIRST multiplicand: the CHECK lines expect the fma call to be kept.
+; CHECK-LABEL: @fma_const_fmul_one(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[RES:%.*]] = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> <double 0x4131233302898702, double 0x40C387800000D6C0>, <2 x double> [[B:%.*]])
+; CHECK-NEXT: ret <2 x double> [[RES]]
+;
+entry:
+ %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1.0, double 1.0>, <2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> %b)
+ ret <2 x double> %res
+}
+
+define <2 x double> @fma_const_fmul_one2(<2 x double> %b) { ; 1.0 as the SECOND multiplicand: the CHECK lines expect an fadd of %b and the other constant.
+; CHECK-LABEL: @fma_const_fmul_one2(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], <double 0x4131233302898702, double 0x40C387800000D6C0>
+; CHECK-NEXT: ret <2 x double> [[RES]]
+;
+entry:
+ %res = call nnan nsz <2 x double> @llvm.fma.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 1.0, double 1.0>, <2 x double> %b)
+ ret <2 x double> %res
+}
+
+define <2 x double> @fmuladd_const_fmul(<2 x double> %b) { ; fmuladd with two constant multiplicands: the CHECK lines expect the product folded to a constant and an fadd with %b.
+; CHECK-LABEL: @fmuladd_const_fmul(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[RES:%.*]] = fadd nnan nsz <2 x double> [[B:%.*]], <double 0x3F8DB6C076AD949B, double 0x3FB75A405B6E6D69>
+; CHECK-NEXT: ret <2 x double> [[RES]]
+;
+entry:
+ %res = call nnan nsz <2 x double> @llvm.fmuladd.v2f64(<2 x double> <double 1123123.0099110012314, double 9999.0000001>, <2 x double> <double 0.0000000129182, double 0.000009123>, <2 x double> %b)
+ ret <2 x double> %res
+}
declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
declare <2 x double> @llvm.sqrt.v2f64(<2 x double>)