]> granicus.if.org Git - llvm/commitdiff
[DAG] Relax type restriction for store merge
authorNirav Dave <niravd@google.com>
Thu, 10 Aug 2017 19:52:45 +0000 (19:52 +0000)
committerNirav Dave <niravd@google.com>
Thu, 10 Aug 2017 19:52:45 +0000 (19:52 +0000)
Summary: Allow stores of bitcastable types to be merged by peeking through BITCAST nodes and recasting stored values constant and vector extract nodes as necessary.

Reviewers: jyknight, hfinkel, efriedma, RKSimon, spatel

Reviewed By: RKSimon

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D34569

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@310655 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/SelectionDAG/DAGCombiner.cpp
test/CodeGen/X86/MergeConsecutiveStores.ll

index 1668642736bb491c9f97c5ef3d09c2afebc16db8..158350dc6d280e3441320a36147835c8cbf6064b 100644 (file)
@@ -466,7 +466,8 @@ namespace {
     /// This is a helper function for MergeConsecutiveStores. When the
     /// source elements of the consecutive stores are all constants or
     /// all extracted vector elements, try to merge them into one
-    /// larger store.  \return True if a merged store was created.
+    /// larger store introducing bitcasts if necessary.  \return True
+    /// if a merged store was created.
     bool MergeStoresOfConstantsOrVecElts(SmallVectorImpl<MemOpLink> &StoreNodes,
                                          EVT MemVT, unsigned NumStores,
                                          bool IsConstantSrc, bool UseVector,
@@ -12474,22 +12475,59 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
       for (unsigned I = 0; I != NumStores; ++I) {
         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[I].MemNode);
         SDValue Val = St->getValue();
-        if (MemVT.getScalarType().isInteger())
-          if (auto *CFP = dyn_cast<ConstantFPSDNode>(Val))
-            Val = DAG.getConstant(
-                (uint32_t)CFP->getValueAPF().bitcastToAPInt().getZExtValue(),
-                SDLoc(CFP), MemVT);
+        // If constant is of the wrong type, convert it now.
+        if (MemVT != Val.getValueType()) {
+          Val = peekThroughBitcast(Val);
+          // Deal with constants of wrong size.
+          if (ElementSizeBytes * 8 != Val.getValueSizeInBits()) {
+            EVT IntMemVT =
+                EVT::getIntegerVT(*DAG.getContext(), MemVT.getSizeInBits());
+            if (auto *CFP = dyn_cast<ConstantFPSDNode>(Val))
+              Val = DAG.getConstant(
+                  CFP->getValueAPF().bitcastToAPInt().zextOrTrunc(
+                      8 * ElementSizeBytes),
+                  SDLoc(CFP), IntMemVT);
+            else if (auto *C = dyn_cast<ConstantSDNode>(Val))
+              Val = DAG.getConstant(
+                  C->getAPIntValue().zextOrTrunc(8 * ElementSizeBytes),
+                  SDLoc(C), IntMemVT);
+          }
+          // Make sure correctly size type is the correct type.
+          Val = DAG.getBitcast(MemVT, Val);
+        }
         BuildVector.push_back(Val);
       }
-      StoredVal = DAG.getBuildVector(StoreTy, DL, BuildVector);
+      StoredVal = DAG.getNode(MemVT.isVector() ? ISD::CONCAT_VECTORS
+                                               : ISD::BUILD_VECTOR,
+                              DL, StoreTy, BuildVector);
     } else {
       SmallVector<SDValue, 8> Ops;
       for (unsigned i = 0; i < NumStores; ++i) {
         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
-        SDValue Val = St->getValue();
-        // All operands of BUILD_VECTOR / CONCAT_VECTOR must have the same type.
-        if (Val.getValueType() != MemVT)
-          return false;
+        SDValue Val = peekThroughBitcast(St->getValue());
+        // All operands of BUILD_VECTOR / CONCAT_VECTOR must be of
+        // type MemVT. If the underlying value is not the correct
+        // type, but it is an extraction of an appropriate vector we
+        // can recast Val to be of the correct type. This may require
+        // converting between EXTRACT_VECTOR_ELT and
+        // EXTRACT_SUBVECTOR.
+        if ((MemVT != Val.getValueType()) &&
+            (Val.getOpcode() == ISD::EXTRACT_VECTOR_ELT ||
+             Val.getOpcode() == ISD::EXTRACT_SUBVECTOR)) {
+          SDValue Vec = Val.getOperand(0);
+          EVT MemVTScalarTy = MemVT.getScalarType();
+          // We may need to add a bitcast here to get types to line up.
+          if (MemVTScalarTy != Vec.getValueType()) {
+            unsigned Elts = Vec.getValueType().getSizeInBits() /
+                            MemVTScalarTy.getSizeInBits();
+            EVT NewVecTy =
+                EVT::getVectorVT(*DAG.getContext(), MemVTScalarTy, Elts);
+            Vec = DAG.getBitcast(NewVecTy, Vec);
+          }
+          auto OpC = (MemVT.isVector()) ? ISD::EXTRACT_SUBVECTOR
+                                        : ISD::EXTRACT_VECTOR_ELT;
+          Val = DAG.getNode(OpC, SDLoc(Val), MemVT, Vec, Val.getOperand(1));
+        }
         Ops.push_back(Val);
       }
 
@@ -12532,7 +12570,7 @@ bool DAGCombiner::MergeStoresOfConstantsOrVecElts(
 
   // make sure we use trunc store if it's necessary to be legal.
   SDValue NewStore;
-  if (UseVector || !UseTrunc) {
+  if (!UseTrunc) {
     NewStore = DAG.getStore(NewChain, DL, StoredVal, FirstInChain->getBasePtr(),
                             FirstInChain->getPointerInfo(),
                             FirstInChain->getAlignment());
@@ -12573,7 +12611,7 @@ void DAGCombiner::getStoreMergeCandidates(
   BaseIndexOffset BasePtr = BaseIndexOffset::match(St->getBasePtr(), DAG);
   EVT MemVT = St->getMemoryVT();
 
-  SDValue Val = St->getValue();
+  SDValue Val = peekThroughBitcast(St->getValue());
   // We must have a base and an offset.
   if (!BasePtr.getBase().getNode())
     return;
@@ -12601,10 +12639,12 @@ void DAGCombiner::getStoreMergeCandidates(
                             int64_t &Offset) -> bool {
     if (Other->isVolatile() || Other->isIndexed())
       return false;
-    SDValue Val = Other->getValue();
+    SDValue Val = peekThroughBitcast(Other->getValue());
+    // Allow merging constants of different types as integers.
+    bool NoTypeMatch = (MemVT.isInteger()) ? !MemVT.bitsEq(Other->getMemoryVT())
+                                           : Other->getMemoryVT() != MemVT;
     if (IsLoadSrc) {
-      // Loads must match type.
-      if (Other->getMemoryVT() != MemVT)
+      if (NoTypeMatch)
         return false;
       // The Load's Base Ptr must also match
       if (LoadSDNode *OtherLd = dyn_cast<LoadSDNode>(Val)) {
@@ -12617,16 +12657,16 @@ void DAGCombiner::getStoreMergeCandidates(
         return false;
     }
     if (IsConstantSrc) {
-      // Allow merging constants of different types as integers.
-      if (MemVT.isInteger() ? !MemVT.bitsEq(Other->getMemoryVT())
-                            : Other->getMemoryVT() != MemVT)
+      if (NoTypeMatch)
         return false;
       if (!(isa<ConstantSDNode>(Val) || isa<ConstantFPSDNode>(Val)))
         return false;
     }
     if (IsExtractVecSrc) {
-      // Must match type.
-      if (Other->getMemoryVT() != MemVT)
+      // Do not merge truncated stores here.
+      if (Other->isTruncatingStore())
+        return false;
+      if (!MemVT.bitsEq(Val.getValueType()))
         return false;
       if (Val.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
           Val.getOpcode() != ISD::EXTRACT_SUBVECTOR)
@@ -12723,7 +12763,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
 
   // Perform an early exit check. Do not bother looking at stored values that
   // are not constants, loads, or extracted vector elements.
-  SDValue StoredVal = St->getValue();
+  SDValue StoredVal = peekThroughBitcast(St->getValue());
   bool IsLoadSrc = isa<LoadSDNode>(StoredVal);
   bool IsConstantSrc = isa<ConstantSDNode>(StoredVal) ||
                        isa<ConstantFPSDNode>(StoredVal);
@@ -12911,7 +12951,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
       unsigned NumStoresToMerge = 1;
       for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
         StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
-        SDValue StVal = St->getValue();
+        SDValue StVal = peekThroughBitcast(St->getValue());
         // This restriction could be loosened.
         // Bail out if any stored values are not elements extracted from a
         // vector. It should be possible to handle mixed sources, but load
@@ -12977,7 +13017,7 @@ bool DAGCombiner::MergeConsecutiveStores(StoreSDNode *St) {
     BaseIndexOffset LdBasePtr;
     for (unsigned i = 0; i < NumConsecutiveStores; ++i) {
       StoreSDNode *St = cast<StoreSDNode>(StoreNodes[i].MemNode);
-      SDValue Val = St->getValue();
+      SDValue Val = peekThroughBitcast(St->getValue());
       LoadSDNode *Ld = dyn_cast<LoadSDNode>(Val);
       if (!Ld)
         break;
index 40ef993ce04747fa010f634fe5b0e66587cad2e4..7b0520fe53019f186e69d9fc21a8bf81365fcd80 100644 (file)
@@ -522,7 +522,7 @@ define void @merge_vec_extract_stores(<8 x float> %v1, <8 x float> %v2, <4 x flo
 ; CHECK-NEXT: retq
 }
 
-; Merging vector stores when sourced from vector loads is not currently handled.
+; Merging vector stores when sourced from vector loads.
 define void @merge_vec_stores_from_loads(<4 x float>* %v, <4 x float>* %ptr) {
   %load_idx0 = getelementptr inbounds <4 x float>, <4 x float>* %v, i64 0
   %load_idx1 = getelementptr inbounds <4 x float>, <4 x float>* %v, i64 1
@@ -621,9 +621,6 @@ define void @merge_bitcast(<4 x i32> %v, float* %ptr) {
   ret void
 
 ; CHECK-LABEL: merge_bitcast
-; CHECK:      vmovd    %xmm0, (%rdi)
-; CHECK-NEXT: vpextrd  $1, %xmm0, 4(%rdi)
-; CHECK-NEXT: vpextrd  $2, %xmm0, 8(%rdi)
-; CHECK-NEXT: vpextrd  $3, %xmm0, 12(%rdi)
+; CHECK:      vmovups  %xmm0, (%rdi)
 ; CHECK-NEXT: retq
 }