]> granicus.if.org Git - llvm/commitdiff
[ConstantFold] Support vector index when factoring out GEP index into preceding dimen...
authorHaicheng Wu <haicheng@codeaurora.org>
Mon, 4 Dec 2017 19:56:33 +0000 (19:56 +0000)
committerHaicheng Wu <haicheng@codeaurora.org>
Mon, 4 Dec 2017 19:56:33 +0000 (19:56 +0000)
Follow-up of r316824. This patch supports the vector type for both current and
previous index when factoring out the current one into the previous one.

Differential Revision: https://reviews.llvm.org/D39556

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@319683 91177308-0d34-0410-b5e6-96231b3b80d8

lib/IR/ConstantFold.cpp
test/Assembler/getelementptr_vec_ce.ll
test/Transforms/InstCombine/gep-vector.ll
test/Transforms/InstSimplify/vector_gep.ll

index c826f757e6dde6e1afd2f1c34a022d4a19036ad0..90b10309b58bdf07f167dc7218399d4d9ba70c78 100644 (file)
@@ -2210,17 +2210,17 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
   SmallVector<Constant *, 8> NewIdxs;
   Type *Ty = PointeeTy;
   Type *Prev = C->getType();
-  bool Unknown = !isa<ConstantInt>(Idxs[0]);
+  bool Unknown =
+      !isa<ConstantInt>(Idxs[0]) && !isa<ConstantDataVector>(Idxs[0]);
   for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
-    auto *CI = dyn_cast<ConstantInt>(Idxs[i]);
-    if (!CI) {
+    if (!isa<ConstantInt>(Idxs[i]) && !isa<ConstantDataVector>(Idxs[i])) {
       // We don't know if it's in range or not.
       Unknown = true;
       continue;
     }
-    if (!isa<ConstantInt>(Idxs[i - 1]))
-      // FIXME: add the support of cosntant vector index.
+    if (!isa<ConstantInt>(Idxs[i - 1]) && !isa<ConstantDataVector>(Idxs[i - 1]))
+      // Skip if the type of the previous index is not supported.
       continue;
     if (InRangeIndex && i == *InRangeIndex + 1) {
       // If an index is marked inrange, we cannot apply this canonicalization to
@@ -2238,46 +2238,91 @@ Constant *llvm::ConstantFoldGetElementPtr(Type *PointeeTy, Constant *C,
       Unknown = true;
       continue;
     }
-    if (isIndexInRangeOfArrayType(STy->getNumElements(), CI))
-      // It's in range, skip to the next index.
-      continue;
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
+      if (isIndexInRangeOfArrayType(STy->getNumElements(), CI))
+        // It's in range, skip to the next index.
+        continue;
+      if (CI->getSExtValue() < 0) {
+        // It's out of range and negative, don't try to factor it.
+        Unknown = true;
+        continue;
+      }
+    } else {
+      auto *CV = cast<ConstantDataVector>(Idxs[i]);
+      bool InRange = true;
+      for (unsigned I = 0, E = CV->getNumElements(); I != E; ++I) {
+        auto *CI = cast<ConstantInt>(CV->getElementAsConstant(I));
+        InRange &= isIndexInRangeOfArrayType(STy->getNumElements(), CI);
+        if (CI->getSExtValue() < 0) {
+          Unknown = true;
+          break;
+        }
+      }
+      if (InRange || Unknown)
+        // It's in range, skip to the next index.
+        // It's out of range and negative, don't try to factor it.
+        continue;
+    }
     if (isa<StructType>(Prev)) {
       // It's out of range, but the prior dimension is a struct
       // so we can't do anything about it.
       Unknown = true;
       continue;
     }
-    if (CI->getSExtValue() < 0) {
-      // It's out of range and negative, don't try to factor it.
-      Unknown = true;
-      continue;
-    }
     // It's out of range, but we can factor it into the prior
     // dimension.
     NewIdxs.resize(Idxs.size());
     // Determine the number of elements in our sequential type.
     uint64_t NumElements = STy->getArrayNumElements();
 
-    ConstantInt *Factor = ConstantInt::get(CI->getType(), NumElements);
-    NewIdxs[i] = ConstantExpr::getSRem(CI, Factor);
+    // Expand the current index or the previous index to a vector from a scalar
+    // if necessary.
+    Constant *CurrIdx = cast<Constant>(Idxs[i]);
+    auto *PrevIdx =
+        NewIdxs[i - 1] ? NewIdxs[i - 1] : cast<Constant>(Idxs[i - 1]);
+    bool IsCurrIdxVector = CurrIdx->getType()->isVectorTy();
+    bool IsPrevIdxVector = PrevIdx->getType()->isVectorTy();
+    bool UseVector = IsCurrIdxVector || IsPrevIdxVector;
+
+    if (!IsCurrIdxVector && IsPrevIdxVector)
+      CurrIdx = ConstantDataVector::getSplat(
+          PrevIdx->getType()->getVectorNumElements(), CurrIdx);
+
+    if (!IsPrevIdxVector && IsCurrIdxVector)
+      PrevIdx = ConstantDataVector::getSplat(
+          CurrIdx->getType()->getVectorNumElements(), PrevIdx);
+
+    Constant *Factor =
+        ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements);
+    if (UseVector)
+      Factor = ConstantDataVector::getSplat(
+          IsPrevIdxVector ? PrevIdx->getType()->getVectorNumElements()
+                          : CurrIdx->getType()->getVectorNumElements(),
+          Factor);
+
+    NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor);
 
-    Constant *PrevIdx = NewIdxs[i-1] ? NewIdxs[i-1] :
-                           cast<Constant>(Idxs[i - 1]);
-    Constant *Div = ConstantExpr::getSDiv(CI, Factor);
+    Constant *Div = ConstantExpr::getSDiv(CurrIdx, Factor);
 
     unsigned CommonExtendedWidth =
-        std::max(PrevIdx->getType()->getIntegerBitWidth(),
-                 Div->getType()->getIntegerBitWidth());
+        std::max(PrevIdx->getType()->getScalarSizeInBits(),
+                 Div->getType()->getScalarSizeInBits());
     CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
 
     // Before adding, extend both operands to i64 to avoid
     // overflow trouble.
-    if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth))
-      PrevIdx = ConstantExpr::getSExt(
-          PrevIdx, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
-    if (!Div->getType()->isIntegerTy(CommonExtendedWidth))
-      Div = ConstantExpr::getSExt(
-          Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
+    Type *ExtendedTy = Type::getIntNTy(Div->getContext(), CommonExtendedWidth);
+    if (UseVector)
+      ExtendedTy = VectorType::get(
+          ExtendedTy, IsPrevIdxVector
+                          ? PrevIdx->getType()->getVectorNumElements()
+                          : CurrIdx->getType()->getVectorNumElements());
+
+    if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
+      PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);
+
+    if (!Div->getType()->isIntOrIntVectorTy(CommonExtendedWidth))
+      Div = ConstantExpr::getSExt(Div, ExtendedTy);
 
     NewIdxs[i - 1] = ConstantExpr::getAdd(PrevIdx, Div);
   }
index 4cf2964a57f7f47bf7c07a0d6c5e7b25d659ae28..67029698bfc52382ae994d3c3fbe3ae0eaeddacb 100644 (file)
@@ -3,7 +3,7 @@
 @G = global [4 x i32] zeroinitializer
 
 ; CHECK-LABEL: @foo
-; CHECK: ret <4 x i32*> getelementptr ([4 x i32], [4 x i32]* @G, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>)
+; CHECK: ret <4 x i32*> getelementptr inbounds ([4 x i32], [4 x i32]* @G, <4 x i32> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>)
 define <4 x i32*> @foo() {
   ret <4 x i32*> getelementptr ([4 x i32], [4 x i32]* @G, i32 0, <4 x i32> <i32 0, i32 1, i32 2, i32 3>)
 }
index f7ed1a776f5309432ce97dfbef67519454d0109a..9f55981ae450498a61989938f136bb3d55797a9c 100644 (file)
@@ -16,9 +16,23 @@ define <8 x i64*> @patatino2() {
 
 @block = global [64 x [8192 x i8]] zeroinitializer, align 1
 
-; CHECK-LABEL:vectorindex
-; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 0, i64 1>, <2 x i64> <i64 8192, i64 8192>)
-define <2 x i8*> @vectorindex() {
+; CHECK-LABEL:vectorindex1
+; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 1, i64 2>, <2 x i64> zeroinitializer)
+define <2 x i8*> @vectorindex1() {
   %1 = getelementptr inbounds [64 x [8192 x i8]], [64 x [8192 x i8]]* @block, i64 0, <2 x i64> <i64 0, i64 1>, i64 8192
   ret <2 x i8*> %1
 }
+
+; CHECK-LABEL:vectorindex2
+; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 1, i64 2>, <2 x i64> <i64 8191, i64 1>)
+define <2 x i8*> @vectorindex2() {
+  %1 = getelementptr inbounds [64 x [8192 x i8]], [64 x [8192 x i8]]* @block, i64 0, i64 1, <2 x i64> <i64 8191, i64 8193>
+  ret <2 x i8*> %1
+}
+
+; CHECK-LABEL:vectorindex3
+; CHECK-NEXT: ret <2 x i8*> getelementptr inbounds ([64 x [8192 x i8]], [64 x [8192 x i8]]* @block, <2 x i64> zeroinitializer, <2 x i64> <i64 0, i64 2>, <2 x i64> <i64 8191, i64 1>)
+define <2 x i8*> @vectorindex3() {
+  %1 = getelementptr inbounds [64 x [8192 x i8]], [64 x [8192 x i8]]* @block, i64 0, <2 x i64> <i64 0, i64 1>, <2 x i64> <i64 8191, i64 8193>
+  ret <2 x i8*> %1
+}
index cdf4732d4b5eb5c890269c2325b238bb2befe589..25f2255a2a7ca9e12ca6bb9c7b6f35dbbcf42be9 100644 (file)
@@ -58,7 +58,7 @@ define <4 x i8*> @test5() {
 
 define <16 x i32*> @test6() {
 ; CHECK-LABEL: @test6
-; CHECK-NEXT: ret <16 x i32*> getelementptr ([24 x [42 x [3 x i32]]], [24 x [42 x [3 x i32]]]* @v, <16 x i64> zeroinitializer, <16 x i64> zeroinitializer, <16 x i64> <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, <16 x i64> zeroinitializer)
+; CHECK-NEXT: ret <16 x i32*> getelementptr inbounds ([24 x [42 x [3 x i32]]], [24 x [42 x [3 x i32]]]* @v, <16 x i64> zeroinitializer, <16 x i64> zeroinitializer, <16 x i64> <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, <16 x i64> zeroinitializer)
   %VectorGep = getelementptr [24 x [42 x [3 x i32]]], [24 x [42 x [3 x i32]]]* @v, i64 0, i64 0, <16 x i64> <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, i64 0
   ret <16 x i32*> %VectorGep
 }