]> granicus.if.org Git - llvm/commitdiff
[InstCombine] Dropping redundant masking before left-shift [2/5] (PR42563)
authorRoman Lebedev <lebedev.ri@gmail.com>
Fri, 19 Jul 2019 08:26:25 +0000 (08:26 +0000)
committerRoman Lebedev <lebedev.ri@gmail.com>
Fri, 19 Jul 2019 08:26:25 +0000 (08:26 +0000)
Summary:
If we have some pattern that leaves only some low bits set, and then performs
left-shift of those bits, if none of the bits that are left after the final
shift are modified by the mask, we can omit the mask.

There are many variants to this pattern:
c. `(x & (-1 >> MaskShAmt)) << ShiftShAmt`
All these patterns can be simplified to just:
`x << ShiftShAmt`
iff:
c. `(ShiftShAmt-MaskShAmt) s>= 0` (i.e. `ShiftShAmt u>= MaskShAmt`)

alive proofs:
c: https://rise4fun.com/Alive/RgJh

For now let's start with patterns where both shift amounts are variable,
with trivial constant "offset" between them, since i believe this is
both simplest to handle and i think this is most common.
But again, there are likely other variants where we could use
ValueTracking/ConstantRange to handle more cases.

https://bugs.llvm.org/show_bug.cgi?id=42563

Differential Revision: https://reviews.llvm.org/D64517

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@366537 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/InstCombine/InstCombineShifts.cpp
test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-c.ll

index 8ffdb661e32fa74a88fb0d9996c02c1eaba7371f..b94febf786edcf7875dbae80995127f338dfc0d1 100644 (file)
@@ -72,10 +72,12 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
 // There are many variants to this pattern:
 //   a)  (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
 //   b)  (x & (~(-1 << MaskShAmt))) << ShiftShAmt
+//   c)  (x & (-1 >> MaskShAmt)) << ShiftShAmt
 // All these patterns can be simplified to just:
 //   x << ShiftShAmt
 // iff:
 //   a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+//   c)   (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
 static Instruction *
 dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
                                      const SimplifyQuery &SQ) {
@@ -91,24 +93,38 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
   auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
   // (~(-1 << maskNbits))
   auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
+  // (-1 >> MaskShAmt)
+  auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
 
   Value *X;
-  if (!match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X))))
-    return nullptr;
-
-  // Can we simplify (MaskShAmt+ShiftShAmt) ?
-  Value *SumOfShAmts =
-      SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
-                      SQ.getWithInstruction(OuterShift));
-  if (!SumOfShAmts)
-    return nullptr; // Did not simplify.
-  // Is the total shift amount *not* smaller than the bit width?
-  // FIXME: could also rely on ConstantRange.
-  unsigned BitWidth = X->getType()->getScalarSizeInBits();
-  if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
-                                             APInt(BitWidth, BitWidth))))
-    return nullptr;
-  // All good, we can do this fold.
+  if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+    // Can we simplify (MaskShAmt+ShiftShAmt) ?
+    Value *SumOfShAmts =
+        SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+                        SQ.getWithInstruction(OuterShift));
+    if (!SumOfShAmts)
+      return nullptr; // Did not simplify.
+    // Is the total shift amount *not* smaller than the bit width?
+    // FIXME: could also rely on ConstantRange.
+    unsigned BitWidth = X->getType()->getScalarSizeInBits();
+    if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
+                                               APInt(BitWidth, BitWidth))))
+      return nullptr;
+    // All good, we can do this fold.
+  } else if (match(Masked, m_c_And(MaskC, m_Value(X)))) {
+    // Can we simplify (ShiftShAmt-MaskShAmt) ?
+    Value *ShAmtsDiff =
+        SimplifySubInst(ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+                        SQ.getWithInstruction(OuterShift));
+    if (!ShAmtsDiff)
+      return nullptr; // Did not simplify.
+    // Is the difference non-negative? (is ShiftShAmt u>= MaskShAmt ?)
+    // FIXME: could also rely on ConstantRange.
+    if (!match(ShAmtsDiff, m_NonNegative()))
+      return nullptr;
+    // All good, we can do this fold.
+  } else
+    return nullptr; // Don't know anything about this pattern.
 
   // No 'NUW'/'NSW'!
   // We no longer know that we won't shift-out non-0 bits.
index 45345819735eb7ae4c4a940d93472db008e4f0e4..54bd16f082bc3bc40a147bc9e7b3079430d28d1b 100644 (file)
@@ -21,7 +21,7 @@ define i32 @t0_basic(i32 %x, i32 %nbits) {
 ; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
-; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
+; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT:    ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits
@@ -40,7 +40,7 @@ define i32 @t1_bigger_shift(i32 %x, i32 %nbits) {
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
 ; CHECK-NEXT:    call void @use32(i32 [[T2]])
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[T1]], [[T2]]
+; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[X]], [[T2]]
 ; CHECK-NEXT:    ret i32 [[T3]]
 ;
   %t0 = lshr i32 -1, %nbits
@@ -65,7 +65,7 @@ define <3 x i32> @t2_vec_splat(<3 x i32> %x, <3 x i32> %nbits) {
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T0]])
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T1]])
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT:    [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT:    [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
 ; CHECK-NEXT:    ret <3 x i32> [[T3]]
 ;
   %t0 = lshr <3 x i32> <i32 -1, i32 -1, i32 -1>, %nbits
@@ -86,7 +86,7 @@ define <3 x i32> @t3_vec_nonsplat(<3 x i32> %x, <3 x i32> %nbits) {
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T0]])
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T1]])
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT:    [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT:    [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
 ; CHECK-NEXT:    ret <3 x i32> [[T3]]
 ;
   %t0 = lshr <3 x i32> <i32 -1, i32 -1, i32 -1>, %nbits
@@ -107,7 +107,7 @@ define <3 x i32> @t4_vec_undef(<3 x i32> %x, <3 x i32> %nbits) {
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T0]])
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T1]])
 ; CHECK-NEXT:    call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT:    [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT:    [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
 ; CHECK-NEXT:    ret <3 x i32> [[T3]]
 ;
   %t0 = lshr <3 x i32> <i32 -1, i32 undef, i32 -1>, %nbits
@@ -131,7 +131,7 @@ define i32 @t5_commutativity0(i32 %nbits) {
 ; CHECK-NEXT:    [[T1:%.*]] = and i32 [[X]], [[T0]]
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
-; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
+; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT:    ret i32 [[T2]]
 ;
   %x = call i32 @gen32()
@@ -151,7 +151,7 @@ define i32 @t6_commutativity1(i32 %nbits0, i32 %nbits1) {
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
 ; CHECK-NEXT:    call void @use32(i32 [[T2]])
-; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[T2]], [[NBITS0]]
+; CHECK-NEXT:    [[T3:%.*]] = shl i32 [[T1]], [[NBITS0]]
 ; CHECK-NEXT:    ret i32 [[T3]]
 ;
   %t0 = lshr i32 -1, %nbits0
@@ -192,7 +192,7 @@ define i32 @t8_nuw(i32 %x, i32 %nbits) {
 ; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
-; CHECK-NEXT:    [[T2:%.*]] = shl nuw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT:    ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits
@@ -209,7 +209,7 @@ define i32 @t9_nsw(i32 %x, i32 %nbits) {
 ; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
-; CHECK-NEXT:    [[T2:%.*]] = shl nsw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT:    ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits
@@ -226,7 +226,7 @@ define i32 @t10_nuw_nsw(i32 %x, i32 %nbits) {
 ; CHECK-NEXT:    [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
 ; CHECK-NEXT:    call void @use32(i32 [[T0]])
 ; CHECK-NEXT:    call void @use32(i32 [[T1]])
-; CHECK-NEXT:    [[T2:%.*]] = shl nuw nsw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT:    [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
 ; CHECK-NEXT:    ret i32 [[T2]]
 ;
   %t0 = lshr i32 -1, %nbits