]> granicus.if.org Git - llvm/commitdiff
[APInt] Add APInt::setBits() method to set all bits in range
authorSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 24 Feb 2017 10:15:29 +0000 (10:15 +0000)
committerSimon Pilgrim <llvm-dev@redking.me.uk>
Fri, 24 Feb 2017 10:15:29 +0000 (10:15 +0000)
The current pattern for setting bits in range is typically:

Mask |= APInt::getBitsSet(MaskSizeInBits, LoPos, HiPos);

Which can be particularly slow for large APInts (MaskSizeInBits > 64) as they require the allocation memory for the temporary variable.

This is one of the key compile time issues identified in PR32037.

This patch adds the APInt::setBits() helper method which avoids the temporary memory allocation completely, this first implementation uses setBit() internally instead but already significantly reduces the regression in PR32037 (~10% drop). Additional optimization may be possible.

I investigated whether there is need for APInt::clearBits() and APInt::flipBits() equivalents but haven't seen these patterns to be particularly common, but reusing the code would be trivial.

Differential Revision: https://reviews.llvm.org/D30265

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@296102 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/ADT/APInt.h
lib/Support/APInt.cpp
lib/Target/X86/X86ISelLowering.cpp
lib/Target/X86/X86ShuffleDecodeConstantPool.cpp
unittests/ADT/APIntTest.cpp

index 77f18f3e5bf82018b958d96a5f0738f047e95a59..32fed77098cee807e41762bcd5ffb2ccf6681325 100644 (file)
@@ -1239,6 +1239,9 @@ public:
   /// Set the given bit to 1 whose position is given as "bitPosition".
   void setBit(unsigned bitPosition);
 
+  /// Set the bits from loBit (inclusive) to hiBit (exclusive) to 1.
+  void setBits(unsigned loBit, unsigned hiBit);
+
   /// \brief Set every bit to 0.
   void clearAllBits() {
     if (isSingleWord())
index 35c15e04bf0186acbd29441ecec90939e5ccbc24..8ddbbe3a70d2a8f21023fa5d7f95c50b6e7a486e 100644 (file)
@@ -565,6 +565,39 @@ void APInt::setBit(unsigned bitPosition) {
     pVal[whichWord(bitPosition)] |= maskBit(bitPosition);
 }
 
+void APInt::setBits(unsigned loBit, unsigned hiBit) {
+  assert(hiBit <= BitWidth && "hiBit out of range");
+  assert(loBit <= hiBit && loBit <= BitWidth && "loBit out of range");
+
+  if (loBit == hiBit)
+    return;
+
+  if (isSingleWord())
+    *this |= APInt::getBitsSet(BitWidth, loBit, hiBit);
+  else {
+    unsigned hiBit1 = hiBit - 1;
+    unsigned loWord = whichWord(loBit);
+    unsigned hiWord = whichWord(hiBit1);
+    if (loWord == hiWord) {
+      // Set bits are all within the same word, create a [loBit,hiBit) mask.
+      uint64_t mask = UINT64_MAX;
+      mask >>= (APINT_BITS_PER_WORD - (hiBit - loBit));
+      mask <<= whichBit(loBit);
+      pVal[loWord] |= mask;
+    } else {
+      // Set bits span multiple words, create a lo mask with set bits starting
+      // at loBit, a hi mask with set bits below hiBit and set all bits of the
+      // words in between.
+      uint64_t loMask = UINT64_MAX << whichBit(loBit);
+      uint64_t hiMask = UINT64_MAX >> (64 - whichBit(hiBit1) - 1);
+      pVal[loWord] |= loMask;
+      pVal[hiWord] |= hiMask;
+      for (unsigned word = loWord + 1; word < hiWord; ++word)
+        pVal[word] = UINT64_MAX;
+    }
+  }
+}
+
 /// Set the given bit to 0 whose position is given as "bitPosition".
 /// @brief Set a given bit to 0.
 void APInt::clearBit(unsigned bitPosition) {
index 720f93a4fbe8cd1684e515cd5e66a0643b1c3c4d..6b404a498e36c233a105d0308182a32e8297f9fb 100644 (file)
@@ -5236,8 +5236,7 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
       return false;
     unsigned CstSizeInBits = Cst->getType()->getPrimitiveSizeInBits();
     if (isa<UndefValue>(Cst)) {
-      unsigned HiBits = BitOffset + CstSizeInBits;
-      Undefs |= APInt::getBitsSet(SizeInBits, BitOffset, HiBits);
+      Undefs.setBits(BitOffset, BitOffset + CstSizeInBits);
       return true;
     }
     if (auto *CInt = dyn_cast<ConstantInt>(Cst)) {
@@ -5258,8 +5257,7 @@ static bool getTargetConstantBitsFromNode(SDValue Op, unsigned EltSizeInBits,
       const SDValue &Src = Op.getOperand(i);
       unsigned BitOffset = i * SrcEltSizeInBits;
       if (Src.isUndef()) {
-        unsigned HiBits = BitOffset + SrcEltSizeInBits;
-        UndefBits |= APInt::getBitsSet(SizeInBits, BitOffset, HiBits);
+        UndefBits.setBits(BitOffset, BitOffset + SrcEltSizeInBits);
         continue;
       }
       auto *Cst = cast<ConstantSDNode>(Src);
@@ -26353,11 +26351,11 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
       break;
     LLVM_FALLTHROUGH;
   case X86ISD::SETCC:
-    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - 1);
+    KnownZero.setBits(1, BitWidth);
     break;
   case X86ISD::MOVMSK: {
     unsigned NumLoBits = Op.getOperand(0).getValueType().getVectorNumElements();
-    KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - NumLoBits);
+    KnownZero.setBits(NumLoBits, BitWidth);
     break;
   }
   case X86ISD::VZEXT: {
@@ -26374,7 +26372,7 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
     DAG.computeKnownBits(N0, KnownZero, KnownOne, DemandedSrcElts, Depth + 1);
     KnownOne = KnownOne.zext(BitWidth);
     KnownZero = KnownZero.zext(BitWidth);
-    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - InBitWidth);
+    KnownZero.setBits(InBitWidth, BitWidth);
     break;
   }
   }
index 4d278d4b05aefa5ea379d0e3d411994c0af8858d..41ad097146110c60504c040ea0509d04832409d2 100644 (file)
@@ -60,8 +60,7 @@ static bool extractConstantMask(const Constant *C, unsigned MaskEltSizeInBits,
     unsigned BitOffset = i * CstEltSizeInBits;
 
     if (isa<UndefValue>(COp)) {
-      unsigned HiBits = BitOffset + CstEltSizeInBits;
-      UndefBits |= APInt::getBitsSet(CstSizeInBits, BitOffset, HiBits);
+      UndefBits.setBits(BitOffset, BitOffset + CstEltSizeInBits);
       continue;
     }
 
index c939c3942d7e80996a2e2c3e623423743dab9b88..3d0f6f57dc0733dcaccfcc2f8e647244d72d5cb4 100644 (file)
@@ -63,6 +63,26 @@ TEST(APIntTest, i33_Count) {
   EXPECT_EQ(((uint64_t)-2)&((1ull<<33) -1), i33minus2.getZExtValue());
 }
 
+TEST(APIntTest, i61_Count) {
+  APInt i61(61, 1 << 15);
+  EXPECT_EQ(45u, i61.countLeadingZeros());
+  EXPECT_EQ(0u, i61.countLeadingOnes());
+  EXPECT_EQ(16u, i61.getActiveBits());
+  EXPECT_EQ(15u, i61.countTrailingZeros());
+  EXPECT_EQ(1u, i61.countPopulation());
+  EXPECT_EQ((1 << 15), i61.getSExtValue());
+  EXPECT_EQ((1 << 15), i61.getZExtValue());
+
+  i61.setBits(8, 19);
+  EXPECT_EQ(42u, i61.countLeadingZeros());
+  EXPECT_EQ(0u, i61.countLeadingOnes());
+  EXPECT_EQ(19u, i61.getActiveBits());
+  EXPECT_EQ(8u, i61.countTrailingZeros());
+  EXPECT_EQ(11u, i61.countPopulation());
+  EXPECT_EQ((1 << 19) - (1 << 8), i61.getSExtValue());
+  EXPECT_EQ((1 << 19) - (1 << 8), i61.getZExtValue());
+}
+
 TEST(APIntTest, i65_Count) {
   APInt i65(65, 0, true);
   EXPECT_EQ(65u, i65.countLeadingZeros());
@@ -118,6 +138,80 @@ TEST(APIntTest, i128_PositiveCount) {
   EXPECT_EQ(1u, one.countPopulation());
   EXPECT_EQ(1, one.getSExtValue());
   EXPECT_EQ(1u, one.getZExtValue());
+
+  APInt s128(128, 2, true);
+  EXPECT_EQ(126u, s128.countLeadingZeros());
+  EXPECT_EQ(0u, s128.countLeadingOnes());
+  EXPECT_EQ(2u, s128.getActiveBits());
+  EXPECT_EQ(1u, s128.countTrailingZeros());
+  EXPECT_EQ(0u, s128.countTrailingOnes());
+  EXPECT_EQ(1u, s128.countPopulation());
+  EXPECT_EQ(2, s128.getSExtValue());
+  EXPECT_EQ(2u, s128.getZExtValue());
+
+  // NOP Test
+  s128.setBits(42, 42);
+  EXPECT_EQ(126u, s128.countLeadingZeros());
+  EXPECT_EQ(0u, s128.countLeadingOnes());
+  EXPECT_EQ(2u, s128.getActiveBits());
+  EXPECT_EQ(1u, s128.countTrailingZeros());
+  EXPECT_EQ(0u, s128.countTrailingOnes());
+  EXPECT_EQ(1u, s128.countPopulation());
+  EXPECT_EQ(2, s128.getSExtValue());
+  EXPECT_EQ(2u, s128.getZExtValue());
+
+  s128.setBits(3, 32);
+  EXPECT_EQ(96u, s128.countLeadingZeros());
+  EXPECT_EQ(0u, s128.countLeadingOnes());
+  EXPECT_EQ(32u, s128.getActiveBits());
+  EXPECT_EQ(33u, s128.getMinSignedBits());
+  EXPECT_EQ(1u, s128.countTrailingZeros());
+  EXPECT_EQ(0u, s128.countTrailingOnes());
+  EXPECT_EQ(30u, s128.countPopulation());
+  EXPECT_EQ(static_cast<uint32_t>((~0u << 3) | 2), s128.getZExtValue());
+
+  s128.setBits(62, 128);
+  EXPECT_EQ(0u, s128.countLeadingZeros());
+  EXPECT_EQ(66u, s128.countLeadingOnes());
+  EXPECT_EQ(128u, s128.getActiveBits());
+  EXPECT_EQ(63u, s128.getMinSignedBits());
+  EXPECT_EQ(1u, s128.countTrailingZeros());
+  EXPECT_EQ(0u, s128.countTrailingOnes());
+  EXPECT_EQ(96u, s128.countPopulation());
+  EXPECT_EQ(static_cast<int64_t>((3ull << 62) |
+                                 static_cast<uint32_t>((~0u << 3) | 2)),
+            s128.getSExtValue());
+}
+
+TEST(APIntTest, i256) {
+  APInt s256(256, 15, true);
+  EXPECT_EQ(252u, s256.countLeadingZeros());
+  EXPECT_EQ(0u, s256.countLeadingOnes());
+  EXPECT_EQ(4u, s256.getActiveBits());
+  EXPECT_EQ(0u, s256.countTrailingZeros());
+  EXPECT_EQ(4u, s256.countTrailingOnes());
+  EXPECT_EQ(4u, s256.countPopulation());
+  EXPECT_EQ(15, s256.getSExtValue());
+  EXPECT_EQ(15u, s256.getZExtValue());
+
+  s256.setBits(62, 66);
+  EXPECT_EQ(190u, s256.countLeadingZeros());
+  EXPECT_EQ(0u, s256.countLeadingOnes());
+  EXPECT_EQ(66u, s256.getActiveBits());
+  EXPECT_EQ(67u, s256.getMinSignedBits());
+  EXPECT_EQ(0u, s256.countTrailingZeros());
+  EXPECT_EQ(4u, s256.countTrailingOnes());
+  EXPECT_EQ(8u, s256.countPopulation());
+
+  s256.setBits(60, 256);
+  EXPECT_EQ(0u, s256.countLeadingZeros());
+  EXPECT_EQ(196u, s256.countLeadingOnes());
+  EXPECT_EQ(256u, s256.getActiveBits());
+  EXPECT_EQ(61u, s256.getMinSignedBits());
+  EXPECT_EQ(0u, s256.countTrailingZeros());
+  EXPECT_EQ(4u, s256.countTrailingOnes());
+  EXPECT_EQ(200u, s256.countPopulation());
+  EXPECT_EQ(static_cast<int64_t>((~0ull << 60) | 15), s256.getSExtValue());
 }
 
 TEST(APIntTest, i1) {