// There are many variants to this pattern:
// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
// b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt
+// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt
// All these patterns can be simplified to just:
// x << ShiftShAmt
// iff:
// a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+// c) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
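+// For example, on i32 with MaskShAmt = 8 and ShiftShAmt = 24 (variants a,b):
+// the mask keeps only the low 8 bits of x, but the shl by 24 discards all but
+// the low 8 bits of its operand anyway, so the mask is redundant. Likewise for
+// variant c with MaskShAmt = ShiftShAmt = 5: the mask clears the 5 high bits
+// of x, which the shl by 5 would shift out regardless.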
static Instruction *
dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
const SimplifyQuery &SQ) {
// ((1 << MaskShAmt) - 1)
auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
// (~(-1 << MaskShAmt))
auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
+ // (-1 >> MaskShAmt)
+ auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
Value *X;
- if (!match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X))))
- return nullptr;
-
- // Can we simplify (MaskShAmt+ShiftShAmt) ?
- Value *SumOfShAmts =
- SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
- SQ.getWithInstruction(OuterShift));
- if (!SumOfShAmts)
- return nullptr; // Did not simplify.
- // Is the total shift amount *not* smaller than the bit width?
- // FIXME: could also rely on ConstantRange.
- unsigned BitWidth = X->getType()->getScalarSizeInBits();
- if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
- APInt(BitWidth, BitWidth))))
- return nullptr;
- // All good, we can do this fold.
+ if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+ // Can we simplify (MaskShAmt+ShiftShAmt) ?
+ Value *SumOfShAmts =
+ SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+ SQ.getWithInstruction(OuterShift));
+ if (!SumOfShAmts)
+ return nullptr; // Did not simplify.
+ // Is the total shift amount *not* smaller than the bit width?
+ // FIXME: could also rely on ConstantRange.
+ unsigned BitWidth = X->getType()->getScalarSizeInBits();
+ if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
+ APInt(BitWidth, BitWidth))))
+ return nullptr;
+ // All good, we can do this fold.
+ } else if (match(Masked, m_c_And(MaskC, m_Value(X)))) {
+ // Can we simplify (ShiftShAmt-MaskShAmt) ?
+ Value *ShAmtsDiff =
+ SimplifySubInst(ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
+ SQ.getWithInstruction(OuterShift));
+ if (!ShAmtsDiff)
+ return nullptr; // Did not simplify.
+ // Is the difference non-negative? (is ShiftShAmt u>= MaskShAmt ?)
+ // FIXME: could also rely on ConstantRange.
+ if (!match(ShAmtsDiff, m_NonNegative()))
+ return nullptr;
+ // All good, we can do this fold.
+ } else
+ return nullptr; // Don't know anything about this pattern.
// No 'NUW'/'NSW'!
// We no longer know that we won't shift-out non-0 bits.
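// E.g. i8, variant c, MaskShAmt = ShiftShAmt = 4, x = 0xFF: the masked shl
// computes 0x0F << 4 = 0xF0 and shifts out only zero bits, so it may be
// marked 'nuw', but the rewritten x << 4 shifts out the set high bits of x.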
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
; CHECK-NEXT: ret i32 [[T2]]
;
%t0 = lshr i32 -1, %nbits
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl i32 [[X]], [[T2]]
; CHECK-NEXT: ret i32 [[T3]]
;
%t0 = lshr i32 -1, %nbits
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
; CHECK-NEXT: ret <3 x i32> [[T3]]
;
%t0 = lshr <3 x i32> <i32 -1, i32 -1, i32 -1>, %nbits
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
; CHECK-NEXT: ret <3 x i32> [[T3]]
;
%t0 = lshr <3 x i32> <i32 -1, i32 -1, i32 -1>, %nbits
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
+; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
; CHECK-NEXT: ret <3 x i32> [[T3]]
;
%t0 = lshr <3 x i32> <i32 -1, i32 undef, i32 -1>, %nbits
; CHECK-NEXT: [[T1:%.*]] = and i32 [[X]], [[T0]]
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
; CHECK-NEXT: ret i32 [[T2]]
;
%x = call i32 @gen32()
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
; CHECK-NEXT: call void @use32(i32 [[T2]])
-; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T2]], [[NBITS0]]
+; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[NBITS0]]
; CHECK-NEXT: ret i32 [[T3]]
;
%t0 = lshr i32 -1, %nbits0
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl nuw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
; CHECK-NEXT: ret i32 [[T2]]
;
%t0 = lshr i32 -1, %nbits
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl nsw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
; CHECK-NEXT: ret i32 [[T2]]
;
%t0 = lshr i32 -1, %nbits
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
; CHECK-NEXT: call void @use32(i32 [[T0]])
; CHECK-NEXT: call void @use32(i32 [[T1]])
-; CHECK-NEXT: [[T2:%.*]] = shl nuw nsw i32 [[T1]], [[NBITS]]
+; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
; CHECK-NEXT: ret i32 [[T2]]
;
%t0 = lshr i32 -1, %nbits