]> granicus.if.org Git - llvm/commitdiff
[TargetTransformInfo, API] Add a list of operands to TTI::getUserCost
authorEvgeny Astigeevich <evgeny.astigeevich@arm.com>
Thu, 29 Jun 2017 13:42:12 +0000 (13:42 +0000)
committerEvgeny Astigeevich <evgeny.astigeevich@arm.com>
Thu, 29 Jun 2017 13:42:12 +0000 (13:42 +0000)
The changes are a result of discussion of https://reviews.llvm.org/D33685.
It solves the following problem:

1. We can inform getGEPCost about simplified indices to help it with
   calculating the cost. But getGEPCost does not take into account the
   context which GEPs are used in.
2. We have getUserCost which can take the context into account but we cannot
   inform about simplified indices.

With the changes getUserCost will have access to additional information
as getGEPCost has.

The one parameter getUserCost is also provided.

Differential Revision: https://reviews.llvm.org/D34057

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@306674 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/Analysis/TargetTransformInfo.h
include/llvm/Analysis/TargetTransformInfoImpl.h
lib/Analysis/TargetTransformInfo.cpp
lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
lib/Target/Hexagon/HexagonTargetTransformInfo.h

index f2a07006da463e966447942aa2bfc5d93ffba1a7..68fbf640994c2351aba0e41414474de0d536e591 100644 (file)
@@ -216,9 +216,23 @@ public:
   /// other context they may not be folded. This routine can distinguish such
   /// cases.
   ///
+  /// \p Operands is a list of operands which can be a result of transformations
+  /// of the current operands. The number of the operands on the list must equal
+  /// to the number of the current operands the IR user has. Their order on the
+  /// list must be the same as the order of the current operands the IR user
+  /// has.
+  ///
   /// The returned cost is defined in terms of \c TargetCostConstants, see its
   /// comments for a detailed explanation of the cost values.
-  int getUserCost(const User *U) const;
+  int getUserCost(const User *U, ArrayRef<const Value *> Operands) const;
+
+  /// \brief This is a helper function which calls the two-argument getUserCost
+  /// with \p Operands which are the current operands U has.
+  int getUserCost(const User *U) const {
+    SmallVector<const Value *, 4> Operands(U->value_op_begin(),
+                                           U->value_op_end());
+    return getUserCost(U, Operands);
+  }
 
   /// \brief Return true if branch divergence exists.
   ///
@@ -824,7 +838,8 @@ public:
                                ArrayRef<const Value *> Arguments) = 0;
   virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                                     unsigned &JTSize) = 0;
-  virtual int getUserCost(const User *U) = 0;
+  virtual int
+  getUserCost(const User *U, ArrayRef<const Value *> Operands) = 0;
   virtual bool hasBranchDivergence() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isAlwaysUniform(const Value *V) = 0;
@@ -1000,7 +1015,9 @@ public:
                        ArrayRef<const Value *> Arguments) override {
     return Impl.getIntrinsicCost(IID, RetTy, Arguments);
   }
-  int getUserCost(const User *U) override { return Impl.getUserCost(U); }
+  int getUserCost(const User *U, ArrayRef<const Value *> Operands) override {
+    return Impl.getUserCost(U, Operands);
+  }
   bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); }
   bool isSourceOfDivergence(const Value *V) override {
     return Impl.isSourceOfDivergence(V);
index defd9e73bc055902128adb1fa8569006f5dac26f..0246fc1c02cced14427808f7f997c02faca751f7 100644 (file)
@@ -685,14 +685,14 @@ public:
     return static_cast<T *>(this)->getIntrinsicCost(IID, RetTy, ParamTys);
   }
 
-  unsigned getUserCost(const User *U) {
+  unsigned getUserCost(const User *U, ArrayRef<const Value *> Operands) {
     if (isa<PHINode>(U))
       return TTI::TCC_Free; // Model all PHI nodes as free.
 
     if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
-      SmallVector<Value *, 4> Indices(GEP->idx_begin(), GEP->idx_end());
-      return static_cast<T *>(this)->getGEPCost(
-          GEP->getSourceElementType(), GEP->getPointerOperand(), Indices);
+      return static_cast<T *>(this)->getGEPCost(GEP->getSourceElementType(),
+                                                GEP->getPointerOperand(),
+                                                Operands.drop_front());
     }
 
     if (auto CS = ImmutableCallSite(U)) {
index 91160954bd997e3f141e05a3f276cceb0030666c..f938a9a52065013339f88f6601be9cc6744762b8 100644 (file)
@@ -89,8 +89,9 @@ TargetTransformInfo::getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
   return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
 }
 
-int TargetTransformInfo::getUserCost(const User *U) const {
-  int Cost = TTIImpl->getUserCost(U);
+int TargetTransformInfo::getUserCost(const User *U,
+    ArrayRef<const Value *> Operands) const {
+  int Cost = TTIImpl->getUserCost(U, Operands);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
index d3848104d2ba1b8088e2a8ea5c562e128140d824..4135d0cec70327aafa787f911c19c9b22e39fe4a 100644 (file)
@@ -46,8 +46,9 @@ unsigned HexagonTTIImpl::getCacheLineSize() const {
   return getST()->getL1CacheLineSize();
 }
 
-int HexagonTTIImpl::getUserCost(const User *U) {
-  auto isCastFoldedIntoLoad = [] (const CastInst *CI) -> bool {
+int HexagonTTIImpl::getUserCost(const User *U,
+                                ArrayRef<const Value *> Operands) {
+  auto isCastFoldedIntoLoad = [](const CastInst *CI) -> bool {
     if (!CI->isIntegerCast())
       return false;
     const LoadInst *LI = dyn_cast<const LoadInst>(CI->getOperand(0));
@@ -67,5 +68,5 @@ int HexagonTTIImpl::getUserCost(const User *U) {
   if (const CastInst *CI = dyn_cast<const CastInst>(U))
     if (isCastFoldedIntoLoad(CI))
       return TargetTransformInfo::TCC_Free;
-  return BaseT::getUserCost(U);
+  return BaseT::getUserCost(U, Operands);
 }
index 08faa2acd90ba339f2be3d0fdebfa03fe0e15ae3..f08a27310577e16afef8be6e2ec7f5e5f816b2a7 100644 (file)
@@ -62,7 +62,7 @@ public:
 
   /// @}
 
-  int getUserCost(const User *U);
+  int getUserCost(const User *U, ArrayRef<const Value *> Operands);
 };
 
 } // end namespace llvm