]> granicus.if.org Git - clang/commitdiff
Allow implicit casts during arithmetic for OCUVector operations
authorNate Begeman <natebegeman@mac.com>
Sun, 30 Dec 2007 02:59:45 +0000 (02:59 +0000)
committerNate Begeman <natebegeman@mac.com>
Sun, 30 Dec 2007 02:59:45 +0000 (02:59 +0000)
Add codegen support and test for said casts.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@45443 91177308-0d34-0410-b5e6-96231b3b80d8

CodeGen/CGExprScalar.cpp
CodeGen/CodeGenFunction.h
Sema/SemaExpr.cpp
test/CodeGen/ocu-vector.c

index 8d5f8143c993cc85efad3458b5689b8beb43a3da..0fc1d0f8e79ab3feb86231215df04e63058d7fd2 100644 (file)
@@ -360,6 +360,15 @@ Value *ScalarExprEmitter::EmitScalarConversion(Value *Src, QualType SrcType,
     return Builder.CreatePtrToInt(Src, DstTy, "conv");
   }
   
+  // A scalar source can be splatted to a vector of the same element type
+  if (isa<llvm::VectorType>(DstTy) && !isa<VectorType>(SrcType)) {
+    const llvm::VectorType *VT = cast<llvm::VectorType>(DstTy);
+    assert((VT->getElementType() == Src->getType()) &&
+           "Vector element type must match scalar type to splat.");
+    return CGF.EmitVector(&Src, DstType->getAsVectorType()->getNumElements(), 
+                          true);
+  }
+
   if (isa<llvm::VectorType>(Src->getType()) ||
       isa<llvm::VectorType>(DstTy)) {
     return Builder.CreateBitCast(Src, DstTy, "conv");
@@ -1049,14 +1058,15 @@ Value *CodeGenFunction::EmitShuffleVector(Value* V1, Value *V2, ...) {
 }
 
 llvm::Value *CodeGenFunction::EmitVector(llvm::Value * const *Vals, 
-                                         unsigned NumVals)
+                                         unsigned NumVals, bool isSplat)
 {
   llvm::Value *Vec
   = llvm::UndefValue::get(llvm::VectorType::get(Vals[0]->getType(), NumVals));
   
   for (unsigned i = 0, e = NumVals ; i != e; ++i) {
+    llvm::Value *Val = isSplat ? Vals[0] : Vals[i];
     llvm::Value *Idx = llvm::ConstantInt::get(llvm::Type::Int32Ty, i);
-    Vec = Builder.CreateInsertElement(Vec, Vals[i], Idx, "tmp");
+    Vec = Builder.CreateInsertElement(Vec, Val, Idx, "tmp");
   }
   
   return Vec;  
index a263ca0405e1f5ef39613e69560203a8274f08b6..216b7cd71a695efacca6d19c9ec427aed141e6f3 100644 (file)
@@ -395,7 +395,8 @@ public:
   llvm::Value *EmitPPCBuiltinExpr(unsigned BuiltinID, const CallExpr *E);
   
   llvm::Value *EmitShuffleVector(llvm::Value* V1, llvm::Value *V2, ...);
-  llvm::Value *EmitVector(llvm::Value * const *Vals, unsigned NumVals);
+  llvm::Value *EmitVector(llvm::Value * const *Vals, unsigned NumVals,
+                          bool isSplat = false);
   
   llvm::Value *EmitObjCStringLiteral(const ObjCStringLiteral *E);
 
index 014e457b9760cd606d5e6ae91b9ac86eaf8d112c..20118c265d12026e536db35add99987673fb8d52 100644 (file)
@@ -1125,15 +1125,15 @@ Sema::CheckAssignmentConstraints(QualType lhsType, QualType rhsType) {
   }
   else if (lhsType->isArithmeticType() && rhsType->isArithmeticType()) {
     if (lhsType->isVectorType() || rhsType->isVectorType()) {
+      // For OCUVector, allow vector splats; float -> <n x float>
+      if (const OCUVectorType *LV = lhsType->getAsOCUVectorType()) {
+        if (LV->getElementType().getTypePtr() == rhsType.getTypePtr())
+          return Compatible;
+      }
       if (!getLangOptions().LaxVectorConversions) {
         if (lhsType.getCanonicalType() != rhsType.getCanonicalType())
           return Incompatible;
       } else {
-        // For OCUVector, allow vector splats; float -> <n x float>
-        if (const OCUVectorType *LV = lhsType->getAsOCUVectorType()) {
-          if (LV->getElementType().getTypePtr() == rhsType.getTypePtr())
-            return Compatible;
-        }
         if (lhsType->isVectorType() && rhsType->isVectorType()) {
           // If LHS and RHS are both integer or both floating point types, and
           // the total vector length is the same, allow the conversion.  This is
@@ -1218,6 +1218,27 @@ inline QualType Sema::CheckVectorOperands(SourceLocation loc, Expr *&lex,
   // make sure the vector types are identical. 
   if (lhsType == rhsType)
     return lhsType;
+
+  // if the lhs is an ocu vector and the rhs is a scalar of the same type,
+  // promote the rhs to the vector type.
+  if (const OCUVectorType *V = lhsType->getAsOCUVectorType()) {
+    if (V->getElementType().getCanonicalType().getTypePtr()
+        == rhsType.getCanonicalType().getTypePtr()) {
+      promoteExprToType(rex, lhsType);
+      return lhsType;
+    }
+  }
+
+  // if the rhs is an ocu vector and the lhs is a scalar of the same type,
+  // promote the lhs to the vector type.
+  if (const OCUVectorType *V = rhsType->getAsOCUVectorType()) {
+    if (V->getElementType().getCanonicalType().getTypePtr()
+        == lhsType.getCanonicalType().getTypePtr()) {
+      promoteExprToType(lex, rhsType);
+      return rhsType;
+    }
+  }
+
   // You cannot convert between vector values of different size.
   Diag(loc, diag::err_typecheck_vector_not_convertable, 
        lex->getType().getAsString(), rex->getType().getAsString(),
index 9e904f69ff46a22d3542d1b87cdd9d28b4f36725..ee6e737a5853eca52a0b4ed840331c848bfab0d6 100644 (file)
@@ -1,7 +1,6 @@
 // RUN: clang -emit-llvm %s
 
 typedef __attribute__(( ocu_vector_type(4) )) float float4;
-//typedef __attribute__(( ocu_vector_type(3) )) float float3;
 typedef __attribute__(( ocu_vector_type(2) )) float float2;
 
 
@@ -33,3 +32,16 @@ static void test4(float4 *out) {
   float d = 4.0f;
   *out = ((float4) {a,b,c,d});
 }
+
+static void test5(float4 *out) {
+  float a;
+  float4 b;
+  
+  a = 1.0f;
+  b = a;
+  b = b * 5.0f;
+  b = 5.0f * b;
+  b *= a;
+  
+  *out = b;
+}