// Recursively descend the def-use lists from V to find non-bitcast users of
// bitcasts of V.
static void FindUses(Value *V, Function &F,
- SmallVectorImpl<std::pair<Use *, Function *>> &Uses) {
+ SmallVectorImpl<std::pair<Use *, Function *>> &Uses,
+ SmallPtrSetImpl<Constant *> &ConstantBCs) {
for (Use &U : V->uses()) {
if (BitCastOperator *BC = dyn_cast<BitCastOperator>(U.getUser()))
- FindUses(BC, F, Uses);
- else if (U.get()->getType() != F.getType())
+ FindUses(BC, F, Uses, ConstantBCs);
+ else if (U.get()->getType() != F.getType()) {
+ if (isa<Constant>(U.get())) {
+ // Only add constant bitcasts to the list once; they get RAUW'd
+ auto c = ConstantBCs.insert(cast<Constant>(U.get()));
+ if (!c.second) continue;
+ }
Uses.push_back(std::make_pair(&U, &F));
+ }
}
}
bool FixFunctionBitcasts::runOnModule(Module &M) {
SmallVector<std::pair<Use *, Function *>, 0> Uses;
+ SmallPtrSet<Constant *, 2> ConstantBCs;
// Collect all the places that need wrappers.
- for (Function &F : M)
- FindUses(&F, F, Uses);
+ for (Function &F : M) FindUses(&F, F, Uses, ConstantBCs);
DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;
; CHECK-LABEL: test:
; CHECK-NEXT: call .Lbitcast@FUNCTION{{$}}
+; CHECK-NEXT: call .Lbitcast@FUNCTION{{$}}
; CHECK-NEXT: call .Lbitcast.1@FUNCTION{{$}}
; CHECK-NEXT: i32.const $push[[L0:[0-9]+]]=, 0
; CHECK-NEXT: call .Lbitcast.2@FUNCTION, $pop[[L0]]{{$}}
+; CHECK-NEXT: i32.const $push[[L1:[0-9]+]]=, 0
+; CHECK-NEXT: call .Lbitcast.2@FUNCTION, $pop[[L1]]{{$}}
+; CHECK-NEXT: i32.const $push[[L2:[0-9]+]]=, 0
+; CHECK-NEXT: call .Lbitcast.2@FUNCTION, $pop[[L2]]{{$}}
+; CHECK-NEXT: call foo0@FUNCTION
; CHECK-NEXT: i32.call $drop=, .Lbitcast.3@FUNCTION{{$}}
; CHECK-NEXT: call foo2@FUNCTION{{$}}
+; CHECK-NEXT: call foo1@FUNCTION{{$}}
; CHECK-NEXT: call foo3@FUNCTION{{$}}
; CHECK-NEXT: .endfunc
define void @test() {
entry:
+ call void bitcast (void (i32)* @has_i32_arg to void ()*)()
call void bitcast (void (i32)* @has_i32_arg to void ()*)()
call void bitcast (i32 ()* @has_i32_ret to void ()*)()
call void bitcast (void ()* @foo0 to void (i32)*)(i32 0)
+ %p = bitcast void ()* @foo0 to void (i32)*
+ call void %p(i32 0)
+ %q = bitcast void ()* @foo0 to void (i32)*
+ call void %q(i32 0)
+ %r = bitcast void (i32)* %q to void ()*
+ call void %r()
%t = call i32 bitcast (void ()* @foo1 to i32 ()*)()
call void bitcast (void ()* @foo2 to void ()*)()
+ call void @foo1()
call void @foo3()
+
ret void
}