Fix parallel-safety code for parallel aggregation.
authorRobert Haas <rhaas@postgresql.org>
Tue, 5 Apr 2016 20:06:15 +0000 (16:06 -0400)
committerRobert Haas <rhaas@postgresql.org>
Tue, 5 Apr 2016 20:06:15 +0000 (16:06 -0400)
has_parallel_hazard() was ignoring the proparallel markings for
aggregates, which is no good.  Fix that.  There was no way to mark
an aggregate as actually being parallel-safe, either, so add a
PARALLEL option to CREATE AGGREGATE.

Patch by me, reviewed by David Rowley.

doc/src/sgml/ref/create_aggregate.sgml
src/backend/catalog/pg_aggregate.c
src/backend/commands/aggregatecmds.c
src/backend/commands/functioncmds.c
src/backend/optimizer/util/clauses.c
src/include/catalog/pg_aggregate.h
src/test/regress/expected/create_aggregate.out
src/test/regress/sql/create_aggregate.sql

index 7a6f8a97fdac11ec6f123523ce79af8bcf6a4430..3df330393dec93b65e1036545edab4b4beb42532 100644 (file)
@@ -40,6 +40,7 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replacea
     [ , MFINALFUNC_EXTRA ]
     [ , MINITCOND = <replaceable class="PARAMETER">minitial_condition</replaceable> ]
     [ , SORTOP = <replaceable class="PARAMETER">sort_operator</replaceable> ]
+    [ , PARALLEL = { SAFE | RESTRICTED | UNSAFE } ]
 )
 
 CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ]
@@ -55,6 +56,8 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replac
     [ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
     [ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
     [ , HYPOTHETICAL ]
+    [ , PARALLEL = { SAFE | RESTRICTED | UNSAFE } ]
+
 )
 
 <phrase>or the old syntax</phrase>
@@ -684,6 +687,12 @@ SELECT col FROM tab ORDER BY col USING sortop LIMIT 1;
     Currently, ordered-set aggregates do not need to support
     moving-aggregate mode, since they cannot be used as window functions.
    </para>
+
+   <para>
+    The meaning of <literal>PARALLEL SAFE</>, <literal>PARALLEL RESTRICTED</>,
+    and <literal>PARALLEL UNSAFE</> is the same as for
+    <xref linkend="sql-createfunction">.
+   </para>
  </refsect1>
 
  <refsect1>
index b420349835b05a7740ea688f2d8b03eb0a02e7b1..bcc941104f5c44452b05aa0e00249742005c3d25 100644 (file)
@@ -72,7 +72,8 @@ AggregateCreate(const char *aggName,
                                Oid aggmTransType,
                                int32 aggmTransSpace,
                                const char *agginitval,
-                               const char *aggminitval)
+                               const char *aggminitval,
+                               char proparallel)
 {
        Relation        aggdesc;
        HeapTuple       tup;
@@ -622,7 +623,7 @@ AggregateCreate(const char *aggName,
                                                         false,         /* isStrict (not needed for agg) */
                                                         PROVOLATILE_IMMUTABLE,         /* volatility (not
                                                                                                                 * needed for agg) */
-                                                        PROPARALLEL_UNSAFE,
+                                                        proparallel,
                                                         parameterTypes,        /* paramTypes */
                                                         allParameterTypes, /* allParamTypes */
                                                         parameterModes,        /* parameterModes */
index 3424f842b9c0b4e30098cc334772d273af35b175..5c4d576b8660f37f29ba53937f113e8931a0b464 100644 (file)
@@ -78,6 +78,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
        int32           mtransSpace = 0;
        char       *initval = NULL;
        char       *minitval = NULL;
+       char       *parallel = NULL;
        int                     numArgs;
        int                     numDirectArgs = 0;
        oidvector  *parameterTypes;
@@ -91,6 +92,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
        Oid                     mtransTypeId = InvalidOid;
        char            transTypeType;
        char            mtransTypeType = 0;
+       char            proparallel = PROPARALLEL_UNSAFE;
        ListCell   *pl;
 
        /* Convert list of names to a name and namespace */
@@ -178,6 +180,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
                        initval = defGetString(defel);
                else if (pg_strcasecmp(defel->defname, "minitcond") == 0)
                        minitval = defGetString(defel);
+               else if (pg_strcasecmp(defel->defname, "parallel") == 0)
+                       parallel = defGetString(defel);
                else
                        ereport(WARNING,
                                        (errcode(ERRCODE_SYNTAX_ERROR),
@@ -449,6 +453,20 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
                (void) OidInputFunctionCall(typinput, minitval, typioparam, -1);
        }
 
+       if (parallel)
+       {
+               if (pg_strcasecmp(parallel, "safe") == 0)
+                       proparallel = PROPARALLEL_SAFE;
+               else if (pg_strcasecmp(parallel, "restricted") == 0)
+                       proparallel = PROPARALLEL_RESTRICTED;
+               else if (pg_strcasecmp(parallel, "unsafe") == 0)
+                       proparallel = PROPARALLEL_UNSAFE;
+               else
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_SYNTAX_ERROR),
+                                        errmsg("parameter \"parallel\" must be SAFE, RESTRICTED, or UNSAFE")));
+       }
+
        /*
         * Most of the argument-checking is done inside of AggregateCreate
         */
@@ -480,5 +498,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
                                                   mtransTypeId,                /* transition data type */
                                                   mtransSpace, /* transition space */
                                                   initval,             /* initial condition */
-                                                  minitval);   /* initial condition */
+                                                  minitval,    /* initial condition */
+                                                  proparallel);                /* parallel safe? */
 }
index a745d73c7a517e0a163ff850149fbfd531778a4c..748c8f75d4824b1dbd2e4a81ceb7bc26a9a84be9 100644 (file)
@@ -566,9 +566,8 @@ interpret_func_parallel(DefElem *defel)
        else
        {
                ereport(ERROR,
-                               (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
-                                errmsg("parallel option \"%s\" not recognized",
-                                               str)));
+                               (errcode(ERRCODE_SYNTAX_ERROR),
+                                errmsg("parameter \"parallel\" must be SAFE, RESTRICTED, or UNSAFE")));
                return PROPARALLEL_UNSAFE;              /* keep compiler quiet */
        }
 }
index c615717dea3214a36a475c5dcf9dc22bb8700e18..5674a73dfe0f834e4ab73ac2c06052dd8c7e495c 100644 (file)
@@ -1419,6 +1419,13 @@ has_parallel_hazard_walker(Node *node, has_parallel_hazard_arg *context)
                if (parallel_too_dangerous(func_parallel(expr->funcid), context))
                        return true;
        }
+       else if (IsA(node, Aggref))
+       {
+               Aggref     *aggref = (Aggref *) node;
+
+               if (parallel_too_dangerous(func_parallel(aggref->aggfnoid), context))
+                       return true;
+       }
        else if (IsA(node, OpExpr))
        {
                OpExpr     *expr = (OpExpr *) node;
index 101d073a04a896c9464e557f42fb61009a0f5027..7d5015a1cf312de7c34218ad0743ac444ab5eb77 100644 (file)
@@ -349,6 +349,7 @@ extern ObjectAddress AggregateCreate(const char *aggName,
                                Oid aggmTransType,
                                int32 aggmTransSpace,
                                const char *agginitval,
-                               const char *aggminitval);
+                               const char *aggminitval,
+                               char proparallel);
 
 #endif   /* PG_AGGREGATE_H */
index dac26982bcafca7e192cf9662492d065c7ddd41e..1aba0c626696a6048b13133b7ad9a81ab823bdce 100644 (file)
@@ -20,9 +20,9 @@ CREATE AGGREGATE newsum (
 -- zero-argument aggregate
 CREATE AGGREGATE newcnt (*) (
    sfunc = int8inc, stype = int8,
-   initcond = '0'
+   initcond = '0', parallel = safe
 );
--- old-style spelling of same
+-- old-style spelling of same (except without parallel-safe; that's too new)
 CREATE AGGREGATE oldcnt (
    sfunc = int8inc, basetype = 'ANY', stype = int8,
    initcond = '0'
@@ -188,6 +188,14 @@ WHERE aggfnoid = 'myavg'::REGPROC;
 (1 row)
 
 DROP AGGREGATE myavg (numeric);
+-- invalid: bad parallel-safety marking
+CREATE AGGREGATE mysum (int)
+(
+       stype = int,
+       sfunc = int4pl,
+       parallel = pear
+);
+ERROR:  parameter "parallel" must be SAFE, RESTRICTED, or UNSAFE
 -- invalid: nonstrict inverse with strict forward function
 CREATE FUNCTION float8mi_n(float8, float8) RETURNS float8 AS
 $$ SELECT $1 - $2; $$
index a7da31e59433ddad65d890e988f9bc3d14a73965..c98c154a8296baa6e345576c716a482cdd23880d 100644 (file)
@@ -23,10 +23,10 @@ CREATE AGGREGATE newsum (
 -- zero-argument aggregate
 CREATE AGGREGATE newcnt (*) (
    sfunc = int8inc, stype = int8,
-   initcond = '0'
+   initcond = '0', parallel = safe
 );
 
--- old-style spelling of same
+-- old-style spelling of same (except without parallel-safe; that's too new)
 CREATE AGGREGATE oldcnt (
    sfunc = int8inc, basetype = 'ANY', stype = int8,
    initcond = '0'
@@ -201,6 +201,14 @@ WHERE aggfnoid = 'myavg'::REGPROC;
 
 DROP AGGREGATE myavg (numeric);
 
+-- invalid: bad parallel-safety marking
+CREATE AGGREGATE mysum (int)
+(
+       stype = int,
+       sfunc = int4pl,
+       parallel = pear
+);
+
 -- invalid: nonstrict inverse with strict forward function
 
 CREATE FUNCTION float8mi_n(float8, float8) RETURNS float8 AS