]> granicus.if.org Git - llvm/commitdiff
Infer lowest bits of an integer Multiply when the low bits of the operands are known
authorSimon Dardis <simon.dardis@mips.com>
Sat, 9 Dec 2017 23:25:57 +0000 (23:25 +0000)
committerSimon Dardis <simon.dardis@mips.com>
Sat, 9 Dec 2017 23:25:57 +0000 (23:25 +0000)
When the lowest bits of the operands to an integer multiply are known, the low bits of the result are deducible.
Code to deduce known-zero bottom bits already existed, but this change improves on that by deducing known-ones.

Patch by: Pedro Ferreira

Reviewers: craig.topper, sanjoy, efriedma

Differential Revision: https://reviews.llvm.org/D34029

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@320269 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Analysis/ValueTracking.cpp
unittests/Analysis/ValueTrackingTest.cpp

index 4f7039c6aa7a09cc6488584d62251fa60c8aedca..e086d27005cc5e1d7edf3e7889513a453a9a6ada 100644 (file)
@@ -336,21 +336,78 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
     }
   }
 
-  // If low bits are zero in either operand, output low known-0 bits.
-  // Also compute a conservative estimate for high known-0 bits.
-  // More trickiness is possible, but this is sufficient for the
-  // interesting case of alignment computation.
-  unsigned TrailZ = Known.countMinTrailingZeros() +
-                    Known2.countMinTrailingZeros();
+  assert(!Known.hasConflict() && !Known2.hasConflict());
+  // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ =  std::max(Known.countMinLeadingZeros() +
                              Known2.countMinLeadingZeros(),
                              BitWidth) - BitWidth;
-
-  TrailZ = std::min(TrailZ, BitWidth);
   LeadZ = std::min(LeadZ, BitWidth);
+
+  // The result of the bottom bits of an integer multiply can be
+  // inferred by looking at the bottom bits of both operands and
+  // multiplying them together.
+  // We can infer at least the minimum number of known trailing bits
+  // of both operands. Depending on number of trailing zeros, we can
+  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
+  // a and b are divisible by m and n respectively.
+  // We then calculate how many of those bits are inferrable and set
+  // the output. For example, the i8 mul:
+  //  a = XXXX1100 (12)
+  //  b = XXXX1110 (14)
+  // We know the bottom 3 bits are zero since the first can be divided by
+  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
+  // Applying the multiplication to the trimmed arguments gets:
+  //    XX11 (3)
+  //    X111 (7)
+  // -------
+  //    XX11
+  //   XX11
+  //  XX11
+  // XX11
+  // -------
+  // XXXXX01
+  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
+  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
+  // The proof for this can be described as:
+  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
+  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
+  //                    umin(countTrailingZeros(C2), C6) +
+  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
+  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
+  // %aa = shl i8 %a, C5
+  // %bb = shl i8 %b, C6
+  // %aaa = or i8 %aa, C1
+  // %bbb = or i8 %bb, C2
+  // %mul = mul i8 %aaa, %bbb
+  // %mask = and i8 %mul, C7
+  //   =>
+  // %mask = i8 ((C1*C2)&C7)
+  // Where C5, C6 describe the known bits of %a, %b
+  // C1, C2 describe the known bottom bits of %a, %b.
+  // C7 describes the mask of the known bits of the result.
+  APInt Bottom0 = Known.One;
+  APInt Bottom1 = Known2.One;
+
+  // How many times we'd be able to divide each argument by 2 (shr by 1).
+  // This gives us the number of trailing zeros on the multiplication result.
+  unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
+  unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
+  unsigned TrailZero0 = Known.countMinTrailingZeros();
+  unsigned TrailZero1 = Known2.countMinTrailingZeros();
+  unsigned TrailZ = TrailZero0 + TrailZero1;
+
+  // Figure out the fewest known-bits operand.
+  unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0,
+                                      TrailBitsKnown1 - TrailZero1);
+  unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
+
+  APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
+                      Bottom1.getLoBits(TrailBitsKnown1);
+
   Known.resetAll();
-  Known.Zero.setLowBits(TrailZ);
   Known.Zero.setHighBits(LeadZ);
+  Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+  Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in
index 3c8ecfbe1ee230ec6be847fb8b76c6696ada305f..cfdf264da3104b9f1afc006dca87dea5528bf5f0 100644 (file)
@@ -15,6 +15,7 @@
 #include "llvm/IR/Module.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/KnownBits.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -258,3 +259,57 @@ TEST(ValueTracking, ComputeNumSignBits_PR32045) {
       cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
   EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u);
 }
+
+TEST(ValueTracking, ComputeKnownBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %ash = mul i32 %a, 8 "
+                       "  %aad = add i32 %ash, 7 "
+                       "  %aan = and i32 %aad, 4095 "
+                       "  %bsh = shl i32 %b, 4 "
+                       "  %bad = or i32 %bsh, 6 "
+                       "  %ban = and i32 %bad, 4095 "
+                       "  %mul = mul i32 %aan, %ban "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 10u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 4278190085u);
+}
+
+TEST(ValueTracking, ComputeKnownMulBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %aa = shl i32 %a, 5 "
+                       "  %bb = shl i32 %b, 5 "
+                       "  %aaa = or i32 %aa, 24 "
+                       "  %bbb = or i32 %bb, 28 "
+                       "  %mul = mul i32 %aaa, %bbb "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  EXPECT_EQ(Known.One.getZExtValue(), 32u);
+  EXPECT_EQ(Known.Zero.getZExtValue(), 95u);
+}