From 284b3cbfd1cd5b6585437fbb8b95fe136f273efb Mon Sep 17 00:00:00 2001 From: Richard Smith Date: Sun, 12 May 2013 17:32:42 +0000 Subject: [PATCH] C++1y: support for 'switch' statements in constexpr functions. This is somewhat inefficient; we perform a linear scan of switch labels to find the one matching the condition, and then walk the body looking for that label. Both parts should be straightforward to optimize. git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@181671 91177308-0d34-0410-b5e6-96231b3b80d8 --- include/clang/AST/Stmt.h | 3 - lib/AST/ExprConstant.cpp | 151 +++++++++++++++++++-- test/SemaCXX/constant-expression-cxx1y.cpp | 114 ++++++++++++++++ 3 files changed, 256 insertions(+), 12 deletions(-) diff --git a/include/clang/AST/Stmt.h b/include/clang/AST/Stmt.h index 74c9ec2053..b2ab672627 100644 --- a/include/clang/AST/Stmt.h +++ b/include/clang/AST/Stmt.h @@ -967,9 +967,6 @@ public: SwitchCase *getSwitchCaseList() { return FirstCase; } /// \brief Set the case list for this switch statement. - /// - /// The caller is responsible for incrementing the retain counts on - /// all of the SwitchCase statements in this list. void setSwitchCaseList(SwitchCase *SC) { FirstCase = SC; } SourceLocation getSwitchLoc() const { return SwitchLoc; } diff --git a/lib/AST/ExprConstant.cpp b/lib/AST/ExprConstant.cpp index bf967961aa..339e78bbe9 100644 --- a/lib/AST/ExprConstant.cpp +++ b/lib/AST/ExprConstant.cpp @@ -2769,7 +2769,9 @@ enum EvalStmtResult { /// Hit a 'continue' statement. ESR_Continue, /// Hit a 'break' statement. - ESR_Break + ESR_Break, + /// Still scanning for 'case' or 'default' statement. + ESR_CaseNotFound }; } @@ -2803,12 +2805,13 @@ static bool EvaluateCond(EvalInfo &Info, const VarDecl *CondDecl, } static EvalStmtResult EvaluateStmt(APValue &Result, EvalInfo &Info, - const Stmt *S); + const Stmt *S, const SwitchCase *SC = 0); /// Evaluate the body of a loop, and translate the result as appropriate. static EvalStmtResult EvaluateLoopBody(APValue &Result, EvalInfo &Info, - const Stmt *Body) { - switch (EvalStmtResult ESR = EvaluateStmt(Result, Info, Body)) { + const Stmt *Body, + const SwitchCase *Case = 0) { + switch (EvalStmtResult ESR = EvaluateStmt(Result, Info, Body, Case)) { case ESR_Break: return ESR_Succeeded; case ESR_Succeeded: @@ -2816,17 +2819,128 @@ static EvalStmtResult EvaluateLoopBody(APValue &Result, EvalInfo &Info, return ESR_Continue; case ESR_Failed: case ESR_Returned: + case ESR_CaseNotFound: return ESR; } llvm_unreachable("Invalid EvalStmtResult!"); } +/// Evaluate a switch statement. +static EvalStmtResult EvaluateSwitch(APValue &Result, EvalInfo &Info, + const SwitchStmt *SS) { + // Evaluate the switch condition. + if (SS->getConditionVariable() && + !EvaluateDecl(Info, SS->getConditionVariable())) + return ESR_Failed; + APSInt Value; + if (!EvaluateInteger(SS->getCond(), Value, Info)) + return ESR_Failed; + + // Find the switch case corresponding to the value of the condition. + // FIXME: Cache this lookup. + const SwitchCase *Found = 0; + for (const SwitchCase *SC = SS->getSwitchCaseList(); SC; + SC = SC->getNextSwitchCase()) { + if (isa(SC)) { + Found = SC; + continue; + } + + const CaseStmt *CS = cast(SC); + APSInt LHS = CS->getLHS()->EvaluateKnownConstInt(Info.Ctx); + APSInt RHS = CS->getRHS() ? CS->getRHS()->EvaluateKnownConstInt(Info.Ctx) + : LHS; + if (LHS <= Value && Value <= RHS) { + Found = SC; + break; + } + } + + if (!Found) + return ESR_Succeeded; + + // Search the switch body for the switch case and evaluate it from there. + switch (EvalStmtResult ESR = EvaluateStmt(Result, Info, SS->getBody(), Found)) { + case ESR_Break: + return ESR_Succeeded; + case ESR_Succeeded: + case ESR_Continue: + case ESR_Failed: + case ESR_Returned: + return ESR; + case ESR_CaseNotFound: + Found->dump(); + SS->getBody()->dump(); + llvm_unreachable("couldn't find switch case"); + } +} + // Evaluate a statement. static EvalStmtResult EvaluateStmt(APValue &Result, EvalInfo &Info, - const Stmt *S) { + const Stmt *S, const SwitchCase *Case) { if (!Info.nextStep(S)) return ESR_Failed; + // If we're hunting down a 'case' or 'default' label, recurse through + // substatements until we hit the label. + if (Case) { + // FIXME: We don't start the lifetime of objects whose initialization we + // jump over. However, such objects must be of class type with a trivial + // default constructor that initialize all subobjects, so must be empty, + // so this almost never matters. + switch (S->getStmtClass()) { + case Stmt::CompoundStmtClass: + // FIXME: Precompute which substatement of a compound statement we + // would jump to, and go straight there rather than performing a + // linear scan each time. + case Stmt::LabelStmtClass: + case Stmt::AttributedStmtClass: + case Stmt::DoStmtClass: + break; + + case Stmt::CaseStmtClass: + case Stmt::DefaultStmtClass: + if (Case == S) + Case = 0; + break; + + case Stmt::IfStmtClass: { + // FIXME: Precompute which side of an 'if' we would jump to, and go + // straight there rather than scanning both sides. + const IfStmt *IS = cast(S); + EvalStmtResult ESR = EvaluateStmt(Result, Info, IS->getThen(), Case); + if (ESR != ESR_CaseNotFound || !IS->getElse()) + return ESR; + return EvaluateStmt(Result, Info, IS->getElse(), Case); + } + + case Stmt::WhileStmtClass: { + EvalStmtResult ESR = + EvaluateLoopBody(Result, Info, cast(S)->getBody(), Case); + if (ESR != ESR_Continue) + return ESR; + break; + } + + case Stmt::ForStmtClass: { + const ForStmt *FS = cast(S); + EvalStmtResult ESR = + EvaluateLoopBody(Result, Info, FS->getBody(), Case); + if (ESR != ESR_Continue) + return ESR; + if (FS->getInc() && !EvaluateIgnoredValue(Info, FS->getInc())) + return ESR_Failed; + break; + } + + case Stmt::DeclStmtClass: + // FIXME: If the variable has initialization that can't be jumped over, + // bail out of any immediately-surrounding compound-statement too. + default: + return ESR_CaseNotFound; + } + } + // FIXME: Mark all temporaries in the current frame as destroyed at // the end of each full-expression. switch (S->getStmtClass()) { @@ -2865,11 +2979,13 @@ static EvalStmtResult EvaluateStmt(APValue &Result, EvalInfo &Info, const CompoundStmt *CS = cast(S); for (CompoundStmt::const_body_iterator BI = CS->body_begin(), BE = CS->body_end(); BI != BE; ++BI) { - EvalStmtResult ESR = EvaluateStmt(Result, Info, *BI); - if (ESR != ESR_Succeeded) + EvalStmtResult ESR = EvaluateStmt(Result, Info, *BI, Case); + if (ESR == ESR_Succeeded) + Case = 0; + else if (ESR != ESR_CaseNotFound) return ESR; } - return ESR_Succeeded; + return Case ? ESR_CaseNotFound : ESR_Succeeded; } case Stmt::IfStmtClass: { @@ -2909,9 +3025,10 @@ static EvalStmtResult EvaluateStmt(APValue &Result, EvalInfo &Info, const DoStmt *DS = cast(S); bool Continue; do { - EvalStmtResult ESR = EvaluateLoopBody(Result, Info, DS->getBody()); + EvalStmtResult ESR = EvaluateLoopBody(Result, Info, DS->getBody(), Case); if (ESR != ESR_Continue) return ESR; + Case = 0; if (!EvaluateAsBooleanCondition(DS->getCond(), Continue, Info)) return ESR_Failed; @@ -2983,11 +3100,27 @@ static EvalStmtResult EvaluateStmt(APValue &Result, EvalInfo &Info, return ESR_Succeeded; } + case Stmt::SwitchStmtClass: + return EvaluateSwitch(Result, Info, cast(S)); + case Stmt::ContinueStmtClass: return ESR_Continue; case Stmt::BreakStmtClass: return ESR_Break; + + case Stmt::LabelStmtClass: + return EvaluateStmt(Result, Info, cast(S)->getSubStmt(), Case); + + case Stmt::AttributedStmtClass: + // As a general principle, C++11 attributes can be ignored without + // any semantic impact. + return EvaluateStmt(Result, Info, cast(S)->getSubStmt(), + Case); + + case Stmt::CaseStmtClass: + case Stmt::DefaultStmtClass: + return EvaluateStmt(Result, Info, cast(S)->getSubStmt(), Case); } } diff --git a/test/SemaCXX/constant-expression-cxx1y.cpp b/test/SemaCXX/constant-expression-cxx1y.cpp index 824a0e8b91..59a972a4fd 100644 --- a/test/SemaCXX/constant-expression-cxx1y.cpp +++ b/test/SemaCXX/constant-expression-cxx1y.cpp @@ -593,3 +593,117 @@ namespace assignment_op { } static_assert(testC(), ""); } + +namespace switch_stmt { + constexpr int f(char k) { + bool b = false; + int z = 6; + switch (k) { + return -1; + case 0: + if (false) { + case 1: + z = 1; + for (; b;) { + return 5; + while (0) + case 2: return 2; + case 7: z = 7; + do case 6: { + return z; + if (false) + case 3: return 3; + case 4: z = 4; + } while (1); + case 5: b = true; + case 9: z = 9; + } + return z; + } else if (false) case 8: z = 8; + else if (false) { + case 10: + z = -10; + break; + } + else z = 0; + return z; + default: + return -1; + } + return -z; + } + static_assert(f(0) == 0, ""); + static_assert(f(1) == 1, ""); + static_assert(f(2) == 2, ""); + static_assert(f(3) == 3, ""); + static_assert(f(4) == 4, ""); + static_assert(f(5) == 5, ""); + static_assert(f(6) == 6, ""); + static_assert(f(7) == 7, ""); + static_assert(f(8) == 8, ""); + static_assert(f(9) == 9, ""); + static_assert(f(10) == 10, ""); + + // Check that we can continue an outer loop from within a switch. + constexpr bool contin() { + for (int n = 0; n != 10; ++n) { + switch (n) { + case 0: + ++n; + continue; + case 1: + return false; + case 2: + return true; + } + } + return false; + } + static_assert(contin(), ""); + + constexpr bool switch_into_for() { + int n = 0; + switch (n) { + for (; n == 1; ++n) { + return n == 1; + case 0: ; + } + } + return false; + } + static_assert(switch_into_for(), ""); + + constexpr void duff_copy(char *a, const char *b, int n) { + switch ((n - 1) % 8 + 1) { + for ( ; n; n = (n - 1) & ~7) { + case 8: a[n-8] = b[n-8]; + case 7: a[n-7] = b[n-7]; + case 6: a[n-6] = b[n-6]; + case 5: a[n-5] = b[n-5]; + case 4: a[n-4] = b[n-4]; + case 3: a[n-3] = b[n-3]; + case 2: a[n-2] = b[n-2]; + case 1: a[n-1] = b[n-1]; + } + case 0: ; + } + } + + constexpr bool test_copy(const char *str, int n) { + char buffer[16] = {}; + duff_copy(buffer, str, n); + for (int i = 0; i != sizeof(buffer); ++i) + if (buffer[i] != (i < n ? str[i] : 0)) + return false; + return true; + } + static_assert(test_copy("foo", 0), ""); + static_assert(test_copy("foo", 1), ""); + static_assert(test_copy("foo", 2), ""); + static_assert(test_copy("hello world", 0), ""); + static_assert(test_copy("hello world", 7), ""); + static_assert(test_copy("hello world", 8), ""); + static_assert(test_copy("hello world", 9), ""); + static_assert(test_copy("hello world", 10), ""); + static_assert(test_copy("hello world", 10), ""); +} -- 2.40.0