From: Joey Gouly Date: Tue, 7 Feb 2017 11:58:22 +0000 (+0000) Subject: [APInt] Fix rotl/rotr when the shift amount is greater than the total bit width. X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=f83a9ee2dfc4808fdbf4dcd4555e2fb3776cedee;p=llvm [APInt] Fix rotl/rotr when the shift amount is greater than the total bit width. Review: https://reviews.llvm.org/D27749 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@294295 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/lib/Support/APInt.cpp b/lib/Support/APInt.cpp index 7b3be916f31..7e17e42dc91 100644 --- a/lib/Support/APInt.cpp +++ b/lib/Support/APInt.cpp @@ -1252,8 +1252,21 @@ APInt APInt::shlSlowCase(unsigned shiftAmt) const { return Result; } +// Calculate the rotate amount modulo the bit width. +static unsigned rotateModulo(unsigned BitWidth, const APInt &rotateAmt) { + unsigned rotBitWidth = rotateAmt.getBitWidth(); + APInt rot = rotateAmt; + if (rotBitWidth < BitWidth) { + // Extend the rotate APInt, so that the urem doesn't divide by 0. + // e.g. APInt(1, 32) would give APInt(1, 0). + rot = rotateAmt.zext(BitWidth); + } + rot = rot.urem(APInt(rot.getBitWidth(), BitWidth)); + return rot.getLimitedValue(BitWidth); +} + APInt APInt::rotl(const APInt &rotateAmt) const { - return rotl((unsigned)rotateAmt.getLimitedValue(BitWidth)); + return rotl(rotateModulo(BitWidth, rotateAmt)); } APInt APInt::rotl(unsigned rotateAmt) const { @@ -1264,7 +1277,7 @@ APInt APInt::rotl(unsigned rotateAmt) const { } APInt APInt::rotr(const APInt &rotateAmt) const { - return rotr((unsigned)rotateAmt.getLimitedValue(BitWidth)); + return rotr(rotateModulo(BitWidth, rotateAmt)); } APInt APInt::rotr(unsigned rotateAmt) const { diff --git a/unittests/ADT/APIntTest.cpp b/unittests/ADT/APIntTest.cpp index 130bc256b2c..c939c3942d7 100644 --- a/unittests/ADT/APIntTest.cpp +++ b/unittests/ADT/APIntTest.cpp @@ -1060,6 +1060,29 @@ TEST(APIntTest, Rotate) { EXPECT_EQ(APInt(8, 1), APInt(8, 16).rotl(4)); EXPECT_EQ(APInt(8, 16), APInt(8, 16).rotl(8)); + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(33)); + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(APInt(32, 33))); + + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(33)); + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(APInt(32, 33))); + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(APInt(33, 33))); + EXPECT_EQ(APInt(32, (1 << 8)), APInt(32, 1).rotl(APInt(32, 40))); + EXPECT_EQ(APInt(32, (1 << 30)), APInt(32, 1).rotl(APInt(31, 30))); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotl(APInt(31, 31))); + + EXPECT_EQ(APInt(32, 1), APInt(32, 1).rotl(APInt(1, 0))); + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(APInt(1, 1))); + + EXPECT_EQ(APInt(32, 16), APInt(32, 1).rotl(APInt(3, 4))); + + EXPECT_EQ(APInt(32, 1), APInt(32, 1).rotl(APInt(64, 64))); + EXPECT_EQ(APInt(32, 2), APInt(32, 1).rotl(APInt(64, 65))); + + EXPECT_EQ(APInt(7, 24), APInt(7, 3).rotl(APInt(7, 3))); + EXPECT_EQ(APInt(7, 24), APInt(7, 3).rotl(APInt(7, 10))); + EXPECT_EQ(APInt(7, 24), APInt(7, 3).rotl(APInt(5, 10))); + EXPECT_EQ(APInt(7, 6), APInt(7, 3).rotl(APInt(12, 120))); + EXPECT_EQ(APInt(8, 16), APInt(8, 16).rotr(0)); EXPECT_EQ(APInt(8, 8), APInt(8, 16).rotr(1)); EXPECT_EQ(APInt(8, 4), APInt(8, 16).rotr(2)); @@ -1072,9 +1095,36 @@ TEST(APIntTest, Rotate) { EXPECT_EQ(APInt(8, 16), APInt(8, 1).rotr(4)); EXPECT_EQ(APInt(8, 1), APInt(8, 1).rotr(8)); - APInt Big(256, "00004000800000000000000000003fff8000000000000000", 16); - APInt Rot(256, "3fff80000000000000000000000000000000000040008000", 16); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(33)); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(APInt(32, 33))); + + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(33)); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(APInt(32, 33))); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(APInt(33, 33))); + EXPECT_EQ(APInt(32, (1 << 24)), APInt(32, 1).rotr(APInt(32, 40))); + + EXPECT_EQ(APInt(32, (1 << 2)), APInt(32, 1).rotr(APInt(31, 30))); + EXPECT_EQ(APInt(32, (1 << 1)), APInt(32, 1).rotr(APInt(31, 31))); + + EXPECT_EQ(APInt(32, 1), APInt(32, 1).rotr(APInt(1, 0))); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(APInt(1, 1))); + + EXPECT_EQ(APInt(32, (1 << 28)), APInt(32, 1).rotr(APInt(3, 4))); + + EXPECT_EQ(APInt(32, 1), APInt(32, 1).rotr(APInt(64, 64))); + EXPECT_EQ(APInt(32, (1 << 31)), APInt(32, 1).rotr(APInt(64, 65))); + + EXPECT_EQ(APInt(7, 48), APInt(7, 3).rotr(APInt(7, 3))); + EXPECT_EQ(APInt(7, 48), APInt(7, 3).rotr(APInt(7, 10))); + EXPECT_EQ(APInt(7, 48), APInt(7, 3).rotr(APInt(5, 10))); + EXPECT_EQ(APInt(7, 65), APInt(7, 3).rotr(APInt(12, 120))); + + APInt Big(256, "00004000800000000000000000003fff8000000000000003", 16); + APInt Rot(256, "3fff80000000000000030000000000000000000040008000", 16); EXPECT_EQ(Rot, Big.rotr(144)); + + EXPECT_EQ(APInt(32, 8), APInt(32, 1).rotl(Big)); + EXPECT_EQ(APInt(32, (1 << 29)), APInt(32, 1).rotr(Big)); } TEST(APIntTest, Splat) {