]> granicus.if.org Git - clang/commitdiff
Improve template argument deduction from a call. In particular,
authorDouglas Gregor <dgregor@apple.com>
Tue, 7 Jul 2009 23:09:34 +0000 (23:09 +0000)
committerDouglas Gregor <dgregor@apple.com>
Tue, 7 Jul 2009 23:09:34 +0000 (23:09 +0000)
implement C++ [temp.deduct.call]p3b3, which allows a template-id
parameter to match a derived class of the argument, while deducing
template arguments.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@74965 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Sema/Sema.h
lib/Sema/SemaTemplateDeduction.cpp
test/CXX/temp/temp.fct.spec/temp.deduct/temp.deduct.call/p3.cpp

index 99d3566726b5169d0c339d160d0e4a46aecea87a..11c0c25f26362723babd421fb64101fe75710738 100644 (file)
@@ -2266,7 +2266,7 @@ public:
     /// \brief When performing template argument deduction for a function 
     /// template, there were too many call arguments.
     TDK_TooManyArguments,
-    /// \brief When performing template argument deduction for a class 
+    /// \brief When performing template argument deduction for a function 
     /// template, there were too few call arguments.
     TDK_TooFewArguments,
     /// \brief The explicitly-specified template arguments were not valid
index 5a0578f6bcb70f7c5ef740019a5fb2851033b651..61bddd52e8a9944d97c32a4af633ca01c4a9cc9e 100644 (file)
@@ -180,6 +180,109 @@ DeduceTemplateArguments(ASTContext &Context,
   return Sema::TDK_Success;
 }
 
+/// \brief Deduce the template arguments by comparing the template parameter 
+/// type (which is a template-id) with the template argument type.
+///
+/// \param Context the AST context in which this deduction occurs.
+///
+/// \param TemplateParams the template parameters that we are deducing
+///
+/// \param Param the parameter type
+///
+/// \param Arg the argument type
+///
+/// \param Info information about the template argument deduction itself
+///
+/// \param Deduced the deduced template arguments
+///
+/// \returns the result of template argument deduction so far. Note that a
+/// "success" result means that template argument deduction has not yet failed,
+/// but it may still fail, later, for other reasons.
+static Sema::TemplateDeductionResult
+DeduceTemplateArguments(ASTContext &Context,
+                        TemplateParameterList *TemplateParams,
+                        const TemplateSpecializationType *Param,
+                        QualType Arg,
+                        Sema::TemplateDeductionInfo &Info,
+                        llvm::SmallVectorImpl<TemplateArgument> &Deduced) {
+  assert(Arg->isCanonical() && "Argument type must be canonical");
+  
+  // Check whether the template argument is a dependent template-id.
+  // FIXME: This is untested code; it can be tested when we implement
+  // partial ordering of class template partial specializations.
+  if (const TemplateSpecializationType *SpecArg 
+        = dyn_cast<TemplateSpecializationType>(Arg)) {
+    // Perform template argument deduction for the template name.
+    if (Sema::TemplateDeductionResult Result
+          = DeduceTemplateArguments(Context,
+                                    Param->getTemplateName(),
+                                    SpecArg->getTemplateName(),
+                                    Info, Deduced))
+      return Result;
+    
+    unsigned NumArgs = Param->getNumArgs();
+    
+    // FIXME: When one of the template-names refers to a
+    // declaration with default template arguments, do we need to
+    // fill in those default template arguments here? Most likely,
+    // the answer is "yes", but I don't see any references. This
+    // issue may be resolved elsewhere, because we may want to
+    // instantiate default template arguments when we actually write
+    // the template-id.
+    if (SpecArg->getNumArgs() != NumArgs)
+      return Sema::TDK_NonDeducedMismatch;
+    
+    // Perform template argument deduction on each template
+    // argument.
+    for (unsigned I = 0; I != NumArgs; ++I)
+      if (Sema::TemplateDeductionResult Result
+            = DeduceTemplateArguments(Context, TemplateParams,
+                                      Param->getArg(I),
+                                      SpecArg->getArg(I),
+                                      Info, Deduced))
+        return Result;
+    
+    return Sema::TDK_Success;
+  }
+  
+  // If the argument type is a class template specialization, we
+  // perform template argument deduction using its template
+  // arguments.
+  const RecordType *RecordArg = dyn_cast<RecordType>(Arg);
+  if (!RecordArg)
+    return Sema::TDK_NonDeducedMismatch;
+  
+  ClassTemplateSpecializationDecl *SpecArg 
+    = dyn_cast<ClassTemplateSpecializationDecl>(RecordArg->getDecl());
+  if (!SpecArg)
+    return Sema::TDK_NonDeducedMismatch;
+  
+  // Perform template argument deduction for the template name.
+  if (Sema::TemplateDeductionResult Result
+        = DeduceTemplateArguments(Context, 
+                                  Param->getTemplateName(),
+                               TemplateName(SpecArg->getSpecializedTemplate()),
+                                  Info, Deduced))
+    return Result;
+    
+  // FIXME: Can the # of arguments in the parameter and the argument
+  // differ due to default arguments?
+  unsigned NumArgs = Param->getNumArgs();
+  const TemplateArgumentList &ArgArgs = SpecArg->getTemplateArgs();
+  if (NumArgs != ArgArgs.size())
+    return Sema::TDK_NonDeducedMismatch;
+  
+  for (unsigned I = 0; I != NumArgs; ++I)
+    if (Sema::TemplateDeductionResult Result 
+          = DeduceTemplateArguments(Context, TemplateParams,
+                                    Param->getArg(I),
+                                    ArgArgs.get(I),
+                                    Info, Deduced))
+      return Result;
+    
+  return Sema::TDK_Success;
+}
+
 /// \brief Deduce the template arguments by comparing the parameter type and
 /// the argument type (C++ [temp.deduct.type]).
 ///
@@ -453,84 +556,76 @@ DeduceTemplateArguments(ASTContext &Context,
     case Type::TemplateSpecialization: {
       const TemplateSpecializationType *SpecParam
         = cast<TemplateSpecializationType>(Param);
-
-      // Check whether the template argument is a dependent template-id.
-      // FIXME: This is untested code; it can be tested when we implement
-      // partial ordering of class template partial specializations.
-      if (const TemplateSpecializationType *SpecArg 
-            = dyn_cast<TemplateSpecializationType>(Arg)) {
-        // Perform template argument deduction for the template name.
-        if (Sema::TemplateDeductionResult Result
-              = DeduceTemplateArguments(Context,
-                                        SpecParam->getTemplateName(),
-                                        SpecArg->getTemplateName(),
-                                        Info, Deduced))
-          return Result;
+      
+      // Try to deduce template arguments from the template-id.
+      Sema::TemplateDeductionResult Result
+        = DeduceTemplateArguments(Context, TemplateParams, SpecParam, Arg,  
+                                  Info, Deduced);
+      
+      if (Result && (TDF & TDF_DerivedClass) && 
+          Result != Sema::TDK_Inconsistent) {
+        // C++ [temp.deduct.call]p3b3:
+        //   If P is a class, and P has the form template-id, then A can be a
+        //   derived class of the deduced A. Likewise, if P is a pointer to a
+        //   class of the form template-id, A can be a pointer to a derived 
+        //   class pointed to by the deduced A.
+        //
+        // More importantly:
+        //   These alternatives are considered only if type deduction would 
+        //   otherwise fail.
+        if (const RecordType *RecordT = dyn_cast<RecordType>(Arg)) {
+          // Use data recursion to crawl through the list of base classes.
+          // Visited contains the set of nodes we have already visited, while 
+          // ToVisit is our stack of records that we still need to visit.
+          llvm::SmallPtrSet<const RecordType *, 8> Visited;
+          llvm::SmallVector<const RecordType *, 8> ToVisit;
+          ToVisit.push_back(RecordT);
+          bool Successful = false;
+          while (!ToVisit.empty()) {
+            // Retrieve the next class in the inheritance hierarchy.
+            const RecordType *NextT = ToVisit.back();
+            ToVisit.pop_back();
             
-        unsigned NumArgs = SpecParam->getNumArgs();
-
-        // FIXME: When one of the template-names refers to a
-        // declaration with default template arguments, do we need to
-        // fill in those default template arguments here? Most likely,
-        // the answer is "yes", but I don't see any references. This
-        // issue may be resolved elsewhere, because we may want to
-        // instantiate default template arguments when
-        if (SpecArg->getNumArgs() != NumArgs)
-          return Sema::TDK_NonDeducedMismatch;
-
-        // Perform template argument deduction on each template
-        // argument.
-        for (unsigned I = 0; I != NumArgs; ++I)
-          if (Sema::TemplateDeductionResult Result
-                = DeduceTemplateArguments(Context, TemplateParams,
-                                          SpecParam->getArg(I),
-                                          SpecArg->getArg(I),
-                                          Info, Deduced))
-            return Result;
-
-        return Sema::TDK_Success;
-      } 
-
-      // If the argument type is a class template specialization, we
-      // perform template argument deduction using its template
-      // arguments.
-      const RecordType *RecordArg = dyn_cast<RecordType>(Arg);
-      if (!RecordArg)
-        return Sema::TDK_NonDeducedMismatch;
-
-      // FIXME: Check TDF_DerivedClass here. When this flag is set, we need
-      // to troll through the base classes of the argument and try matching
-      // all of them. Failure to match does not mean that there is a problem,
-      // of course.
-
-      ClassTemplateSpecializationDecl *SpecArg 
-        = dyn_cast<ClassTemplateSpecializationDecl>(RecordArg->getDecl());
-      if (!SpecArg)
-        return Sema::TDK_NonDeducedMismatch;
-
-      // Perform template argument deduction for the template name.
-      if (Sema::TemplateDeductionResult Result
-            = DeduceTemplateArguments(Context, 
-                                      SpecParam->getTemplateName(),
-                              TemplateName(SpecArg->getSpecializedTemplate()),
-                                      Info, Deduced))
-          return Result;
-
-      // FIXME: Can the # of arguments in the parameter and the argument differ?
-      unsigned NumArgs = SpecParam->getNumArgs();
-      const TemplateArgumentList &ArgArgs = SpecArg->getTemplateArgs();
-      if (NumArgs != ArgArgs.size())
-        return Sema::TDK_NonDeducedMismatch;
-
-      for (unsigned I = 0; I != NumArgs; ++I)
-        if (Sema::TemplateDeductionResult Result
-              = DeduceTemplateArguments(Context, TemplateParams,
-                                        SpecParam->getArg(I),
-                                        ArgArgs.get(I),
-                                        Info, Deduced))
-          return Result;
+            // If we have already seen this type, skip it.
+            if (!Visited.insert(NextT))
+              continue;
+           
+            // If this is a base class, try to perform template argument
+            // deduction from it.
+            if (NextT != RecordT) {
+              Sema::TemplateDeductionResult BaseResult
+                = DeduceTemplateArguments(Context, TemplateParams, SpecParam,
+                                          QualType(NextT, 0), Info, Deduced);
+              
+              // If template argument deduction for this base was successful,
+              // note that we had some success.
+              if (BaseResult == Sema::TDK_Success)
+                Successful = true;
+              // If deduction against this base resulted in an inconsistent
+              // set of deduced template arguments, template argument
+              // deduction fails.
+              else if (BaseResult == Sema::TDK_Inconsistent)
+                return BaseResult;
+            }
+            
+            // Visit base classes
+            CXXRecordDecl *Next = cast<CXXRecordDecl>(NextT->getDecl());
+            for (CXXRecordDecl::base_class_iterator Base = Next->bases_begin(),
+                                                 BaseEnd = Next->bases_end();
+               Base != BaseEnd; ++Base) {
+              assert(Base->getType()->isRecordType() && 
+                     "Base class that isn't a record?");
+              ToVisit.push_back(Base->getType()->getAsRecordType());
+            }
+          }
+          
+          if (Successful)
+            return Sema::TDK_Success;
+        }
+        
+      }
       
-      return Sema::TDK_Success;
+      return Result;
     }
 
     //     T type::*
index c014c663598cc4425c2670d8613b40fc396fd1f9..596427adf9efd6db3044fa7f46007385ee87fa16 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: clang-cc -fsyntax-only %s
+// RUN: clang-cc -fsyntax-only -verify %s
 
 template<typename T> struct A { };
 
@@ -57,4 +57,30 @@ void test_f3(int ***ip, volatile int ***vip) {
   A<volatile int> a1 = f3(vip);
 }
                              
-// FIXME: the next bullet requires a lot of effort.
+//   - If P is a class, and P has the form template-id, then A can be a 
+//     derived class of the deduced A. Likewise, if P is a pointer to a class
+//     of the form template-id, A can be a pointer to a derived class pointed 
+//     to by the deduced A.
+template<typename T, int I> struct C { };
+
+struct D : public C<int, 1> { };
+struct E : public D { };
+struct F : A<float> { };
+
+template<typename T, int I>
+  C<T, I> *f4a(const C<T, I>&);
+template<typename T, int I>
+  C<T, I> *f4b(C<T, I>);
+template<typename T, int I>
+  C<T, I> *f4c(C<T, I>*);
+int *f4c(...);
+
+void test_f4(D d, E e, F f) {
+  C<int, 1> *ci1a = f4a(d);
+  C<int, 1> *ci2a = f4a(e);
+  C<int, 1> *ci1b = f4b(d);
+  C<int, 1> *ci2b = f4b(e);
+  C<int, 1> *ci1c = f4c(&d);
+  C<int, 1> *ci2c = f4c(&e);
+  int       *ip1 = f4c(&f);
+}