From c874b4b3b131a32aeab924ae30b4568dd453346a Mon Sep 17 00:00:00 2001 From: Alexey Bataev Date: Tue, 18 Jul 2017 20:17:46 +0000 Subject: [PATCH] [OPENMP] Initial support for 'task_reduction' clause. Parsing/sema analysis of the 'task_reduction' clause. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@308352 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/clang/AST/OpenMPClause.h | 211 ++++++++++++++ include/clang/AST/RecursiveASTVisitor.h | 22 ++ include/clang/AST/StmtOpenMP.h | 34 ++- include/clang/Basic/DiagnosticSemaKinds.td | 6 +- include/clang/Basic/OpenMPKinds.def | 8 + include/clang/Sema/Sema.h | 10 +- lib/AST/OpenMPClause.cpp | 57 ++++ lib/AST/StmtOpenMP.cpp | 21 +- lib/AST/StmtPrinter.cpp | 25 +- lib/AST/StmtProfile.cpp | 24 ++ lib/Basic/OpenMPKinds.cpp | 16 +- lib/CodeGen/CGStmtOpenMP.cpp | 1 + lib/Parse/ParseOpenMP.cpp | 23 +- lib/Sema/SemaOpenMP.cpp | 72 +++-- lib/Sema/TreeTransform.h | 60 ++++ lib/Serialization/ASTReaderStmt.cpp | 42 ++- lib/Serialization/ASTWriterStmt.cpp | 20 ++ test/OpenMP/taskgroup_ast_print.cpp | 60 +++- test/OpenMP/taskgroup_messages.cpp | 2 + .../taskgroup_task_reduction_messages.cpp | 258 ++++++++++++++++++ tools/libclang/CIndex.cpp | 17 ++ 21 files changed, 929 insertions(+), 60 deletions(-) create mode 100644 test/OpenMP/taskgroup_task_reduction_messages.cpp diff --git a/include/clang/AST/OpenMPClause.h b/include/clang/AST/OpenMPClause.h index 14e73819f5..a1cae8e18f 100644 --- a/include/clang/AST/OpenMPClause.h +++ b/include/clang/AST/OpenMPClause.h @@ -1890,6 +1890,217 @@ public: } }; +/// This represents clause 'task_reduction' in the '#pragma omp taskgroup' +/// directives. +/// +/// \code +/// #pragma omp taskgroup task_reduction(+:a,b) +/// \endcode +/// In this example directive '#pragma omp taskgroup' has clause +/// 'task_reduction' with operator '+' and the variables 'a' and 'b'. +/// +class OMPTaskReductionClause final + : public OMPVarListClause, + public OMPClauseWithPostUpdate, + private llvm::TrailingObjects { + friend TrailingObjects; + friend OMPVarListClause; + friend class OMPClauseReader; + /// Location of ':'. + SourceLocation ColonLoc; + /// Nested name specifier for C++. + NestedNameSpecifierLoc QualifierLoc; + /// Name of custom operator. + DeclarationNameInfo NameInfo; + + /// Build clause with number of variables \a N. + /// + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + /// \param ColonLoc Location of ':'. + /// \param N Number of the variables in the clause. + /// \param QualifierLoc The nested-name qualifier with location information + /// \param NameInfo The full name info for reduction identifier. + /// + OMPTaskReductionClause(SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation ColonLoc, SourceLocation EndLoc, + unsigned N, NestedNameSpecifierLoc QualifierLoc, + const DeclarationNameInfo &NameInfo) + : OMPVarListClause(OMPC_task_reduction, StartLoc, + LParenLoc, EndLoc, N), + OMPClauseWithPostUpdate(this), ColonLoc(ColonLoc), + QualifierLoc(QualifierLoc), NameInfo(NameInfo) {} + + /// Build an empty clause. + /// + /// \param N Number of variables. + /// + explicit OMPTaskReductionClause(unsigned N) + : OMPVarListClause( + OMPC_task_reduction, SourceLocation(), SourceLocation(), + SourceLocation(), N), + OMPClauseWithPostUpdate(this), ColonLoc(), QualifierLoc(), NameInfo() {} + + /// Sets location of ':' symbol in clause. + void setColonLoc(SourceLocation CL) { ColonLoc = CL; } + /// Sets the name info for specified reduction identifier. + void setNameInfo(DeclarationNameInfo DNI) { NameInfo = DNI; } + /// Sets the nested name specifier. + void setQualifierLoc(NestedNameSpecifierLoc NSL) { QualifierLoc = NSL; } + + /// Set list of helper expressions, required for proper codegen of the clause. + /// These expressions represent private copy of the reduction variable. + void setPrivates(ArrayRef Privates); + + /// Get the list of helper privates. + MutableArrayRef getPrivates() { + return MutableArrayRef(varlist_end(), varlist_size()); + } + ArrayRef getPrivates() const { + return llvm::makeArrayRef(varlist_end(), varlist_size()); + } + + /// Set list of helper expressions, required for proper codegen of the clause. + /// These expressions represent LHS expression in the final reduction + /// expression performed by the reduction clause. + void setLHSExprs(ArrayRef LHSExprs); + + /// Get the list of helper LHS expressions. + MutableArrayRef getLHSExprs() { + return MutableArrayRef(getPrivates().end(), varlist_size()); + } + ArrayRef getLHSExprs() const { + return llvm::makeArrayRef(getPrivates().end(), varlist_size()); + } + + /// Set list of helper expressions, required for proper codegen of the clause. + /// These expressions represent RHS expression in the final reduction + /// expression performed by the reduction clause. Also, variables in these + /// expressions are used for proper initialization of reduction copies. + void setRHSExprs(ArrayRef RHSExprs); + + /// Get the list of helper destination expressions. + MutableArrayRef getRHSExprs() { + return MutableArrayRef(getLHSExprs().end(), varlist_size()); + } + ArrayRef getRHSExprs() const { + return llvm::makeArrayRef(getLHSExprs().end(), varlist_size()); + } + + /// Set list of helper reduction expressions, required for proper + /// codegen of the clause. These expressions are binary expressions or + /// operator/custom reduction call that calculates new value from source + /// helper expressions to destination helper expressions. + void setReductionOps(ArrayRef ReductionOps); + + /// Get the list of helper reduction expressions. + MutableArrayRef getReductionOps() { + return MutableArrayRef(getRHSExprs().end(), varlist_size()); + } + ArrayRef getReductionOps() const { + return llvm::makeArrayRef(getRHSExprs().end(), varlist_size()); + } + +public: + /// Creates clause with a list of variables \a VL. + /// + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param ColonLoc Location of ':'. + /// \param EndLoc Ending location of the clause. + /// \param VL The variables in the clause. + /// \param QualifierLoc The nested-name qualifier with location information + /// \param NameInfo The full name info for reduction identifier. + /// \param Privates List of helper expressions for proper generation of + /// private copies. + /// \param LHSExprs List of helper expressions for proper generation of + /// assignment operation required for copyprivate clause. This list represents + /// LHSs of the reduction expressions. + /// \param RHSExprs List of helper expressions for proper generation of + /// assignment operation required for copyprivate clause. This list represents + /// RHSs of the reduction expressions. + /// Also, variables in these expressions are used for proper initialization of + /// reduction copies. + /// \param ReductionOps List of helper expressions that represents reduction + /// expressions: + /// \code + /// LHSExprs binop RHSExprs; + /// operator binop(LHSExpr, RHSExpr); + /// (LHSExpr, RHSExpr); + /// \endcode + /// Required for proper codegen of final reduction operation performed by the + /// reduction clause. + /// \param PreInit Statement that must be executed before entering the OpenMP + /// region with this clause. + /// \param PostUpdate Expression that must be executed after exit from the + /// OpenMP region with this clause. + /// + static OMPTaskReductionClause * + Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation ColonLoc, SourceLocation EndLoc, ArrayRef VL, + NestedNameSpecifierLoc QualifierLoc, + const DeclarationNameInfo &NameInfo, ArrayRef Privates, + ArrayRef LHSExprs, ArrayRef RHSExprs, + ArrayRef ReductionOps, Stmt *PreInit, Expr *PostUpdate); + + /// Creates an empty clause with the place for \a N variables. + /// + /// \param C AST context. + /// \param N The number of variables. + /// + static OMPTaskReductionClause *CreateEmpty(const ASTContext &C, unsigned N); + + /// Gets location of ':' symbol in clause. + SourceLocation getColonLoc() const { return ColonLoc; } + /// Gets the name info for specified reduction identifier. + const DeclarationNameInfo &getNameInfo() const { return NameInfo; } + /// Gets the nested name specifier. + NestedNameSpecifierLoc getQualifierLoc() const { return QualifierLoc; } + + typedef MutableArrayRef::iterator helper_expr_iterator; + typedef ArrayRef::iterator helper_expr_const_iterator; + typedef llvm::iterator_range helper_expr_range; + typedef llvm::iterator_range + helper_expr_const_range; + + helper_expr_const_range privates() const { + return helper_expr_const_range(getPrivates().begin(), getPrivates().end()); + } + helper_expr_range privates() { + return helper_expr_range(getPrivates().begin(), getPrivates().end()); + } + helper_expr_const_range lhs_exprs() const { + return helper_expr_const_range(getLHSExprs().begin(), getLHSExprs().end()); + } + helper_expr_range lhs_exprs() { + return helper_expr_range(getLHSExprs().begin(), getLHSExprs().end()); + } + helper_expr_const_range rhs_exprs() const { + return helper_expr_const_range(getRHSExprs().begin(), getRHSExprs().end()); + } + helper_expr_range rhs_exprs() { + return helper_expr_range(getRHSExprs().begin(), getRHSExprs().end()); + } + helper_expr_const_range reduction_ops() const { + return helper_expr_const_range(getReductionOps().begin(), + getReductionOps().end()); + } + helper_expr_range reduction_ops() { + return helper_expr_range(getReductionOps().begin(), + getReductionOps().end()); + } + + child_range children() { + return child_range(reinterpret_cast(varlist_begin()), + reinterpret_cast(varlist_end())); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == OMPC_task_reduction; + } +}; + /// \brief This represents clause 'linear' in the '#pragma omp ...' /// directives. /// diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h index 917b240428..e7f271cc08 100644 --- a/include/clang/AST/RecursiveASTVisitor.h +++ b/include/clang/AST/RecursiveASTVisitor.h @@ -3016,6 +3016,28 @@ RecursiveASTVisitor::VisitOMPReductionClause(OMPReductionClause *C) { return true; } +template +bool RecursiveASTVisitor::VisitOMPTaskReductionClause( + OMPTaskReductionClause *C) { + TRY_TO(TraverseNestedNameSpecifierLoc(C->getQualifierLoc())); + TRY_TO(TraverseDeclarationNameInfo(C->getNameInfo())); + TRY_TO(VisitOMPClauseList(C)); + TRY_TO(VisitOMPClauseWithPostUpdate(C)); + for (auto *E : C->privates()) { + TRY_TO(TraverseStmt(E)); + } + for (auto *E : C->lhs_exprs()) { + TRY_TO(TraverseStmt(E)); + } + for (auto *E : C->rhs_exprs()) { + TRY_TO(TraverseStmt(E)); + } + for (auto *E : C->reduction_ops()) { + TRY_TO(TraverseStmt(E)); + } + return true; +} + template bool RecursiveASTVisitor::VisitOMPFlushClause(OMPFlushClause *C) { TRY_TO(VisitOMPClauseList(C)); diff --git a/include/clang/AST/StmtOpenMP.h b/include/clang/AST/StmtOpenMP.h index 463af06fdd..09dd87fdc8 100644 --- a/include/clang/AST/StmtOpenMP.h +++ b/include/clang/AST/StmtOpenMP.h @@ -1895,7 +1895,7 @@ public: } }; -/// \brief This represents '#pragma omp taskgroup' directive. +/// This represents '#pragma omp taskgroup' directive. /// /// \code /// #pragma omp taskgroup @@ -1903,39 +1903,45 @@ public: /// class OMPTaskgroupDirective : public OMPExecutableDirective { friend class ASTStmtReader; - /// \brief Build directive with the given start and end location. + /// Build directive with the given start and end location. /// /// \param StartLoc Starting location of the directive kind. /// \param EndLoc Ending location of the directive. + /// \param NumClauses Number of clauses. /// - OMPTaskgroupDirective(SourceLocation StartLoc, SourceLocation EndLoc) + OMPTaskgroupDirective(SourceLocation StartLoc, SourceLocation EndLoc, + unsigned NumClauses) : OMPExecutableDirective(this, OMPTaskgroupDirectiveClass, OMPD_taskgroup, - StartLoc, EndLoc, 0, 1) {} + StartLoc, EndLoc, NumClauses, 1) {} - /// \brief Build an empty directive. + /// Build an empty directive. + /// \param NumClauses Number of clauses. /// - explicit OMPTaskgroupDirective() + explicit OMPTaskgroupDirective(unsigned NumClauses) : OMPExecutableDirective(this, OMPTaskgroupDirectiveClass, OMPD_taskgroup, - SourceLocation(), SourceLocation(), 0, 1) {} + SourceLocation(), SourceLocation(), NumClauses, + 1) {} public: - /// \brief Creates directive. + /// Creates directive. /// /// \param C AST context. /// \param StartLoc Starting location of the directive kind. /// \param EndLoc Ending Location of the directive. + /// \param Clauses List of clauses. /// \param AssociatedStmt Statement, associated with the directive. /// - static OMPTaskgroupDirective *Create(const ASTContext &C, - SourceLocation StartLoc, - SourceLocation EndLoc, - Stmt *AssociatedStmt); + static OMPTaskgroupDirective * + Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef Clauses, Stmt *AssociatedStmt); - /// \brief Creates an empty directive. + /// Creates an empty directive. /// /// \param C AST context. + /// \param NumClauses Number of clauses. /// - static OMPTaskgroupDirective *CreateEmpty(const ASTContext &C, EmptyShell); + static OMPTaskgroupDirective *CreateEmpty(const ASTContext &C, + unsigned NumClauses, EmptyShell); static bool classof(const Stmt *T) { return T->getStmtClass() == OMPTaskgroupDirectiveClass; diff --git a/include/clang/Basic/DiagnosticSemaKinds.td b/include/clang/Basic/DiagnosticSemaKinds.td index 09ad1c5b2a..af14638e1d 100644 --- a/include/clang/Basic/DiagnosticSemaKinds.td +++ b/include/clang/Basic/DiagnosticSemaKinds.td @@ -8655,11 +8655,11 @@ def err_omp_unknown_reduction_identifier : Error< def err_omp_not_resolved_reduction_identifier : Error< "unable to resolve declare reduction construct for type %0">; def err_omp_reduction_ref_type_arg : Error< - "argument of OpenMP clause 'reduction' must reference the same object in all threads">; + "argument of OpenMP clause '%0' must reference the same object in all threads">; def err_omp_clause_not_arithmetic_type_arg : Error< - "arguments of OpenMP clause 'reduction' for 'min' or 'max' must be of %select{scalar|arithmetic}0 type">; + "arguments of OpenMP clause '%0' for 'min' or 'max' must be of %select{scalar|arithmetic}1 type">; def err_omp_clause_floating_type_arg : Error< - "arguments of OpenMP clause 'reduction' with bitwise operators cannot be of floating type">; + "arguments of OpenMP clause '%0' with bitwise operators cannot be of floating type">; def err_omp_once_referenced : Error< "variable can appear only once in OpenMP '%0' clause">; def err_omp_once_referenced_in_target_update : Error< diff --git a/include/clang/Basic/OpenMPKinds.def b/include/clang/Basic/OpenMPKinds.def index aae1c3a9b8..645ed52b59 100644 --- a/include/clang/Basic/OpenMPKinds.def +++ b/include/clang/Basic/OpenMPKinds.def @@ -168,6 +168,9 @@ #ifndef OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE #define OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(Name) #endif +#ifndef OPENMP_TASKGROUP_CLAUSE +#define OPENMP_TASKGROUP_CLAUSE(Name) +#endif // OpenMP directives. OPENMP_DIRECTIVE(threadprivate) @@ -270,6 +273,7 @@ OPENMP_CLAUSE(to, OMPToClause) OPENMP_CLAUSE(from, OMPFromClause) OPENMP_CLAUSE(use_device_ptr, OMPUseDevicePtrClause) OPENMP_CLAUSE(is_device_ptr, OMPIsDevicePtrClause) +OPENMP_CLAUSE(task_reduction, OMPTaskReductionClause) // Clauses allowed for OpenMP directive 'parallel'. OPENMP_PARALLEL_CLAUSE(if) @@ -848,6 +852,10 @@ OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(aligned) OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(safelen) OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(simdlen) +// Clauses allowed for OpenMP directive 'taskgroup'. +OPENMP_TASKGROUP_CLAUSE(task_reduction) + +#undef OPENMP_TASKGROUP_CLAUSE #undef OPENMP_TASKLOOP_SIMD_CLAUSE #undef OPENMP_TASKLOOP_CLAUSE #undef OPENMP_LINEAR_KIND diff --git a/include/clang/Sema/Sema.h b/include/clang/Sema/Sema.h index 49d626755e..1fa3f831d9 100644 --- a/include/clang/Sema/Sema.h +++ b/include/clang/Sema/Sema.h @@ -8682,7 +8682,8 @@ public: StmtResult ActOnOpenMPTaskwaitDirective(SourceLocation StartLoc, SourceLocation EndLoc); /// \brief Called on well-formed '\#pragma omp taskgroup'. - StmtResult ActOnOpenMPTaskgroupDirective(Stmt *AStmt, SourceLocation StartLoc, + StmtResult ActOnOpenMPTaskgroupDirective(ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); /// \brief Called on well-formed '\#pragma omp flush'. StmtResult ActOnOpenMPFlushDirective(ArrayRef Clauses, @@ -9024,6 +9025,13 @@ public: CXXScopeSpec &ReductionIdScopeSpec, const DeclarationNameInfo &ReductionId, ArrayRef UnresolvedReductions = llvm::None); + /// Called on well-formed 'task_reduction' clause. + OMPClause *ActOnOpenMPTaskReductionClause( + ArrayRef VarList, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation ColonLoc, SourceLocation EndLoc, + CXXScopeSpec &ReductionIdScopeSpec, + const DeclarationNameInfo &ReductionId, + ArrayRef UnresolvedReductions = llvm::None); /// \brief Called on well-formed 'linear' clause. OMPClause * ActOnOpenMPLinearClause(ArrayRef VarList, Expr *Step, diff --git a/lib/AST/OpenMPClause.cpp b/lib/AST/OpenMPClause.cpp index 77470a9b76..2c4d159a1b 100644 --- a/lib/AST/OpenMPClause.cpp +++ b/lib/AST/OpenMPClause.cpp @@ -46,6 +46,8 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) { return static_cast(C); case OMPC_reduction: return static_cast(C); + case OMPC_task_reduction: + return static_cast(C); case OMPC_linear: return static_cast(C); case OMPC_if: @@ -112,6 +114,8 @@ const OMPClauseWithPostUpdate *OMPClauseWithPostUpdate::get(const OMPClause *C) return static_cast(C); case OMPC_reduction: return static_cast(C); + case OMPC_task_reduction: + return static_cast(C); case OMPC_linear: return static_cast(C); case OMPC_schedule: @@ -505,6 +509,59 @@ OMPReductionClause *OMPReductionClause::CreateEmpty(const ASTContext &C, return new (Mem) OMPReductionClause(N); } +void OMPTaskReductionClause::setPrivates(ArrayRef Privates) { + assert(Privates.size() == varlist_size() && + "Number of private copies is not the same as the preallocated buffer"); + std::copy(Privates.begin(), Privates.end(), varlist_end()); +} + +void OMPTaskReductionClause::setLHSExprs(ArrayRef LHSExprs) { + assert( + LHSExprs.size() == varlist_size() && + "Number of LHS expressions is not the same as the preallocated buffer"); + std::copy(LHSExprs.begin(), LHSExprs.end(), getPrivates().end()); +} + +void OMPTaskReductionClause::setRHSExprs(ArrayRef RHSExprs) { + assert( + RHSExprs.size() == varlist_size() && + "Number of RHS expressions is not the same as the preallocated buffer"); + std::copy(RHSExprs.begin(), RHSExprs.end(), getLHSExprs().end()); +} + +void OMPTaskReductionClause::setReductionOps(ArrayRef ReductionOps) { + assert(ReductionOps.size() == varlist_size() && "Number of task reduction " + "expressions is not the same " + "as the preallocated buffer"); + std::copy(ReductionOps.begin(), ReductionOps.end(), getRHSExprs().end()); +} + +OMPTaskReductionClause *OMPTaskReductionClause::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc, SourceLocation ColonLoc, ArrayRef VL, + NestedNameSpecifierLoc QualifierLoc, const DeclarationNameInfo &NameInfo, + ArrayRef Privates, ArrayRef LHSExprs, + ArrayRef RHSExprs, ArrayRef ReductionOps, Stmt *PreInit, + Expr *PostUpdate) { + void *Mem = C.Allocate(totalSizeToAlloc(5 * VL.size())); + OMPTaskReductionClause *Clause = new (Mem) OMPTaskReductionClause( + StartLoc, LParenLoc, EndLoc, ColonLoc, VL.size(), QualifierLoc, NameInfo); + Clause->setVarRefs(VL); + Clause->setPrivates(Privates); + Clause->setLHSExprs(LHSExprs); + Clause->setRHSExprs(RHSExprs); + Clause->setReductionOps(ReductionOps); + Clause->setPreInitStmt(PreInit); + Clause->setPostUpdateExpr(PostUpdate); + return Clause; +} + +OMPTaskReductionClause *OMPTaskReductionClause::CreateEmpty(const ASTContext &C, + unsigned N) { + void *Mem = C.Allocate(totalSizeToAlloc(5 * N)); + return new (Mem) OMPTaskReductionClause(N); +} + OMPFlushClause *OMPFlushClause::Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, diff --git a/lib/AST/StmtOpenMP.cpp b/lib/AST/StmtOpenMP.cpp index cccb2f075b..1dcb4fd519 100644 --- a/lib/AST/StmtOpenMP.cpp +++ b/lib/AST/StmtOpenMP.cpp @@ -522,23 +522,28 @@ OMPTaskwaitDirective *OMPTaskwaitDirective::CreateEmpty(const ASTContext &C, return new (Mem) OMPTaskwaitDirective(); } -OMPTaskgroupDirective *OMPTaskgroupDirective::Create(const ASTContext &C, - SourceLocation StartLoc, - SourceLocation EndLoc, - Stmt *AssociatedStmt) { - unsigned Size = llvm::alignTo(sizeof(OMPTaskgroupDirective), alignof(Stmt *)); +OMPTaskgroupDirective *OMPTaskgroupDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef Clauses, Stmt *AssociatedStmt) { + unsigned Size = llvm::alignTo(sizeof(OMPTaskgroupDirective) + + sizeof(OMPClause *) * Clauses.size(), + alignof(Stmt *)); void *Mem = C.Allocate(Size + sizeof(Stmt *)); OMPTaskgroupDirective *Dir = - new (Mem) OMPTaskgroupDirective(StartLoc, EndLoc); + new (Mem) OMPTaskgroupDirective(StartLoc, EndLoc, Clauses.size()); Dir->setAssociatedStmt(AssociatedStmt); + Dir->setClauses(Clauses); return Dir; } OMPTaskgroupDirective *OMPTaskgroupDirective::CreateEmpty(const ASTContext &C, + unsigned NumClauses, EmptyShell) { - unsigned Size = llvm::alignTo(sizeof(OMPTaskgroupDirective), alignof(Stmt *)); + unsigned Size = llvm::alignTo(sizeof(OMPTaskgroupDirective) + + sizeof(OMPClause *) * NumClauses, + alignof(Stmt *)); void *Mem = C.Allocate(Size + sizeof(Stmt *)); - return new (Mem) OMPTaskgroupDirective(); + return new (Mem) OMPTaskgroupDirective(NumClauses); } OMPCancellationPointDirective *OMPCancellationPointDirective::Create( diff --git a/lib/AST/StmtPrinter.cpp b/lib/AST/StmtPrinter.cpp index 21f5259c3c..5ebaa32b49 100644 --- a/lib/AST/StmtPrinter.cpp +++ b/lib/AST/StmtPrinter.cpp @@ -836,6 +836,29 @@ void OMPClausePrinter::VisitOMPReductionClause(OMPReductionClause *Node) { } } +void OMPClausePrinter::VisitOMPTaskReductionClause( + OMPTaskReductionClause *Node) { + if (!Node->varlist_empty()) { + OS << "task_reduction("; + NestedNameSpecifier *QualifierLoc = + Node->getQualifierLoc().getNestedNameSpecifier(); + OverloadedOperatorKind OOK = + Node->getNameInfo().getName().getCXXOverloadedOperator(); + if (QualifierLoc == nullptr && OOK != OO_None) { + // Print reduction identifier in C format + OS << getOperatorSpelling(OOK); + } else { + // Use C++ format + if (QualifierLoc != nullptr) + QualifierLoc->print(OS, Policy); + OS << Node->getNameInfo(); + } + OS << ":"; + VisitOMPClauseList(Node, ' '); + OS << ")"; + } +} + void OMPClausePrinter::VisitOMPLinearClause(OMPLinearClause *Node) { if (!Node->varlist_empty()) { OS << "linear"; @@ -1081,7 +1104,7 @@ void StmtPrinter::VisitOMPTaskwaitDirective(OMPTaskwaitDirective *Node) { } void StmtPrinter::VisitOMPTaskgroupDirective(OMPTaskgroupDirective *Node) { - Indent() << "#pragma omp taskgroup"; + Indent() << "#pragma omp taskgroup "; PrintOMPExecutableDirective(Node); } diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp index f1fbe2806b..e06386018d 100644 --- a/lib/AST/StmtProfile.cpp +++ b/lib/AST/StmtProfile.cpp @@ -549,6 +549,30 @@ void OMPClauseProfiler::VisitOMPReductionClause( Profiler->VisitStmt(E); } } +void OMPClauseProfiler::VisitOMPTaskReductionClause( + const OMPTaskReductionClause *C) { + Profiler->VisitNestedNameSpecifier( + C->getQualifierLoc().getNestedNameSpecifier()); + Profiler->VisitName(C->getNameInfo().getName()); + VisitOMPClauseList(C); + VistOMPClauseWithPostUpdate(C); + for (auto *E : C->privates()) { + if (E) + Profiler->VisitStmt(E); + } + for (auto *E : C->lhs_exprs()) { + if (E) + Profiler->VisitStmt(E); + } + for (auto *E : C->rhs_exprs()) { + if (E) + Profiler->VisitStmt(E); + } + for (auto *E : C->reduction_ops()) { + if (E) + Profiler->VisitStmt(E); + } +} void OMPClauseProfiler::VisitOMPLinearClause(const OMPLinearClause *C) { VisitOMPClauseList(C); VistOMPClauseWithPostUpdate(C); diff --git a/lib/Basic/OpenMPKinds.cpp b/lib/Basic/OpenMPKinds.cpp index 76a0e18c2d..050c0cc466 100644 --- a/lib/Basic/OpenMPKinds.cpp +++ b/lib/Basic/OpenMPKinds.cpp @@ -138,6 +138,7 @@ unsigned clang::getOpenMPSimpleClauseType(OpenMPClauseKind Kind, case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_aligned: case OMPC_copyin: case OMPC_copyprivate: @@ -277,6 +278,7 @@ const char *clang::getOpenMPSimpleClauseTypeName(OpenMPClauseKind Kind, case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_aligned: case OMPC_copyin: case OMPC_copyprivate: @@ -705,6 +707,16 @@ bool clang::isAllowedClauseForDirective(OpenMPDirectiveKind DKind, #define OPENMP_TARGET_TEAMS_DISTRIBUTE_SIMD_CLAUSE(Name) \ case OMPC_##Name: \ return true; +#include "clang/Basic/OpenMPKinds.def" + default: + break; + } + break; + case OMPD_taskgroup: + switch (CKind) { +#define OPENMP_TASKGROUP_CLAUSE(Name) \ + case OMPC_##Name: \ + return true; #include "clang/Basic/OpenMPKinds.def" default: break; @@ -719,7 +731,6 @@ bool clang::isAllowedClauseForDirective(OpenMPDirectiveKind DKind, case OMPD_taskyield: case OMPD_barrier: case OMPD_taskwait: - case OMPD_taskgroup: case OMPD_cancellation_point: case OMPD_declare_reduction: break; @@ -840,7 +851,8 @@ bool clang::isOpenMPDistributeDirective(OpenMPDirectiveKind Kind) { bool clang::isOpenMPPrivate(OpenMPClauseKind Kind) { return Kind == OMPC_private || Kind == OMPC_firstprivate || Kind == OMPC_lastprivate || Kind == OMPC_linear || - Kind == OMPC_reduction; // TODO add next clauses like 'reduction'. + Kind == OMPC_reduction || + Kind == OMPC_task_reduction; // TODO add next clauses like 'reduction'. } bool clang::isOpenMPThreadPrivate(OpenMPClauseKind Kind) { diff --git a/lib/CodeGen/CGStmtOpenMP.cpp b/lib/CodeGen/CGStmtOpenMP.cpp index ad4b6e5e11..6135cf31d1 100644 --- a/lib/CodeGen/CGStmtOpenMP.cpp +++ b/lib/CodeGen/CGStmtOpenMP.cpp @@ -3466,6 +3466,7 @@ static void EmitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind, case OMPC_firstprivate: case OMPC_lastprivate: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_safelen: case OMPC_simdlen: case OMPC_collapse: diff --git a/lib/Parse/ParseOpenMP.cpp b/lib/Parse/ParseOpenMP.cpp index 2e5e36242e..d9a088595a 100644 --- a/lib/Parse/ParseOpenMP.cpp +++ b/lib/Parse/ParseOpenMP.cpp @@ -1102,7 +1102,7 @@ bool Parser::ParseOpenMPSimpleVarList( /// simdlen-clause | threads-clause | simd-clause | num_teams-clause | /// thread_limit-clause | priority-clause | grainsize-clause | /// nogroup-clause | num_tasks-clause | hint-clause | to-clause | -/// from-clause | is_device_ptr-clause +/// from-clause | is_device_ptr-clause | task_reduction-clause /// OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, OpenMPClauseKind CKind, bool FirstClause) { @@ -1220,6 +1220,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_linear: case OMPC_aligned: case OMPC_copyin: @@ -1585,7 +1586,7 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind, BalancedDelimiterTracker LinearT(*this, tok::l_paren, tok::annot_pragma_openmp_end); // Handle reduction-identifier for reduction clause. - if (Kind == OMPC_reduction) { + if (Kind == OMPC_reduction || Kind == OMPC_task_reduction) { ColonProtectionRAIIObject ColonRAII(*this); if (getLangOpts().CPlusPlus) ParseOptionalCXXScopeSpecifier(Data.ReductionIdScopeSpec, @@ -1733,13 +1734,13 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind, Diag(Tok, diag::warn_pragma_expected_colon) << "map type"; } - bool IsComma = - (Kind != OMPC_reduction && Kind != OMPC_depend && Kind != OMPC_map) || - (Kind == OMPC_reduction && !InvalidReductionId) || - (Kind == OMPC_map && Data.MapType != OMPC_MAP_unknown && - (!MapTypeModifierSpecified || - Data.MapTypeModifier == OMPC_MAP_always)) || - (Kind == OMPC_depend && Data.DepKind != OMPC_DEPEND_unknown); + bool IsComma = (Kind != OMPC_reduction && Kind != OMPC_task_reduction && + Kind != OMPC_depend && Kind != OMPC_map) || + (Kind == OMPC_reduction && !InvalidReductionId) || + (Kind == OMPC_map && Data.MapType != OMPC_MAP_unknown && + (!MapTypeModifierSpecified || + Data.MapTypeModifier == OMPC_MAP_always)) || + (Kind == OMPC_depend && Data.DepKind != OMPC_DEPEND_unknown); const bool MayHaveTail = (Kind == OMPC_linear || Kind == OMPC_aligned); while (IsComma || (Tok.isNot(tok::r_paren) && Tok.isNot(tok::colon) && Tok.isNot(tok::annot_pragma_openmp_end))) { @@ -1795,7 +1796,7 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind, } /// \brief Parsing of OpenMP clause 'private', 'firstprivate', 'lastprivate', -/// 'shared', 'copyin', 'copyprivate', 'flush' or 'reduction'. +/// 'shared', 'copyin', 'copyprivate', 'flush', 'reduction' or 'task_reduction'. /// /// private-clause: /// 'private' '(' list ')' @@ -1811,6 +1812,8 @@ bool Parser::ParseOpenMPVarList(OpenMPDirectiveKind DKind, /// 'aligned' '(' list [ ':' alignment ] ')' /// reduction-clause: /// 'reduction' '(' reduction-identifier ':' list ')' +/// task_reduction-clause: +/// 'task_reduction' '(' reduction-identifier ':' list ')' /// copyprivate-clause: /// 'copyprivate' '(' list ')' /// flush-clause: diff --git a/lib/Sema/SemaOpenMP.cpp b/lib/Sema/SemaOpenMP.cpp index fda63f4c2e..01f574b6ae 100644 --- a/lib/Sema/SemaOpenMP.cpp +++ b/lib/Sema/SemaOpenMP.cpp @@ -2500,9 +2500,8 @@ StmtResult Sema::ActOnOpenMPExecutableDirective( Res = ActOnOpenMPTaskwaitDirective(StartLoc, EndLoc); break; case OMPD_taskgroup: - assert(ClausesWithImplicit.empty() && - "No clauses are allowed for 'omp taskgroup' directive"); - Res = ActOnOpenMPTaskgroupDirective(AStmt, StartLoc, EndLoc); + Res = ActOnOpenMPTaskgroupDirective(ClausesWithImplicit, AStmt, StartLoc, + EndLoc); break; case OMPD_flush: assert(AStmt == nullptr && @@ -5069,7 +5068,8 @@ StmtResult Sema::ActOnOpenMPTaskwaitDirective(SourceLocation StartLoc, return OMPTaskwaitDirective::Create(Context, StartLoc, EndLoc); } -StmtResult Sema::ActOnOpenMPTaskgroupDirective(Stmt *AStmt, +StmtResult Sema::ActOnOpenMPTaskgroupDirective(ArrayRef Clauses, + Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) { if (!AStmt) @@ -5079,7 +5079,8 @@ StmtResult Sema::ActOnOpenMPTaskgroupDirective(Stmt *AStmt, getCurFunction()->setHasBranchProtectedScope(); - return OMPTaskgroupDirective::Create(Context, StartLoc, EndLoc, AStmt); + return OMPTaskgroupDirective::Create(Context, StartLoc, EndLoc, Clauses, + AStmt); } StmtResult Sema::ActOnOpenMPFlushDirective(ArrayRef Clauses, @@ -6851,6 +6852,7 @@ OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr, case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_linear: case OMPC_aligned: case OMPC_copyin: @@ -7154,6 +7156,7 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause( case OMPC_firstprivate: case OMPC_lastprivate: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_linear: case OMPC_default: case OMPC_proc_bind: @@ -7469,6 +7472,7 @@ OMPClause *Sema::ActOnOpenMPSimpleClause( case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_linear: case OMPC_aligned: case OMPC_copyin: @@ -7626,6 +7630,7 @@ OMPClause *Sema::ActOnOpenMPSingleExprWithArgClause( case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_linear: case OMPC_aligned: case OMPC_copyin: @@ -7823,6 +7828,7 @@ OMPClause *Sema::ActOnOpenMPClause(OpenMPClauseKind Kind, case OMPC_lastprivate: case OMPC_shared: case OMPC_reduction: + case OMPC_task_reduction: case OMPC_linear: case OMPC_aligned: case OMPC_copyin: @@ -7935,6 +7941,11 @@ OMPClause *Sema::ActOnOpenMPVarListClause( Res = ActOnOpenMPReductionClause(VarList, StartLoc, LParenLoc, ColonLoc, EndLoc, ReductionIdScopeSpec, ReductionId); break; + case OMPC_task_reduction: + Res = ActOnOpenMPTaskReductionClause(VarList, StartLoc, LParenLoc, ColonLoc, + EndLoc, ReductionIdScopeSpec, + ReductionId); + break; case OMPC_linear: Res = ActOnOpenMPLinearClause(VarList, TailExpr, StartLoc, LParenLoc, LinKind, DepLinMapLoc, ColonLoc, EndLoc); @@ -8953,10 +8964,10 @@ struct ReductionData { } // namespace static bool ActOnOMPReductionKindClause( - Sema &S, DSAStackTy *Stack, ArrayRef VarList, - SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation ColonLoc, - SourceLocation EndLoc, CXXScopeSpec &ReductionIdScopeSpec, - const DeclarationNameInfo &ReductionId, + Sema &S, DSAStackTy *Stack, OpenMPClauseKind ClauseKind, + ArrayRef VarList, SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation ColonLoc, SourceLocation EndLoc, + CXXScopeSpec &ReductionIdScopeSpec, const DeclarationNameInfo &ReductionId, ArrayRef UnresolvedReductions, ReductionData &RD) { auto DN = ReductionId.getName(); auto OOK = DN.getCXXOverloadedOperator(); @@ -9111,8 +9122,7 @@ static bool ActOnOMPReductionKindClause( // A list item that appears in a reduction clause must not be // const-qualified. if (Type.getNonReferenceType().isConstant(Context)) { - S.Diag(ELoc, diag::err_omp_const_reduction_list_item) - << getOpenMPClauseName(OMPC_reduction) << Type << ERange; + S.Diag(ELoc, diag::err_omp_const_reduction_list_item) << ERange; if (!ASE && !OASE) { bool IsDecl = !VD || VD->isThisDeclarationADefinition(Context) == VarDecl::DeclarationOnly; @@ -9130,7 +9140,8 @@ static bool ActOnOMPReductionKindClause( if (VD->getType()->isReferenceType() && VDDef && VDDef->hasInit()) { DSARefChecker Check(Stack); if (Check.Visit(VDDef->getInit())) { - S.Diag(ELoc, diag::err_omp_reduction_ref_type_arg) << ERange; + S.Diag(ELoc, diag::err_omp_reduction_ref_type_arg) + << getOpenMPClauseName(ClauseKind) << ERange; S.Diag(VDDef->getLocation(), diag::note_defined_here) << VDDef; continue; } @@ -9152,7 +9163,7 @@ static bool ActOnOMPReductionKindClause( DVar = Stack->getTopDSA(D, false); if (DVar.CKind == OMPC_reduction) { S.Diag(ELoc, diag::err_omp_once_referenced) - << getOpenMPClauseName(OMPC_reduction); + << getOpenMPClauseName(ClauseKind); if (DVar.RefExpr) S.Diag(DVar.RefExpr->getExprLoc(), diag::note_omp_referenced); } else if (DVar.CKind != OMPC_unknown) { @@ -9216,7 +9227,7 @@ static bool ActOnOMPReductionKindClause( !(Type->isScalarType() || (S.getLangOpts().CPlusPlus && Type->isArithmeticType()))) { S.Diag(ELoc, diag::err_omp_clause_not_arithmetic_type_arg) - << S.getLangOpts().CPlusPlus; + << getOpenMPClauseName(ClauseKind) << S.getLangOpts().CPlusPlus; if (!ASE && !OASE) { bool IsDecl = !VD || VD->isThisDeclarationADefinition(Context) == VarDecl::DeclarationOnly; @@ -9228,7 +9239,8 @@ static bool ActOnOMPReductionKindClause( } if ((BOK == BO_OrAssign || BOK == BO_AndAssign || BOK == BO_XorAssign) && !S.getLangOpts().CPlusPlus && Type->isFloatingType()) { - S.Diag(ELoc, diag::err_omp_clause_floating_type_arg); + S.Diag(ELoc, diag::err_omp_clause_floating_type_arg) + << getOpenMPClauseName(ClauseKind); if (!ASE && !OASE) { bool IsDecl = !VD || VD->isThisDeclarationADefinition(Context) == VarDecl::DeclarationOnly; @@ -9481,6 +9493,8 @@ static bool ActOnOMPReductionKindClause( } } } + // All reduction items are still marked as reduction (to do not increase + // code base size). Stack->addDSA(D, RefExpr->IgnoreParens(), OMPC_reduction, Ref); RD.push(VarsExpr, PrivateDRE, LHSDRE, RHSDRE, ReductionOp.get()); } @@ -9494,9 +9508,10 @@ OMPClause *Sema::ActOnOpenMPReductionClause( ArrayRef UnresolvedReductions) { ReductionData RD(VarList.size()); - if (ActOnOMPReductionKindClause(*this, DSAStack, VarList, StartLoc, LParenLoc, - ColonLoc, EndLoc, ReductionIdScopeSpec, - ReductionId, UnresolvedReductions, RD)) + if (ActOnOMPReductionKindClause(*this, DSAStack, OMPC_reduction, VarList, + StartLoc, LParenLoc, ColonLoc, EndLoc, + ReductionIdScopeSpec, ReductionId, + UnresolvedReductions, RD)) return nullptr; return OMPReductionClause::Create( @@ -9507,6 +9522,27 @@ OMPClause *Sema::ActOnOpenMPReductionClause( buildPostUpdate(*this, RD.ExprPostUpdates)); } +OMPClause *Sema::ActOnOpenMPTaskReductionClause( + ArrayRef VarList, SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation ColonLoc, SourceLocation EndLoc, + CXXScopeSpec &ReductionIdScopeSpec, const DeclarationNameInfo &ReductionId, + ArrayRef UnresolvedReductions) { + ReductionData RD(VarList.size()); + + if (ActOnOMPReductionKindClause(*this, DSAStack, OMPC_task_reduction, + VarList, StartLoc, LParenLoc, ColonLoc, + EndLoc, ReductionIdScopeSpec, ReductionId, + UnresolvedReductions, RD)) + return nullptr; + + return OMPTaskReductionClause::Create( + Context, StartLoc, LParenLoc, ColonLoc, EndLoc, RD.Vars, + ReductionIdScopeSpec.getWithLocInContext(Context), ReductionId, + RD.Privates, RD.LHSs, RD.RHSs, RD.ReductionOps, + buildPreInits(Context, RD.ExprCaptures), + buildPostUpdate(*this, RD.ExprPostUpdates)); +} + bool Sema::CheckOpenMPLinearModifier(OpenMPLinearClauseKind LinKind, SourceLocation LinLoc) { if ((!LangOpts.CPlusPlus && LinKind != OMPC_LINEAR_val) || diff --git a/lib/Sema/TreeTransform.h b/lib/Sema/TreeTransform.h index 7aa8f64d50..91da9f88c5 100644 --- a/lib/Sema/TreeTransform.h +++ b/lib/Sema/TreeTransform.h @@ -1651,6 +1651,21 @@ public: ReductionId, UnresolvedReductions); } + /// Build a new OpenMP 'task_reduction' clause. + /// + /// By default, performs semantic analysis to build the new statement. + /// Subclasses may override this routine to provide different behavior. + OMPClause *RebuildOMPTaskReductionClause( + ArrayRef VarList, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation ColonLoc, SourceLocation EndLoc, + CXXScopeSpec &ReductionIdScopeSpec, + const DeclarationNameInfo &ReductionId, + ArrayRef UnresolvedReductions) { + return getSema().ActOnOpenMPTaskReductionClause( + VarList, StartLoc, LParenLoc, ColonLoc, EndLoc, ReductionIdScopeSpec, + ReductionId, UnresolvedReductions); + } + /// \brief Build a new OpenMP 'linear' clause. /// /// By default, performs semantic analysis to build the new OpenMP clause. @@ -8399,6 +8414,51 @@ TreeTransform::TransformOMPReductionClause(OMPReductionClause *C) { C->getLocEnd(), ReductionIdScopeSpec, NameInfo, UnresolvedReductions); } +template +OMPClause *TreeTransform::TransformOMPTaskReductionClause( + OMPTaskReductionClause *C) { + llvm::SmallVector Vars; + Vars.reserve(C->varlist_size()); + for (auto *VE : C->varlists()) { + ExprResult EVar = getDerived().TransformExpr(cast(VE)); + if (EVar.isInvalid()) + return nullptr; + Vars.push_back(EVar.get()); + } + CXXScopeSpec ReductionIdScopeSpec; + ReductionIdScopeSpec.Adopt(C->getQualifierLoc()); + + DeclarationNameInfo NameInfo = C->getNameInfo(); + if (NameInfo.getName()) { + NameInfo = getDerived().TransformDeclarationNameInfo(NameInfo); + if (!NameInfo.getName()) + return nullptr; + } + // Build a list of all UDR decls with the same names ranged by the Scopes. + // The Scope boundary is a duplication of the previous decl. + llvm::SmallVector UnresolvedReductions; + for (auto *E : C->reduction_ops()) { + // Transform all the decls. + if (E) { + auto *ULE = cast(E); + UnresolvedSet<8> Decls; + for (auto *D : ULE->decls()) { + NamedDecl *InstD = + cast(getDerived().TransformDecl(E->getExprLoc(), D)); + Decls.addDecl(InstD, InstD->getAccess()); + } + UnresolvedReductions.push_back(UnresolvedLookupExpr::Create( + SemaRef.Context, /*NamingClass=*/nullptr, + ReductionIdScopeSpec.getWithLocInContext(SemaRef.Context), NameInfo, + /*ADL=*/true, ULE->isOverloaded(), Decls.begin(), Decls.end())); + } else + UnresolvedReductions.push_back(nullptr); + } + return getDerived().RebuildOMPTaskReductionClause( + Vars, C->getLocStart(), C->getLParenLoc(), C->getColonLoc(), + C->getLocEnd(), ReductionIdScopeSpec, NameInfo, UnresolvedReductions); +} + template OMPClause * TreeTransform::TransformOMPLinearClause(OMPLinearClause *C) { diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp index afee50ffa3..21adcddd3a 100644 --- a/lib/Serialization/ASTReaderStmt.cpp +++ b/lib/Serialization/ASTReaderStmt.cpp @@ -1834,6 +1834,9 @@ OMPClause *OMPClauseReader::readClause() { case OMPC_reduction: C = OMPReductionClause::CreateEmpty(Context, Reader->Record.readInt()); break; + case OMPC_task_reduction: + C = OMPTaskReductionClause::CreateEmpty(Context, Reader->Record.readInt()); + break; case OMPC_linear: C = OMPLinearClause::CreateEmpty(Context, Reader->Record.readInt()); break; @@ -2138,6 +2141,40 @@ void OMPClauseReader::VisitOMPReductionClause(OMPReductionClause *C) { C->setReductionOps(Vars); } +void OMPClauseReader::VisitOMPTaskReductionClause(OMPTaskReductionClause *C) { + VisitOMPClauseWithPostUpdate(C); + C->setLParenLoc(Reader->ReadSourceLocation()); + C->setColonLoc(Reader->ReadSourceLocation()); + NestedNameSpecifierLoc NNSL = Reader->Record.readNestedNameSpecifierLoc(); + DeclarationNameInfo DNI; + Reader->ReadDeclarationNameInfo(DNI); + C->setQualifierLoc(NNSL); + C->setNameInfo(DNI); + + unsigned NumVars = C->varlist_size(); + SmallVector Vars; + Vars.reserve(NumVars); + for (unsigned I = 0; I != NumVars; ++I) + Vars.push_back(Reader->Record.readSubExpr()); + C->setVarRefs(Vars); + Vars.clear(); + for (unsigned I = 0; I != NumVars; ++I) + Vars.push_back(Reader->Record.readSubExpr()); + C->setPrivates(Vars); + Vars.clear(); + for (unsigned I = 0; I != NumVars; ++I) + Vars.push_back(Reader->Record.readSubExpr()); + C->setLHSExprs(Vars); + Vars.clear(); + for (unsigned I = 0; I != NumVars; ++I) + Vars.push_back(Reader->Record.readSubExpr()); + C->setRHSExprs(Vars); + Vars.clear(); + for (unsigned I = 0; I != NumVars; ++I) + Vars.push_back(Reader->Record.readSubExpr()); + C->setReductionOps(Vars); +} + void OMPClauseReader::VisitOMPLinearClause(OMPLinearClause *C) { VisitOMPClauseWithPostUpdate(C); C->setLParenLoc(Reader->ReadSourceLocation()); @@ -2709,6 +2746,8 @@ void ASTStmtReader::VisitOMPTaskwaitDirective(OMPTaskwaitDirective *D) { void ASTStmtReader::VisitOMPTaskgroupDirective(OMPTaskgroupDirective *D) { VisitStmt(D); + // The NumClauses field was read in ReadStmtFromStream. + Record.skipInts(1); VisitOMPExecutableDirective(D); } @@ -3479,7 +3518,8 @@ Stmt *ASTReader::ReadStmtFromStream(ModuleFile &F) { break; case STMT_OMP_TASKGROUP_DIRECTIVE: - S = OMPTaskgroupDirective::CreateEmpty(Context, Empty); + S = OMPTaskgroupDirective::CreateEmpty( + Context, Record[ASTStmtReader::NumStmtFields], Empty); break; case STMT_OMP_FLUSH_DIRECTIVE: diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp index 90a732e575..ae2e0b88c3 100644 --- a/lib/Serialization/ASTWriterStmt.cpp +++ b/lib/Serialization/ASTWriterStmt.cpp @@ -1963,6 +1963,25 @@ void OMPClauseWriter::VisitOMPReductionClause(OMPReductionClause *C) { Record.AddStmt(E); } +void OMPClauseWriter::VisitOMPTaskReductionClause(OMPTaskReductionClause *C) { + Record.push_back(C->varlist_size()); + VisitOMPClauseWithPostUpdate(C); + Record.AddSourceLocation(C->getLParenLoc()); + Record.AddSourceLocation(C->getColonLoc()); + Record.AddNestedNameSpecifierLoc(C->getQualifierLoc()); + Record.AddDeclarationNameInfo(C->getNameInfo()); + for (auto *VE : C->varlists()) + Record.AddStmt(VE); + for (auto *VE : C->privates()) + Record.AddStmt(VE); + for (auto *E : C->lhs_exprs()) + Record.AddStmt(E); + for (auto *E : C->rhs_exprs()) + Record.AddStmt(E); + for (auto *E : C->reduction_ops()) + Record.AddStmt(E); +} + void OMPClauseWriter::VisitOMPLinearClause(OMPLinearClause *C) { Record.push_back(C->varlist_size()); VisitOMPClauseWithPostUpdate(C); @@ -2440,6 +2459,7 @@ void ASTStmtWriter::VisitOMPTaskwaitDirective(OMPTaskwaitDirective *D) { void ASTStmtWriter::VisitOMPTaskgroupDirective(OMPTaskgroupDirective *D) { VisitStmt(D); + Record.push_back(D->getNumClauses()); VisitOMPExecutableDirective(D); Code = serialization::STMT_OMP_TASKGROUP_DIRECTIVE; } diff --git a/test/OpenMP/taskgroup_ast_print.cpp b/test/OpenMP/taskgroup_ast_print.cpp index 683141e686..4292eedfb5 100644 --- a/test/OpenMP/taskgroup_ast_print.cpp +++ b/test/OpenMP/taskgroup_ast_print.cpp @@ -8,6 +8,62 @@ void foo() {} +struct S1 { + S1(): a(0) {} + S1(int v) : a(v) {} + int a; + typedef int type; + S1& operator +(const S1&); + S1& operator *(const S1&); + S1& operator &&(const S1&); + S1& operator ^(const S1&); +}; + +template +class S7 : public T { +protected: + T a; + T b[100]; + S7() : a(0) {} + +public: + S7(typename T::type v) : a(v) { +#pragma omp taskgroup task_reduction(+ : a) task_reduction(*: b[:]) + for (int k = 0; k < a.a; ++k) + ++this->a.a; + } + S7 &operator=(S7 &s) { +#pragma omp taskgroup task_reduction(&& : this->a) task_reduction(^: b[s.a.a]) + for (int k = 0; k < s.a.a; ++k) + ++s.a.a; + return *this; + } +}; + +// CHECK: #pragma omp taskgroup task_reduction(+: this->a) task_reduction(*: this->b[:]) +// CHECK: #pragma omp taskgroup task_reduction(&&: this->a) task_reduction(^: this->b[s.a.a]) +// CHECK: #pragma omp taskgroup task_reduction(+: this->a) task_reduction(*: this->b[:]) + +class S8 : public S7 { + S8() {} + +public: + S8(int v) : S7(v){ +#pragma omp taskgroup task_reduction(^ : S7 < S1 > ::a) task_reduction(+ : S7 < S1 > ::b[ : S7 < S1 > ::a.a]) + for (int k = 0; k < a.a; ++k) + ++this->a.a; + } + S8 &operator=(S8 &s) { +#pragma omp taskgroup task_reduction(* : this->a) task_reduction(&&:this->b[a.a:]) + for (int k = 0; k < s.a.a; ++k) + ++s.a.a; + return *this; + } +}; + +// CHECK: #pragma omp taskgroup task_reduction(^: this->S7::a) task_reduction(+: this->S7::b[:this->S7::a.a]) +// CHECK: #pragma omp taskgroup task_reduction(*: this->a) task_reduction(&&: this->b[this->a.a:]) + int main (int argc, char **argv) { int b = argc, c, d, e, f, g; static int a; @@ -18,9 +74,9 @@ int main (int argc, char **argv) { // CHECK-NEXT: a = 2; // CHECK-NEXT: ++a; ++a; -#pragma omp taskgroup +#pragma omp taskgroup task_reduction(min: a) foo(); -// CHECK-NEXT: #pragma omp taskgroup +// CHECK-NEXT: #pragma omp taskgroup task_reduction(min: a) // CHECK-NEXT: foo(); // CHECK-NEXT: return 0; return 0; diff --git a/test/OpenMP/taskgroup_messages.cpp b/test/OpenMP/taskgroup_messages.cpp index e893da4be3..23820734a5 100644 --- a/test/OpenMP/taskgroup_messages.cpp +++ b/test/OpenMP/taskgroup_messages.cpp @@ -63,5 +63,7 @@ int foo() { foo(); } +#pragma omp taskgroup init // expected-warning {{extra tokens at the end of '#pragma omp taskgroup' are ignored}} + ; return 0; } diff --git a/test/OpenMP/taskgroup_task_reduction_messages.cpp b/test/OpenMP/taskgroup_task_reduction_messages.cpp new file mode 100644 index 0000000000..9c8177b0bd --- /dev/null +++ b/test/OpenMP/taskgroup_task_reduction_messages.cpp @@ -0,0 +1,258 @@ +// RUN: %clang_cc1 -verify -fopenmp -ferror-limit 150 -o - %s +// RUN: %clang_cc1 -verify -fopenmp -std=c++98 -ferror-limit 150 -o - %s +// RUN: %clang_cc1 -verify -fopenmp -std=c++11 -ferror-limit 150 -o - %s + +void foo() { +} + +bool foobool(int argc) { + return argc; +} + +void foobar(int &ref) { +#pragma omp taskgroup task_reduction(+:ref) + foo(); +} + +struct S1; // expected-note {{declared here}} expected-note 4 {{forward declaration of 'S1'}} +extern S1 a; +class S2 { + mutable int a; + S2 &operator+(const S2 &arg) { return (*this); } // expected-note 3 {{implicitly declared private here}} + +public: + S2() : a(0) {} + S2(S2 &s2) : a(s2.a) {} + static float S2s; // expected-note 2 {{static data member is predetermined as shared}} + static const float S2sc; +}; +const float S2::S2sc = 0; // expected-note 2 {{'S2sc' defined here}} +S2 b; // expected-note 3 {{'b' defined here}} +const S2 ba[5]; // expected-note 2 {{'ba' defined here}} +class S3 { + int a; + +public: + int b; + S3() : a(0) {} + S3(const S3 &s3) : a(s3.a) {} + S3 operator+(const S3 &arg1) { return arg1; } +}; +int operator+(const S3 &arg1, const S3 &arg2) { return 5; } +S3 c; // expected-note 3 {{'c' defined here}} +const S3 ca[5]; // expected-note 2 {{'ca' defined here}} +extern const int f; // expected-note 4 {{'f' declared here}} +class S4 { + int a; + S4(); // expected-note {{implicitly declared private here}} + S4(const S4 &s4); + S4 &operator+(const S4 &arg) { return (*this); } + +public: + S4(int v) : a(v) {} +}; +S4 &operator&=(S4 &arg1, S4 &arg2) { return arg1; } +class S5 { + int a; + S5() : a(0) {} // expected-note {{implicitly declared private here}} + S5(const S5 &s5) : a(s5.a) {} + S5 &operator+(const S5 &arg); + +public: + S5(int v) : a(v) {} +}; +class S6 { // expected-note 3 {{candidate function (the implicit copy assignment operator) not viable: no known conversion from 'int' to 'const S6' for 1st argument}} +#if __cplusplus >= 201103L // C++11 or later +// expected-note@-2 3 {{candidate function (the implicit move assignment operator) not viable}} +#endif + int a; + +public: + S6() : a(6) {} + operator int() { return 6; } +} o; + +S3 h, k; +#pragma omp threadprivate(h) // expected-note 2 {{defined as threadprivate or thread local}} + +template // expected-note {{declared here}} +T tmain(T argc) { + const T d = T(); // expected-note 4 {{'d' defined here}} + const T da[5] = {T()}; // expected-note 2 {{'da' defined here}} + T qa[5] = {T()}; + T i; + T &j = i; // expected-note 2 {{'j' defined here}} + S3 &p = k; // expected-note 2 {{'p' defined here}} + const T &r = da[(int)i]; // expected-note 2 {{'r' defined here}} + T &q = qa[(int)i]; + T fl; +#pragma omp taskgroup task_reduction // expected-error {{expected '(' after 'task_reduction'}} + foo(); +#pragma omp taskgroup task_reduction + // expected-error {{expected '(' after 'task_reduction'}} expected-warning {{extra tokens at the end of '#pragma omp taskgroup' are ignored}} + foo(); +#pragma omp taskgroup task_reduction( // expected-error {{expected unqualified-id}} expected-warning {{missing ':' after reduction identifier - ignoring}} expected-error {{expected ')'}} expected-note {{to match this '('}} + foo(); +#pragma omp taskgroup task_reduction(- // expected-warning {{missing ':' after reduction identifier - ignoring}} expected-error {{expected ')'}} expected-note {{to match this '('}} + foo(); +#pragma omp taskgroup task_reduction() // expected-error {{expected unqualified-id}} expected-warning {{missing ':' after reduction identifier - ignoring}} + foo(); +#pragma omp taskgroup task_reduction(*) // expected-warning {{missing ':' after reduction identifier - ignoring}} + foo(); +#pragma omp taskgroup task_reduction(\) // expected-error {{expected unqualified-id}} expected-warning {{missing ':' after reduction identifier - ignoring}} + foo(); +#pragma omp taskgroup task_reduction(& : argc // expected-error {{expected ')'}} expected-note {{to match this '('}} expected-error {{invalid operands to binary expression ('float' and 'float')}} + foo(); +#pragma omp taskgroup task_reduction(| : argc, // expected-error {{expected expression}} expected-error {{expected ')'}} expected-note {{to match this '('}} expected-error {{invalid operands to binary expression ('float' and 'float')}} + foo(); +#pragma omp taskgroup task_reduction(|| : argc ? i : argc) // expected-error 2 {{expected variable name, array element or array section}} + foo(); +#pragma omp taskgroup task_reduction(foo : argc) //expected-error {{incorrect reduction identifier, expected one of '+', '-', '*', '&', '|', '^', '&&', '||', 'min' or 'max' or declare reduction for type 'float'}} expected-error {{incorrect reduction identifier, expected one of '+', '-', '*', '&', '|', '^', '&&', '||', 'min' or 'max' or declare reduction for type 'int'}} + foo(); +#pragma omp taskgroup task_reduction(&& : argc) + foo(); +#pragma omp taskgroup task_reduction(^ : T) // expected-error {{'T' does not refer to a value}} + foo(); +#pragma omp taskgroup task_reduction(+ : a, b, c, d, f) // expected-error {{a reduction list item with incomplete type 'S1'}} expected-error 3 {{const-qualified list item cannot be reduction}} expected-error 2 {{'operator+' is a private member of 'S2'}} + foo(); +#pragma omp taskgroup task_reduction(min : a, b, c, d, f) // expected-error {{a reduction list item with incomplete type 'S1'}} expected-error 4 {{arguments of OpenMP clause 'task_reduction' for 'min' or 'max' must be of arithmetic type}} expected-error 3 {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(max : h.b) // expected-error {{expected variable name, array element or array section}} + foo(); +#pragma omp taskgroup task_reduction(+ : ba) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(* : ca) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(- : da) // expected-error {{const-qualified list item cannot be reduction}} expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(^ : fl) // expected-error {{invalid operands to binary expression ('float' and 'float')}} + foo(); +#pragma omp taskgroup task_reduction(&& : S2::S2s) // expected-error {{shared variable cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(&& : S2::S2sc) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(+ : h, k) // expected-error {{threadprivate or thread local variable cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(+ : o) // expected-error 2 {{no viable overloaded '='}} + foo(); +#pragma omp parallel private(k) +#pragma omp taskgroup task_reduction(+ : p), task_reduction(+ : p) // expected-error 2 {{argument of OpenMP clause 'task_reduction' must reference the same object in all threads}} + foo(); +#pragma omp taskgroup task_reduction(+ : p), task_reduction(+ : p) // expected-error 2 {{variable can appear only once in OpenMP 'task_reduction' clause}} expected-note 2 {{previously referenced here}} + foo(); +#pragma omp taskgroup task_reduction(+ : r) // expected-error 2 {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp parallel shared(i) +#pragma omp parallel reduction(min : i) +#pragma omp taskgroup task_reduction(max : j) // expected-error 2 {{argument of OpenMP clause 'task_reduction' must reference the same object in all threads}} + foo(); +#pragma omp parallel +#pragma omp for private(fl) + for (int i = 0; i < 10; ++i) +#pragma omp taskgroup task_reduction(+ : fl) + foo(); +#pragma omp parallel +#pragma omp for reduction(- : fl) + for (int i = 0; i < 10; ++i) +#pragma omp taskgroup task_reduction(+ : fl) + foo(); + + return T(); +} + +namespace A { +double x; +#pragma omp threadprivate(x) // expected-note {{defined as threadprivate or thread local}} +} +namespace B { +using A::x; +} + +int main(int argc, char **argv) { + const int d = 5; // expected-note 2 {{'d' defined here}} + const int da[5] = {0}; // expected-note {{'da' defined here}} + int qa[5] = {0}; + S4 e(4); + S5 g(5); + int i; + int &j = i; // expected-note {{'j' defined here}} + S3 &p = k; // expected-note 2 {{'p' defined here}} + const int &r = da[i]; // expected-note {{'r' defined here}} + int &q = qa[i]; + float fl; +#pragma omp taskgroup task_reduction // expected-error {{expected '(' after 'task_reduction'}} + foo(); +#pragma omp taskgroup task_reduction + // expected-error {{expected '(' after 'task_reduction'}} expected-warning {{extra tokens at the end of '#pragma omp taskgroup' are ignored}} + foo(); +#pragma omp taskgroup task_reduction( // expected-error {{expected unqualified-id}} expected-warning {{missing ':' after reduction identifier - ignoring}} expected-error {{expected ')'}} expected-note {{to match this '('}} + foo(); +#pragma omp taskgroup task_reduction(- // expected-warning {{missing ':' after reduction identifier - ignoring}} expected-error {{expected ')'}} expected-note {{to match this '('}} + foo(); +#pragma omp taskgroup task_reduction() // expected-error {{expected unqualified-id}} expected-warning {{missing ':' after reduction identifier - ignoring}} + foo(); +#pragma omp taskgroup task_reduction(*) // expected-warning {{missing ':' after reduction identifier - ignoring}} + foo(); +#pragma omp taskgroup task_reduction(\) // expected-error {{expected unqualified-id}} expected-warning {{missing ':' after reduction identifier - ignoring}} + foo(); +#pragma omp taskgroup task_reduction(foo : argc // expected-error {{expected ')'}} expected-note {{to match this '('}} expected-error {{incorrect reduction identifier, expected one of '+', '-', '*', '&', '|', '^', '&&', '||', 'min' or 'max'}} + foo(); +#pragma omp taskgroup task_reduction(| : argc, // expected-error {{expected expression}} expected-error {{expected ')'}} expected-note {{to match this '('}} + foo(); +#pragma omp taskgroup task_reduction(|| : argc > 0 ? argv[1] : argv[2]) // expected-error {{expected variable name, array element or array section}} + foo(); +#pragma omp taskgroup task_reduction(~ : argc) // expected-error {{expected unqualified-id}} + foo(); +#pragma omp taskgroup task_reduction(&& : argc) + foo(); +#pragma omp taskgroup task_reduction(^ : S1) // expected-error {{'S1' does not refer to a value}} + foo(); +#pragma omp taskgroup task_reduction(+ : a, b, c, d, f) // expected-error {{a reduction list item with incomplete type 'S1'}} expected-error 2 {{const-qualified list item cannot be reduction}} expected-error {{'operator+' is a private member of 'S2'}} + foo(); +#pragma omp taskgroup task_reduction(min : a, b, c, d, f) // expected-error {{a reduction list item with incomplete type 'S1'}} expected-error 2 {{arguments of OpenMP clause 'task_reduction' for 'min' or 'max' must be of arithmetic type}} expected-error 2 {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(max : h.b) // expected-error {{expected variable name, array element or array section}} + foo(); +#pragma omp taskgroup task_reduction(+ : ba) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(* : ca) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(- : da) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(^ : fl) // expected-error {{invalid operands to binary expression ('float' and 'float')}} + foo(); +#pragma omp taskgroup task_reduction(&& : S2::S2s) // expected-error {{shared variable cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(&& : S2::S2sc) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(& : e, g) // expected-error {{calling a private constructor of class 'S4'}} expected-error {{nvalid operands to binary expression ('S4' and 'S4')}} expected-error {{calling a private constructor of class 'S5'}} expected-error {{invalid operands to binary expression ('S5' and 'S5')}} + foo(); +#pragma omp taskgroup task_reduction(+ : h, k, B::x) // expected-error 2 {{threadprivate or thread local variable cannot be reduction}} + foo(); +#pragma omp taskgroup task_reduction(+ : o) // expected-error {{no viable overloaded '='}} + foo(); +#pragma omp parallel private(k) +#pragma omp taskgroup task_reduction(+ : p), task_reduction(+ : p) // expected-error 2 {{argument of OpenMP clause 'task_reduction' must reference the same object in all threads}} + foo(); +#pragma omp taskgroup task_reduction(+ : p), task_reduction(+ : p) // expected-error {{variable can appear only once in OpenMP 'task_reduction' clause}} expected-note {{previously referenced here}} + foo(); +#pragma omp taskgroup task_reduction(+ : r) // expected-error {{const-qualified list item cannot be reduction}} + foo(); +#pragma omp parallel shared(i) +#pragma omp parallel reduction(min : i) +#pragma omp taskgroup task_reduction(max : j) // expected-error {{argument of OpenMP clause 'task_reduction' must reference the same object in all threads}} + foo(); +#pragma omp parallel +#pragma omp for private(fl) + for (int i = 0; i < 10; ++i) +#pragma omp taskgroup task_reduction(+ : fl) + foo(); +#pragma omp parallel +#pragma omp for reduction(- : fl) + for (int i = 0; i < 10; ++i) +#pragma omp taskgroup task_reduction(+ : fl) + foo(); + static int m; +#pragma omp taskgroup task_reduction(+ : m) // OK + m++; + + return tmain(argc) + tmain(fl); // expected-note {{in instantiation of function template specialization 'tmain' requested here}} expected-note {{in instantiation of function template specialization 'tmain' requested here}} +} diff --git a/tools/libclang/CIndex.cpp b/tools/libclang/CIndex.cpp index 869aeac099..d527535a17 100644 --- a/tools/libclang/CIndex.cpp +++ b/tools/libclang/CIndex.cpp @@ -2264,6 +2264,23 @@ void OMPClauseEnqueue::VisitOMPReductionClause(const OMPReductionClause *C) { Visitor->AddStmt(E); } } +void OMPClauseEnqueue::VisitOMPTaskReductionClause( + const OMPTaskReductionClause *C) { + VisitOMPClauseList(C); + VisitOMPClauseWithPostUpdate(C); + for (auto *E : C->privates()) { + Visitor->AddStmt(E); + } + for (auto *E : C->lhs_exprs()) { + Visitor->AddStmt(E); + } + for (auto *E : C->rhs_exprs()) { + Visitor->AddStmt(E); + } + for (auto *E : C->reduction_ops()) { + Visitor->AddStmt(E); + } +} void OMPClauseEnqueue::VisitOMPLinearClause(const OMPLinearClause *C) { VisitOMPClauseList(C); VisitOMPClauseWithPostUpdate(C); -- 2.40.0