]> granicus.if.org Git - llvm/commitdiff
[InstCombine] add helper function for icmp+zext/sext; NFC
authorSanjay Patel <spatel@rotateright.com>
Tue, 20 Aug 2019 18:15:17 +0000 (18:15 +0000)
committerSanjay Patel <spatel@rotateright.com>
Tue, 20 Aug 2019 18:15:17 +0000 (18:15 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@369421 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineCompares.cpp

index dc6dcc36061332c2849f98f763afb9cdcb894e0d..b318d0b2de950b00fe1b7df9c7502d85383d3854 100644 (file)
@@ -4026,103 +4026,66 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
   return nullptr;
 }
 
-/// Handle icmp (cast x), (cast or constant).
-Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) {
-  auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
-  if (!CastOp0)
-    return nullptr;
-  if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1)))
-    return nullptr;
-
-  Value *Op0Src = CastOp0->getOperand(0);
-  Type *SrcTy = CastOp0->getSrcTy();
-  Type *DestTy = CastOp0->getDestTy();
-
-  // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
-  // integer type is the same size as the pointer type.
-  auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) {
-    if (isa<VectorType>(SrcTy)) {
-      SrcTy = cast<VectorType>(SrcTy)->getElementType();
-      DestTy = cast<VectorType>(DestTy)->getElementType();
-    }
-    return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth();
-  };
-  if (CastOp0->getOpcode() == Instruction::PtrToInt &&
-      CompatibleSizes(SrcTy, DestTy)) {
-    Value *NewOp1 = nullptr;
-    if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) {
-      Value *PtrSrc = PtrToIntOp1->getOperand(0);
-      if (PtrSrc->getType()->getPointerAddressSpace() ==
-          Op0Src->getType()->getPointerAddressSpace()) {
-        NewOp1 = PtrToIntOp1->getOperand(0);
-        // If the pointer types don't match, insert a bitcast.
-        if (Op0Src->getType() != NewOp1->getType())
-          NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType());
-      }
-    } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
-      NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy);
-    }
-
-    if (NewOp1)
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
-  }
-
-  // The code below only handles extension cast instructions, so far.
-  // Enforce this.
-  if (CastOp0->getOpcode() != Instruction::ZExt &&
-      CastOp0->getOpcode() != Instruction::SExt)
+static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp) {
+  assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
+  auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0));
+  Value *X;
+  if (!match(CastOp0, m_ZExtOrSExt(m_Value(X))))
     return nullptr;
 
-  bool isSignedExt = CastOp0->getOpcode() == Instruction::SExt;
-  bool isSignedCmp = ICmp.isSigned();
-
+  bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt;
+  bool IsSignedCmp = ICmp.isSigned();
   if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) {
-    // Not an extension from the same type?
-    Value *Op1Src = CastOp1->getOperand(0);
-    if (Op1Src->getType() != Op0Src->getType())
-      return nullptr;
-
     // If the signedness of the two casts doesn't agree (i.e. one is a sext
     // and the other is a zext), then we can't handle this.
-    if (CastOp1->getOpcode() != CastOp0->getOpcode())
+    if (CastOp0->getOpcode() != CastOp1->getOpcode())
+      return nullptr;
+
+    // Not an extension from the same type?
+    // TODO: Handle this by extending the narrower operand to the type of
+    //       the wider operand.
+    Value *Y = CastOp1->getOperand(0);
+    if (X->getType() != Y->getType())
       return nullptr;
 
-    // Deal with equality cases early.
+    // (zext X) == (zext Y) --> X == Y
+    // (sext X) == (sext Y) --> X == Y
     if (ICmp.isEquality())
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Op1Src);
+      return new ICmpInst(ICmp.getPredicate(), X, Y);
 
     // A signed comparison of sign extended values simplifies into a
     // signed comparison.
-    if (isSignedCmp && isSignedExt)
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Op1Src);
+    if (IsSignedCmp && IsSignedExt)
+      return new ICmpInst(ICmp.getPredicate(), X, Y);
 
     // The other three cases all fold into an unsigned comparison.
-    return new ICmpInst(ICmp.getUnsignedPredicate(), Op0Src, Op1Src);
+    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y);
   }
 
-  // If we aren't dealing with a constant on the RHS, exit early.
+  // Below here, we are only folding a compare with constant.
   auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
   if (!C)
     return nullptr;
 
   // Compute the constant that would happen if we truncated to SrcTy then
   // re-extended to DestTy.
+  Type *SrcTy = CastOp0->getSrcTy();
+  Type *DestTy = CastOp0->getDestTy();
   Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
   Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);
 
   // If the re-extended constant didn't change...
   if (Res2 == C) {
-    // Deal with equality cases early.
     if (ICmp.isEquality())
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Res1);
+      return new ICmpInst(ICmp.getPredicate(), X, Res1);
 
     // A signed comparison of sign extended values simplifies into a
     // signed comparison.
-    if (isSignedExt && isSignedCmp)
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Res1);
+    if (IsSignedExt && IsSignedCmp)
+      return new ICmpInst(ICmp.getPredicate(), X, Res1);
 
     // The other three cases all fold into an unsigned comparison.
-    return new ICmpInst(ICmp.getUnsignedPredicate(), Op0Src, Res1);
+    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1);
   }
 
   // The re-extended constant changed, partly changed (in the case of a vector),
@@ -4130,19 +4093,62 @@ Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) {
   // expression), so the constant cannot be represented in the shorter type.
   // All the cases that fold to true or false will have already been handled
   // by SimplifyICmpInst, so only deal with the tricky case.
-  if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C))
+  if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C))
     return nullptr;
 
   // Is source op positive?
   // icmp ult (sext X), C --> icmp sgt X, -1
   if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
-    return new ICmpInst(CmpInst::ICMP_SGT, Op0Src,
-                        Constant::getAllOnesValue(SrcTy));
+    return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy));
 
   // Is source op negative?
   // icmp ugt (sext X), C --> icmp slt X, 0
   assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
-  return new ICmpInst(CmpInst::ICMP_SLT, Op0Src, Constant::getNullValue(SrcTy));
+  return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
+}
+
+/// Handle icmp (cast x), (cast or constant).
+Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) {
+  auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
+  if (!CastOp0)
+    return nullptr;
+  if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1)))
+    return nullptr;
+
+  Value *Op0Src = CastOp0->getOperand(0);
+  Type *SrcTy = CastOp0->getSrcTy();
+  Type *DestTy = CastOp0->getDestTy();
+
+  // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
+  // integer type is the same size as the pointer type.
+  auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) {
+    if (isa<VectorType>(SrcTy)) {
+      SrcTy = cast<VectorType>(SrcTy)->getElementType();
+      DestTy = cast<VectorType>(DestTy)->getElementType();
+    }
+    return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth();
+  };
+  if (CastOp0->getOpcode() == Instruction::PtrToInt &&
+      CompatibleSizes(SrcTy, DestTy)) {
+    Value *NewOp1 = nullptr;
+    if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) {
+      Value *PtrSrc = PtrToIntOp1->getOperand(0);
+      if (PtrSrc->getType()->getPointerAddressSpace() ==
+          Op0Src->getType()->getPointerAddressSpace()) {
+        NewOp1 = PtrToIntOp1->getOperand(0);
+        // If the pointer types don't match, insert a bitcast.
+        if (Op0Src->getType() != NewOp1->getType())
+          NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType());
+      }
+    } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) {
+      NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy);
+    }
+
+    if (NewOp1)
+      return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
+  }
+
+  return foldICmpWithZextOrSext(ICmp);
 }
 
 static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {