granicus.if.org Git - clang/commitdiff
Extend __builtin_shufflevector to expose the full power of the llvm shufflevector...
author Nate Begeman <natebegeman@mac.com>
Tue, 8 Jun 2010 00:16:34 +0000 (00:16 +0000)
committer Nate Begeman <natebegeman@mac.com>
Tue, 8 Jun 2010 00:16:34 +0000 (00:16 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@105589 91177308-0d34-0410-b5e6-96231b3b80d8

lib/CodeGen/CGExprScalar.cpp
lib/Sema/SemaChecking.cpp

index 2108414c5ae66969316e39583573d5dc6c180184..08374353169ef5343def39e5372ea0712c589305 100644 (file)
@@ -578,12 +578,104 @@ Value *ScalarExprEmitter::VisitExpr(Expr *E) {
 }
 
 Value *ScalarExprEmitter::VisitShuffleVectorExpr(ShuffleVectorExpr *E) {
-  llvm::SmallVector<llvm::Constant*, 32> indices;
-  for (unsigned i = 2; i < E->getNumSubExprs(); i++) {
-    indices.push_back(cast<llvm::Constant>(CGF.EmitScalarExpr(E->getExpr(i))));
+  // Vector Mask Case
+  if (E->getNumSubExprs() == 2 || 
+      E->getNumSubExprs() == 3 && E->getExpr(2)->getType()->isVectorType()) {
+    Value* LHS = CGF.EmitScalarExpr(E->getExpr(0));
+    Value* RHS = CGF.EmitScalarExpr(E->getExpr(1));
+    Value* Mask;
+    
+    const llvm::Type *I32Ty = llvm::Type::getInt32Ty(CGF.getLLVMContext());
+    const llvm::VectorType *LTy = cast<llvm::VectorType>(LHS->getType());
+    unsigned LHSElts = LTy->getNumElements();
+
+    if (E->getNumSubExprs() == 3) {
+      Mask = CGF.EmitScalarExpr(E->getExpr(2));
+      
+      // Shuffle LHS & RHS into one input vector.
+      llvm::SmallVector<llvm::Constant*, 32> concat;
+      for (unsigned i = 0; i != LHSElts; ++i) {
+        concat.push_back(llvm::ConstantInt::get(I32Ty, 2*i));
+        concat.push_back(llvm::ConstantInt::get(I32Ty, 2*i+1));
+      }
+      
+      Value* CV = llvm::ConstantVector::get(concat.begin(), concat.size());
+      LHS = Builder.CreateShuffleVector(LHS, RHS, CV, "concat");
+      LHSElts *= 2;
+    } else {
+      Mask = RHS;
+    }
+    
+    const llvm::VectorType *MTy = cast<llvm::VectorType>(Mask->getType());
+    llvm::Constant* EltMask;
+    
+    // Treat vec3 like vec4.
+    if ((LHSElts == 6) && (E->getNumSubExprs() == 3))
+      EltMask = llvm::ConstantInt::get(MTy->getElementType(),
+                                       (1 << llvm::Log2_32(LHSElts+2))-1);
+    else if ((LHSElts == 3) && (E->getNumSubExprs() == 2))
+      EltMask = llvm::ConstantInt::get(MTy->getElementType(),
+                                       (1 << llvm::Log2_32(LHSElts+1))-1);
+    else
+      EltMask = llvm::ConstantInt::get(MTy->getElementType(),
+                                       (1 << llvm::Log2_32(LHSElts))-1);
+             
+    // Mask off the high bits of each shuffle index.
+    llvm::SmallVector<llvm::Constant *, 32> MaskV;
+    for (unsigned i = 0, e = MTy->getNumElements(); i != e; ++i)
+      MaskV.push_back(EltMask);
+    
+    Value* MaskBits = llvm::ConstantVector::get(MaskV.begin(), MaskV.size());
+    Mask = Builder.CreateAnd(Mask, MaskBits, "mask");
+    
+    // newv = undef
+    // mask = mask & maskbits
+    // for each elt
+    //   n = extract mask i
+    //   x = extract val n
+    //   newv = insert newv, x, i
+    const llvm::VectorType *RTy = llvm::VectorType::get(LTy->getElementType(),
+                                                        MTy->getNumElements());
+    Value* NewV = llvm::UndefValue::get(RTy);
+    for (unsigned i = 0, e = MTy->getNumElements(); i != e; ++i) {
+      Value *Indx = llvm::ConstantInt::get(I32Ty, i);
+      Indx = Builder.CreateExtractElement(Mask, Indx, "shuf_idx");
+      Indx = Builder.CreateZExt(Indx, I32Ty, "idx_zext");
+      
+      // Handle vec3 special since the index will be off by one for the RHS.
+      if ((LHSElts == 6) && (E->getNumSubExprs() == 3)) {
+        Value *cmpIndx, *newIndx;
+        cmpIndx = Builder.CreateICmpUGT(Indx, llvm::ConstantInt::get(I32Ty, 3),
+                                        "cmp_shuf_idx");
+        newIndx = Builder.CreateSub(Indx, llvm::ConstantInt::get(I32Ty, 1),
+                                    "shuf_idx_adj");
+        Indx = Builder.CreateSelect(cmpIndx, newIndx, Indx, "sel_shuf_idx");
+      }
+      Value *VExt = Builder.CreateExtractElement(LHS, Indx, "shuf_elt");
+      NewV = Builder.CreateInsertElement(NewV, VExt, Indx, "shuf_ins");
+    }
+    return NewV;
   }
+  
   Value* V1 = CGF.EmitScalarExpr(E->getExpr(0));
   Value* V2 = CGF.EmitScalarExpr(E->getExpr(1));
+  
+  // Handle vec3 special since the index will be off by one for the RHS.
+  llvm::SmallVector<llvm::Constant*, 32> indices;
+  for (unsigned i = 2; i < E->getNumSubExprs(); i++) {
+    llvm::Constant *C = cast<llvm::Constant>(CGF.EmitScalarExpr(E->getExpr(i)));
+    const llvm::VectorType *VTy = cast<llvm::VectorType>(V1->getType());
+    if (VTy->getNumElements() == 3) {
+      if (llvm::ConstantInt *CI = dyn_cast<llvm::ConstantInt>(C)) {
+        uint64_t cVal = CI->getZExtValue();
+        if (cVal > 3) {
+          C = llvm::ConstantInt::get(C->getType(), cVal-1);
+        }
+      }
+    }
+    indices.push_back(C);
+  }
+
   Value* SV = llvm::ConstantVector::get(indices.begin(), indices.size());
   return Builder.CreateShuffleVector(V1, V2, SV, "shuffle");
 }
index d47f2ce8edc0fe69c871d20a65efd623d05240e8..7fa5762f06a42e936c5689387b0c697f28f866f3 100644 (file)
@@ -633,45 +633,53 @@ bool Sema::SemaBuiltinFPClassification(CallExpr *TheCall, unsigned NumArgs) {
 /// SemaBuiltinShuffleVector - Handle __builtin_shufflevector.
 // This is declared to take (...), so we have to check everything.
 Action::OwningExprResult Sema::SemaBuiltinShuffleVector(CallExpr *TheCall) {
-  if (TheCall->getNumArgs() < 3)
+  if (TheCall->getNumArgs() < 2)
     return ExprError(Diag(TheCall->getLocEnd(),
                           diag::err_typecheck_call_too_few_args_at_least)
-      << 0 /*function call*/ << 3 << TheCall->getNumArgs()
+      << 0 /*function call*/ << 2 << TheCall->getNumArgs()
       << TheCall->getSourceRange());
 
-  unsigned numElements = std::numeric_limits<unsigned>::max();
+  // Determine which of the following types of shufflevector we're checking:
+  // 1) unary, vector mask: (lhs, mask)
+  // 2) binary, vector mask: (lhs, rhs, mask)
+  // 3) binary, scalar mask: (lhs, rhs, index, ..., index)
+  QualType resType = TheCall->getArg(0)->getType();
+  unsigned numElements = 0;
+  
   if (!TheCall->getArg(0)->isTypeDependent() &&
       !TheCall->getArg(1)->isTypeDependent()) {
-    QualType FAType = TheCall->getArg(0)->getType();
-    QualType SAType = TheCall->getArg(1)->getType();
-
-    if (!FAType->isVectorType() || !SAType->isVectorType()) {
+    QualType LHSType = TheCall->getArg(0)->getType();
+    QualType RHSType = TheCall->getArg(1)->getType();
+    
+    if (!LHSType->isVectorType() || !RHSType->isVectorType()) {
       Diag(TheCall->getLocStart(), diag::err_shufflevector_non_vector)
         << SourceRange(TheCall->getArg(0)->getLocStart(),
                        TheCall->getArg(1)->getLocEnd());
       return ExprError();
     }
-
-    if (!Context.hasSameUnqualifiedType(FAType, SAType)) {
+    
+    numElements = LHSType->getAs<VectorType>()->getNumElements();
+    unsigned numResElements = TheCall->getNumArgs() - 2;
+
+    // Check to see if we have a call with 2 vector arguments, the unary shuffle
+    // with mask.  If so, verify that RHS is an integer vector type with the
+    // same number of elts as lhs.
+    if (TheCall->getNumArgs() == 2) {
+      if (!RHSType->isIntegerType() || 
+          RHSType->getAs<VectorType>()->getNumElements() != numElements)
+        Diag(TheCall->getLocStart(), diag::err_shufflevector_incompatible_vector)
+          << SourceRange(TheCall->getArg(1)->getLocStart(),
+                         TheCall->getArg(1)->getLocEnd());
+      numResElements = numElements;
+    }
+    else if (!Context.hasSameUnqualifiedType(LHSType, RHSType)) {
       Diag(TheCall->getLocStart(), diag::err_shufflevector_incompatible_vector)
         << SourceRange(TheCall->getArg(0)->getLocStart(),
                        TheCall->getArg(1)->getLocEnd());
       return ExprError();
-    }
-
-    numElements = FAType->getAs<VectorType>()->getNumElements();
-    if (TheCall->getNumArgs() != numElements+2) {
-      if (TheCall->getNumArgs() < numElements+2)
-        return ExprError(Diag(TheCall->getLocEnd(),
-                              diag::err_typecheck_call_too_few_args)
-                 << 0 /*function call*/ 
-                 << numElements+2 << TheCall->getNumArgs()
-                 << TheCall->getSourceRange());
-      return ExprError(Diag(TheCall->getLocEnd(),
-                            diag::err_typecheck_call_too_many_args)
-                 << 0 /*function call*/ 
-                 << numElements+2 << TheCall->getNumArgs()
-                 << TheCall->getSourceRange());
+    } else if (numElements != numResElements) {
+      QualType eltType = LHSType->getAs<VectorType>()->getElementType();
+      resType = Context.getVectorType(eltType, numResElements, false, false);
     }
   }
 
@@ -680,9 +688,11 @@ Action::OwningExprResult Sema::SemaBuiltinShuffleVector(CallExpr *TheCall) {
         TheCall->getArg(i)->isValueDependent())
       continue;
 
-    llvm::APSInt Result;
-    if (SemaBuiltinConstantArg(TheCall, i, Result))
-      return ExprError();
+    llvm::APSInt Result(32);
+    if (!TheCall->getArg(i)->isIntegerConstantExpr(Result, Context))
+      return ExprError(Diag(TheCall->getLocStart(),
+                  diag::err_shufflevector_nonconstant_argument)
+                << TheCall->getArg(i)->getSourceRange());
 
     if (Result.getActiveBits() > 64 || Result.getZExtValue() >= numElements*2)
       return ExprError(Diag(TheCall->getLocStart(),
@@ -698,7 +708,7 @@ Action::OwningExprResult Sema::SemaBuiltinShuffleVector(CallExpr *TheCall) {
   }
 
   return Owned(new (Context) ShuffleVectorExpr(Context, exprs.begin(),
-                                            exprs.size(), exprs[0]->getType(),
+                                            exprs.size(), resType,
                                             TheCall->getCallee()->getLocStart(),
                                             TheCall->getRParenLoc()));
 }