From 8678b086003c9edd0ffd6e68c5e99ee3909bfce4 Mon Sep 17 00:00:00 2001 From: Simon Pilgrim Date: Fri, 29 Mar 2019 15:28:25 +0000 Subject: [PATCH] [SLP] Add support for commutative icmp/fcmp predicates For the cases where the icmp/fcmp predicate is commutative, use reorderInputsAccordingToOpcode to collect and commute the operands. This requires a helper to recognise commutativity in both general Instruction and CmpInstr types - the CmpInst::isCommutative doesn't overload the Instruction::isCommutative method for reasons I'm not clear on (maybe because its based on predicate not opcode?!?). Differential Revision: https://reviews.llvm.org/D59992 git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@357266 91177308-0d34-0410-b5e6-96231b3b80d8 --- lib/Transforms/Vectorize/SLPVectorizer.cpp | 42 +++++--- .../SLPVectorizer/X86/cmp_commute.ll | 96 ++++--------------- 2 files changed, 44 insertions(+), 94 deletions(-) diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index e56c9f48f7d..9bed86b16bd 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -206,6 +206,13 @@ static bool isSplat(ArrayRef<Value *> VL) { return true; } +/// \returns True if \p I is commutative, handles CmpInst as well as Instruction. +static bool isCommutative(Instruction *I) { + if (auto *IC = dyn_cast<CmpInst>(I)) + return IC->isCommutative(); + return I->isCommutative(); +} + /// Checks if the vector of instructions can be represented as a shuffle, like: /// %x0 = extractelement <4 x i8> %x, i32 0 /// %x3 = extractelement <4 x i8> %x, i32 3 @@ -1854,16 +1861,23 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, newTreeEntry(VL, true, UserTreeIdx, ReuseShuffleIndicies); LLVM_DEBUG(dbgs() << "SLP: added a vector of compares.\n"); - // Collect operands - commute if it uses the swapped predicate. 
ValueList Left, Right; - for (Value *V : VL) { - auto *Cmp = cast<CmpInst>(V); - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); - if (Cmp->getPredicate() != P0) - std::swap(LHS, RHS); - Left.push_back(LHS); - Right.push_back(RHS); + if (cast<CmpInst>(VL0)->isCommutative()) { + // Commutative predicate - collect + sort operands of the instructions + // so that each side is more likely to have the same opcode. + assert(P0 == SwapP0 && "Commutative Predicate mismatch"); + reorderInputsAccordingToOpcode(S, VL, Left, Right); + } else { + // Collect operands - commute if it uses the swapped predicate. + for (Value *V : VL) { + auto *Cmp = cast<CmpInst>(V); + Value *LHS = Cmp->getOperand(0); + Value *RHS = Cmp->getOperand(1); + if (Cmp->getPredicate() != P0) + std::swap(LHS, RHS); + Left.push_back(LHS); + Right.push_back(RHS); + } } UserTreeIdx.EdgeIdx = 0; @@ -2884,7 +2898,7 @@ void BoUpSLP::reorderInputsAccordingToOpcode(const InstructionsState &S, Instruction *I = cast<Instruction>(VL[i]); // Commute to favor either a splat or maximizing having the same opcodes on // one side. 
- if (I->isCommutative() && + if (isCommutative(I) && shouldReorderOperands(i, Left, Right, AllSameOpcodeLeft, AllSameOpcodeRight, SplatLeft, SplatRight)) std::swap(Left[i], Right[i]); @@ -2925,11 +2939,11 @@ void BoUpSLP::reorderInputsAccordingToOpcode(const InstructionsState &S, if (isConsecutiveAccess(L, L1, *DL, *SE)) { auto *VL1 = cast<Instruction>(VL[j]); auto *VL2 = cast<Instruction>(VL[j + 1]); - if (VL2->isCommutative()) { + if (isCommutative(VL2)) { std::swap(Left[j + 1], Right[j + 1]); continue; } - if (VL1->isCommutative()) { + if (isCommutative(VL1)) { std::swap(Left[j], Right[j]); continue; } @@ -2941,11 +2955,11 @@ void BoUpSLP::reorderInputsAccordingToOpcode(const InstructionsState &S, if (isConsecutiveAccess(L, L1, *DL, *SE)) { auto *VL1 = cast<Instruction>(VL[j]); auto *VL2 = cast<Instruction>(VL[j + 1]); - if (VL2->isCommutative()) { + if (isCommutative(VL2)) { std::swap(Left[j + 1], Right[j + 1]); continue; } - if (VL1->isCommutative()) { + if (isCommutative(VL1)) { std::swap(Left[j], Right[j]); continue; } diff --git a/test/Transforms/SLPVectorizer/X86/cmp_commute.ll b/test/Transforms/SLPVectorizer/X86/cmp_commute.ll index 46345a18164..03db2af9239 100644 --- a/test/Transforms/SLPVectorizer/X86/cmp_commute.ll +++ b/test/Transforms/SLPVectorizer/X86/cmp_commute.ll @@ -8,26 +8,10 @@ define <4 x i32> @icmp_eq_v4i32(<4 x i32> %a, i32* %b) { ; CHECK-LABEL: @icmp_eq_v4i32( -; CHECK-NEXT: [[A0:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 0 -; CHECK-NEXT: [[A1:%.*]] = extractelement <4 x i32> [[A]], i32 1 -; CHECK-NEXT: [[A2:%.*]] = extractelement <4 x i32> [[A]], i32 2 -; CHECK-NEXT: [[A3:%.*]] = extractelement <4 x i32> [[A]], i32 3 -; CHECK-NEXT: [[P1:%.*]] = getelementptr inbounds i32, i32* [[B:%.*]], i64 1 -; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, i32* [[B]], i64 2 -; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, i32* [[B]], i64 3 -; CHECK-NEXT: [[B0:%.*]] = load i32, i32* [[B]], align 4 -; CHECK-NEXT: [[B1:%.*]] = load i32, i32* [[P1]], align 4 -; CHECK-NEXT: [[B2:%.*]] = 
load i32, i32* [[P2]], align 4 -; CHECK-NEXT: [[B3:%.*]] = load i32, i32* [[P3]], align 4 -; CHECK-NEXT: [[C0:%.*]] = icmp eq i32 [[A0]], [[B0]] -; CHECK-NEXT: [[C1:%.*]] = icmp eq i32 [[B1]], [[A1]] -; CHECK-NEXT: [[C2:%.*]] = icmp eq i32 [[B2]], [[A2]] -; CHECK-NEXT: [[C3:%.*]] = icmp eq i32 [[A3]], [[B3]] -; CHECK-NEXT: [[D0:%.*]] = insertelement <4 x i1> undef, i1 [[C0]], i32 0 -; CHECK-NEXT: [[D1:%.*]] = insertelement <4 x i1> [[D0]], i1 [[C1]], i32 1 -; CHECK-NEXT: [[D2:%.*]] = insertelement <4 x i1> [[D1]], i1 [[C2]], i32 2 -; CHECK-NEXT: [[D3:%.*]] = insertelement <4 x i1> [[D2]], i1 [[C3]], i32 3 -; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[D3]] to <4 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i32* [[B:%.*]] to <4 x i32>* +; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, <4 x i32>* [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = icmp eq <4 x i32> [[TMP2]], [[A:%.*]] +; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[TMP3]] to <4 x i32> ; CHECK-NEXT: ret <4 x i32> [[R]] ; %a0 = extractelement <4 x i32> %a, i32 0 @@ -56,26 +40,10 @@ define <4 x i32> @icmp_eq_v4i32(<4 x i32> %a, i32* %b) { define <4 x i32> @icmp_ne_v4i32(<4 x i32> %a, i32* %b) { ; CHECK-LABEL: @icmp_ne_v4i32( -; CHECK-NEXT: [[A0:%.*]] = extractelement <4 x i32> [[A:%.*]], i32 0 -; CHECK-NEXT: [[A1:%.*]] = extractelement <4 x i32> [[A]], i32 1 -; CHECK-NEXT: [[A2:%.*]] = extractelement <4 x i32> [[A]], i32 2 -; CHECK-NEXT: [[A3:%.*]] = extractelement <4 x i32> [[A]], i32 3 -; CHECK-NEXT: [[P1:%.*]] = getelementptr inbounds i32, i32* [[B:%.*]], i64 1 -; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds i32, i32* [[B]], i64 2 -; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds i32, i32* [[B]], i64 3 -; CHECK-NEXT: [[B0:%.*]] = load i32, i32* [[B]], align 4 -; CHECK-NEXT: [[B1:%.*]] = load i32, i32* [[P1]], align 4 -; CHECK-NEXT: [[B2:%.*]] = load i32, i32* [[P2]], align 4 -; CHECK-NEXT: [[B3:%.*]] = load i32, i32* [[P3]], align 4 -; CHECK-NEXT: [[C0:%.*]] = icmp ne i32 [[A0]], [[B0]] -; CHECK-NEXT: [[C1:%.*]] 
= icmp ne i32 [[B1]], [[A1]] -; CHECK-NEXT: [[C2:%.*]] = icmp ne i32 [[B2]], [[A2]] -; CHECK-NEXT: [[C3:%.*]] = icmp ne i32 [[A3]], [[B3]] -; CHECK-NEXT: [[D0:%.*]] = insertelement <4 x i1> undef, i1 [[C0]], i32 0 -; CHECK-NEXT: [[D1:%.*]] = insertelement <4 x i1> [[D0]], i1 [[C1]], i32 1 -; CHECK-NEXT: [[D2:%.*]] = insertelement <4 x i1> [[D1]], i1 [[C2]], i32 2 -; CHECK-NEXT: [[D3:%.*]] = insertelement <4 x i1> [[D2]], i1 [[C3]], i32 3 -; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[D3]] to <4 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i32* [[B:%.*]] to <4 x i32>* +; CHECK-NEXT: [[TMP2:%.*]] = load <4 x i32>, <4 x i32>* [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = icmp ne <4 x i32> [[TMP2]], [[A:%.*]] +; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[TMP3]] to <4 x i32> ; CHECK-NEXT: ret <4 x i32> [[R]] ; %a0 = extractelement <4 x i32> %a, i32 0 @@ -104,26 +72,10 @@ define <4 x i32> @icmp_ne_v4i32(<4 x i32> %a, i32* %b) { define <4 x i32> @fcmp_oeq_v4i32(<4 x float> %a, float* %b) { ; CHECK-LABEL: @fcmp_oeq_v4i32( -; CHECK-NEXT: [[A0:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 -; CHECK-NEXT: [[A1:%.*]] = extractelement <4 x float> [[A]], i32 1 -; CHECK-NEXT: [[A2:%.*]] = extractelement <4 x float> [[A]], i32 2 -; CHECK-NEXT: [[A3:%.*]] = extractelement <4 x float> [[A]], i32 3 -; CHECK-NEXT: [[P1:%.*]] = getelementptr inbounds float, float* [[B:%.*]], i64 1 -; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds float, float* [[B]], i64 2 -; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds float, float* [[B]], i64 3 -; CHECK-NEXT: [[B0:%.*]] = load float, float* [[B]], align 4 -; CHECK-NEXT: [[B1:%.*]] = load float, float* [[P1]], align 4 -; CHECK-NEXT: [[B2:%.*]] = load float, float* [[P2]], align 4 -; CHECK-NEXT: [[B3:%.*]] = load float, float* [[P3]], align 4 -; CHECK-NEXT: [[C0:%.*]] = fcmp oeq float [[A0]], [[B0]] -; CHECK-NEXT: [[C1:%.*]] = fcmp oeq float [[B1]], [[A1]] -; CHECK-NEXT: [[C2:%.*]] = fcmp oeq float [[B2]], [[A2]] -; CHECK-NEXT: [[C3:%.*]] = fcmp 
oeq float [[A3]], [[B3]] -; CHECK-NEXT: [[D0:%.*]] = insertelement <4 x i1> undef, i1 [[C0]], i32 0 -; CHECK-NEXT: [[D1:%.*]] = insertelement <4 x i1> [[D0]], i1 [[C1]], i32 1 -; CHECK-NEXT: [[D2:%.*]] = insertelement <4 x i1> [[D1]], i1 [[C2]], i32 2 -; CHECK-NEXT: [[D3:%.*]] = insertelement <4 x i1> [[D2]], i1 [[C3]], i32 3 -; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[D3]] to <4 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[B:%.*]] to <4 x float>* +; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = fcmp oeq <4 x float> [[TMP2]], [[A:%.*]] +; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[TMP3]] to <4 x i32> ; CHECK-NEXT: ret <4 x i32> [[R]] ; %a0 = extractelement <4 x float> %a, i32 0 @@ -152,26 +104,10 @@ define <4 x i32> @fcmp_oeq_v4i32(<4 x float> %a, float* %b) { define <4 x i32> @fcmp_uno_v4i32(<4 x float> %a, float* %b) { ; CHECK-LABEL: @fcmp_uno_v4i32( -; CHECK-NEXT: [[A0:%.*]] = extractelement <4 x float> [[A:%.*]], i32 0 -; CHECK-NEXT: [[A1:%.*]] = extractelement <4 x float> [[A]], i32 1 -; CHECK-NEXT: [[A2:%.*]] = extractelement <4 x float> [[A]], i32 2 -; CHECK-NEXT: [[A3:%.*]] = extractelement <4 x float> [[A]], i32 3 -; CHECK-NEXT: [[P1:%.*]] = getelementptr inbounds float, float* [[B:%.*]], i64 1 -; CHECK-NEXT: [[P2:%.*]] = getelementptr inbounds float, float* [[B]], i64 2 -; CHECK-NEXT: [[P3:%.*]] = getelementptr inbounds float, float* [[B]], i64 3 -; CHECK-NEXT: [[B0:%.*]] = load float, float* [[B]], align 4 -; CHECK-NEXT: [[B1:%.*]] = load float, float* [[P1]], align 4 -; CHECK-NEXT: [[B2:%.*]] = load float, float* [[P2]], align 4 -; CHECK-NEXT: [[B3:%.*]] = load float, float* [[P3]], align 4 -; CHECK-NEXT: [[C0:%.*]] = fcmp uno float [[A0]], [[B0]] -; CHECK-NEXT: [[C1:%.*]] = fcmp uno float [[B1]], [[A1]] -; CHECK-NEXT: [[C2:%.*]] = fcmp uno float [[B2]], [[A2]] -; CHECK-NEXT: [[C3:%.*]] = fcmp uno float [[A3]], [[B3]] -; CHECK-NEXT: [[D0:%.*]] = insertelement <4 x i1> undef, i1 [[C0]], i32 
0 -; CHECK-NEXT: [[D1:%.*]] = insertelement <4 x i1> [[D0]], i1 [[C1]], i32 1 -; CHECK-NEXT: [[D2:%.*]] = insertelement <4 x i1> [[D1]], i1 [[C2]], i32 2 -; CHECK-NEXT: [[D3:%.*]] = insertelement <4 x i1> [[D2]], i1 [[C3]], i32 3 -; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[D3]] to <4 x i32> +; CHECK-NEXT: [[TMP1:%.*]] = bitcast float* [[B:%.*]] to <4 x float>* +; CHECK-NEXT: [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4 +; CHECK-NEXT: [[TMP3:%.*]] = fcmp uno <4 x float> [[TMP2]], [[A:%.*]] +; CHECK-NEXT: [[R:%.*]] = sext <4 x i1> [[TMP3]] to <4 x i32> ; CHECK-NEXT: ret <4 x i32> [[R]] ; %a0 = extractelement <4 x float> %a, i32 0 -- 2.50.1