From: Sanjay Patel Date: Tue, 20 Aug 2019 18:15:17 +0000 (+0000) Subject: [InstCombine] add helper function for icmp+zext/sext; NFC X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=e687343b4cc8a7f7b3903788ecdd628239772eba;p=llvm [InstCombine] add helper function for icmp+zext/sext; NFC git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@369421 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index dc6dcc36061..b318d0b2de9 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -4026,103 +4026,66 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { return nullptr; } -/// Handle icmp (cast x), (cast or constant). -Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { - auto *CastOp0 = dyn_cast(ICmp.getOperand(0)); - if (!CastOp0) - return nullptr; - if (!isa(ICmp.getOperand(1)) && !isa(ICmp.getOperand(1))) - return nullptr; - - Value *Op0Src = CastOp0->getOperand(0); - Type *SrcTy = CastOp0->getSrcTy(); - Type *DestTy = CastOp0->getDestTy(); - - // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the - // integer type is the same size as the pointer type. - auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) { - if (isa(SrcTy)) { - SrcTy = cast(SrcTy)->getElementType(); - DestTy = cast(DestTy)->getElementType(); - } - return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); - }; - if (CastOp0->getOpcode() == Instruction::PtrToInt && - CompatibleSizes(SrcTy, DestTy)) { - Value *NewOp1 = nullptr; - if (auto *PtrToIntOp1 = dyn_cast(ICmp.getOperand(1))) { - Value *PtrSrc = PtrToIntOp1->getOperand(0); - if (PtrSrc->getType()->getPointerAddressSpace() == - Op0Src->getType()->getPointerAddressSpace()) { - NewOp1 = PtrToIntOp1->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (Op0Src->getType() != NewOp1->getType()) - NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); - } - } else if (auto *RHSC = dyn_cast(ICmp.getOperand(1))) { - NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); - } - - if (NewOp1) - return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); - } - - // The code below only handles extension cast instructions, so far. - // Enforce this. - if (CastOp0->getOpcode() != Instruction::ZExt && - CastOp0->getOpcode() != Instruction::SExt) +static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp) { + assert(isa(ICmp.getOperand(0)) && "Expected cast for operand 0"); + auto *CastOp0 = cast(ICmp.getOperand(0)); + Value *X; + if (!match(CastOp0, m_ZExtOrSExt(m_Value(X)))) return nullptr; - bool isSignedExt = CastOp0->getOpcode() == Instruction::SExt; - bool isSignedCmp = ICmp.isSigned(); - + bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt; + bool IsSignedCmp = ICmp.isSigned(); if (auto *CastOp1 = dyn_cast(ICmp.getOperand(1))) { - // Not an extension from the same type? - Value *Op1Src = CastOp1->getOperand(0); - if (Op1Src->getType() != Op0Src->getType()) - return nullptr; - // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. - if (CastOp1->getOpcode() != CastOp0->getOpcode()) + if (CastOp0->getOpcode() != CastOp1->getOpcode()) + return nullptr; + + // Not an extension from the same type? + // TODO: Handle this by extending the narrower operand to the type of + // the wider operand. + Value *Y = CastOp1->getOperand(0); + if (X->getType() != Y->getType()) return nullptr; - // Deal with equality cases early. + // (zext X) == (zext Y) --> X == Y + // (sext X) == (sext Y) --> X == Y if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), Op0Src, Op1Src); + return new ICmpInst(ICmp.getPredicate(), X, Y); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedCmp && isSignedExt) - return new ICmpInst(ICmp.getPredicate(), Op0Src, Op1Src); + if (IsSignedCmp && IsSignedExt) + return new ICmpInst(ICmp.getPredicate(), X, Y); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), Op0Src, Op1Src); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y); } - // If we aren't dealing with a constant on the RHS, exit early. + // Below here, we are only folding a compare with constant. auto *C = dyn_cast(ICmp.getOperand(1)); if (!C) return nullptr; // Compute the constant that would happen if we truncated to SrcTy then // re-extended to DestTy. + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy); // If the re-extended constant didn't change... if (Res2 == C) { - // Deal with equality cases early. if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), Op0Src, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res1); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedExt && isSignedCmp) - return new ICmpInst(ICmp.getPredicate(), Op0Src, Res1); + if (IsSignedExt && IsSignedCmp) + return new ICmpInst(ICmp.getPredicate(), X, Res1); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), Op0Src, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1); } // The re-extended constant changed, partly changed (in the case of a vector), @@ -4130,19 +4093,62 @@ Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { // expression), so the constant cannot be represented in the shorter type. // All the cases that fold to true or false will have already been handled // by SimplifyICmpInst, so only deal with the tricky case. - if (isSignedCmp || !isSignedExt || !isa(C)) + if (IsSignedCmp || !IsSignedExt || !isa(C)) return nullptr; // Is source op positive? // icmp ult (sext X), C --> icmp sgt X, -1 if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) - return new ICmpInst(CmpInst::ICMP_SGT, Op0Src, - Constant::getAllOnesValue(SrcTy)); + return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy)); // Is source op negative? // icmp ugt (sext X), C --> icmp slt X, 0 assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); - return new ICmpInst(CmpInst::ICMP_SLT, Op0Src, Constant::getNullValue(SrcTy)); + return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy)); +} + +/// Handle icmp (cast x), (cast or constant). +Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { + auto *CastOp0 = dyn_cast(ICmp.getOperand(0)); + if (!CastOp0) + return nullptr; + if (!isa(ICmp.getOperand(1)) && !isa(ICmp.getOperand(1))) + return nullptr; + + Value *Op0Src = CastOp0->getOperand(0); + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); + + // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the + // integer type is the same size as the pointer type. + auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) { + if (isa(SrcTy)) { + SrcTy = cast(SrcTy)->getElementType(); + DestTy = cast(DestTy)->getElementType(); + } + return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); + }; + if (CastOp0->getOpcode() == Instruction::PtrToInt && + CompatibleSizes(SrcTy, DestTy)) { + Value *NewOp1 = nullptr; + if (auto *PtrToIntOp1 = dyn_cast(ICmp.getOperand(1))) { + Value *PtrSrc = PtrToIntOp1->getOperand(0); + if (PtrSrc->getType()->getPointerAddressSpace() == + Op0Src->getType()->getPointerAddressSpace()) { + NewOp1 = PtrToIntOp1->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (Op0Src->getType() != NewOp1->getType()) + NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); + } + } else if (auto *RHSC = dyn_cast(ICmp.getOperand(1))) { + NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } + + if (NewOp1) + return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); + } + + return foldICmpWithZextOrSext(ICmp); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {