From: Alexey Bataev Date: Wed, 22 Nov 2017 21:12:03 +0000 (+0000) Subject: [OPENMP] Add support for cancel constructs in `target teams distribute X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=13e60aabcadd6758402f321fb1957f1bba8012ee;p=clang [OPENMP] Add support for cancel constructs in `target teams distribute parallel for`. Add support for cancel/cancellation point directives inside `target teams distribute parallel for` directives. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@318881 91177308-0d34-0410-b5e6-96231b3b80d8 --- diff --git a/include/clang/AST/StmtOpenMP.h b/include/clang/AST/StmtOpenMP.h index d26c4bdb30..fc03c424ed 100644 --- a/include/clang/AST/StmtOpenMP.h +++ b/include/clang/AST/StmtOpenMP.h @@ -957,8 +957,13 @@ public: T->getStmtClass() == OMPTargetSimdDirectiveClass || T->getStmtClass() == OMPTeamsDistributeDirectiveClass || T->getStmtClass() == OMPTeamsDistributeSimdDirectiveClass || - T->getStmtClass() == OMPTeamsDistributeParallelForSimdDirectiveClass || - T->getStmtClass() == OMPTeamsDistributeParallelForDirectiveClass; + T->getStmtClass() == + OMPTeamsDistributeParallelForSimdDirectiveClass || + T->getStmtClass() == OMPTeamsDistributeParallelForDirectiveClass || + T->getStmtClass() == + OMPTargetTeamsDistributeParallelForDirectiveClass || + T->getStmtClass() == + OMPTargetTeamsDistributeParallelForSimdDirectiveClass; } }; @@ -3799,6 +3804,8 @@ public: class OMPTargetTeamsDistributeParallelForDirective final : public OMPLoopDirective { friend class ASTStmtReader; + /// true if the construct has inner cancel directive. + bool HasCancel = false; /// Build directive with the given start and end location. /// @@ -3814,7 +3821,8 @@ class OMPTargetTeamsDistributeParallelForDirective final : OMPLoopDirective(this, OMPTargetTeamsDistributeParallelForDirectiveClass, OMPD_target_teams_distribute_parallel_for, StartLoc, - EndLoc, CollapsedNum, NumClauses) {} + EndLoc, CollapsedNum, NumClauses), + HasCancel(false) {} /// Build an empty directive. /// @@ -3826,7 +3834,11 @@ class OMPTargetTeamsDistributeParallelForDirective final : OMPLoopDirective( this, OMPTargetTeamsDistributeParallelForDirectiveClass, OMPD_target_teams_distribute_parallel_for, SourceLocation(), - SourceLocation(), CollapsedNum, NumClauses) {} + SourceLocation(), CollapsedNum, NumClauses), + HasCancel(false) {} + + /// Set cancel state. + void setHasCancel(bool Has) { HasCancel = Has; } public: /// Creates directive with a list of \a Clauses. @@ -3838,11 +3850,12 @@ public: /// \param Clauses List of clauses. /// \param AssociatedStmt Statement, associated with the directive. /// \param Exprs Helper expressions for CodeGen. + /// \param HasCancel true if this directive has inner cancel directive. /// static OMPTargetTeamsDistributeParallelForDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, unsigned CollapsedNum, ArrayRef Clauses, - Stmt *AssociatedStmt, const HelperExprs &Exprs); + Stmt *AssociatedStmt, const HelperExprs &Exprs, bool HasCancel); /// Creates an empty directive with the place for \a NumClauses clauses. /// @@ -3854,6 +3867,9 @@ public: CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum, EmptyShell); + /// Return true if current directive has inner cancel directive. + bool hasCancel() const { return HasCancel; } + static bool classof(const Stmt *T) { return T->getStmtClass() == OMPTargetTeamsDistributeParallelForDirectiveClass; diff --git a/lib/AST/StmtOpenMP.cpp b/lib/AST/StmtOpenMP.cpp index c0e9e72cfb..87bf5aaaa5 100644 --- a/lib/AST/StmtOpenMP.cpp +++ b/lib/AST/StmtOpenMP.cpp @@ -1624,7 +1624,7 @@ OMPTargetTeamsDistributeParallelForDirective * OMPTargetTeamsDistributeParallelForDirective::Create( const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, unsigned CollapsedNum, ArrayRef Clauses, Stmt *AssociatedStmt, - const HelperExprs &Exprs) { + const HelperExprs &Exprs, bool HasCancel) { auto Size = llvm::alignTo(sizeof(OMPTargetTeamsDistributeParallelForDirective), alignof(OMPClause *)); @@ -1670,6 +1670,7 @@ OMPTargetTeamsDistributeParallelForDirective::Create( Dir->setCombinedCond(Exprs.DistCombinedFields.Cond); Dir->setCombinedNextLowerBound(Exprs.DistCombinedFields.NLB); Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB); + Dir->HasCancel = HasCancel; return Dir; } diff --git a/lib/CodeGen/CGStmtOpenMP.cpp b/lib/CodeGen/CGStmtOpenMP.cpp index 7942189b87..af9f7e2eff 100644 --- a/lib/CodeGen/CGStmtOpenMP.cpp +++ b/lib/CodeGen/CGStmtOpenMP.cpp @@ -2014,6 +2014,9 @@ emitInnerParallelForWhenCombined(CodeGenFunction &CGF, HasCancel = D->hasCancel(); else if (const auto *D = dyn_cast(&S)) HasCancel = D->hasCancel(); + else if (const auto *D = + dyn_cast(&S)) + HasCancel = D->hasCancel(); } CodeGenFunction::OMPCancelStackRAII CancelRegion(CGF, S.getDirectiveKind(), HasCancel); @@ -3949,7 +3952,8 @@ CodeGenFunction::getOMPCancelDestination(OpenMPDirectiveKind Kind) { Kind == OMPD_parallel_sections || Kind == OMPD_parallel_for || Kind == OMPD_distribute_parallel_for || Kind == OMPD_target_parallel_for || - Kind == OMPD_teams_distribute_parallel_for); + Kind == OMPD_teams_distribute_parallel_for || + Kind == OMPD_target_teams_distribute_parallel_for); return OMPCancelStack.getExitBlock(); } diff --git a/lib/Sema/SemaOpenMP.cpp b/lib/Sema/SemaOpenMP.cpp index 939d75bddc..985cc62df0 100644 --- a/lib/Sema/SemaOpenMP.cpp +++ b/lib/Sema/SemaOpenMP.cpp @@ -2593,7 +2593,8 @@ static bool checkNestingOfRegions(Sema &SemaRef, DSAStackTy *Stack, (ParentRegion == OMPD_for || ParentRegion == OMPD_parallel_for || ParentRegion == OMPD_target_parallel_for || ParentRegion == OMPD_distribute_parallel_for || - ParentRegion == OMPD_teams_distribute_parallel_for)) || + ParentRegion == OMPD_teams_distribute_parallel_for || + ParentRegion == OMPD_target_teams_distribute_parallel_for)) || (CancelRegion == OMPD_taskgroup && ParentRegion == OMPD_task) || (CancelRegion == OMPD_sections && (ParentRegion == OMPD_section || ParentRegion == OMPD_sections || @@ -7324,7 +7325,8 @@ StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForDirective( getCurFunction()->setHasBranchProtectedScope(); return OMPTargetTeamsDistributeParallelForDirective::Create( - Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); + Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B, + DSAStack->isCancelRegion()); } StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective( diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp index a94b2e78e6..8ef1491eb2 100644 --- a/lib/Serialization/ASTReaderStmt.cpp +++ b/lib/Serialization/ASTReaderStmt.cpp @@ -2978,6 +2978,7 @@ void ASTStmtReader::VisitOMPTargetTeamsDistributeDirective( void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForDirective( OMPTargetTeamsDistributeParallelForDirective *D) { VisitOMPLoopDirective(D); + D->setHasCancel(Record.readInt()); } void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForSimdDirective( diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp index a695dbd15e..c5f4495d2f 100644 --- a/lib/Serialization/ASTWriterStmt.cpp +++ b/lib/Serialization/ASTWriterStmt.cpp @@ -2636,6 +2636,7 @@ void ASTStmtWriter::VisitOMPTargetTeamsDistributeDirective( void ASTStmtWriter::VisitOMPTargetTeamsDistributeParallelForDirective( OMPTargetTeamsDistributeParallelForDirective *D) { VisitOMPLoopDirective(D); + Record.push_back(D->hasCancel() ? 1 : 0); Code = serialization::STMT_OMP_TARGET_TEAMS_DISTRIBUTE_PARALLEL_FOR_DIRECTIVE; } diff --git a/test/OpenMP/target_teams_distribute_parallel_for_ast_print.cpp b/test/OpenMP/target_teams_distribute_parallel_for_ast_print.cpp index 8df3c972dc..de8d630d49 100644 --- a/test/OpenMP/target_teams_distribute_parallel_for_ast_print.cpp +++ b/test/OpenMP/target_teams_distribute_parallel_for_ast_print.cpp @@ -24,8 +24,10 @@ protected: public: S7(typename T::type v) : a(v) { #pragma omp target teams distribute parallel for private(a) private(this->a) private(T::a) - for (int k = 0; k < a.a; ++k) + for (int k = 0; k < a.a; ++k) { ++this->a.a; +#pragma omp cancel for + } } S7 &operator=(S7 &s) { #pragma omp target teams distribute parallel for private(a) private(this->a) @@ -43,6 +45,7 @@ public: } }; // CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a) private(T::a) +// CHECK: #pragma omp cancel for // CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a) // CHECK: #pragma omp target teams distribute parallel for default(none) private(b) firstprivate(argv) shared(d) reduction(+: c) reduction(max: e) num_teams(f) thread_limit(d)