]> granicus.if.org Git - clang/commitdiff
[OPENMP] Add support for cancel constructs in `target teams distribute
authorAlexey Bataev <a.bataev@hotmail.com>
Wed, 22 Nov 2017 21:12:03 +0000 (21:12 +0000)
committerAlexey Bataev <a.bataev@hotmail.com>
Wed, 22 Nov 2017 21:12:03 +0000 (21:12 +0000)
parallel for`.

Add support for cancel/cancellation point directives inside `target
teams distribute parallel for` directives.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@318881 91177308-0d34-0410-b5e6-96231b3b80d8

include/clang/AST/StmtOpenMP.h
lib/AST/StmtOpenMP.cpp
lib/CodeGen/CGStmtOpenMP.cpp
lib/Sema/SemaOpenMP.cpp
lib/Serialization/ASTReaderStmt.cpp
lib/Serialization/ASTWriterStmt.cpp
test/OpenMP/target_teams_distribute_parallel_for_ast_print.cpp

index d26c4bdb30f9ae93219c1d9bf1f94e8c6ef4b724..fc03c424ed60a48c4884e6b9c68e155ae238829d 100644 (file)
@@ -957,8 +957,13 @@ public:
            T->getStmtClass() == OMPTargetSimdDirectiveClass ||
            T->getStmtClass() == OMPTeamsDistributeDirectiveClass ||
            T->getStmtClass() == OMPTeamsDistributeSimdDirectiveClass ||
-           T->getStmtClass() == OMPTeamsDistributeParallelForSimdDirectiveClass ||
-           T->getStmtClass() == OMPTeamsDistributeParallelForDirectiveClass;
+           T->getStmtClass() ==
+               OMPTeamsDistributeParallelForSimdDirectiveClass ||
+           T->getStmtClass() == OMPTeamsDistributeParallelForDirectiveClass ||
+           T->getStmtClass() ==
+               OMPTargetTeamsDistributeParallelForDirectiveClass ||
+           T->getStmtClass() ==
+               OMPTargetTeamsDistributeParallelForSimdDirectiveClass;
   }
 };
 
@@ -3799,6 +3804,8 @@ public:
 class OMPTargetTeamsDistributeParallelForDirective final
     : public OMPLoopDirective {
   friend class ASTStmtReader;
+  /// true if the construct has inner cancel directive.
+  bool HasCancel = false;
 
   /// Build directive with the given start and end location.
   ///
@@ -3814,7 +3821,8 @@ class OMPTargetTeamsDistributeParallelForDirective final
       : OMPLoopDirective(this,
                          OMPTargetTeamsDistributeParallelForDirectiveClass,
                          OMPD_target_teams_distribute_parallel_for, StartLoc,
-                         EndLoc, CollapsedNum, NumClauses) {}
+                         EndLoc, CollapsedNum, NumClauses),
+        HasCancel(false) {}
 
   /// Build an empty directive.
   ///
@@ -3826,7 +3834,11 @@ class OMPTargetTeamsDistributeParallelForDirective final
       : OMPLoopDirective(
             this, OMPTargetTeamsDistributeParallelForDirectiveClass,
             OMPD_target_teams_distribute_parallel_for, SourceLocation(),
-            SourceLocation(), CollapsedNum, NumClauses) {}
+            SourceLocation(), CollapsedNum, NumClauses),
+        HasCancel(false) {}
+
+  /// Set cancel state.
+  void setHasCancel(bool Has) { HasCancel = Has; }
 
 public:
   /// Creates directive with a list of \a Clauses.
@@ -3838,11 +3850,12 @@ public:
   /// \param Clauses List of clauses.
   /// \param AssociatedStmt Statement, associated with the directive.
   /// \param Exprs Helper expressions for CodeGen.
+  /// \param HasCancel true if this directive has inner cancel directive.
   ///
   static OMPTargetTeamsDistributeParallelForDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses,
-         Stmt *AssociatedStmt, const HelperExprs &Exprs);
+         Stmt *AssociatedStmt, const HelperExprs &Exprs, bool HasCancel);
 
   /// Creates an empty directive with the place for \a NumClauses clauses.
   ///
@@ -3854,6 +3867,9 @@ public:
   CreateEmpty(const ASTContext &C, unsigned NumClauses, unsigned CollapsedNum,
               EmptyShell);
 
+  /// Return true if current directive has inner cancel directive.
+  bool hasCancel() const { return HasCancel; }
+
   static bool classof(const Stmt *T) {
     return T->getStmtClass() ==
            OMPTargetTeamsDistributeParallelForDirectiveClass;
index c0e9e72cfbe2df26a0f6202e4413705564ad533f..87bf5aaaa585dafa3c54d227b62ddb9777e0ff82 100644 (file)
@@ -1624,7 +1624,7 @@ OMPTargetTeamsDistributeParallelForDirective *
 OMPTargetTeamsDistributeParallelForDirective::Create(
     const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
     unsigned CollapsedNum, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
-    const HelperExprs &Exprs) {
+    const HelperExprs &Exprs, bool HasCancel) {
   auto Size =
       llvm::alignTo(sizeof(OMPTargetTeamsDistributeParallelForDirective),
                     alignof(OMPClause *));
@@ -1670,6 +1670,7 @@ OMPTargetTeamsDistributeParallelForDirective::Create(
   Dir->setCombinedCond(Exprs.DistCombinedFields.Cond);
   Dir->setCombinedNextLowerBound(Exprs.DistCombinedFields.NLB);
   Dir->setCombinedNextUpperBound(Exprs.DistCombinedFields.NUB);
+  Dir->HasCancel = HasCancel;
   return Dir;
 }
 
index 7942189b87e809bd899f13f5f673f92b3a1ede96..af9f7e2eff51e51ad906ba34a76be0edc78d25af 100644 (file)
@@ -2014,6 +2014,9 @@ emitInnerParallelForWhenCombined(CodeGenFunction &CGF,
         HasCancel = D->hasCancel();
       else if (const auto *D = dyn_cast<OMPDistributeParallelForDirective>(&S))
         HasCancel = D->hasCancel();
+      else if (const auto *D =
+                   dyn_cast<OMPTargetTeamsDistributeParallelForDirective>(&S))
+        HasCancel = D->hasCancel();
     }
     CodeGenFunction::OMPCancelStackRAII CancelRegion(CGF, S.getDirectiveKind(),
                                                      HasCancel);
@@ -3949,7 +3952,8 @@ CodeGenFunction::getOMPCancelDestination(OpenMPDirectiveKind Kind) {
          Kind == OMPD_parallel_sections || Kind == OMPD_parallel_for ||
          Kind == OMPD_distribute_parallel_for ||
          Kind == OMPD_target_parallel_for ||
-         Kind == OMPD_teams_distribute_parallel_for);
+         Kind == OMPD_teams_distribute_parallel_for ||
+         Kind == OMPD_target_teams_distribute_parallel_for);
   return OMPCancelStack.getExitBlock();
 }
 
index 939d75bddc6a750b382e2b1808609982cc339608..985cc62df09bd9f4a1dc76a7bb7b701b29d9323b 100644 (file)
@@ -2593,7 +2593,8 @@ static bool checkNestingOfRegions(Sema &SemaRef, DSAStackTy *Stack,
              (ParentRegion == OMPD_for || ParentRegion == OMPD_parallel_for ||
               ParentRegion == OMPD_target_parallel_for ||
               ParentRegion == OMPD_distribute_parallel_for ||
-              ParentRegion == OMPD_teams_distribute_parallel_for)) ||
+              ParentRegion == OMPD_teams_distribute_parallel_for ||
+              ParentRegion == OMPD_target_teams_distribute_parallel_for)) ||
             (CancelRegion == OMPD_taskgroup && ParentRegion == OMPD_task) ||
             (CancelRegion == OMPD_sections &&
              (ParentRegion == OMPD_section || ParentRegion == OMPD_sections ||
@@ -7324,7 +7325,8 @@ StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForDirective(
 
   getCurFunction()->setHasBranchProtectedScope();
   return OMPTargetTeamsDistributeParallelForDirective::Create(
-      Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B);
+      Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B,
+      DSAStack->isCancelRegion());
 }
 
 StmtResult Sema::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective(
index a94b2e78e6d609f41349cabc3bd13bb14ed2e80a..8ef1491eb2da7307650e231ec798b6283fda2028 100644 (file)
@@ -2978,6 +2978,7 @@ void ASTStmtReader::VisitOMPTargetTeamsDistributeDirective(
 void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForDirective(
     OMPTargetTeamsDistributeParallelForDirective *D) {
   VisitOMPLoopDirective(D);
+  D->setHasCancel(Record.readInt());
 }
 
 void ASTStmtReader::VisitOMPTargetTeamsDistributeParallelForSimdDirective(
index a695dbd15ebe3a320ccaaf795444712d0ab37ff4..c5f4495d2f01341ac4ca63cfaa3f32168084de19 100644 (file)
@@ -2636,6 +2636,7 @@ void ASTStmtWriter::VisitOMPTargetTeamsDistributeDirective(
 void ASTStmtWriter::VisitOMPTargetTeamsDistributeParallelForDirective(
     OMPTargetTeamsDistributeParallelForDirective *D) {
   VisitOMPLoopDirective(D);
+  Record.push_back(D->hasCancel() ? 1 : 0);
   Code = serialization::STMT_OMP_TARGET_TEAMS_DISTRIBUTE_PARALLEL_FOR_DIRECTIVE;
 }
 
index 8df3c972dce061cfd2b45cf22fcb1b6f77ea6da9..de8d630d49a89ecef6de030d54a116a1d183df65 100644 (file)
@@ -24,8 +24,10 @@ protected:
 public:
   S7(typename T::type v) : a(v) {
 #pragma omp target teams distribute parallel for private(a) private(this->a) private(T::a)
-    for (int k = 0; k < a.a; ++k)
+    for (int k = 0; k < a.a; ++k) {
       ++this->a.a;
+#pragma omp cancel for
+    }
   }
   S7 &operator=(S7 &s) {
 #pragma omp target teams distribute parallel for private(a) private(this->a)
@@ -43,6 +45,7 @@ public:
   }
 };
 // CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a) private(T::a)
+// CHECK: #pragma omp cancel for
 // CHECK: #pragma omp target teams distribute parallel for private(this->a) private(this->a)
 // CHECK: #pragma omp target teams distribute parallel for default(none) private(b) firstprivate(argv) shared(d) reduction(+: c) reduction(max: e) num_teams(f) thread_limit(d)