[AArch64][GlobalISel] Support sibling calls with mismatched calling conventions
authorJessica Paquette <jpaquette@apple.com>
Tue, 10 Sep 2019 23:25:12 +0000 (23:25 +0000)
committerJessica Paquette <jpaquette@apple.com>
Tue, 10 Sep 2019 23:25:12 +0000 (23:25 +0000)
Add support for sibcalling calls whose calling convention differs from the
caller's.

- Port over `CCState::resultsCombatible` from CallingConvLower.cpp into
  CallLowering. This is used to verify that the way the caller and callee CC
  handle incoming arguments matches up.

- Add `CallLowering::analyzeCallResult`. This is basically a port of
  `CCState::AnalyzeCallResult`, but using `ArgInfo` rather than `ISD::InputArg`.

- Add `AArch64CallLowering::doCallerAndCalleePassArgsTheSameWay`. This checks
  that the calling conventions are compatible, and that the caller and callee
  preserve the same registers.

For testing:

- Update call-translator-tail-call.ll to show that we can now handle this.

- Add a GISel line to tailcall-ccmismatch.ll to show that we will not tail call
  when the regmasks don't line up.

Differential Revision: https://reviews.llvm.org/D67361

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@371570 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/CodeGen/GlobalISel/CallLowering.h
lib/CodeGen/GlobalISel/CallLowering.cpp
lib/Target/AArch64/AArch64CallLowering.cpp
lib/Target/AArch64/AArch64CallLowering.h
test/CodeGen/AArch64/GlobalISel/call-translator-tail-call.ll
test/CodeGen/AArch64/tailcall-ccmismatch.ll

index cfdf3f5bf901d1768379ac7fda01bc0dad19ae93..e2135fc877368ebb0063cabb242743f43a503cc0 100644 (file)
@@ -211,6 +211,24 @@ protected:
                          SmallVectorImpl<ArgInfo> &Args,
                          ValueHandler &Handler) const;
 
+  /// Analyze the return values of a call, incorporating info about the passed
+  /// values into \p CCState.
+  bool analyzeCallResult(CCState &CCState, SmallVectorImpl<ArgInfo> &Args,
+                         CCAssignFn &Fn) const;
+
+  /// \returns True if the calling convention for a callee and its caller pass
+  /// results in the same way. Typically used for tail call eligibility checks.
+  ///
+  /// \p Info is the CallLoweringInfo for the call.
+  /// \p MF is the MachineFunction for the caller.
+  /// \p InArgs contains the results of the call.
+  /// \p CalleeAssignFn is the CCAssignFn to be used for the callee.
+  /// \p CallerAssignFn is the CCAssignFn to be used for the caller.
+  bool resultsCompatible(CallLoweringInfo &Info, MachineFunction &MF,
+                         SmallVectorImpl<ArgInfo> &InArgs,
+                         CCAssignFn &CalleeAssignFn,
+                         CCAssignFn &CallerAssignFn) const;
+
 public:
   CallLowering(const TargetLowering *TLI) : TLI(TLI) {}
   virtual ~CallLowering() = default;
index 1c8e45418176e12f616654949703cb36b272e69a..04aa072bc31131f8f7db25138ab16b0217ceed9e 100644 (file)
@@ -370,6 +370,75 @@ bool CallLowering::handleAssignments(CCState &CCInfo,
   return true;
 }
 
+bool CallLowering::analyzeCallResult(CCState &CCState,
+                                     SmallVectorImpl<ArgInfo> &Args,
+                                     CCAssignFn &Fn) const {
+  for (unsigned i = 0, e = Args.size(); i < e; ++i) {
+    MVT VT = MVT::getVT(Args[i].Ty);
+    if (Fn(i, VT, VT, CCValAssign::Full, Args[i].Flags[0], CCState)) {
+      // Bail out on anything we can't handle.
+      LLVM_DEBUG(dbgs() << "Cannot analyze " << EVT(VT).getEVTString()
+                        << " (arg number = " << i << "\n");
+      return false;
+    }
+  }
+  return true;
+}
+
+bool CallLowering::resultsCompatible(CallLoweringInfo &Info,
+                                     MachineFunction &MF,
+                                     SmallVectorImpl<ArgInfo> &InArgs,
+                                     CCAssignFn &CalleeAssignFn,
+                                     CCAssignFn &CallerAssignFn) const {
+  const Function &F = MF.getFunction();
+  CallingConv::ID CalleeCC = Info.CallConv;
+  CallingConv::ID CallerCC = F.getCallingConv();
+
+  if (CallerCC == CalleeCC)
+    return true;
+
+  SmallVector<CCValAssign, 16> ArgLocs1;
+  CCState CCInfo1(CalleeCC, false, MF, ArgLocs1, F.getContext());
+  if (!analyzeCallResult(CCInfo1, InArgs, CalleeAssignFn))
+    return false;
+
+  SmallVector<CCValAssign, 16> ArgLocs2;
+  CCState CCInfo2(CallerCC, false, MF, ArgLocs2, F.getContext());
+  if (!analyzeCallResult(CCInfo2, InArgs, CallerAssignFn))
+    return false;
+
+  // We need the argument locations to match up exactly. If there's more in
+  // one than the other, then we are done.
+  if (ArgLocs1.size() != ArgLocs2.size())
+    return false;
+
+  // Make sure that each location is passed in exactly the same way.
+  for (unsigned i = 0, e = ArgLocs1.size(); i < e; ++i) {
+    const CCValAssign &Loc1 = ArgLocs1[i];
+    const CCValAssign &Loc2 = ArgLocs2[i];
+
+    // We need both of them to be the same. So if one is a register and one
+    // isn't, we're done.
+    if (Loc1.isRegLoc() != Loc2.isRegLoc())
+      return false;
+
+    if (Loc1.isRegLoc()) {
+      // If they don't have the same register location, we're done.
+      if (Loc1.getLocReg() != Loc2.getLocReg())
+        return false;
+
+      // They matched, so we can move to the next ArgLoc.
+      continue;
+    }
+
+    // Loc1 wasn't a RegLoc, so they both must be MemLocs. Check if they match.
+    if (Loc1.getLocMemOffset() != Loc2.getLocMemOffset())
+      return false;
+  }
+
+  return true;
+}
+
 Register CallLowering::ValueHandler::extendRegister(Register ValReg,
                                                     CCValAssign &VA) {
   LLT LocTy{VA.getLocVT()};
index 03f20a2625951ec6681d08e59a88fa62ed20e7d8..64e1c84d98d61d4d030a9e27167393858772771d 100644 (file)
@@ -431,13 +431,44 @@ static bool mayTailCallThisCC(CallingConv::ID CC) {
   }
 }
 
+bool AArch64CallLowering::doCallerAndCalleePassArgsTheSameWay(
+    CallLoweringInfo &Info, MachineFunction &MF,
+    SmallVectorImpl<ArgInfo> &InArgs) const {
+  const Function &CallerF = MF.getFunction();
+  CallingConv::ID CalleeCC = Info.CallConv;
+  CallingConv::ID CallerCC = CallerF.getCallingConv();
+
+  // If the calling conventions match, then everything must be the same.
+  if (CalleeCC == CallerCC)
+    return true;
+
+  // Check if the caller and callee will handle arguments in the same way.
+  const AArch64TargetLowering &TLI = *getTLI<AArch64TargetLowering>();
+  CCAssignFn *CalleeAssignFn = TLI.CCAssignFnForCall(CalleeCC, Info.IsVarArg);
+  CCAssignFn *CallerAssignFn =
+      TLI.CCAssignFnForCall(CallerCC, CallerF.isVarArg());
+
+  if (!resultsCompatible(Info, MF, InArgs, *CalleeAssignFn, *CallerAssignFn))
+    return false;
+
+  // Make sure that the caller and callee preserve all of the same registers.
+  auto TRI = MF.getSubtarget<AArch64Subtarget>().getRegisterInfo();
+  const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
+  const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
+  if (MF.getSubtarget<AArch64Subtarget>().hasCustomCallingConv()) {
+    TRI->UpdateCustomCallPreservedMask(MF, &CallerPreserved);
+    TRI->UpdateCustomCallPreservedMask(MF, &CalleePreserved);
+  }
+
+  return TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved);
+}
+
 bool AArch64CallLowering::isEligibleForTailCallOptimization(
-    MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info) const {
+    MachineIRBuilder &MIRBuilder, CallLoweringInfo &Info,
+    SmallVectorImpl<ArgInfo> &InArgs) const {
   CallingConv::ID CalleeCC = Info.CallConv;
   MachineFunction &MF = MIRBuilder.getMF();
   const Function &CallerF = MF.getFunction();
-  CallingConv::ID CallerCC = CallerF.getCallingConv();
-  bool CCMatch = CallerCC == CalleeCC;
 
   LLVM_DEBUG(dbgs() << "Attempting to lower call as tail call\n");
 
@@ -509,11 +540,11 @@ bool AArch64CallLowering::isEligibleForTailCallOptimization(
   assert((!Info.IsVarArg || CalleeCC == CallingConv::C) &&
          "Unexpected variadic calling convention");
 
-  // For now, only support the case where the calling conventions match.
-  if (!CCMatch) {
+  // Look at the incoming values.
+  if (!doCallerAndCalleePassArgsTheSameWay(Info, MF, InArgs)) {
     LLVM_DEBUG(
         dbgs()
-        << "... Cannot tail call with mismatched calling conventions yet.\n");
+        << "... Caller and callee have incompatible calling conventions.\n");
     return false;
   }
 
@@ -552,6 +583,7 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   const Function &F = MF.getFunction();
   MachineRegisterInfo &MRI = MF.getRegInfo();
   auto &DL = F.getParent()->getDataLayout();
+  const AArch64TargetLowering &TLI = *getTLI<AArch64TargetLowering>();
 
   if (Info.IsMustTailCall) {
     // TODO: Until we lower all tail calls, we should fall back on this.
@@ -573,13 +605,16 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
       SplitArgs.back().Flags[0].setZExt();
   }
 
-  bool IsSibCall =
-      Info.IsTailCall && isEligibleForTailCallOptimization(MIRBuilder, Info);
+  SmallVector<ArgInfo, 8> InArgs;
+  if (!Info.OrigRet.Ty->isVoidTy())
+    splitToValueTypes(Info.OrigRet, InArgs, DL, MRI, F.getCallingConv());
+
+  bool IsSibCall = Info.IsTailCall &&
+                   isEligibleForTailCallOptimization(MIRBuilder, Info, InArgs);
   if (IsSibCall)
     MF.getFrameInfo().setHasTailCall();
 
   // Find out which ABI gets to decide where things go.
-  const AArch64TargetLowering &TLI = *getTLI<AArch64TargetLowering>();
   CCAssignFn *AssignFnFixed =
       TLI.CCAssignFnForCall(Info.CallConv, /*IsVarArg=*/false);
   CCAssignFn *AssignFnVarArg =
@@ -649,14 +684,10 @@ bool AArch64CallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
   // Finally we can copy the returned value back into its virtual-register. In
   // symmetry with the arugments, the physical register must be an
   // implicit-define of the call instruction.
-  CCAssignFn *RetAssignFn = TLI.CCAssignFnForReturn(F.getCallingConv());
   if (!Info.OrigRet.Ty->isVoidTy()) {
-    SplitArgs.clear();
-
-    splitToValueTypes(Info.OrigRet, SplitArgs, DL, MRI, F.getCallingConv());
-
+    CCAssignFn *RetAssignFn = TLI.CCAssignFnForReturn(F.getCallingConv());
     CallReturnHandler Handler(MIRBuilder, MRI, MIB, RetAssignFn);
-    if (!handleAssignments(MIRBuilder, SplitArgs, Handler))
+    if (!handleAssignments(MIRBuilder, InArgs, Handler))
       return false;
   }
 
index 0bf250b85a31c75eec0360ec9875d7d2dae5898a..696d4d8385a4ab7fb41ac06bfa74c804ad45222a 100644 (file)
@@ -44,8 +44,10 @@ public:
                  CallLoweringInfo &Info) const override;
 
   /// Returns true if the call can be lowered as a tail call.
-  bool isEligibleForTailCallOptimization(MachineIRBuilder &MIRBuilder,
-                                         CallLoweringInfo &Info) const;
+  bool
+  isEligibleForTailCallOptimization(MachineIRBuilder &MIRBuilder,
+                                    CallLoweringInfo &Info,
+                                    SmallVectorImpl<ArgInfo> &InArgs) const;
 
   bool supportSwiftError() const override { return true; }
 
@@ -60,6 +62,11 @@ private:
                          SmallVectorImpl<ArgInfo> &SplitArgs,
                          const DataLayout &DL, MachineRegisterInfo &MRI,
                          CallingConv::ID CallConv) const;
+
+  bool
+  doCallerAndCalleePassArgsTheSameWay(CallLoweringInfo &Info,
+                                      MachineFunction &MF,
+                                      SmallVectorImpl<ArgInfo> &InArgs) const;
 };
 
 } // end namespace llvm
index fb937c36b38f8c86f596cc0bf6fe7f9449602037..d253c54c1bb8c0533ed2189adc1923f9f1e8661b 100644 (file)
@@ -175,16 +175,11 @@ define void @test_extern_weak() {
   ret void
 }
 
-; Right now, mismatched calling conventions should not be tail called.
-; TODO: Support this.
 declare fastcc void @fast_fn()
 define void @test_mismatched_caller() {
   ; COMMON-LABEL: name: test_mismatched_caller
   ; COMMON: bb.1 (%ir-block.0):
-  ; COMMON:   ADJCALLSTACKDOWN 0, 0, implicit-def $sp, implicit $sp
-  ; COMMON:   BL @fast_fn, csr_aarch64_aapcs, implicit-def $lr, implicit $sp
-  ; COMMON:   ADJCALLSTACKUP 0, 0, implicit-def $sp, implicit $sp
-  ; COMMON:   RET_ReallyLR
+  ; COMMON:   TCRETURNdi @fast_fn, 0, csr_aarch64_aapcs, implicit $sp
   tail call fastcc void @fast_fn()
   ret void
 }
index ab96e609dd468c36522fe0a87d765c440d5000a5..64a5fad59e978f0b22570f820d2866d5dc503b4e 100644 (file)
@@ -1,4 +1,5 @@
 ; RUN: llc -o - %s | FileCheck %s
+; RUN: llc -global-isel -verify-machineinstrs -o - %s | FileCheck %s
 target triple="aarch64--"
 
 declare void @somefunc()