]> granicus.if.org Git - llvm/commitdiff
[X86] When selecting sse_load_f32/f64 pattern, make sure there's only one use of...
authorCraig Topper <craig.topper@intel.com>
Mon, 21 Aug 2017 16:04:04 +0000 (16:04 +0000)
committerCraig Topper <craig.topper@intel.com>
Mon, 21 Aug 2017 16:04:04 +0000 (16:04 +0000)
Summary: With masked operations, its possible for the operation node like fadd, fsub, etc. to be used by multiple different vselects. Since the pattern matching will start at the vselect, we need to make sure the operation node itself is only used once before we can fold a load. Otherwise we'll end up folding the same load into multiple instructions.

Reviewers: RKSimon, spatel, zvi, igorb

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D36938

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@311342 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/X86/X86ISelDAGToDAG.cpp
test/CodeGen/X86/avx512-memfold.ll

index 789d91175d30641a0089cfff81d955078bbc7cea..4bac46b82b42070a866830617aa8413e64e2065f 100644 (file)
@@ -1538,6 +1538,20 @@ bool X86DAGToDAGISel::selectAddr(SDNode *Parent, SDValue N, SDValue &Base,
   return true;
 }
 
+// We can only fold a load if all nodes between it and the root node have a
+// single use. If there are additional uses, we could end up duplicating the
+// load.
+static bool hasSingleUsesFromRoot(SDNode *Root, SDNode *N) {
+  SDNode *User = *N->use_begin();
+  while (User != Root) {
+    if (!User->hasOneUse())
+      return false;
+    User = *User->use_begin();
+  }
+
+  return true;
+}
+
 /// Match a scalar SSE load. In particular, we want to match a load whose top
 /// elements are either undef or zeros. The load flavor is derived from the
 /// type of N, which is either v4f32 or v2f64.
@@ -1554,7 +1568,8 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root,
   if (ISD::isNON_EXTLoad(N.getNode())) {
     PatternNodeWithChain = N;
     if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
       return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
                         Segment);
@@ -1565,7 +1580,8 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root,
   if (N.getOpcode() == X86ISD::VZEXT_LOAD) {
     PatternNodeWithChain = N;
     if (IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, *N->use_begin(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       auto *MI = cast<MemIntrinsicSDNode>(PatternNodeWithChain);
       return selectAddr(MI, MI->getBasePtr(), Base, Scale, Index, Disp,
                         Segment);
@@ -1579,7 +1595,8 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root,
     PatternNodeWithChain = N.getOperand(0);
     if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) &&
         IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
       return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
                         Segment);
@@ -1595,7 +1612,8 @@ bool X86DAGToDAGISel::selectScalarSSELoad(SDNode *Root,
     PatternNodeWithChain = N.getOperand(0).getOperand(0);
     if (ISD::isNON_EXTLoad(PatternNodeWithChain.getNode()) &&
         IsProfitableToFold(PatternNodeWithChain, N.getNode(), Root) &&
-        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel)) {
+        IsLegalToFold(PatternNodeWithChain, N.getNode(), Root, OptLevel) &&
+        hasSingleUsesFromRoot(Root, N.getNode())) {
       // Okay, this is a zero extending load.  Fold it.
       LoadSDNode *LD = cast<LoadSDNode>(PatternNodeWithChain);
       return selectAddr(LD, LD->getBasePtr(), Base, Scale, Index, Disp,
index e1bb87af568d58818738a9b51770655881f1f412..7490b99fd54dbde45ed19bfc34ded3ba44fed859 100644 (file)
@@ -72,9 +72,10 @@ define <2 x double> @test_int_x86_avx512_mask_vfmadd_sd(<2 x double> %a, <2 x do
 define <4 x float> @test_mask_add_ss_double_use(<4 x float> %a, float* %b, i8 %mask, <4 x float> %c) {
 ; CHECK-LABEL: test_mask_add_ss_double_use:
 ; CHECK:       ## BB#0:
+; CHECK-NEXT:    vmovss {{.*#+}} xmm2 = mem[0],zero,zero,zero
 ; CHECK-NEXT:    kmovw %esi, %k1
-; CHECK-NEXT:    vaddss (%rdi), %xmm0, %xmm1 {%k1}
-; CHECK-NEXT:    vaddss (%rdi), %xmm0, %xmm0 {%k1} {z}
+; CHECK-NEXT:    vaddss %xmm2, %xmm0, %xmm1 {%k1}
+; CHECK-NEXT:    vaddss %xmm2, %xmm0, %xmm0 {%k1} {z}
 ; CHECK-NEXT:    vmulps %xmm0, %xmm1, %xmm0
 ; CHECK-NEXT:    retq
   %b.val = load float, float* %b