}
}
- // If low bits are zero in either operand, output low known-0 bits.
- // Also compute a conservative estimate for high known-0 bits.
- // More trickiness is possible, but this is sufficient for the
- // interesting case of alignment computation.
- unsigned TrailZ = Known.countMinTrailingZeros() +
- Known2.countMinTrailingZeros();
+ assert(!Known.hasConflict() && !Known2.hasConflict());
+ // Compute a conservative estimate for high known-0 bits.
unsigned LeadZ = std::max(Known.countMinLeadingZeros() +
Known2.countMinLeadingZeros(),
BitWidth) - BitWidth;
-
- TrailZ = std::min(TrailZ, BitWidth);
LeadZ = std::min(LeadZ, BitWidth);
+
+ // The result of the bottom bits of an integer multiply can be
+ // inferred by looking at the bottom bits of both operands and
+ // multiplying them together.
+ // We can infer at least the minimum number of known trailing bits
+ // of both operands. Depending on number of trailing zeros, we can
+ // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
+ // a and b are divisible by m and n respectively.
+ // We then calculate how many of those bits are inferrable and set
+ // the output. For example, the i8 mul:
+ // a = XXXX1100 (12)
+ // b = XXXX1110 (14)
+ // We know the bottom 3 bits are zero since the first can be divided by
+ // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
+ // Applying the multiplication to the trimmed arguments gets:
+ // XX11 (3)
+ // X111 (7)
+ // -------
+ // XX11
+ // XX11
+ // XX11
+ // XX11
+ // -------
+ // XXXXX01
+ // Which allows us to infer the 2 LSBs. Since we're multiplying the result
+ // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
+ // The proof for this can be described as:
+ // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
+ // (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
+ // umin(countTrailingZeros(C2), C6) +
+ // umin(C5 - umin(countTrailingZeros(C1), C5),
+ // C6 - umin(countTrailingZeros(C2), C6)))) - 1)
+ // %aa = shl i8 %a, C5
+ // %bb = shl i8 %b, C6
+ // %aaa = or i8 %aa, C1
+ // %bbb = or i8 %bb, C2
+ // %mul = mul i8 %aaa, %bbb
+ // %mask = and i8 %mul, C7
+ // =>
+ // %mask = i8 ((C1*C2)&C7)
+ // Where C5, C6 describe the known bits of %a, %b
+ // C1, C2 describe the known bottom bits of %a, %b.
+ // C7 describes the mask of the known bits of the result.
+ APInt Bottom0 = Known.One;
+ APInt Bottom1 = Known2.One;
+
+ // How many times we'd be able to divide each argument by 2 (shr by 1).
+ // This gives us the number of trailing zeros on the multiplication result.
+ unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes();
+ unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes();
+ unsigned TrailZero0 = Known.countMinTrailingZeros();
+ unsigned TrailZero1 = Known2.countMinTrailingZeros();
+ unsigned TrailZ = TrailZero0 + TrailZero1;
+
+ // Figure out the fewest known-bits operand.
+ unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0,
+ TrailBitsKnown1 - TrailZero1);
+ unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth);
+
+ APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) *
+ Bottom1.getLoBits(TrailBitsKnown1);
+
Known.resetAll();
- Known.Zero.setLowBits(TrailZ);
Known.Zero.setHighBits(LeadZ);
+ Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
+ Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
// Only make use of no-wrap flags if we failed to compute the sign bit
// directly. This matters if the multiplication always overflows, in
#include "llvm/IR/Module.h"
#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SourceMgr.h"
#include "gtest/gtest.h"
using namespace llvm;
cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
EXPECT_EQ(ComputeNumSignBits(RVal, M->getDataLayout()), 1u);
}
+
+// End-to-end check of known-bits propagation through a mul of two values
+// with partially known bits:
+//   %aan = ((a * 8) + 7) & 4095  ->  low bits ...0111, bits 12-31 zero
+//   %ban = ((b << 4) | 6) & 4095 ->  low bits .0110,   bits 12-31 zero
+// Multiplying the known bottom bits, 7 * 6 = 42 = 0b101010, makes the low
+// four bits of the product inferable (1010); the operands' 20 + 20 leading
+// zeros leave 8 known-zero high bits in the 32-bit result.
+TEST(ValueTracking, ComputeKnownBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %ash = mul i32 %a, 8 "
+                       "  %aad = add i32 %ash, 7 "
+                       "  %aan = and i32 %aad, 4095 "
+                       "  %bsh = shl i32 %b, 4 "
+                       "  %bad = or i32 %bsh, 6 "
+                       "  %ban = and i32 %bad, 4095 "
+                       "  %mul = mul i32 %aan, %ban "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  // Known one-bits:  0b1010 = 10 (the inferable low bits of 7 * 6 = 42).
+  EXPECT_EQ(Known.One.getZExtValue(), 10u);
+  // Known zero-bits: 0xFF000005 (8 high bits plus low bits 0 and 2).
+  EXPECT_EQ(Known.Zero.getZExtValue(), 4278190085u);
+}
+
+// Focused check of the mul known-bits refinement itself:
+//   %aaa = (a << 5) | 24  ->  known low 5 bits 11000 (3 trailing zeros)
+//   %bbb = (b << 5) | 28  ->  known low 5 bits 11100 (2 trailing zeros)
+// The product of the known bottoms is 24 * 28 = 672 = 0b1010100000; with
+// 3 + 2 guaranteed trailing zeros plus min(5-3, 5-2) = 2 further inferable
+// bits, the low 7 bits of the result are known to be 0100000.
+TEST(ValueTracking, ComputeKnownMulBits) {
+  StringRef Assembly = "define i32 @f(i32 %a, i32 %b) { "
+                       "  %aa = shl i32 %a, 5 "
+                       "  %bb = shl i32 %b, 5 "
+                       "  %aaa = or i32 %aa, 24 "
+                       "  %bbb = or i32 %bb, 28 "
+                       "  %mul = mul i32 %aaa, %bbb "
+                       "  ret i32 %mul "
+                       "} ";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto *RVal =
+      cast<ReturnInst>(F->getEntryBlock().getTerminator())->getOperand(0);
+  auto Known = computeKnownBits(RVal, M->getDataLayout());
+  ASSERT_FALSE(Known.hasConflict());
+  // Known one-bits:  0b0100000 = 32 (bit 5 of 672's low seven bits).
+  EXPECT_EQ(Known.One.getZExtValue(), 32u);
+  // Known zero-bits: 0b1011111 = 95 (the remaining known low bits).
+  EXPECT_EQ(Known.Zero.getZExtValue(), 95u);
+}