From: David Stuttard Date: Thu, 9 May 2019 15:02:10 +0000 (+0000) Subject: [CodeGenPrepare] Limit recursion depth for collectBitParts X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=f617e867aed8bbc4c46c872597cc011d85936678;p=llvm [CodeGenPrepare] Limit recursion depth for collectBitParts Summary: Seeing some issues for windows debug pathological cases with collectBitParts recursion (1525 levels of recursion!) Setting the limit to 64 as this should be sufficient - passes all lit cases Subscribers: llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D61728 Change-Id: I7f44cdc6c1badf1c2ccbf1b0c4b6afe27ecb39a1 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@360347 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index f9dc09e88e9..bf57fcdbdee 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -91,6 +91,10 @@ using namespace llvm::PatternMatch; STATISTIC(NumRemoved, "Number of unreachable basic blocks removed"); +// Max recursion depth for collectBitParts used when detecting bswap and +// bitreverse idioms +static const unsigned BitPartRecursionMaxDepth = 64; + //===----------------------------------------------------------------------===// // Local constant propagation. // @@ -2619,7 +2623,7 @@ struct BitPart { /// does not invalidate internal references (std::map instead of DenseMap). static const Optional & collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, - std::map> &BPS) { + std::map> &BPS, int Depth) { auto I = BPS.find(V); if (I != BPS.end()) return I->second; @@ -2627,13 +2631,19 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, auto &Result = BPS[V] = None; auto BitWidth = cast(V->getType())->getBitWidth(); + // Prevent stack overflow by limiting the recursion depth + if (Depth == BitPartRecursionMaxDepth) { + LLVM_DEBUG(dbgs() << "collectBitParts max recursion depth reached.\n"); + return Result; + } + if (Instruction *I = dyn_cast(V)) { // If this is an or instruction, it may be an inner node of the bswap. if (I->getOpcode() == Instruction::Or) { auto &A = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); auto &B = collectBitParts(I->getOperand(1), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!A || !B) return Result; @@ -2666,7 +2676,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, return Result; auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!Res) return Result; Result = Res; @@ -2698,7 +2708,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, return Result; auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!Res) return Result; Result = Res; @@ -2713,7 +2723,7 @@ collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, // If this is a zext instruction zero extend the result. if (I->getOpcode() == Instruction::ZExt) { auto &Res = collectBitParts(I->getOperand(0), MatchBSwaps, - MatchBitReversals, BPS); + MatchBitReversals, BPS, Depth + 1); if (!Res) return Result; @@ -2775,7 +2785,7 @@ bool llvm::recognizeBSwapOrBitReverseIdiom( // Try to find all the pieces corresponding to the bswap. std::map> BPS; - auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS); + auto Res = collectBitParts(I, MatchBSwaps, MatchBitReversals, BPS, 0); if (!Res) return false; auto &BitProvenance = Res->Provenance;