]> granicus.if.org Git - llvm/commitdiff
Address Eli's post-commit comments
authorDavid Majnemer <david.majnemer@gmail.com>
Sun, 19 Jun 2016 21:36:35 +0000 (21:36 +0000)
committerDavid Majnemer <david.majnemer@gmail.com>
Sun, 19 Jun 2016 21:36:35 +0000 (21:36 +0000)
Use an APInt to handle pointers of arbitrary width, let
accumulateConstantOffset handle overflow issues.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@273126 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/LoadCombine.cpp

index 9e457a191807d3062fffa233f26ac2e13262262b..dfe51a4ce44c5fd49693f8a198e7fe42fa5d75a9 100644 (file)
@@ -40,7 +40,7 @@ STATISTIC(NumLoadsCombined, "Number of loads combined");
 namespace {
 struct PointerOffsetPair {
   Value *Pointer;
-  int64_t Offset;
+  APInt Offset;
 };
 
 struct LoadPOPPair {
@@ -93,22 +93,25 @@ bool LoadCombine::doInitialization(Function &F) {
 }
 
 PointerOffsetPair LoadCombine::getPointerOffsetPair(LoadInst &LI) {
+  auto &DL = LI.getModule()->getDataLayout();
+
   PointerOffsetPair POP;
   POP.Pointer = LI.getPointerOperand();
-  POP.Offset = 0;
+  unsigned BitWidth = DL.getPointerSizeInBits(LI.getPointerAddressSpace());
+  POP.Offset = APInt(BitWidth, 0);
+
   while (isa<BitCastInst>(POP.Pointer) || isa<GetElementPtrInst>(POP.Pointer)) {
     if (auto *GEP = dyn_cast<GetElementPtrInst>(POP.Pointer)) {
-      auto &DL = LI.getModule()->getDataLayout();
-      unsigned BitWidth = DL.getPointerTypeSizeInBits(GEP->getType());
-      APInt Offset(BitWidth, 0);
-      if (GEP->accumulateConstantOffset(DL, Offset))
-        POP.Offset += Offset.getSExtValue();
-      else
+      APInt LastOffset = POP.Offset;
+      if (!GEP->accumulateConstantOffset(DL, POP.Offset)) {
         // Can't handle GEPs with variable indices.
+        POP.Offset = LastOffset;
         return POP;
+      }
       POP.Pointer = GEP->getPointerOperand();
-    } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer))
+    } else if (auto *BC = dyn_cast<BitCastInst>(POP.Pointer)) {
       POP.Pointer = BC->getOperand(0);
+    }
   }
   return POP;
 }
@@ -121,8 +124,8 @@ bool LoadCombine::combineLoads(
       continue;
     std::sort(Loads.second.begin(), Loads.second.end(),
               [](const LoadPOPPair &A, const LoadPOPPair &B) {
-      return A.POP.Offset < B.POP.Offset;
-    });
+                return A.POP.Offset.slt(B.POP.Offset);
+              });
     if (aggregateLoads(Loads.second))
       Combined = true;
   }
@@ -139,7 +142,7 @@ bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
   SmallVector<LoadPOPPair, 8> AggregateLoads;
   bool Combined = false;
   bool ValidPrevOffset = false;
-  int64_t PrevOffset = 0;
+  APInt PrevOffset;
   uint64_t PrevSize = 0;
   for (auto &L : Loads) {
     if (ValidPrevOffset == false) {
@@ -153,8 +156,8 @@ bool LoadCombine::aggregateLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
     }
     if (L.Load->getAlignment() > BaseLoad->getAlignment())
       continue;
-    int64_t PrevEnd = PrevOffset + PrevSize;
-    if (L.POP.Offset > PrevEnd) {
+    APInt PrevEnd = PrevOffset + PrevSize;
+    if (L.POP.Offset.sgt(PrevEnd)) {
       // No other load will be combinable
       if (combineLoads(AggregateLoads))
         Combined = true;
@@ -208,7 +211,7 @@ bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
   Value *Ptr = Builder->CreateConstGEP1_64(
       Builder->CreatePointerCast(Loads[0].POP.Pointer,
                                  Builder->getInt8PtrTy(AddressSpace)),
-      Loads[0].POP.Offset);
+      Loads[0].POP.Offset.getSExtValue());
   LoadInst *NewLoad = new LoadInst(
       Builder->CreatePointerCast(
           Ptr, PointerType::get(IntegerType::get(Ptr->getContext(), TotalSize),
@@ -221,7 +224,7 @@ bool LoadCombine::combineLoads(SmallVectorImpl<LoadPOPPair> &Loads) {
     Value *V = Builder->CreateExtractInteger(
         L.Load->getModule()->getDataLayout(), NewLoad,
         cast<IntegerType>(L.Load->getType()),
-        L.POP.Offset - Loads[0].POP.Offset, "combine.extract");
+        (L.POP.Offset - Loads[0].POP.Offset).getZExtValue(), "combine.extract");
     L.Load->replaceAllUsesWith(V);
   }