]> granicus.if.org Git - postgresql/commitdiff
Implement OR REPLACE option for CREATE AGGREGATE.
authorAndrew Gierth <rhodiumtoad@postgresql.org>
Tue, 19 Mar 2019 01:16:50 +0000 (01:16 +0000)
committerAndrew Gierth <rhodiumtoad@postgresql.org>
Tue, 19 Mar 2019 01:16:50 +0000 (01:16 +0000)
Aggregates have acquired a dozen or so optional attributes in recent
years for things like parallel query and moving-aggregate mode; the
lack of an OR REPLACE option to add or change these for an existing
agg makes extension upgrades gratuitously hard. Rectify.

13 files changed:
doc/src/sgml/ref/create_aggregate.sgml
src/backend/catalog/pg_aggregate.c
src/backend/catalog/pg_proc.c
src/backend/commands/aggregatecmds.c
src/backend/nodes/copyfuncs.c
src/backend/nodes/equalfuncs.c
src/backend/parser/gram.y
src/backend/tcop/utility.c
src/include/catalog/pg_aggregate.h
src/include/commands/defrem.h
src/include/nodes/parsenodes.h
src/test/regress/expected/create_aggregate.out
src/test/regress/sql/create_aggregate.sql

index b8cd2e7af90489912913399fa1c5f15e2e0d364c..ca0e9db8b139d8578c2fae39cb4819bcbfbaae60 100644 (file)
@@ -21,7 +21,7 @@ PostgreSQL documentation
 
  <refsynopsisdiv>
 <synopsis>
-CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ) (
+CREATE [ OR REPLACE ] AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ) (
     SFUNC = <replaceable class="parameter">sfunc</replaceable>,
     STYPE = <replaceable class="parameter">state_data_type</replaceable>
     [ , SSPACE = <replaceable class="parameter">state_data_size</replaceable> ]
@@ -44,7 +44,7 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replacea
     [ , PARALLEL = { SAFE | RESTRICTED | UNSAFE } ]
 )
 
-CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ]
+CREATE [ OR REPLACE ] AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ]
                         ORDER BY [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ) (
     SFUNC = <replaceable class="parameter">sfunc</replaceable>,
     STYPE = <replaceable class="parameter">state_data_type</replaceable>
@@ -59,7 +59,7 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replac
 
 <phrase>or the old syntax</phrase>
 
-CREATE AGGREGATE <replaceable class="parameter">name</replaceable> (
+CREATE [ OR REPLACE ] AGGREGATE <replaceable class="parameter">name</replaceable> (
     BASETYPE = <replaceable class="parameter">base_type</replaceable>,
     SFUNC = <replaceable class="parameter">sfunc</replaceable>,
     STYPE = <replaceable class="parameter">state_data_type</replaceable>
@@ -88,12 +88,21 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> (
   <title>Description</title>
 
   <para>
-   <command>CREATE AGGREGATE</command> defines a new aggregate
-   function. Some basic and commonly-used aggregate functions are
-   included with the distribution; they are documented in <xref
-   linkend="functions-aggregate"/>. If one defines new types or needs
-   an aggregate function not already provided, then <command>CREATE
-   AGGREGATE</command> can be used to provide the desired features.
+   <command>CREATE AGGREGATE</command> defines a new aggregate function.
+   <command>CREATE OR REPLACE AGGREGATE</command> will either define a new
+   aggregate function or replace an existing definition. Some basic and
+   commonly-used aggregate functions are included with the distribution; they
+   are documented in <xref linkend="functions-aggregate"/>. If one defines new
+   types or needs an aggregate function not already provided, then
+   <command>CREATE AGGREGATE</command> can be used to provide the desired
+   features.
+  </para>
+
+  <para>
+   When replacing an existing definition, the argument types, result type,
+   and number of direct arguments may not be changed. Also, the new definition
+   must be of the same kind (ordinary aggregate, ordered-set aggregate, or
+   hypothetical-set aggregate) as the old one.
   </para>
 
   <para>
index 19e3171bf7df4ce257a688f2cf4fafa35e82c56a..cdc8d9453d9c9779280d0586c8e592916c4ed2d3 100644 (file)
@@ -45,6 +45,7 @@ static Oid lookup_agg_function(List *fnName, int nargs, Oid *input_types,
 ObjectAddress
 AggregateCreate(const char *aggName,
                                Oid aggNamespace,
+                               bool replace,
                                char aggKind,
                                int numArgs,
                                int numDirectArgs,
@@ -77,8 +78,10 @@ AggregateCreate(const char *aggName,
 {
        Relation        aggdesc;
        HeapTuple       tup;
+       HeapTuple       oldtup;
        bool            nulls[Natts_pg_aggregate];
        Datum           values[Natts_pg_aggregate];
+       bool            replaces[Natts_pg_aggregate];
        Form_pg_proc proc;
        Oid                     transfn;
        Oid                     finalfn = InvalidOid;   /* can be omitted */
@@ -609,7 +612,7 @@ AggregateCreate(const char *aggName,
 
        myself = ProcedureCreate(aggName,
                                                         aggNamespace,
-                                                        false, /* no replacement */
+                                                        replace, /* maybe replacement */
                                                         false, /* doesn't return a set */
                                                         finaltype, /* returnType */
                                                         GetUserId(),   /* proowner */
@@ -648,6 +651,7 @@ AggregateCreate(const char *aggName,
        {
                nulls[i] = false;
                values[i] = (Datum) NULL;
+               replaces[i] = true;
        }
        values[Anum_pg_aggregate_aggfnoid - 1] = ObjectIdGetDatum(procOid);
        values[Anum_pg_aggregate_aggkind - 1] = CharGetDatum(aggKind);
@@ -678,8 +682,51 @@ AggregateCreate(const char *aggName,
        else
                nulls[Anum_pg_aggregate_aggminitval - 1] = true;
 
-       tup = heap_form_tuple(tupDesc, values, nulls);
-       CatalogTupleInsert(aggdesc, tup);
+       if (replace)
+               oldtup = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(procOid));
+       else
+               oldtup = NULL;
+
+       if (HeapTupleIsValid(oldtup))
+       {
+               Form_pg_aggregate oldagg = (Form_pg_aggregate) GETSTRUCT(oldtup);
+
+               /*
+                * If we're replacing an existing entry, we need to validate that
+                * we're not changing anything that would break callers.
+                * Specifically we must not change aggkind or aggnumdirectargs,
+                * which affect how an aggregate call is treated in parse
+                * analysis.
+                */
+               if (aggKind != oldagg->aggkind)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_WRONG_OBJECT_TYPE),
+                                        errmsg("cannot change routine kind"),
+                                        (oldagg->aggkind == AGGKIND_NORMAL ?
+                                         errdetail("\"%s\" is an ordinary aggregate function.", aggName) :
+                                         oldagg->aggkind == AGGKIND_ORDERED_SET ?
+                                         errdetail("\"%s\" is an ordered-set aggregate.", aggName) :
+                                         oldagg->aggkind == AGGKIND_HYPOTHETICAL ?
+                                         errdetail("\"%s\" is a hypothetical-set aggregate.", aggName) :
+                                         0)));
+               if (numDirectArgs != oldagg->aggnumdirectargs)
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                        errmsg("cannot change number of direct args of an aggregate function")));
+
+               replaces[Anum_pg_aggregate_aggfnoid - 1] = false;
+               replaces[Anum_pg_aggregate_aggkind - 1] = false;
+               replaces[Anum_pg_aggregate_aggnumdirectargs - 1] = false;
+
+               tup = heap_modify_tuple(oldtup, tupDesc, values, nulls, replaces);
+               CatalogTupleUpdate(aggdesc, &tup->t_self, tup);
+               ReleaseSysCache(oldtup);
+       }
+       else
+       {
+               tup = heap_form_tuple(tupDesc, values, nulls);
+               CatalogTupleInsert(aggdesc, tup);
+       }
 
        table_close(aggdesc, RowExclusiveLock);
 
@@ -688,6 +735,10 @@ AggregateCreate(const char *aggName,
         * made by ProcedureCreate).  Note: we don't need an explicit dependency
         * on aggTransType since we depend on it indirectly through transfn.
         * Likewise for aggmTransType using the mtransfunc, if it exists.
+        *
+        * If we're replacing an existing definition, ProcedureCreate deleted all
+        * our existing dependencies, so we have to do the same things here either
+        * way.
         */
 
        /* Depends on transition function */
index 557e0ea1f1446e43685db99763a2a4f5509d5db1..fb22035a2a6d9f071656f0ae37cbb884a78791e6 100644 (file)
@@ -404,7 +404,9 @@ ProcedureCreate(const char *procedureName,
                                          errdetail("\"%s\" is a window function.", procedureName) :
                                          0)));
 
-               dropcmd = (prokind == PROKIND_PROCEDURE ? "DROP PROCEDURE" : "DROP FUNCTION");
+               dropcmd = (prokind == PROKIND_PROCEDURE ? "DROP PROCEDURE" :
+                                  prokind == PROKIND_AGGREGATE ? "DROP AGGREGATE" :
+                                  "DROP FUNCTION");
 
                /*
                 * Not okay to change the return type of the existing proc, since
@@ -421,7 +423,7 @@ ProcedureCreate(const char *procedureName,
                                         prokind == PROKIND_PROCEDURE
                                         ? errmsg("cannot change whether a procedure has output parameters")
                                         : errmsg("cannot change return type of existing function"),
-                                        /* translator: first %s is DROP FUNCTION or DROP PROCEDURE */
+                                        /* translator: first %s is DROP FUNCTION, DROP PROCEDURE or DROP AGGREGATE */
                                         errhint("Use %s %s first.",
                                                         dropcmd,
                                                         format_procedure(oldproc->oid))));
index d00765fbc744f0a512782693e04ed172846ae5ae..d569067dc4d04f5eb2bb1a510c87f57c2f1a300b 100644 (file)
@@ -54,7 +54,12 @@ static char extractModify(DefElem *defel);
  * "parameters" is a list of DefElem representing the agg's definition clauses.
  */
 ObjectAddress
-DefineAggregate(ParseState *pstate, List *name, List *args, bool oldstyle, List *parameters)
+DefineAggregate(ParseState *pstate,
+                               List *name,
+                               List *args,
+                               bool oldstyle,
+                               List *parameters,
+                               bool replace)
 {
        char       *aggName;
        Oid                     aggNamespace;
@@ -436,6 +441,7 @@ DefineAggregate(ParseState *pstate, List *name, List *args, bool oldstyle, List
         */
        return AggregateCreate(aggName, /* aggregate name */
                                                   aggNamespace,        /* namespace */
+                                                  replace,
                                                   aggKind,
                                                   numArgs,
                                                   numDirectArgs,
index a8a735c2476d8e2390a8d3fd51e7a7fa7cdb6df8..6f3565ad205d424140738735d3465aa5cf2ea117 100644 (file)
@@ -3372,6 +3372,7 @@ _copyDefineStmt(const DefineStmt *from)
        COPY_NODE_FIELD(args);
        COPY_NODE_FIELD(definition);
        COPY_SCALAR_FIELD(if_not_exists);
+       COPY_SCALAR_FIELD(replace);
 
        return newnode;
 }
index 3cab90e9f883bfae48307f51b9783b3b6b5c671a..813606ce0e780297ee1032625b85328351eb31eb 100644 (file)
@@ -1265,6 +1265,7 @@ _equalDefineStmt(const DefineStmt *a, const DefineStmt *b)
        COMPARE_NODE_FIELD(args);
        COMPARE_NODE_FIELD(definition);
        COMPARE_SCALAR_FIELD(if_not_exists);
+       COMPARE_SCALAR_FIELD(replace);
 
        return true;
 }
index e814939a254ff48d83ffaf3b23f851f44cad7254..502e51bb0e1ab54cfa53dcfbd832531e07f60648 100644 (file)
@@ -5618,25 +5618,27 @@ CreateAssertionStmt:
  *****************************************************************************/
 
 DefineStmt:
-                       CREATE AGGREGATE func_name aggr_args definition
+                       CREATE opt_or_replace AGGREGATE func_name aggr_args definition
                                {
                                        DefineStmt *n = makeNode(DefineStmt);
                                        n->kind = OBJECT_AGGREGATE;
                                        n->oldstyle = false;
-                                       n->defnames = $3;
-                                       n->args = $4;
-                                       n->definition = $5;
+                                       n->replace = $2;
+                                       n->defnames = $4;
+                                       n->args = $5;
+                                       n->definition = $6;
                                        $$ = (Node *)n;
                                }
-                       | CREATE AGGREGATE func_name old_aggr_definition
+                       | CREATE opt_or_replace AGGREGATE func_name old_aggr_definition
                                {
                                        /* old-style (pre-8.2) syntax for CREATE AGGREGATE */
                                        DefineStmt *n = makeNode(DefineStmt);
                                        n->kind = OBJECT_AGGREGATE;
                                        n->oldstyle = true;
-                                       n->defnames = $3;
+                                       n->replace = $2;
+                                       n->defnames = $4;
                                        n->args = NIL;
-                                       n->definition = $4;
+                                       n->definition = $5;
                                        $$ = (Node *)n;
                                }
                        | CREATE OPERATOR any_operator definition
index bdfaa506e7e25be40975114d3b2f11fb83cc374d..5053ef05effd3c2709cfe4900f16f0bca383c7ae 100644 (file)
@@ -1237,7 +1237,8 @@ ProcessUtilitySlow(ParseState *pstate,
                                                        address =
                                                                DefineAggregate(pstate, stmt->defnames, stmt->args,
                                                                                                stmt->oldstyle,
-                                                                                               stmt->definition);
+                                                                                               stmt->definition,
+                                                                                               stmt->replace);
                                                        break;
                                                case OBJECT_OPERATOR:
                                                        Assert(stmt->args == NIL);
index 832b7c2145a4023d2ee91b1454c5fd686cf49bf2..0b111b128328da3f892528f5e920b3475b789092 100644 (file)
@@ -142,6 +142,7 @@ typedef FormData_pg_aggregate *Form_pg_aggregate;
 
 extern ObjectAddress AggregateCreate(const char *aggName,
                                Oid aggNamespace,
+                               bool replace,
                                char aggKind,
                                int numArgs,
                                int numDirectArgs,
index e592a914a482f0ecbe46aa18f47590239e621b21..3bc2e8eb16c4e4362e5d3854a1da8ee204afc953 100644 (file)
@@ -94,7 +94,7 @@ extern void UpdateStatisticsForTypeChange(Oid statsOid,
 
 /* commands/aggregatecmds.c */
 extern ObjectAddress DefineAggregate(ParseState *pstate, List *name, List *args, bool oldstyle,
-                               List *parameters);
+                                                                        List *parameters, bool replace);
 
 /* commands/opclasscmds.c */
 extern ObjectAddress DefineOpClass(CreateOpClassStmt *stmt);
index fcfba4be4c84a0194dfa6eb83a1e6d323aad53cf..81278e40197609df9fe5a753266b2f454f37a5f5 100644 (file)
@@ -2532,6 +2532,7 @@ typedef struct DefineStmt
        List       *args;                       /* a list of TypeName (if needed) */
        List       *definition;         /* a list of DefElem */
        bool            if_not_exists;  /* just do nothing if it already exists? */
+       bool            replace;                /* replace if already exists? */
 } DefineStmt;
 
 /* ----------------------
index 3d92084e1304158c254e6dba5bc24fb093fcca3b..a2eb9996e1579f83609b9f91078937bb8495a942 100644 (file)
@@ -160,6 +160,77 @@ WHERE aggfnoid = 'myavg'::REGPROC;
  myavg    | numeric_avg_accum | numeric_avg_combine | internal     | numeric_avg_serialize | numeric_avg_deserialize | s
 (1 row)
 
+DROP AGGREGATE myavg (numeric);
+-- create or replace aggregate
+CREATE AGGREGATE myavg (numeric)
+(
+       stype = internal,
+       sfunc = numeric_avg_accum,
+       finalfunc = numeric_avg
+);
+CREATE OR REPLACE AGGREGATE myavg (numeric)
+(
+       stype = internal,
+       sfunc = numeric_avg_accum,
+       finalfunc = numeric_avg,
+       serialfunc = numeric_avg_serialize,
+       deserialfunc = numeric_avg_deserialize,
+       combinefunc = numeric_avg_combine,
+       finalfunc_modify = shareable  -- just to test a non-default setting
+);
+-- Ensure all these functions made it into the catalog again
+SELECT aggfnoid, aggtransfn, aggcombinefn, aggtranstype::regtype,
+       aggserialfn, aggdeserialfn, aggfinalmodify
+FROM pg_aggregate
+WHERE aggfnoid = 'myavg'::REGPROC;
+ aggfnoid |    aggtransfn     |    aggcombinefn     | aggtranstype |      aggserialfn      |      aggdeserialfn      | aggfinalmodify 
+----------+-------------------+---------------------+--------------+-----------------------+-------------------------+----------------
+ myavg    | numeric_avg_accum | numeric_avg_combine | internal     | numeric_avg_serialize | numeric_avg_deserialize | s
+(1 row)
+
+-- can change stype:
+CREATE OR REPLACE AGGREGATE myavg (numeric)
+(
+       stype = numeric,
+       sfunc = numeric_add
+);
+SELECT aggfnoid, aggtransfn, aggcombinefn, aggtranstype::regtype,
+       aggserialfn, aggdeserialfn, aggfinalmodify
+FROM pg_aggregate
+WHERE aggfnoid = 'myavg'::REGPROC;
+ aggfnoid | aggtransfn  | aggcombinefn | aggtranstype | aggserialfn | aggdeserialfn | aggfinalmodify 
+----------+-------------+--------------+--------------+-------------+---------------+----------------
+ myavg    | numeric_add | -            | numeric      | -           | -             | r
+(1 row)
+
+-- can't change return type:
+CREATE OR REPLACE AGGREGATE myavg (numeric)
+(
+       stype = numeric,
+       sfunc = numeric_add,
+       finalfunc = numeric_out
+);
+ERROR:  cannot change return type of existing function
+HINT:  Use DROP AGGREGATE myavg(numeric) first.
+-- can't change to a different kind:
+CREATE OR REPLACE AGGREGATE myavg (order by numeric)
+(
+       stype = numeric,
+       sfunc = numeric_add
+);
+ERROR:  cannot change routine kind
+DETAIL:  "myavg" is an ordinary aggregate function.
+-- can't change plain function to aggregate:
+create function sum4(int8,int8,int8,int8) returns int8 as
+'select $1 + $2 + $3 + $4' language sql strict immutable;
+CREATE OR REPLACE AGGREGATE sum3 (int8,int8,int8)
+(
+       stype = int8,
+       sfunc = sum4
+);
+ERROR:  cannot change routine kind
+DETAIL:  "sum3" is a function.
+drop function sum4(int8,int8,int8,int8);
 DROP AGGREGATE myavg (numeric);
 -- invalid: bad parallel-safety marking
 CREATE AGGREGATE mysum (int)
index cb6552e2d68047461fd460b0c9829f10f98bd111..fd7cd400c192ba09ea481af055f144f92be0dffe 100644 (file)
@@ -174,6 +174,71 @@ WHERE aggfnoid = 'myavg'::REGPROC;
 
 DROP AGGREGATE myavg (numeric);
 
+-- create or replace aggregate
+CREATE AGGREGATE myavg (numeric)
+(
+       stype = internal,
+       sfunc = numeric_avg_accum,
+       finalfunc = numeric_avg
+);
+
+CREATE OR REPLACE AGGREGATE myavg (numeric)
+(
+       stype = internal,
+       sfunc = numeric_avg_accum,
+       finalfunc = numeric_avg,
+       serialfunc = numeric_avg_serialize,
+       deserialfunc = numeric_avg_deserialize,
+       combinefunc = numeric_avg_combine,
+       finalfunc_modify = shareable  -- just to test a non-default setting
+);
+
+-- Ensure all these functions made it into the catalog again
+SELECT aggfnoid, aggtransfn, aggcombinefn, aggtranstype::regtype,
+       aggserialfn, aggdeserialfn, aggfinalmodify
+FROM pg_aggregate
+WHERE aggfnoid = 'myavg'::REGPROC;
+
+-- can change stype:
+CREATE OR REPLACE AGGREGATE myavg (numeric)
+(
+       stype = numeric,
+       sfunc = numeric_add
+);
+SELECT aggfnoid, aggtransfn, aggcombinefn, aggtranstype::regtype,
+       aggserialfn, aggdeserialfn, aggfinalmodify
+FROM pg_aggregate
+WHERE aggfnoid = 'myavg'::REGPROC;
+
+-- can't change return type:
+CREATE OR REPLACE AGGREGATE myavg (numeric)
+(
+       stype = numeric,
+       sfunc = numeric_add,
+       finalfunc = numeric_out
+);
+
+-- can't change to a different kind:
+CREATE OR REPLACE AGGREGATE myavg (order by numeric)
+(
+       stype = numeric,
+       sfunc = numeric_add
+);
+
+-- can't change plain function to aggregate:
+create function sum4(int8,int8,int8,int8) returns int8 as
+'select $1 + $2 + $3 + $4' language sql strict immutable;
+
+CREATE OR REPLACE AGGREGATE sum3 (int8,int8,int8)
+(
+       stype = int8,
+       sfunc = sum4
+);
+
+drop function sum4(int8,int8,int8,int8);
+
+DROP AGGREGATE myavg (numeric);
+
 -- invalid: bad parallel-safety marking
 CREATE AGGREGATE mysum (int)
 (