From f8fc02fbd4ab33383c010d33675acf9763d0bd44 Mon Sep 17 00:00:00 2001 From: Martin Bohme Date: Tue, 24 Oct 2017 20:40:02 +0000 Subject: [PATCH] Revert "[CodeGen][ExpandMemcmp][NFC] Allow memcmp to expand to vector loads (1)" This reverts commit r316417, which causes internal compiles to OOM. I don't unfortunately have a self-contained test case but will follow up with courbet. git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@316497 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/CodeGen/CodeGenPrepare.cpp | 397 ++++++++++++++++----------------- 1 file changed, 196 insertions(+), 201 deletions(-) diff --git a/lib/CodeGen/CodeGenPrepare.cpp b/lib/CodeGen/CodeGenPrepare.cpp index 68e53d55b12..9f0c1f7fb1a 100644 --- a/lib/CodeGen/CodeGenPrepare.cpp +++ b/lib/CodeGen/CodeGenPrepare.cpp @@ -1710,61 +1710,43 @@ class MemCmpExpansion { ResultBlock() = default; }; - CallInst *const CI; + CallInst *CI; ResultBlock ResBlock; - const unsigned Size; unsigned MaxLoadSize; - unsigned NumLoadsNonOneByte; - const unsigned NumLoadsPerBlock; + unsigned NumBlocks; + unsigned NumBlocksNonOneByte; + unsigned NumLoadsPerBlock; std::vector LoadCmpBlocks; BasicBlock *EndBlock; PHINode *PhiRes; - const bool IsUsedForZeroCmp; + bool IsUsedForZeroCmp; const DataLayout &DL; IRBuilder<> Builder; - // Represents the decomposition in blocks of the expansion. For example, - // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and - // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}. - // TODO(courbet): Involve the target more in this computation. On X86, 7 - // bytes can be done more efficiently with two overlaping 4-byte loads than - // covering the interval with [{4, 0},{2, 4},{1, 6}}. - struct LoadEntry { - LoadEntry(unsigned LoadSize, uint64_t Offset) - : LoadSize(LoadSize), Offset(Offset) { - assert(Offset % LoadSize == 0 && "invalid load entry"); - } - - uint64_t getGEPIndex() const { return Offset / LoadSize; } - - // The size of the load for this block, in bytes. - const unsigned LoadSize; - // The offset of this load WRT the base pointer, in bytes. - const uint64_t Offset; - }; - SmallVector LoadSequence; + unsigned calculateNumBlocks(unsigned Size); void createLoadCmpBlocks(); void createResultBlock(); void setupResultBlockPHINodes(); void setupEndBlockPHINodes(); - Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex); - void emitLoadCompareBlock(unsigned BlockIndex); - void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, - unsigned &LoadIndex); - void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex); + void emitLoadCompareBlock(unsigned Index, unsigned LoadSize, + unsigned GEPIndex); + Value *getCompareLoadPairs(unsigned Index, unsigned Size, + unsigned &NumBytesProcessed); + void emitLoadCompareBlockMultipleLoads(unsigned Index, unsigned Size, + unsigned &NumBytesProcessed); + void emitLoadCompareByteBlock(unsigned Index, unsigned GEPIndex); void emitMemCmpResultBlock(); - Value *getMemCmpExpansionZeroCase(); - Value *getMemCmpEqZeroOneBlock(); - Value *getMemCmpOneBlock(); + Value *getMemCmpExpansionZeroCase(unsigned Size); + Value *getMemCmpEqZeroOneBlock(unsigned Size); + Value *getMemCmpOneBlock(unsigned Size); + unsigned getLoadSize(unsigned Size); + unsigned getNumLoads(unsigned Size); - public: +public: MemCmpExpansion(CallInst *CI, uint64_t Size, unsigned MaxLoadSize, unsigned NumLoadsPerBlock, const DataLayout &DL); - unsigned getNumBlocks(); - unsigned getNumLoads() const { return LoadSequence.size(); } - - Value *getMemCmpExpansion(); + Value *getMemCmpExpansion(uint64_t Size); }; } // end anonymous namespace @@ -1777,56 +1759,43 @@ class MemCmpExpansion { // return from. // 3. ResultBlock, block to branch to for early exit when a // LoadCmpBlock finds a difference. -MemCmpExpansion::MemCmpExpansion(CallInst *const CI, uint64_t Size, - const unsigned MaxLoadSize, - const unsigned LoadsPerBlock, +MemCmpExpansion::MemCmpExpansion(CallInst *CI, uint64_t Size, + unsigned MaxLoadSize, unsigned LoadsPerBlock, const DataLayout &TheDataLayout) - : CI(CI), - Size(Size), - MaxLoadSize(MaxLoadSize), - NumLoadsNonOneByte(0), - NumLoadsPerBlock(LoadsPerBlock), - IsUsedForZeroCmp(isOnlyUsedInZeroEqualityComparison(CI)), - DL(TheDataLayout), - Builder(CI) { - // Scale the max size down if the target can load more bytes than we need. - while (this->MaxLoadSize > Size) { - this->MaxLoadSize /= 2; - } - // Compute the decomposition. - unsigned LoadSize = this->MaxLoadSize; - assert(Size > 0 && "zero blocks"); - uint64_t Offset = 0; - while (Size) { - assert(LoadSize > 0 && "zero load size"); - const uint64_t NumLoadsForThisSize = Size / LoadSize; - if (NumLoadsForThisSize > 0) { - for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { - LoadSequence.push_back({LoadSize, Offset}); - Offset += LoadSize; - } - if (LoadSize > 1) { - ++NumLoadsNonOneByte; - } - Size = Size % LoadSize; - } - // FIXME: This can result in a non-native load size (e.g. X86-32+SSE can - // load 16 and 4 but not 8), which throws the load count off (e.g. in the - // aforementioned case, 16 bytes will count for 2 loads but will generate - // 4). - LoadSize /= 2; + : CI(CI), MaxLoadSize(MaxLoadSize), NumLoadsPerBlock(LoadsPerBlock), + DL(TheDataLayout), Builder(CI) { + // A memcmp with zero-comparison with only one block of load and compare does + // not need to set up any extra blocks. This case could be handled in the DAG, + // but since we have all of the machinery to flexibly expand any memcpy here, + // we choose to handle this case too to avoid fragmented lowering. + IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI); + NumBlocks = calculateNumBlocks(Size); + if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || NumBlocks != 1) { + BasicBlock *StartBlock = CI->getParent(); + EndBlock = StartBlock->splitBasicBlock(CI, "endblock"); + setupEndBlockPHINodes(); + createResultBlock(); + + // If return value of memcmp is not used in a zero equality, we need to + // calculate which source was larger. The calculation requires the + // two loaded source values of each load compare block. + // These will be saved in the phi nodes created by setupResultBlockPHINodes. + if (!IsUsedForZeroCmp) + setupResultBlockPHINodes(); + + // Create the number of required load compare basic blocks. + createLoadCmpBlocks(); + + // Update the terminator added by splitBasicBlock to branch to the first + // LoadCmpBlock. + StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]); } -} -unsigned MemCmpExpansion::getNumBlocks() { - if (IsUsedForZeroCmp) - return getNumLoads() / NumLoadsPerBlock + - (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0); - return getNumLoads(); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); } void MemCmpExpansion::createLoadCmpBlocks() { - for (unsigned i = 0; i < getNumBlocks(); i++) { + for (unsigned i = 0; i < NumBlocks; i++) { BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb", EndBlock->getParent(), EndBlock); LoadCmpBlocks.push_back(BB); @@ -1842,12 +1811,12 @@ void MemCmpExpansion::createResultBlock() { // It loads 1 byte from each source of the memcmp parameters with the given // GEPIndex. It then subtracts the two loaded values and adds this result to the // final phi node for selecting the memcmp result. -void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, +void MemCmpExpansion::emitLoadCompareByteBlock(unsigned Index, unsigned GEPIndex) { Value *Source1 = CI->getArgOperand(0); Value *Source2 = CI->getArgOperand(1); - Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); + Builder.SetInsertPoint(LoadCmpBlocks[Index]); Type *LoadSizeType = Type::getInt8Ty(CI->getContext()); // Cast source to LoadSizeType*. if (Source1->getType() != LoadSizeType) @@ -1870,15 +1839,15 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext())); Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2); - PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]); + PhiRes->addIncoming(Diff, LoadCmpBlocks[Index]); - if (BlockIndex < (LoadCmpBlocks.size() - 1)) { + if (Index < (LoadCmpBlocks.size() - 1)) { // Early exit branch if difference found to EndBlock. Otherwise, continue to // next LoadCmpBlock, Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff, ConstantInt::get(Diff->getType(), 0)); BranchInst *CmpBr = - BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp); + BranchInst::Create(EndBlock, LoadCmpBlocks[Index + 1], Cmp); Builder.Insert(CmpBr); } else { // The last block has an unconditional branch to EndBlock. @@ -1887,37 +1856,42 @@ void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, } } +unsigned MemCmpExpansion::getNumLoads(unsigned Size) { + return (Size / MaxLoadSize) + countPopulation(Size % MaxLoadSize); +} + +unsigned MemCmpExpansion::getLoadSize(unsigned Size) { + return MinAlign(PowerOf2Floor(Size), MaxLoadSize); +} + /// Generate an equality comparison for one or more pairs of loaded values. /// This is used in the case where the memcmp() call is compared equal or not /// equal to zero. -Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, - unsigned &LoadIndex) { - assert(LoadIndex < getNumLoads() && - "getCompareLoadPairs() called with no remaining loads"); +Value *MemCmpExpansion::getCompareLoadPairs(unsigned Index, unsigned Size, + unsigned &NumBytesProcessed) { std::vector XorList, OrList; Value *Diff; - const unsigned NumLoads = - std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock); + unsigned RemainingBytes = Size - NumBytesProcessed; + unsigned NumLoadsRemaining = getNumLoads(RemainingBytes); + unsigned NumLoads = std::min(NumLoadsRemaining, NumLoadsPerBlock); // For a single-block expansion, start inserting before the memcmp call. if (LoadCmpBlocks.empty()) Builder.SetInsertPoint(CI); else - Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); + Builder.SetInsertPoint(LoadCmpBlocks[Index]); Value *Cmp = nullptr; - // If we have multiple loads per block, we need to generate a composite - // comparison using xor+or. The type for the combinations is the largest load - // type. - IntegerType *const MaxLoadType = - NumLoads == 1 ? nullptr - : IntegerType::get(CI->getContext(), MaxLoadSize * 8); - for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) { - const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex]; - - IntegerType *LoadSizeType = - IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); + for (unsigned i = 0; i < NumLoads; ++i) { + unsigned LoadSize = getLoadSize(RemainingBytes); + unsigned GEPIndex = NumBytesProcessed / LoadSize; + NumBytesProcessed += LoadSize; + RemainingBytes -= LoadSize; + + Type *LoadSizeType = IntegerType::get(CI->getContext(), LoadSize * 8); + Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); + assert(LoadSize <= MaxLoadSize && "Unexpected load type"); Value *Source1 = CI->getArgOperand(0); Value *Source2 = CI->getArgOperand(1); @@ -1928,14 +1902,12 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, if (Source2->getType() != LoadSizeType) Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - // Get the base address using a GEP. - if (CurLoadEntry.Offset != 0) { - Source1 = Builder.CreateGEP( - LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); - Source2 = Builder.CreateGEP( - LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + // Get the base address using the GEPIndex. + if (GEPIndex != 0) { + Source1 = Builder.CreateGEP(LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, GEPIndex)); + Source2 = Builder.CreateGEP(LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, GEPIndex)); } // Get a constant or load a value for each source address. @@ -1992,13 +1964,13 @@ Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, return Cmp; } -void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, - unsigned &LoadIndex) { - Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex); +void MemCmpExpansion::emitLoadCompareBlockMultipleLoads( + unsigned Index, unsigned Size, unsigned &NumBytesProcessed) { + Value *Cmp = getCompareLoadPairs(Index, Size, NumBytesProcessed); - BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) + BasicBlock *NextBB = (Index == (LoadCmpBlocks.size() - 1)) ? EndBlock - : LoadCmpBlocks[BlockIndex + 1]; + : LoadCmpBlocks[Index + 1]; // Early exit branch if difference found to ResultBlock. Otherwise, // continue to next LoadCmpBlock or EndBlock. BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp); @@ -2007,9 +1979,9 @@ void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0 // since early exit to ResultBlock was not taken (no difference was found in // any of the bytes). - if (BlockIndex == LoadCmpBlocks.size() - 1) { + if (Index == LoadCmpBlocks.size() - 1) { Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); - PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); + PhiRes->addIncoming(Zero, LoadCmpBlocks[Index]); } } @@ -2022,39 +1994,33 @@ void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with // a special case through emitLoadCompareByteBlock. The special handling can // simply subtract the loaded values and add it to the result phi node. -void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { - // There is one load per block in this case, BlockIndex == LoadIndex. - const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex]; - - if (CurLoadEntry.LoadSize == 1) { - MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, - CurLoadEntry.getGEPIndex()); +void MemCmpExpansion::emitLoadCompareBlock(unsigned Index, unsigned LoadSize, + unsigned GEPIndex) { + if (LoadSize == 1) { + MemCmpExpansion::emitLoadCompareByteBlock(Index, GEPIndex); return; } - Type *LoadSizeType = - IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); + Type *LoadSizeType = IntegerType::get(CI->getContext(), LoadSize * 8); Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); - assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); + assert(LoadSize <= MaxLoadSize && "Unexpected load type"); Value *Source1 = CI->getArgOperand(0); Value *Source2 = CI->getArgOperand(1); - Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); + Builder.SetInsertPoint(LoadCmpBlocks[Index]); // Cast source to LoadSizeType*. if (Source1->getType() != LoadSizeType) Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo()); if (Source2->getType() != LoadSizeType) Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo()); - // Get the base address using a GEP. - if (CurLoadEntry.Offset != 0) { - Source1 = Builder.CreateGEP( - LoadSizeType, Source1, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); - Source2 = Builder.CreateGEP( - LoadSizeType, Source2, - ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex())); + // Get the base address using the GEPIndex. + if (GEPIndex != 0) { + Source1 = Builder.CreateGEP(LoadSizeType, Source1, + ConstantInt::get(LoadSizeType, GEPIndex)); + Source2 = Builder.CreateGEP(LoadSizeType, Source2, + ConstantInt::get(LoadSizeType, GEPIndex)); } // Load LoadSizeType from the base address. @@ -2076,14 +2042,14 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { // Add the loaded values to the phi nodes for calculating memcmp result only // if result is not used in a zero equality. if (!IsUsedForZeroCmp) { - ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]); - ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]); + ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[Index]); + ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[Index]); } Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2); - BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) + BasicBlock *NextBB = (Index == (LoadCmpBlocks.size() - 1)) ? EndBlock - : LoadCmpBlocks[BlockIndex + 1]; + : LoadCmpBlocks[Index + 1]; // Early exit branch if difference found to ResultBlock. Otherwise, continue // to next LoadCmpBlock or EndBlock. BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp); @@ -2092,9 +2058,9 @@ void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0 // since early exit to ResultBlock was not taken (no difference was found in // any of the bytes). - if (BlockIndex == LoadCmpBlocks.size() - 1) { + if (Index == LoadCmpBlocks.size() - 1) { Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); - PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); + PhiRes->addIncoming(Zero, LoadCmpBlocks[Index]); } } @@ -2128,14 +2094,34 @@ void MemCmpExpansion::emitMemCmpResultBlock() { PhiRes->addIncoming(Res, ResBlock.BB); } +unsigned MemCmpExpansion::calculateNumBlocks(unsigned Size) { + unsigned NumBlocks = 0; + bool HaveOneByteLoad = false; + unsigned RemainingSize = Size; + unsigned LoadSize = MaxLoadSize; + while (RemainingSize) { + if (LoadSize == 1) + HaveOneByteLoad = true; + NumBlocks += RemainingSize / LoadSize; + RemainingSize = RemainingSize % LoadSize; + LoadSize = LoadSize / 2; + } + NumBlocksNonOneByte = HaveOneByteLoad ? (NumBlocks - 1) : NumBlocks; + + if (IsUsedForZeroCmp) + NumBlocks = NumBlocks / NumLoadsPerBlock + + (NumBlocks % NumLoadsPerBlock != 0 ? 1 : 0); + + return NumBlocks; +} + void MemCmpExpansion::setupResultBlockPHINodes() { Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); Builder.SetInsertPoint(ResBlock.BB); - // Note: this assumes one load per block. ResBlock.PhiSrc1 = - Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1"); + Builder.CreatePHI(MaxLoadType, NumBlocksNonOneByte, "phi.src1"); ResBlock.PhiSrc2 = - Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2"); + Builder.CreatePHI(MaxLoadType, NumBlocksNonOneByte, "phi.src2"); } void MemCmpExpansion::setupEndBlockPHINodes() { @@ -2143,13 +2129,12 @@ void MemCmpExpansion::setupEndBlockPHINodes() { PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res"); } -Value *MemCmpExpansion::getMemCmpExpansionZeroCase() { - unsigned LoadIndex = 0; +Value *MemCmpExpansion::getMemCmpExpansionZeroCase(unsigned Size) { + unsigned NumBytesProcessed = 0; // This loop populates each of the LoadCmpBlocks with the IR sequence to // handle multiple loads per block. - for (unsigned I = 0; I < getNumBlocks(); ++I) { - emitLoadCompareBlockMultipleLoads(I, LoadIndex); - } + for (unsigned i = 0; i < NumBlocks; ++i) + emitLoadCompareBlockMultipleLoads(i, Size, NumBytesProcessed); emitMemCmpResultBlock(); return PhiRes; @@ -2158,16 +2143,15 @@ Value *MemCmpExpansion::getMemCmpExpansionZeroCase() { /// A memcmp expansion that compares equality with 0 and only has one block of /// load and compare can bypass the compare, branch, and phi IR that is required /// in the general case. -Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() { - unsigned LoadIndex = 0; - Value *Cmp = getCompareLoadPairs(0, LoadIndex); - assert(LoadIndex == getNumLoads() && "some entries were not consumed"); +Value *MemCmpExpansion::getMemCmpEqZeroOneBlock(unsigned Size) { + unsigned NumBytesProcessed = 0; + Value *Cmp = getCompareLoadPairs(0, Size, NumBytesProcessed); return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext())); } /// A memcmp expansion that only has one block of load and compare can bypass /// the compare, branch, and phi IR that is required in the general case. -Value *MemCmpExpansion::getMemCmpOneBlock() { +Value *MemCmpExpansion::getMemCmpOneBlock(unsigned Size) { assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block"); Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); @@ -2214,42 +2198,37 @@ Value *MemCmpExpansion::getMemCmpOneBlock() { // This function expands the memcmp call into an inline expansion and returns // the memcmp result. -Value *MemCmpExpansion::getMemCmpExpansion() { - // A memcmp with zero-comparison with only one block of load and compare does - // not need to set up any extra blocks. This case could be handled in the DAG, - // but since we have all of the machinery to flexibly expand any memcpy here, - // we choose to handle this case too to avoid fragmented lowering. - if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) { - BasicBlock *StartBlock = CI->getParent(); - EndBlock = StartBlock->splitBasicBlock(CI, "endblock"); - setupEndBlockPHINodes(); - createResultBlock(); - - // If return value of memcmp is not used in a zero equality, we need to - // calculate which source was larger. The calculation requires the - // two loaded source values of each load compare block. - // These will be saved in the phi nodes created by setupResultBlockPHINodes. - if (!IsUsedForZeroCmp) setupResultBlockPHINodes(); - - // Create the number of required load compare basic blocks. - createLoadCmpBlocks(); - - // Update the terminator added by splitBasicBlock to branch to the first - // LoadCmpBlock. - StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]); - } - - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - +Value *MemCmpExpansion::getMemCmpExpansion(uint64_t Size) { if (IsUsedForZeroCmp) - return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock() - : getMemCmpExpansionZeroCase(); + return NumBlocks == 1 ? getMemCmpEqZeroOneBlock(Size) : + getMemCmpExpansionZeroCase(Size); // TODO: Handle more than one load pair per block in getMemCmpOneBlock(). - if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock(); - - for (unsigned I = 0; I < getNumBlocks(); ++I) { - emitLoadCompareBlock(I); + if (NumBlocks == 1 && NumLoadsPerBlock == 1) + return getMemCmpOneBlock(Size); + + // This loop calls emitLoadCompareBlock for comparing Size bytes of the two + // memcmp sources. It starts with loading using the maximum load size set by + // the target. It processes any remaining bytes using a load size which is the + // next smallest power of 2. + unsigned LoadSize = MaxLoadSize; + unsigned NumBytesToBeProcessed = Size; + unsigned Index = 0; + while (NumBytesToBeProcessed) { + // Calculate how many blocks we can create with the current load size. + unsigned NumBlocks = NumBytesToBeProcessed / LoadSize; + unsigned GEPIndex = (Size - NumBytesToBeProcessed) / LoadSize; + NumBytesToBeProcessed = NumBytesToBeProcessed % LoadSize; + + // For each NumBlocks, populate the instruction sequence for loading and + // comparing LoadSize bytes. + while (NumBlocks--) { + emitLoadCompareBlock(Index, LoadSize, GEPIndex); + Index++; + GEPIndex++; + } + // Get the next LoadSize to use. + LoadSize = LoadSize / 2; } emitMemCmpResultBlock(); @@ -2333,6 +2312,12 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, const TargetLowering *TLI, const DataLayout *DL) { NumMemCmpCalls++; + // TTI call to check if target would like to expand memcmp. Also, get the + // MaxLoadSize. + unsigned MaxLoadSize; + if (!TTI->enableMemCmpExpansion(MaxLoadSize)) + return false; + // Early exit from expansion if -Oz. if (CI->getFunction()->optForMinSize()) return false; @@ -2343,26 +2328,36 @@ static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, NumMemCmpNotConstant++; return false; } - const uint64_t SizeVal = SizeCast->getZExtValue(); - // TTI call to check if target would like to expand memcmp. Also, get the - // max LoadSize. - unsigned MaxLoadSize; - if (!TTI->enableMemCmpExpansion(MaxLoadSize)) return false; + // Scale the max size down if the target can load more bytes than we need. + uint64_t SizeVal = SizeCast->getZExtValue(); + if (MaxLoadSize > SizeVal) + MaxLoadSize = 1 << SizeCast->getValue().logBase2(); - MemCmpExpansion Expansion(CI, SizeVal, MaxLoadSize, MemCmpNumLoadsPerBlock, - *DL); + // Calculate how many load pairs are needed for the constant size. + unsigned NumLoads = 0; + unsigned RemainingSize = SizeVal; + unsigned LoadSize = MaxLoadSize; + while (RemainingSize) { + NumLoads += RemainingSize / LoadSize; + RemainingSize = RemainingSize % LoadSize; + LoadSize = LoadSize / 2; + } // Don't expand if this will require more loads than desired by the target. - if (Expansion.getNumLoads() > - TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize())) { + if (NumLoads > TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize())) { NumMemCmpGreaterThanMax++; return false; } NumMemCmpInlined++; - Value *Res = Expansion.getMemCmpExpansion(); + // MemCmpHelper object creates and sets up basic blocks required for + // expanding memcmp with size SizeVal. + unsigned NumLoadsPerBlock = MemCmpNumLoadsPerBlock; + MemCmpExpansion MemCmpHelper(CI, SizeVal, MaxLoadSize, NumLoadsPerBlock, *DL); + + Value *Res = MemCmpHelper.getMemCmpExpansion(SizeVal); // Replace call with result of expansion and erase call. CI->replaceAllUsesWith(Res); -- 2.40.0