From: Kevin P. Neal Date: Mon, 8 Jul 2019 16:18:18 +0000 (+0000) Subject: Teach the IRBuilder about fadd and friends. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=fb8a52a58666795bef263deac618fc5fc7f3eca4;p=llvm Teach the IRBuilder about fadd and friends. The IRBuilder has calls to create floating point instructions like fadd. It does not have calls to create constrained versions of them. This patch adds support for constrained creation of fadd, fsub, fmul, fdiv, and frem. Reviewed by: John McCall, Sanjay Patel Approved by: John McCall Differential Revision: https://reviews.llvm.org/D53157 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@365339 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/llvm/IR/IRBuilder.h b/include/llvm/IR/IRBuilder.h index 980f7345a40..552ea3ef7a1 100644 --- a/include/llvm/IR/IRBuilder.h +++ b/include/llvm/IR/IRBuilder.h @@ -31,7 +31,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" @@ -96,12 +96,18 @@ protected: MDNode *DefaultFPMathTag; FastMathFlags FMF; + bool IsFPConstrained; + ConstrainedFPIntrinsic::ExceptionBehavior DefaultConstrainedExcept; + ConstrainedFPIntrinsic::RoundingMode DefaultConstrainedRounding; + ArrayRef DefaultOperandBundles; public: IRBuilderBase(LLVMContext &context, MDNode *FPMathTag = nullptr, ArrayRef OpBundles = None) - : Context(context), DefaultFPMathTag(FPMathTag), + : Context(context), DefaultFPMathTag(FPMathTag), IsFPConstrained(false), + DefaultConstrainedExcept(ConstrainedFPIntrinsic::ebStrict), + DefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDynamic), DefaultOperandBundles(OpBundles) { ClearInsertionPoint(); } @@ -218,6 +224,37 @@ public: /// Set the fast-math flags to be used with generated fp-math operators void setFastMathFlags(FastMathFlags NewFMF) { FMF = NewFMF; } + /// Enable/Disable use of constrained floating point math. When + /// enabled the CreateF() calls instead create constrained + /// floating point intrinsic calls. Fast math flags are unaffected + /// by this setting. + void setIsFPConstrained(bool IsCon) { IsFPConstrained = IsCon; } + + /// Query for the use of constrained floating point math + bool getIsFPConstrained() { return IsFPConstrained; } + + /// Set the exception handling to be used with constrained floating point + void setDefaultConstrainedExcept( + ConstrainedFPIntrinsic::ExceptionBehavior NewExcept) { + DefaultConstrainedExcept = NewExcept; + } + + /// Set the rounding mode handling to be used with constrained floating point + void setDefaultConstrainedRounding( + ConstrainedFPIntrinsic::RoundingMode NewRounding) { + DefaultConstrainedRounding = NewRounding; + } + + /// Get the exception handling used with constrained floating point + ConstrainedFPIntrinsic::ExceptionBehavior getDefaultConstrainedExcept() { + return DefaultConstrainedExcept; + } + + /// Get the rounding mode handling used with constrained floating point + ConstrainedFPIntrinsic::RoundingMode getDefaultConstrainedRounding() { + return DefaultConstrainedRounding; + } + //===--------------------------------------------------------------------===// // RAII helpers. //===--------------------------------------------------------------------===// @@ -1045,6 +1082,38 @@ private: return (LC && RC) ? Insert(Folder.CreateBinOp(Opc, LC, RC), Name) : nullptr; } + Value *getConstrainedFPRounding( + Optional Rounding) { + ConstrainedFPIntrinsic::RoundingMode UseRounding = + DefaultConstrainedRounding; + + if (Rounding.hasValue()) + UseRounding = Rounding.getValue(); + + Optional RoundingStr = + ConstrainedFPIntrinsic::RoundingModeToStr(UseRounding); + assert(RoundingStr.hasValue() && "Garbage strict rounding mode!"); + auto *RoundingMDS = MDString::get(Context, RoundingStr.getValue()); + + return MetadataAsValue::get(Context, RoundingMDS); + } + + Value *getConstrainedFPExcept( + Optional Except) { + ConstrainedFPIntrinsic::ExceptionBehavior UseExcept = + DefaultConstrainedExcept; + + if (Except.hasValue()) + UseExcept = Except.getValue(); + + Optional ExceptStr = + ConstrainedFPIntrinsic::ExceptionBehaviorToStr(UseExcept); + assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); + auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); + + return MetadataAsValue::get(Context, ExceptMDS); + } + public: Value *CreateAdd(Value *LHS, Value *RHS, const Twine &Name = "", bool HasNUW = false, bool HasNSW = false) { @@ -1263,6 +1332,10 @@ public: Value *CreateFAdd(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd, + L, R, nullptr, Name, FPMD); + if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), FPMD, FMF); return Insert(I, Name); @@ -1272,6 +1345,10 @@ public: /// default FMF. Value *CreateFAddFMF(Value *L, Value *R, Instruction *FMFSource, const Twine &Name = "") { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fadd, + L, R, FMFSource, Name); + if (Value *V = foldConstant(Instruction::FAdd, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFAdd(L, R), nullptr, FMFSource->getFastMathFlags()); @@ -1280,6 +1357,10 @@ public: Value *CreateFSub(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub, + L, R, nullptr, Name, FPMD); + if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), FPMD, FMF); return Insert(I, Name); @@ -1289,6 +1370,10 @@ public: /// default FMF. Value *CreateFSubFMF(Value *L, Value *R, Instruction *FMFSource, const Twine &Name = "") { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fsub, + L, R, FMFSource, Name); + if (Value *V = foldConstant(Instruction::FSub, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFSub(L, R), nullptr, FMFSource->getFastMathFlags()); @@ -1297,6 +1382,10 @@ public: Value *CreateFMul(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul, + L, R, nullptr, Name, FPMD); + if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), FPMD, FMF); return Insert(I, Name); @@ -1306,6 +1395,10 @@ public: /// default FMF. Value *CreateFMulFMF(Value *L, Value *R, Instruction *FMFSource, const Twine &Name = "") { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fmul, + L, R, FMFSource, Name); + if (Value *V = foldConstant(Instruction::FMul, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFMul(L, R), nullptr, FMFSource->getFastMathFlags()); @@ -1314,6 +1407,10 @@ public: Value *CreateFDiv(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv, + L, R, nullptr, Name, FPMD); + if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), FPMD, FMF); return Insert(I, Name); @@ -1323,6 +1420,10 @@ public: /// default FMF. Value *CreateFDivFMF(Value *L, Value *R, Instruction *FMFSource, const Twine &Name = "") { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_fdiv, + L, R, FMFSource, Name); + if (Value *V = foldConstant(Instruction::FDiv, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFDiv(L, R), nullptr, FMFSource->getFastMathFlags()); @@ -1331,6 +1432,10 @@ public: Value *CreateFRem(Value *L, Value *R, const Twine &Name = "", MDNode *FPMD = nullptr) { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem, + L, R, nullptr, Name, FPMD); + if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), FPMD, FMF); return Insert(I, Name); @@ -1340,6 +1445,10 @@ public: /// default FMF. Value *CreateFRemFMF(Value *L, Value *R, Instruction *FMFSource, const Twine &Name = "") { + if (IsFPConstrained) + return CreateConstrainedFPBinOp(Intrinsic::experimental_constrained_frem, + L, R, FMFSource, Name); + if (Value *V = foldConstant(Instruction::FRem, L, R, Name)) return V; Instruction *I = setFPAttrs(BinaryOperator::CreateFRem(L, R), nullptr, FMFSource->getFastMathFlags()); @@ -1356,6 +1465,23 @@ public: return Insert(BinOp, Name); } + CallInst *CreateConstrainedFPBinOp( + Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr, + const Twine &Name = "", MDNode *FPMathTag = nullptr, + Optional Rounding = None, + Optional Except = None) { + Value *RoundingV = getConstrainedFPRounding(Rounding); + Value *ExceptV = getConstrainedFPExcept(Except); + + FastMathFlags UseFMF = FMF; + if (FMFSource) + UseFMF = FMFSource->getFastMathFlags(); + + CallInst *C = CreateIntrinsic(ID, {L->getType()}, + {L, R, RoundingV, ExceptV}, nullptr, Name); + return cast(setFPAttrs(C, FPMathTag, UseFMF)); + } + Value *CreateNeg(Value *V, const Twine &Name = "", bool HasNUW = false, bool HasNSW = false) { if (auto *VC = dyn_cast(V)) diff --git a/include/llvm/IR/IntrinsicInst.h b/include/llvm/IR/IntrinsicInst.h index 9b816b0a224..438bdb29b70 100644 --- a/include/llvm/IR/IntrinsicInst.h +++ b/include/llvm/IR/IntrinsicInst.h @@ -208,26 +208,47 @@ namespace llvm { /// This is the common base class for constrained floating point intrinsics. class ConstrainedFPIntrinsic : public IntrinsicInst { public: - enum RoundingMode { - rmInvalid, - rmDynamic, - rmToNearest, - rmDownward, - rmUpward, - rmTowardZero + /// Specifies the rounding mode to be assumed. This is only used when + /// when constrained floating point is enabled. See the LLVM Language + /// Reference Manual for details. + enum RoundingMode : uint8_t { + rmDynamic, ///< This corresponds to "fpround.dynamic". + rmToNearest, ///< This corresponds to "fpround.tonearest". + rmDownward, ///< This corresponds to "fpround.downward". + rmUpward, ///< This corresponds to "fpround.upward". + rmTowardZero ///< This corresponds to "fpround.tozero". }; - enum ExceptionBehavior { - ebInvalid, - ebIgnore, - ebMayTrap, - ebStrict + /// Specifies the required exception behavior. This is only used when + /// when constrained floating point is used. See the LLVM Language + /// Reference Manual for details. + enum ExceptionBehavior : uint8_t { + ebIgnore, ///< This corresponds to "fpexcept.ignore". + ebMayTrap, ///< This corresponds to "fpexcept.maytrap". + ebStrict ///< This corresponds to "fpexcept.strict". }; bool isUnaryOp() const; bool isTernaryOp() const; - RoundingMode getRoundingMode() const; - ExceptionBehavior getExceptionBehavior() const; + Optional getRoundingMode() const; + Optional getExceptionBehavior() const; + + /// Returns a valid RoundingMode enumerator when given a string + /// that is valid as input in constrained intrinsic rounding mode + /// metadata. + static Optional StrToRoundingMode(StringRef); + + /// For any RoundingMode enumerator, returns a string valid as input in + /// constrained intrinsic rounding mode metadata. + static Optional RoundingModeToStr(RoundingMode); + + /// Returns a valid ExceptionBehavior enumerator when given a string + /// valid as input in constrained intrinsic exception behavior metadata. + static Optional StrToExceptionBehavior(StringRef); + + /// For any ExceptionBehavior enumerator, returns a string valid as + /// input in constrained intrinsic exception behavior metadata. + static Optional ExceptionBehaviorToStr(ExceptionBehavior); // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I) { diff --git a/lib/IR/IntrinsicInst.cpp b/lib/IR/IntrinsicInst.cpp index 793e2895dce..7a042326f67 100644 --- a/lib/IR/IntrinsicInst.cpp +++ b/lib/IR/IntrinsicInst.cpp @@ -103,39 +103,86 @@ Value *InstrProfIncrementInst::getStep() const { return ConstantInt::get(Type::getInt64Ty(Context), 1); } -ConstrainedFPIntrinsic::RoundingMode +Optional ConstrainedFPIntrinsic::getRoundingMode() const { unsigned NumOperands = getNumArgOperands(); Metadata *MD = dyn_cast(getArgOperand(NumOperands - 2))->getMetadata(); if (!MD || !isa(MD)) - return rmInvalid; - StringRef RoundingArg = cast(MD)->getString(); + return None; + return StrToRoundingMode(cast(MD)->getString()); +} +Optional +ConstrainedFPIntrinsic::StrToRoundingMode(StringRef RoundingArg) { // For dynamic rounding mode, we use round to nearest but we will set the // 'exact' SDNodeFlag so that the value will not be rounded. - return StringSwitch(RoundingArg) + return StringSwitch>(RoundingArg) .Case("round.dynamic", rmDynamic) .Case("round.tonearest", rmToNearest) .Case("round.downward", rmDownward) .Case("round.upward", rmUpward) .Case("round.towardzero", rmTowardZero) - .Default(rmInvalid); + .Default(None); } -ConstrainedFPIntrinsic::ExceptionBehavior +Optional +ConstrainedFPIntrinsic::RoundingModeToStr(RoundingMode UseRounding) { + Optional RoundingStr = None; + switch (UseRounding) { + case ConstrainedFPIntrinsic::rmDynamic: + RoundingStr = "round.dynamic"; + break; + case ConstrainedFPIntrinsic::rmToNearest: + RoundingStr = "round.tonearest"; + break; + case ConstrainedFPIntrinsic::rmDownward: + RoundingStr = "round.downward"; + break; + case ConstrainedFPIntrinsic::rmUpward: + RoundingStr = "round.upward"; + break; + case ConstrainedFPIntrinsic::rmTowardZero: + RoundingStr = "round.tozero"; + break; + } + return RoundingStr; +} + +Optional ConstrainedFPIntrinsic::getExceptionBehavior() const { unsigned NumOperands = getNumArgOperands(); Metadata *MD = dyn_cast(getArgOperand(NumOperands - 1))->getMetadata(); if (!MD || !isa(MD)) - return ebInvalid; - StringRef ExceptionArg = cast(MD)->getString(); - return StringSwitch(ExceptionArg) + return None; + return StrToExceptionBehavior(cast(MD)->getString()); +} + +Optional +ConstrainedFPIntrinsic::StrToExceptionBehavior(StringRef ExceptionArg) { + return StringSwitch>(ExceptionArg) .Case("fpexcept.ignore", ebIgnore) .Case("fpexcept.maytrap", ebMayTrap) .Case("fpexcept.strict", ebStrict) - .Default(ebInvalid); + .Default(None); +} + +Optional +ConstrainedFPIntrinsic::ExceptionBehaviorToStr(ExceptionBehavior UseExcept) { + Optional ExceptStr = None; + switch (UseExcept) { + case ConstrainedFPIntrinsic::ebStrict: + ExceptStr = "fpexcept.strict"; + break; + case ConstrainedFPIntrinsic::ebIgnore: + ExceptStr = "fpexcept.ignore"; + break; + case ConstrainedFPIntrinsic::ebMayTrap: + ExceptStr = "fpexcept.maytrap"; + break; + } + return ExceptStr; } bool ConstrainedFPIntrinsic::isUnaryOp() const { diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp index 2655e3ce81e..744707afa3c 100644 --- a/lib/IR/Verifier.cpp +++ b/lib/IR/Verifier.cpp @@ -4776,11 +4776,11 @@ void Verifier::visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI) { // argument type check is needed here. if (HasExceptionMD) { - Assert(FPI.getExceptionBehavior() != ConstrainedFPIntrinsic::ebInvalid, + Assert(FPI.getExceptionBehavior().hasValue(), "invalid exception behavior argument", &FPI); } if (HasRoundingMD) { - Assert(FPI.getRoundingMode() != ConstrainedFPIntrinsic::rmInvalid, + Assert(FPI.getRoundingMode().hasValue(), "invalid rounding mode argument", &FPI); } } diff --git a/unittests/IR/IRBuilderTest.cpp b/unittests/IR/IRBuilderTest.cpp index c7368e2038b..c43f00cb8da 100644 --- a/unittests/IR/IRBuilderTest.cpp +++ b/unittests/IR/IRBuilderTest.cpp @@ -122,6 +122,70 @@ TEST_F(IRBuilderTest, Intrinsics) { EXPECT_FALSE(II->hasNoNaNs()); } +TEST_F(IRBuilderTest, ConstrainedFP) { + IRBuilder<> Builder(BB); + Value *V; + CallInst *Call; + IntrinsicInst *II; + + V = Builder.CreateLoad(GV); + + // See if we get constrained intrinsics instead of non-constrained + // instructions. + Builder.setIsFPConstrained(true); + + V = Builder.CreateFAdd(V, V); + ASSERT_TRUE(isa(V)); + II = cast(V); + EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fadd); + + V = Builder.CreateFSub(V, V); + ASSERT_TRUE(isa(V)); + II = cast(V); + EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fsub); + + V = Builder.CreateFMul(V, V); + ASSERT_TRUE(isa(V)); + II = cast(V); + EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fmul); + + V = Builder.CreateFDiv(V, V); + ASSERT_TRUE(isa(V)); + II = cast(V); + EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_fdiv); + + V = Builder.CreateFRem(V, V); + ASSERT_TRUE(isa(V)); + II = cast(V); + EXPECT_EQ(II->getIntrinsicID(), Intrinsic::experimental_constrained_frem); + + // Verify the codepaths for setting and overriding the default metadata. + V = Builder.CreateFAdd(V, V); + ASSERT_TRUE(isa(V)); + auto *CII = cast(V); + ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict); + ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic); + + Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); + Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmUpward); + V = Builder.CreateFAdd(V, V); + CII = cast(V); + ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmUpward); + + // Now override the defaults. + Call = Builder.CreateConstrainedFPBinOp( + Intrinsic::experimental_constrained_fadd, V, V, nullptr, "", nullptr, + ConstrainedFPIntrinsic::rmDownward, ConstrainedFPIntrinsic::ebMayTrap); + CII = cast(Call); + EXPECT_EQ(CII->getIntrinsicID(), Intrinsic::experimental_constrained_fadd); + ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebMayTrap); + ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDownward); + + Builder.CreateRetVoid(); + EXPECT_FALSE(verifyModule(*M)); +} + TEST_F(IRBuilderTest, Lifetime) { IRBuilder<> Builder(BB); AllocaInst *Var1 = Builder.CreateAlloca(Builder.getInt8Ty());