]> granicus.if.org Git - llvm/commitdiff
[X86][SSE] Move VSRAI sign extend in reg fold into SimplifyDemandedBits
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 18 Dec 2018 09:11:34 +0000 (09:11 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Tue, 18 Dec 2018 09:11:34 +0000 (09:11 +0000)
(VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1

This works better as part of SimplifyDemandedBits than part of the general combine.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@349462 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelLowering.cpp

index 0988fa9dfe3bcca678415d8ded51578e19eaf151..a6bb174f69033b4a830d0deb37e27baed46de670 100644 (file)
@@ -32447,12 +32447,21 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
       if (ShiftImm->getAPIntValue().uge(BitWidth))
         break;
 
+      unsigned ShAmt = ShiftImm->getZExtValue();
+      APInt DemandedMask = OriginalDemandedBits << ShAmt;
+
       // If we just want the sign bit then we don't need to shift it.
       if (OriginalDemandedBits.isSignMask())
         return TLO.CombineTo(Op, Op0);
 
-      unsigned ShAmt = ShiftImm->getZExtValue();
-      APInt DemandedMask = OriginalDemandedBits << ShAmt;
+      // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1
+      if (Op0.getOpcode() == X86ISD::VSHLI && Op1 == Op0.getOperand(1)) {
+        SDValue Op00 = Op0.getOperand(0);
+        unsigned NumSignBits =
+            TLO.DAG.ComputeNumSignBits(Op00, OriginalDemandedElts);
+        if (ShAmt < NumSignBits)
+          return TLO.CombineTo(Op, Op00);
+      }
 
       // If any of the demanded bits are produced by the sign extension, we also
       // demand the input sign bit.
@@ -35566,15 +35575,6 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
   if (ISD::isBuildVectorAllZeros(N0.getNode()))
     return DAG.getConstant(0, SDLoc(N), VT);
 
-  // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1
-  if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSHLI &&
-      N1 == N0.getOperand(1)) {
-    SDValue N00 = N0.getOperand(0);
-    unsigned NumSignBits = DAG.ComputeNumSignBits(N00);
-    if (ShiftVal < NumSignBits)
-      return N00;
-  }
-
   // Fold (VSRAI (VSRAI X, C1), C2) --> (VSRAI X, (C1 + C2)) with (C1 + C2)
   // clamped to (NumBitsPerElt - 1).
   if (Opcode == X86ISD::VSRAI && N0.getOpcode() == X86ISD::VSRAI) {