From bd6b181280fd87f99e3b096264683713465ccac7 Mon Sep 17 00:00:00 2001 From: Alexey Bataev Date: Tue, 15 Oct 2019 19:37:05 +0000 Subject: [PATCH] [OPENMP]Allow final clause in combined task-based directives. The condition of the final clause must be captured in the combined task-based directives, like 'parallel master taskloop' directive. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@374942 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/clang/AST/OpenMPClause.h | 27 ++++++++++++------- include/clang/AST/RecursiveASTVisitor.h | 1 + lib/AST/OpenMPClause.cpp | 9 ++++++- lib/AST/StmtProfile.cpp | 1 + lib/Sema/SemaOpenMP.cpp | 23 +++++++++++++--- lib/Serialization/ASTReader.cpp | 1 + lib/Serialization/ASTWriter.cpp | 1 + .../parallel_master_taskloop_ast_print.cpp | 4 +-- .../parallel_master_taskloop_codegen.cpp | 15 +++++++---- 9 files changed, 61 insertions(+), 21 deletions(-) diff --git a/include/clang/AST/OpenMPClause.h b/include/clang/AST/OpenMPClause.h index 911c1cfb77..346b3e4a29 100644 --- a/include/clang/AST/OpenMPClause.h +++ b/include/clang/AST/OpenMPClause.h @@ -519,7 +519,7 @@ public: /// \endcode /// In this example directive '#pragma omp task' has simple 'final' /// clause with condition 'a > 5'. -class OMPFinalClause : public OMPClause { +class OMPFinalClause : public OMPClause, public OMPClauseWithPreInit { friend class OMPClauseReader; /// Location of '('. @@ -534,18 +534,26 @@ class OMPFinalClause : public OMPClause { public: /// Build 'final' clause with condition \a Cond. /// + /// \param Cond final condition. + /// \param HelperCond Helper condition for the construct. + /// \param CaptureRegion Innermost OpenMP region where expressions in this + /// clause must be captured. /// \param StartLoc Starting location of the clause. /// \param LParenLoc Location of '('. /// \param Cond Condition of the clause. /// \param EndLoc Ending location of the clause. - OMPFinalClause(Expr *Cond, SourceLocation StartLoc, SourceLocation LParenLoc, - SourceLocation EndLoc) - : OMPClause(OMPC_final, StartLoc, EndLoc), LParenLoc(LParenLoc), - Condition(Cond) {} + OMPFinalClause(Expr *Cond, Stmt *HelperSize, + OpenMPDirectiveKind CaptureRegion, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc) + : OMPClause(OMPC_final, StartLoc, EndLoc), OMPClauseWithPreInit(this), + LParenLoc(LParenLoc), Condition(Cond) { + setPreInitStmt(HelperSize, CaptureRegion); + } /// Build an empty clause. OMPFinalClause() - : OMPClause(OMPC_final, SourceLocation(), SourceLocation()) {} + : OMPClause(OMPC_final, SourceLocation(), SourceLocation()), + OMPClauseWithPreInit(this) {} /// Sets the location of '('. void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } @@ -562,11 +570,10 @@ public: return const_child_range(&Condition, &Condition + 1); } - child_range used_children() { - return child_range(child_iterator(), child_iterator()); - } + child_range used_children(); const_child_range used_children() const { - return const_child_range(const_child_iterator(), const_child_iterator()); + auto Children = const_cast(this)->used_children(); + return const_child_range(Children.begin(), Children.end()); } static bool classof(const OMPClause *T) { diff --git a/include/clang/AST/RecursiveASTVisitor.h b/include/clang/AST/RecursiveASTVisitor.h index 3a21034057..1de458dccb 100644 --- a/include/clang/AST/RecursiveASTVisitor.h +++ b/include/clang/AST/RecursiveASTVisitor.h @@ -2906,6 +2906,7 @@ bool RecursiveASTVisitor::VisitOMPIfClause(OMPIfClause *C) { template bool RecursiveASTVisitor::VisitOMPFinalClause(OMPFinalClause *C) { + TRY_TO(VisitOMPClauseWithPreInit(C)); TRY_TO(TraverseStmt(C->getCondition())); return true; } diff --git a/lib/AST/OpenMPClause.cpp b/lib/AST/OpenMPClause.cpp index c1aeaf6a6b..ccb3cd3436 100644 --- a/lib/AST/OpenMPClause.cpp +++ b/lib/AST/OpenMPClause.cpp @@ -88,9 +88,10 @@ const OMPClauseWithPreInit *OMPClauseWithPreInit::get(const OMPClause *C) { return static_cast(C); case OMPC_num_tasks: return static_cast(C); + case OMPC_final: + return static_cast(C); case OMPC_default: case OMPC_proc_bind: - case OMPC_final: case OMPC_safelen: case OMPC_simdlen: case OMPC_allocator: @@ -248,6 +249,12 @@ OMPClause::child_range OMPNumTasksClause::used_children() { return child_range(&NumTasks, &NumTasks + 1); } +OMPClause::child_range OMPFinalClause::used_children() { + if (Stmt **C = getAddrOfExprAsWritten(getPreInitStmt())) + return child_range(C, C + 1); + return child_range(&Condition, &Condition + 1); +} + OMPOrderedClause *OMPOrderedClause::Create(const ASTContext &C, Expr *Num, unsigned NumLoops, SourceLocation StartLoc, diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp index 82bb4b86d9..e536abfe35 100644 --- a/lib/AST/StmtProfile.cpp +++ b/lib/AST/StmtProfile.cpp @@ -440,6 +440,7 @@ void OMPClauseProfiler::VisitOMPIfClause(const OMPIfClause *C) { } void OMPClauseProfiler::VisitOMPFinalClause(const OMPFinalClause *C) { + VistOMPClauseWithPreInit(C); if (C->getCondition()) Profiler->VisitStmt(C->getCondition()); } diff --git a/lib/Sema/SemaOpenMP.cpp b/lib/Sema/SemaOpenMP.cpp index 4d6ff009c5..b3f711bc7f 100644 --- a/lib/Sema/SemaOpenMP.cpp +++ b/lib/Sema/SemaOpenMP.cpp @@ -4600,6 +4600,11 @@ StmtResult Sema::ActOnOpenMPExecutableDirective( if (isOpenMPParallelDirective(DSAStack->getCurrentDirective())) break; continue; + case OMPC_final: + // Do not analyze if no parent parallel directive. + if (isOpenMPParallelDirective(DSAStack->getCurrentDirective())) + break; + continue; case OMPC_ordered: case OMPC_device: case OMPC_num_teams: @@ -4609,7 +4614,6 @@ StmtResult Sema::ActOnOpenMPExecutableDirective( case OMPC_collapse: case OMPC_safelen: case OMPC_simdlen: - case OMPC_final: case OMPC_default: case OMPC_proc_bind: case OMPC_private: @@ -10783,6 +10787,7 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause( break; case OMPC_grainsize: case OMPC_num_tasks: + case OMPC_final: switch (DKind) { case OMPD_task: case OMPD_taskloop: @@ -10858,7 +10863,6 @@ static OpenMPDirectiveKind getOpenMPCaptureRegionForClause( case OMPC_linear: case OMPC_default: case OMPC_proc_bind: - case OMPC_final: case OMPC_safelen: case OMPC_simdlen: case OMPC_allocator: @@ -10945,6 +10949,8 @@ OMPClause *Sema::ActOnOpenMPFinalClause(Expr *Condition, SourceLocation LParenLoc, SourceLocation EndLoc) { Expr *ValExpr = Condition; + Stmt *HelperValStmt = nullptr; + OpenMPDirectiveKind CaptureRegion = OMPD_unknown; if (!Condition->isValueDependent() && !Condition->isTypeDependent() && !Condition->isInstantiationDependent() && !Condition->containsUnexpandedParameterPack()) { @@ -10953,10 +10959,21 @@ OMPClause *Sema::ActOnOpenMPFinalClause(Expr *Condition, return nullptr; ValExpr = MakeFullExpr(Val.get()).get(); + + OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); + CaptureRegion = getOpenMPCaptureRegionForClause(DKind, OMPC_final); + if (CaptureRegion != OMPD_unknown && !CurContext->isDependentContext()) { + ValExpr = MakeFullExpr(ValExpr).get(); + llvm::MapVector Captures; + ValExpr = tryBuildCapture(*this, ValExpr, Captures).get(); + HelperValStmt = buildPreInits(Context, Captures); + } } - return new (Context) OMPFinalClause(ValExpr, StartLoc, LParenLoc, EndLoc); + return new (Context) OMPFinalClause(ValExpr, HelperValStmt, CaptureRegion, + StartLoc, LParenLoc, EndLoc); } + ExprResult Sema::PerformOpenMPImplicitIntegerConversion(SourceLocation Loc, Expr *Op) { if (!Op) diff --git a/lib/Serialization/ASTReader.cpp b/lib/Serialization/ASTReader.cpp index d8790768cb..dc653039f5 100644 --- a/lib/Serialization/ASTReader.cpp +++ b/lib/Serialization/ASTReader.cpp @@ -12499,6 +12499,7 @@ void OMPClauseReader::VisitOMPIfClause(OMPIfClause *C) { } void OMPClauseReader::VisitOMPFinalClause(OMPFinalClause *C) { + VisitOMPClauseWithPreInit(C); C->setCondition(Record.readSubExpr()); C->setLParenLoc(Record.readSourceLocation()); } diff --git a/lib/Serialization/ASTWriter.cpp b/lib/Serialization/ASTWriter.cpp index aef3523116..1e76c328a1 100644 --- a/lib/Serialization/ASTWriter.cpp +++ b/lib/Serialization/ASTWriter.cpp @@ -6639,6 +6639,7 @@ void OMPClauseWriter::VisitOMPIfClause(OMPIfClause *C) { } void OMPClauseWriter::VisitOMPFinalClause(OMPFinalClause *C) { + VisitOMPClauseWithPreInit(C); Record.AddStmt(C->getCondition()); Record.AddSourceLocation(C->getLParenLoc()); } diff --git a/test/OpenMP/parallel_master_taskloop_ast_print.cpp b/test/OpenMP/parallel_master_taskloop_ast_print.cpp index 23cf67c0a5..b151380935 100644 --- a/test/OpenMP/parallel_master_taskloop_ast_print.cpp +++ b/test/OpenMP/parallel_master_taskloop_ast_print.cpp @@ -60,9 +60,9 @@ int main(int argc, char **argv) { static int a; // CHECK: static int a; #pragma omp taskgroup task_reduction(+: d) -#pragma omp parallel master taskloop if(parallel: a) default(none) shared(a, argc) final(b) priority(5) num_tasks(argc) reduction(*: g) +#pragma omp parallel master taskloop if(parallel: a) default(none) shared(a, b, argc) final(b) priority(5) num_tasks(argc) reduction(*: g) // CHECK-NEXT: #pragma omp taskgroup task_reduction(+: d) - // CHECK-NEXT: #pragma omp parallel master taskloop if(parallel: a) default(none) shared(a,argc) final(b) priority(5) num_tasks(argc) reduction(*: g) + // CHECK-NEXT: #pragma omp parallel master taskloop if(parallel: a) default(none) shared(a,b,argc) final(b) priority(5) num_tasks(argc) reduction(*: g) for (int i = 0; i < 2; ++i) a = 2; // CHECK-NEXT: for (int i = 0; i < 2; ++i) diff --git a/test/OpenMP/parallel_master_taskloop_codegen.cpp b/test/OpenMP/parallel_master_taskloop_codegen.cpp index 289687cff2..70ecfcd5f8 100644 --- a/test/OpenMP/parallel_master_taskloop_codegen.cpp +++ b/test/OpenMP/parallel_master_taskloop_codegen.cpp @@ -187,10 +187,15 @@ int main(int argc, char **argv) { struct S { int a; S(int c) { -// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 2, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, %struct.S*, i32*)* [[OMP_OUTLINED4:@.+]] to void (i32*, i32*, ...)*), %struct.S* %{{.+}}, i32* %{{.+}}) - -// CHECK: define internal void [[OMP_OUTLINED4]](i32* noalias %{{.+}}, i32* noalias %{{.+}}, %struct.S* %{{.+}}, i32* dereferenceable(4) %{{.+}}) -// CHECK: [[TASKV:%.+]] = call i8* @__kmpc_omp_task_alloc(%struct.ident_t* [[DEFLOC]], i32 [[GTID:%.+]], i32 1, i64 80, i64 16, i32 (i32, i8*)* bitcast (i32 (i32, [[TDP_TY:%.+]]*)* [[TASK4:@.+]] to i32 (i32, i8*)*)) +// CHECK: call void (%struct.ident_t*, i32, void (i32*, i32*, ...)*, ...) @__kmpc_fork_call(%struct.ident_t* [[DEFLOC]], i32 3, void (i32*, i32*, ...)* bitcast (void (i32*, i32*, %struct.S*, i32*, i64)* [[OMP_OUTLINED4:@.+]] to void (i32*, i32*, ...)*), %struct.S* %{{.+}}, i32* %{{.+}}, i64 %{{.+}}) + +// CHECK: define internal void [[OMP_OUTLINED4]](i32* noalias %{{.+}}, i32* noalias %{{.+}}, %struct.S* %{{.+}}, i32* dereferenceable(4) %{{.+}}, i64 %{{.+}}) +// CHECK: [[CONV:%.+]] = bitcast i64* %{{.+}} to i8* +// CHECK: [[CONDI8:%.+]] = load i8, i8* [[CONV]], +// CHECK: [[COND:%.+]] = trunc i8 [[CONDI8]] to i1 +// CHECK: [[IS_FINAL:%.+]] = select i1 [[COND:%.+]], i32 2, i32 0 +// CHECK: [[FLAGS:%.+]] = or i32 [[IS_FINAL]], 1 +// CHECK: [[TASKV:%.+]] = call i8* @__kmpc_omp_task_alloc(%struct.ident_t* [[DEFLOC]], i32 [[GTID:%.+]], i32 [[FLAGS]], i64 80, i64 16, i32 (i32, i8*)* bitcast (i32 (i32, [[TDP_TY:%.+]]*)* [[TASK4:@.+]] to i32 (i32, i8*)*)) // CHECK: [[TASK:%.+]] = bitcast i8* [[TASKV]] to [[TDP_TY]]* // CHECK: [[TASK_DATA:%.+]] = getelementptr inbounds [[TDP_TY]], [[TDP_TY]]* [[TASK]], i32 0, i32 0 // CHECK: [[DOWN:%.+]] = getelementptr inbounds [[TD_TY:%.+]], [[TD_TY]]* [[TASK_DATA]], i32 0, i32 5 @@ -201,7 +206,7 @@ struct S { // CHECK: store i64 1, i64* [[ST]], // CHECK: [[ST_VAL:%.+]] = load i64, i64* [[ST]], // CHECK: call void @__kmpc_taskloop(%struct.ident_t* [[DEFLOC]], i32 [[GTID]], i8* [[TASKV]], i32 1, i64* [[DOWN]], i64* [[UP]], i64 [[ST_VAL]], i32 1, i32 2, i64 4, i8* null) -#pragma omp parallel master taskloop shared(c) num_tasks(4) +#pragma omp parallel master taskloop shared(c) num_tasks(4) final(c) for (a = 0; a < c; ++a) ; } -- 2.40.0