]> granicus.if.org Git - postgresql/commitdiff
Share transition state between different aggregates when possible.
authorHeikki Linnakangas <heikki.linnakangas@iki.fi>
Tue, 4 Aug 2015 14:53:10 +0000 (17:53 +0300)
committerHeikki Linnakangas <heikki.linnakangas@iki.fi>
Tue, 4 Aug 2015 14:53:10 +0000 (17:53 +0300)
If there are two different aggregates in the query with same inputs, and
the aggregates have the same initial condition and transition function,
only calculate the state value once, and only call the final functions
separately. For example, AVG(x) and SUM(x) aggregates have the same
transition function, which accumulates the sum and number of input tuples.
For a query like "SELECT AVG(x), SUM(x) FROM x", we can therefore
accumulate the state function only once, which gives a nice speedup.

David Rowley, reviewed and edited by me.

src/backend/executor/execQual.c
src/backend/executor/nodeAgg.c
src/backend/executor/nodeWindowAgg.c
src/backend/parser/parse_agg.c
src/include/nodes/execnodes.h
src/include/parser/parse_agg.h
src/test/regress/expected/aggregates.out
src/test/regress/sql/aggregates.sql

index 16bc8fa5f6c3534191d08551050e76ff9d765969..29f058ce5cbb1e7f7a8e10f712f84288a8be9ccb 100644 (file)
@@ -4487,35 +4487,15 @@ ExecInitExpr(Expr *node, PlanState *parent)
                        break;
                case T_Aggref:
                        {
-                               Aggref     *aggref = (Aggref *) node;
                                AggrefExprState *astate = makeNode(AggrefExprState);
 
                                astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref;
                                if (parent && IsA(parent, AggState))
                                {
                                        AggState   *aggstate = (AggState *) parent;
-                                       int                     naggs;
 
                                        aggstate->aggs = lcons(astate, aggstate->aggs);
-                                       naggs = ++aggstate->numaggs;
-
-                                       astate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
-                                                                                                                                 parent);
-                                       astate->args = (List *) ExecInitExpr((Expr *) aggref->args,
-                                                                                                                parent);
-                                       astate->aggfilter = ExecInitExpr(aggref->aggfilter,
-                                                                                                        parent);
-
-                                       /*
-                                        * Complain if the aggregate's arguments contain any
-                                        * aggregates; nested agg functions are semantically
-                                        * nonsensical.  (This should have been caught earlier,
-                                        * but we defend against it here anyway.)
-                                        */
-                                       if (naggs != aggstate->numaggs)
-                                               ereport(ERROR,
-                                                               (errcode(ERRCODE_GROUPING_ERROR),
-                                               errmsg("aggregate function calls cannot be nested")));
+                                       aggstate->numaggs++;
                                }
                                else
                                {
index 2bf48c54e3cd19badfb891e8ef7811dbed2a7c3a..2e3685557beab734fb4cbd966d449a288e982225 100644 (file)
 
 
 /*
- * AggStatePerAggData - per-aggregate working state for the Agg scan
+ * AggStatePerTransData - per aggregate state value information
+ *
+ * Working state for updating the aggregate's state value, by calling the
+ * transition function with an input row. This struct does not store the
+ * information needed to produce the final aggregate result from the transition
+ * state, that's stored in AggStatePerAggData instead. This separation allows
+ * multiple aggregate results to be produced from a single state value.
  */
-typedef struct AggStatePerAggData
+typedef struct AggStatePerTransData
 {
        /*
         * These values are set up during ExecInitAgg() and do not change
         * thereafter:
         */
 
-       /* Links to Aggref expr and state nodes this working state is for */
-       AggrefExprState *aggrefstate;
+       /*
+        * Link to an Aggref expr this state value is for.
+        *
+        * There can be multiple Aggref's sharing the same state value, as long as
+        * the inputs and transition function are identical. This points to the
+        * first one of them.
+        */
        Aggref     *aggref;
 
        /*
@@ -186,25 +197,22 @@ typedef struct AggStatePerAggData
         */
        int                     numTransInputs;
 
-       /*
-        * Number of arguments to pass to the finalfn.  This is always at least 1
-        * (the transition state value) plus any ordered-set direct args. If the
-        * finalfn wants extra args then we pass nulls corresponding to the
-        * aggregated input columns.
-        */
-       int                     numFinalArgs;
-
-       /* Oids of transfer functions */
+       /* Oid of the state transition function */
        Oid                     transfn_oid;
-       Oid                     finalfn_oid;    /* may be InvalidOid */
+
+       /* Oid of state value's datatype */
+       Oid                     aggtranstype;
+
+       /* ExprStates of the FILTER and argument expressions. */
+       ExprState  *aggfilter;          /* state of FILTER expression, if any */
+       List       *args;                       /* states of aggregated-argument expressions */
+       List       *aggdirectargs;      /* states of direct-argument expressions */
 
        /*
-        * fmgr lookup data for transfer functions --- only valid when
-        * corresponding oid is not InvalidOid.  Note in particular that fn_strict
-        * flags are kept here.
+        * fmgr lookup data for transition function.  Note in particular that the
+        * fn_strict flag is kept here.
         */
        FmgrInfo        transfn;
-       FmgrInfo        finalfn;
 
        /* Input collation derived for aggregate */
        Oid                     aggCollation;
@@ -236,17 +244,15 @@ typedef struct AggStatePerAggData
        bool            initValueIsNull;
 
        /*
-        * We need the len and byval info for the agg's input, result, and
-        * transition data types in order to know how to copy/delete values.
+        * We need the len and byval info for the agg's input and transition data
+        * types in order to know how to copy/delete values.
         *
         * Note that the info for the input type is used only when handling
         * DISTINCT aggs with just one argument, so there is only one input type.
         */
        int16           inputtypeLen,
-                               resulttypeLen,
                                transtypeLen;
        bool            inputtypeByVal,
-                               resulttypeByVal,
                                transtypeByVal;
 
        /*
@@ -288,6 +294,54 @@ typedef struct AggStatePerAggData
         * worth the extra space consumption.
         */
        FunctionCallInfoData transfn_fcinfo;
+}      AggStatePerTransData;
+
+/*
+ * AggStatePerAggData - per-aggregate information
+ *
+ * This contains the information needed to call the final function, to produce
+ * a final aggregate result from the state value. If there are multiple
+ * identical Aggrefs in the query, they can all share the same per-agg data.
+ *
+ * These values are set up during ExecInitAgg() and do not change thereafter.
+ */
+typedef struct AggStatePerAggData
+{
+       /*
+        * Link to an Aggref expr this state value is for.
+        *
+        * There can be multiple identical Aggref's sharing the same per-agg. This
+        * points to the first one of them.
+        */
+       Aggref     *aggref;
+
+       /* index to the state value which this agg should use */
+       int                     transno;
+
+       /* Optional Oid of final function (may be InvalidOid) */
+       Oid                     finalfn_oid;
+
+       /*
+        * fmgr lookup data for final function --- only valid when finalfn_oid oid
+        * is not InvalidOid.
+        */
+       FmgrInfo        finalfn;
+
+       /*
+        * Number of arguments to pass to the finalfn.  This is always at least 1
+        * (the transition state value) plus any ordered-set direct args. If the
+        * finalfn wants extra args then we pass nulls corresponding to the
+        * aggregated input columns.
+        */
+       int                     numFinalArgs;
+
+       /*
+        * We need the len and byval info for the agg's result data type in order
+        * to know how to copy/delete values.
+        */
+       int16           resulttypeLen;
+       bool            resulttypeByVal;
+
 }      AggStatePerAggData;
 
 /*
@@ -358,25 +412,23 @@ typedef struct AggHashEntryData
        AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER];
 }      AggHashEntryData;
 
-
 static void initialize_phase(AggState *aggstate, int newphase);
 static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
 static void initialize_aggregates(AggState *aggstate,
-                                         AggStatePerAgg peragg,
                                          AggStatePerGroup pergroup,
                                          int numReset);
 static void advance_transition_function(AggState *aggstate,
-                                                       AggStatePerAgg peraggstate,
+                                                       AggStatePerTrans pertrans,
                                                        AggStatePerGroup pergroupstate);
 static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
 static void process_ordered_aggregate_single(AggState *aggstate,
-                                                                AggStatePerAgg peraggstate,
+                                                                AggStatePerTrans pertrans,
                                                                 AggStatePerGroup pergroupstate);
 static void process_ordered_aggregate_multi(AggState *aggstate,
-                                                               AggStatePerAgg peraggstate,
+                                                               AggStatePerTrans pertrans,
                                                                AggStatePerGroup pergroupstate);
 static void finalize_aggregate(AggState *aggstate,
-                                  AggStatePerAgg peraggstate,
+                                  AggStatePerAgg peragg,
                                   AggStatePerGroup pergroupstate,
                                   Datum *resultVal, bool *resultIsNull);
 static void prepare_projection_slot(AggState *aggstate,
@@ -396,6 +448,17 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate);
 static void agg_fill_hash_table(AggState *aggstate);
 static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate);
 static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
+static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
+                                                 AggState *aggsate, EState *estate,
+                                                 Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
+                                                 Datum initValue, bool initValueIsNull,
+                                                 Oid *inputTypes, int numArguments);
+static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
+                                          int lastaggno, List **same_input_transnos);
+static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
+                                                Oid aggtransfn, Oid aggtranstype,
+                                                Datum initValue, bool initValueIsNull,
+                                                List *possible_matches);
 
 
 /*
@@ -498,20 +561,20 @@ fetch_input_tuple(AggState *aggstate)
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
-initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
+initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
                                         AggStatePerGroup pergroupstate)
 {
        /*
         * Start a fresh sort operation for each DISTINCT/ORDER BY aggregate.
         */
-       if (peraggstate->numSortCols > 0)
+       if (pertrans->numSortCols > 0)
        {
                /*
                 * In case of rescan, maybe there could be an uncompleted sort
                 * operation?  Clean it up if so.
                 */
-               if (peraggstate->sortstates[aggstate->current_set])
-                       tuplesort_end(peraggstate->sortstates[aggstate->current_set]);
+               if (pertrans->sortstates[aggstate->current_set])
+                       tuplesort_end(pertrans->sortstates[aggstate->current_set]);
 
 
                /*
@@ -519,21 +582,21 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
                 * otherwise sort the full tuple.  (See comments for
                 * process_ordered_aggregate_single.)
                 */
-               if (peraggstate->numInputs == 1)
-                       peraggstate->sortstates[aggstate->current_set] =
-                               tuplesort_begin_datum(peraggstate->evaldesc->attrs[0]->atttypid,
-                                                                         peraggstate->sortOperators[0],
-                                                                         peraggstate->sortCollations[0],
-                                                                         peraggstate->sortNullsFirst[0],
+               if (pertrans->numInputs == 1)
+                       pertrans->sortstates[aggstate->current_set] =
+                               tuplesort_begin_datum(pertrans->evaldesc->attrs[0]->atttypid,
+                                                                         pertrans->sortOperators[0],
+                                                                         pertrans->sortCollations[0],
+                                                                         pertrans->sortNullsFirst[0],
                                                                          work_mem, false);
                else
-                       peraggstate->sortstates[aggstate->current_set] =
-                               tuplesort_begin_heap(peraggstate->evaldesc,
-                                                                        peraggstate->numSortCols,
-                                                                        peraggstate->sortColIdx,
-                                                                        peraggstate->sortOperators,
-                                                                        peraggstate->sortCollations,
-                                                                        peraggstate->sortNullsFirst,
+                       pertrans->sortstates[aggstate->current_set] =
+                               tuplesort_begin_heap(pertrans->evaldesc,
+                                                                        pertrans->numSortCols,
+                                                                        pertrans->sortColIdx,
+                                                                        pertrans->sortOperators,
+                                                                        pertrans->sortCollations,
+                                                                        pertrans->sortNullsFirst,
                                                                         work_mem, false);
        }
 
@@ -543,20 +606,20 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
         * Note that when the initial value is pass-by-ref, we must copy it (into
         * the aggcontext) since we will pfree the transValue later.
         */
-       if (peraggstate->initValueIsNull)
-               pergroupstate->transValue = peraggstate->initValue;
+       if (pertrans->initValueIsNull)
+               pergroupstate->transValue = pertrans->initValue;
        else
        {
                MemoryContext oldContext;
 
                oldContext = MemoryContextSwitchTo(
                aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
-               pergroupstate->transValue = datumCopy(peraggstate->initValue,
-                                                                                         peraggstate->transtypeByVal,
-                                                                                         peraggstate->transtypeLen);
+               pergroupstate->transValue = datumCopy(pertrans->initValue,
+                                                                                         pertrans->transtypeByVal,
+                                                                                         pertrans->transtypeLen);
                MemoryContextSwitchTo(oldContext);
        }
-       pergroupstate->transValueIsNull = peraggstate->initValueIsNull;
+       pergroupstate->transValueIsNull = pertrans->initValueIsNull;
 
        /*
         * If the initial value for the transition state doesn't exist in the
@@ -565,11 +628,11 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
         * aggregates like max() and min().) The noTransValue flag signals that we
         * still need to do this.
         */
-       pergroupstate->noTransValue = peraggstate->initValueIsNull;
+       pergroupstate->noTransValue = pertrans->initValueIsNull;
 }
 
 /*
- * Initialize all aggregates for a new group of input values.
+ * Initialize all aggregate transition states for a new group of input values.
  *
  * If there are multiple grouping sets, we initialize only the first numReset
  * of them (the grouping sets are ordered so that the most specific one, which
@@ -580,61 +643,61 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
  */
 static void
 initialize_aggregates(AggState *aggstate,
-                                         AggStatePerAgg peragg,
                                          AggStatePerGroup pergroup,
                                          int numReset)
 {
-       int                     aggno;
+       int                     transno;
        int                     numGroupingSets = Max(aggstate->phase->numsets, 1);
        int                     setno = 0;
+       AggStatePerTrans transstates = aggstate->pertrans;
 
        if (numReset < 1)
                numReset = numGroupingSets;
 
-       for (aggno = 0; aggno < aggstate->numaggs; aggno++)
+       for (transno = 0; transno < aggstate->numtrans; transno++)
        {
-               AggStatePerAgg peraggstate = &peragg[aggno];
+               AggStatePerTrans pertrans = &transstates[transno];
 
                for (setno = 0; setno < numReset; setno++)
                {
                        AggStatePerGroup pergroupstate;
 
-                       pergroupstate = &pergroup[aggno + (setno * (aggstate->numaggs))];
+                       pergroupstate = &pergroup[transno + (setno * (aggstate->numtrans))];
 
                        aggstate->current_set = setno;
 
-                       initialize_aggregate(aggstate, peraggstate, pergroupstate);
+                       initialize_aggregate(aggstate, pertrans, pergroupstate);
                }
        }
 }
 
 /*
  * Given new input value(s), advance the transition function of one aggregate
- * within one grouping set only (already set in aggstate->current_set)
+ * state within one grouping set only (already set in aggstate->current_set)
  *
  * The new values (and null flags) have been preloaded into argument positions
- * 1 and up in peraggstate->transfn_fcinfo, so that we needn't copy them again
- * to pass to the transition function.  We also expect that the static fields
- * of the fcinfo are already initialized; that was done by ExecInitAgg().
+ * 1 and up in pertrans->transfn_fcinfo, so that we needn't copy them again to
+ * pass to the transition function.  We also expect that the static fields of
+ * the fcinfo are already initialized; that was done by ExecInitAgg().
  *
  * It doesn't matter which memory context this is called in.
  */
 static void
 advance_transition_function(AggState *aggstate,
-                                                       AggStatePerAgg peraggstate,
+                                                       AggStatePerTrans pertrans,
                                                        AggStatePerGroup pergroupstate)
 {
-       FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo;
+       FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
        MemoryContext oldContext;
        Datum           newVal;
 
-       if (peraggstate->transfn.fn_strict)
+       if (pertrans->transfn.fn_strict)
        {
                /*
                 * For a strict transfn, nothing happens when there's a NULL input; we
                 * just keep the prior transValue.
                 */
-               int                     numTransInputs = peraggstate->numTransInputs;
+               int                     numTransInputs = pertrans->numTransInputs;
                int                     i;
 
                for (i = 1; i <= numTransInputs; i++)
@@ -656,8 +719,8 @@ advance_transition_function(AggState *aggstate,
                        oldContext = MemoryContextSwitchTo(
                                                                                           aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
                        pergroupstate->transValue = datumCopy(fcinfo->arg[1],
-                                                                                                 peraggstate->transtypeByVal,
-                                                                                                 peraggstate->transtypeLen);
+                                                                                                 pertrans->transtypeByVal,
+                                                                                                 pertrans->transtypeLen);
                        pergroupstate->transValueIsNull = false;
                        pergroupstate->noTransValue = false;
                        MemoryContextSwitchTo(oldContext);
@@ -678,8 +741,8 @@ advance_transition_function(AggState *aggstate,
        /* We run the transition functions in per-input-tuple memory context */
        oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
 
-       /* set up aggstate->curperagg for AggGetAggref() */
-       aggstate->curperagg = peraggstate;
+       /* set up aggstate->curpertrans for AggGetAggref() */
+       aggstate->curpertrans = pertrans;
 
        /*
         * OK to call the transition function
@@ -690,22 +753,22 @@ advance_transition_function(AggState *aggstate,
 
        newVal = FunctionCallInvoke(fcinfo);
 
-       aggstate->curperagg = NULL;
+       aggstate->curpertrans = NULL;
 
        /*
         * If pass-by-ref datatype, must copy the new value into aggcontext and
         * pfree the prior transValue.  But if transfn returned a pointer to its
         * first input, we don't need to do anything.
         */
-       if (!peraggstate->transtypeByVal &&
+       if (!pertrans->transtypeByVal &&
                DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue))
        {
                if (!fcinfo->isnull)
                {
                        MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
                        newVal = datumCopy(newVal,
-                                                          peraggstate->transtypeByVal,
-                                                          peraggstate->transtypeLen);
+                                                          pertrans->transtypeByVal,
+                                                          pertrans->transtypeLen);
                }
                if (!pergroupstate->transValueIsNull)
                        pfree(DatumGetPointer(pergroupstate->transValue));
@@ -718,26 +781,26 @@ advance_transition_function(AggState *aggstate,
 }
 
 /*
- * Advance all the aggregates for one input tuple.  The input tuple
- * has been stored in tmpcontext->ecxt_outertuple, so that it is accessible
- * to ExecEvalExpr.  pergroup is the array of per-group structs to use
- * (this might be in a hashtable entry).
+ * Advance each aggregate transition state for one input tuple.  The input
+ * tuple has been stored in tmpcontext->ecxt_outertuple, so that it is
+ * accessible to ExecEvalExpr.  pergroup is the array of per-group structs to
+ * use (this might be in a hashtable entry).
  *
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
 advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 {
-       int                     aggno;
+       int                     transno;
        int                     setno = 0;
        int                     numGroupingSets = Max(aggstate->phase->numsets, 1);
-       int                     numAggs = aggstate->numaggs;
+       int                     numTrans = aggstate->numtrans;
 
-       for (aggno = 0; aggno < numAggs; aggno++)
+       for (transno = 0; transno < numTrans; transno++)
        {
-               AggStatePerAgg peraggstate = &aggstate->peragg[aggno];
-               ExprState  *filter = peraggstate->aggrefstate->aggfilter;
-               int                     numTransInputs = peraggstate->numTransInputs;
+               AggStatePerTrans pertrans = &aggstate->pertrans[transno];
+               ExprState  *filter = pertrans->aggfilter;
+               int                     numTransInputs = pertrans->numTransInputs;
                int                     i;
                TupleTableSlot *slot;
 
@@ -754,12 +817,12 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
                }
 
                /* Evaluate the current input expressions for this aggregate */
-               slot = ExecProject(peraggstate->evalproj, NULL);
+               slot = ExecProject(pertrans->evalproj, NULL);
 
-               if (peraggstate->numSortCols > 0)
+               if (pertrans->numSortCols > 0)
                {
                        /* DISTINCT and/or ORDER BY case */
-                       Assert(slot->tts_nvalid == peraggstate->numInputs);
+                       Assert(slot->tts_nvalid == pertrans->numInputs);
 
                        /*
                         * If the transfn is strict, we want to check for nullity before
@@ -768,7 +831,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
                         * not numInputs, since nullity in columns used only for sorting
                         * is not relevant here.
                         */
-                       if (peraggstate->transfn.fn_strict)
+                       if (pertrans->transfn.fn_strict)
                        {
                                for (i = 0; i < numTransInputs; i++)
                                {
@@ -782,18 +845,18 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
                        for (setno = 0; setno < numGroupingSets; setno++)
                        {
                                /* OK, put the tuple into the tuplesort object */
-                               if (peraggstate->numInputs == 1)
-                                       tuplesort_putdatum(peraggstate->sortstates[setno],
+                               if (pertrans->numInputs == 1)
+                                       tuplesort_putdatum(pertrans->sortstates[setno],
                                                                           slot->tts_values[0],
                                                                           slot->tts_isnull[0]);
                                else
-                                       tuplesort_puttupleslot(peraggstate->sortstates[setno], slot);
+                                       tuplesort_puttupleslot(pertrans->sortstates[setno], slot);
                        }
                }
                else
                {
                        /* We can apply the transition function immediately */
-                       FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo;
+                       FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
 
                        /* Load values into fcinfo */
                        /* Start from 1, since the 0th arg will be the transition value */
@@ -806,11 +869,11 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
 
                        for (setno = 0; setno < numGroupingSets; setno++)
                        {
-                               AggStatePerGroup pergroupstate = &pergroup[aggno + (setno * numAggs)];
+                               AggStatePerGroup pergroupstate = &pergroup[transno + (setno * numTrans)];
 
                                aggstate->current_set = setno;
 
-                               advance_transition_function(aggstate, peraggstate, pergroupstate);
+                               advance_transition_function(aggstate, pertrans, pergroupstate);
                        }
                }
        }
@@ -841,7 +904,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
  */
 static void
 process_ordered_aggregate_single(AggState *aggstate,
-                                                                AggStatePerAgg peraggstate,
+                                                                AggStatePerTrans pertrans,
                                                                 AggStatePerGroup pergroupstate)
 {
        Datum           oldVal = (Datum) 0;
@@ -849,14 +912,14 @@ process_ordered_aggregate_single(AggState *aggstate,
        bool            haveOldVal = false;
        MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory;
        MemoryContext oldContext;
-       bool            isDistinct = (peraggstate->numDistinctCols > 0);
-       FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo;
+       bool            isDistinct = (pertrans->numDistinctCols > 0);
+       FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
        Datum      *newVal;
        bool       *isNull;
 
-       Assert(peraggstate->numDistinctCols < 2);
+       Assert(pertrans->numDistinctCols < 2);
 
-       tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]);
+       tuplesort_performsort(pertrans->sortstates[aggstate->current_set]);
 
        /* Load the column into argument 1 (arg 0 will be transition value) */
        newVal = fcinfo->arg + 1;
@@ -868,7 +931,7 @@ process_ordered_aggregate_single(AggState *aggstate,
         * pfree them when they are no longer needed.
         */
 
-       while (tuplesort_getdatum(peraggstate->sortstates[aggstate->current_set],
+       while (tuplesort_getdatum(pertrans->sortstates[aggstate->current_set],
                                                          true, newVal, isNull))
        {
                /*
@@ -887,18 +950,18 @@ process_ordered_aggregate_single(AggState *aggstate,
                        haveOldVal &&
                        ((oldIsNull && *isNull) ||
                         (!oldIsNull && !*isNull &&
-                         DatumGetBool(FunctionCall2(&peraggstate->equalfns[0],
+                         DatumGetBool(FunctionCall2(&pertrans->equalfns[0],
                                                                                 oldVal, *newVal)))))
                {
                        /* equal to prior, so forget this one */
-                       if (!peraggstate->inputtypeByVal && !*isNull)
+                       if (!pertrans->inputtypeByVal && !*isNull)
                                pfree(DatumGetPointer(*newVal));
                }
                else
                {
-                       advance_transition_function(aggstate, peraggstate, pergroupstate);
+                       advance_transition_function(aggstate, pertrans, pergroupstate);
                        /* forget the old value, if any */
-                       if (!oldIsNull && !peraggstate->inputtypeByVal)
+                       if (!oldIsNull && !pertrans->inputtypeByVal)
                                pfree(DatumGetPointer(oldVal));
                        /* and remember the new one for subsequent equality checks */
                        oldVal = *newVal;
@@ -909,11 +972,11 @@ process_ordered_aggregate_single(AggState *aggstate,
                MemoryContextSwitchTo(oldContext);
        }
 
-       if (!oldIsNull && !peraggstate->inputtypeByVal)
+       if (!oldIsNull && !pertrans->inputtypeByVal)
                pfree(DatumGetPointer(oldVal));
 
-       tuplesort_end(peraggstate->sortstates[aggstate->current_set]);
-       peraggstate->sortstates[aggstate->current_set] = NULL;
+       tuplesort_end(pertrans->sortstates[aggstate->current_set]);
+       pertrans->sortstates[aggstate->current_set] = NULL;
 }
 
 /*
@@ -930,25 +993,25 @@ process_ordered_aggregate_single(AggState *aggstate,
  */
 static void
 process_ordered_aggregate_multi(AggState *aggstate,
-                                                               AggStatePerAgg peraggstate,
+                                                               AggStatePerTrans pertrans,
                                                                AggStatePerGroup pergroupstate)
 {
        MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory;
-       FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo;
-       TupleTableSlot *slot1 = peraggstate->evalslot;
-       TupleTableSlot *slot2 = peraggstate->uniqslot;
-       int                     numTransInputs = peraggstate->numTransInputs;
-       int                     numDistinctCols = peraggstate->numDistinctCols;
+       FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
+       TupleTableSlot *slot1 = pertrans->evalslot;
+       TupleTableSlot *slot2 = pertrans->uniqslot;
+       int                     numTransInputs = pertrans->numTransInputs;
+       int                     numDistinctCols = pertrans->numDistinctCols;
        bool            haveOldValue = false;
        int                     i;
 
-       tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]);
+       tuplesort_performsort(pertrans->sortstates[aggstate->current_set]);
 
        ExecClearTuple(slot1);
        if (slot2)
                ExecClearTuple(slot2);
 
-       while (tuplesort_gettupleslot(peraggstate->sortstates[aggstate->current_set],
+       while (tuplesort_gettupleslot(pertrans->sortstates[aggstate->current_set],
                                                                  true, slot1))
        {
                /*
@@ -962,8 +1025,8 @@ process_ordered_aggregate_multi(AggState *aggstate,
                        !haveOldValue ||
                        !execTuplesMatch(slot1, slot2,
                                                         numDistinctCols,
-                                                        peraggstate->sortColIdx,
-                                                        peraggstate->equalfns,
+                                                        pertrans->sortColIdx,
+                                                        pertrans->equalfns,
                                                         workcontext))
                {
                        /* Load values into fcinfo */
@@ -974,7 +1037,7 @@ process_ordered_aggregate_multi(AggState *aggstate,
                                fcinfo->argnull[i + 1] = slot1->tts_isnull[i];
                        }
 
-                       advance_transition_function(aggstate, peraggstate, pergroupstate);
+                       advance_transition_function(aggstate, pertrans, pergroupstate);
 
                        if (numDistinctCols > 0)
                        {
@@ -997,8 +1060,8 @@ process_ordered_aggregate_multi(AggState *aggstate,
        if (slot2)
                ExecClearTuple(slot2);
 
-       tuplesort_end(peraggstate->sortstates[aggstate->current_set]);
-       peraggstate->sortstates[aggstate->current_set] = NULL;
+       tuplesort_end(pertrans->sortstates[aggstate->current_set]);
+       pertrans->sortstates[aggstate->current_set] = NULL;
 }
 
 /*
@@ -1009,10 +1072,14 @@ process_ordered_aggregate_multi(AggState *aggstate,
  *
  * The finalfunction will be run, and the result delivered, in the
  * output-tuple context; caller's CurrentMemoryContext does not matter.
+ *
+ * The finalfn uses the state as set in the transno. This also might be
+ * being used by another aggregate function, so it's important that we do
+ * nothing destructive here.
  */
 static void
 finalize_aggregate(AggState *aggstate,
-                                  AggStatePerAgg peraggstate,
+                                  AggStatePerAgg peragg,
                                   AggStatePerGroup pergroupstate,
                                   Datum *resultVal, bool *resultIsNull)
 {
@@ -1021,6 +1088,7 @@ finalize_aggregate(AggState *aggstate,
        MemoryContext oldContext;
        int                     i;
        ListCell   *lc;
+       AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno];
 
        oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
 
@@ -1031,7 +1099,7 @@ finalize_aggregate(AggState *aggstate,
         * for the transition state value.
         */
        i = 1;
-       foreach(lc, peraggstate->aggrefstate->aggdirectargs)
+       foreach(lc, pertrans->aggdirectargs)
        {
                ExprState  *expr = (ExprState *) lfirst(lc);
 
@@ -1046,16 +1114,16 @@ finalize_aggregate(AggState *aggstate,
        /*
         * Apply the agg's finalfn if one is provided, else return transValue.
         */
-       if (OidIsValid(peraggstate->finalfn_oid))
+       if (OidIsValid(peragg->finalfn_oid))
        {
-               int                     numFinalArgs = peraggstate->numFinalArgs;
+               int                     numFinalArgs = peragg->numFinalArgs;
 
-               /* set up aggstate->curperagg for AggGetAggref() */
-               aggstate->curperagg = peraggstate;
+               /* set up aggstate->curpertrans for AggGetAggref() */
+               aggstate->curpertrans = pertrans;
 
-               InitFunctionCallInfoData(fcinfo, &peraggstate->finalfn,
+               InitFunctionCallInfoData(fcinfo, &peragg->finalfn,
                                                                 numFinalArgs,
-                                                                peraggstate->aggCollation,
+                                                                pertrans->aggCollation,
                                                                 (void *) aggstate, NULL);
 
                /* Fill in the transition state value */
@@ -1082,7 +1150,7 @@ finalize_aggregate(AggState *aggstate,
                        *resultVal = FunctionCallInvoke(&fcinfo);
                        *resultIsNull = fcinfo.isnull;
                }
-               aggstate->curperagg = NULL;
+               aggstate->curpertrans = NULL;
        }
        else
        {
@@ -1093,12 +1161,12 @@ finalize_aggregate(AggState *aggstate,
        /*
         * If result is pass-by-ref, make sure it is in the right context.
         */
-       if (!peraggstate->resulttypeByVal && !*resultIsNull &&
+       if (!peragg->resulttypeByVal && !*resultIsNull &&
                !MemoryContextContains(CurrentMemoryContext,
                                                           DatumGetPointer(*resultVal)))
                *resultVal = datumCopy(*resultVal,
-                                                          peraggstate->resulttypeByVal,
-                                                          peraggstate->resulttypeLen);
+                                                          peragg->resulttypeByVal,
+                                                          peragg->resulttypeLen);
 
        MemoryContextSwitchTo(oldContext);
 }
@@ -1173,7 +1241,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet
  */
 static void
 finalize_aggregates(AggState *aggstate,
-                                       AggStatePerAgg peragg,
+                                       AggStatePerAgg peraggs,
                                        AggStatePerGroup pergroup,
                                        int currentSet)
 {
@@ -1189,26 +1257,28 @@ finalize_aggregates(AggState *aggstate,
 
        for (aggno = 0; aggno < aggstate->numaggs; aggno++)
        {
-               AggStatePerAgg peraggstate = &peragg[aggno];
+               AggStatePerAgg peragg = &peraggs[aggno];
+               int                     transno = peragg->transno;
+               AggStatePerTrans pertrans = &aggstate->pertrans[transno];
                AggStatePerGroup pergroupstate;
 
-               pergroupstate = &pergroup[aggno + (currentSet * (aggstate->numaggs))];
+               pergroupstate = &pergroup[transno + (currentSet * (aggstate->numtrans))];
 
-               if (peraggstate->numSortCols > 0)
+               if (pertrans->numSortCols > 0)
                {
                        Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED);
 
-                       if (peraggstate->numInputs == 1)
+                       if (pertrans->numInputs == 1)
                                process_ordered_aggregate_single(aggstate,
-                                                                                                peraggstate,
+                                                                                                pertrans,
                                                                                                 pergroupstate);
                        else
                                process_ordered_aggregate_multi(aggstate,
-                                                                                               peraggstate,
+                                                                                               pertrans,
                                                                                                pergroupstate);
                }
 
-               finalize_aggregate(aggstate, peraggstate, pergroupstate,
+               finalize_aggregate(aggstate, peragg, pergroupstate,
                                                   &aggvalues[aggno], &aggnulls[aggno]);
        }
 }
@@ -1428,7 +1498,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot)
        if (isnew)
        {
                /* initialize aggregates for new tuple group */
-               initialize_aggregates(aggstate, aggstate->peragg, entry->pergroup, 0);
+               initialize_aggregates(aggstate, entry->pergroup, 0);
        }
 
        return entry;
@@ -1716,7 +1786,7 @@ agg_retrieve_direct(AggState *aggstate)
                        /*
                         * Initialize working state for a new input tuple group.
                         */
-                       initialize_aggregates(aggstate, peragg, pergroup, numReset);
+                       initialize_aggregates(aggstate, pergroup, numReset);
 
                        if (aggstate->grp_firstTuple != NULL)
                        {
@@ -1945,17 +2015,18 @@ AggState *
 ExecInitAgg(Agg *node, EState *estate, int eflags)
 {
        AggState   *aggstate;
-       AggStatePerAgg peragg;
+       AggStatePerAgg peraggs;
+       AggStatePerTrans pertransstates;
        Plan       *outerPlan;
        ExprContext *econtext;
        int                     numaggs,
+                               transno,
                                aggno;
        int                     phase;
        ListCell   *l;
        Bitmapset  *all_grouped_cols = NULL;
        int                     numGroupingSets = 1;
        int                     numPhases;
-       int                     currentsortno = 0;
        int                     i = 0;
        int                     j = 0;
 
@@ -1971,12 +2042,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 
        aggstate->aggs = NIL;
        aggstate->numaggs = 0;
+       aggstate->numtrans = 0;
        aggstate->maxsets = 0;
        aggstate->hashfunctions = NULL;
        aggstate->projected_set = -1;
        aggstate->current_set = 0;
        aggstate->peragg = NULL;
-       aggstate->curperagg = NULL;
+       aggstate->pertrans = NULL;
+       aggstate->curpertrans = NULL;
        aggstate->agg_done = false;
        aggstate->input_done = false;
        aggstate->pergroup = NULL;
@@ -2209,8 +2282,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
        econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs);
        econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs);
 
-       peragg = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs);
-       aggstate->peragg = peragg;
+       peraggs = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs);
+       pertransstates = (AggStatePerTrans) palloc0(sizeof(AggStatePerTransData) * numaggs);
+
+       aggstate->peragg = peraggs;
+       aggstate->pertrans = pertransstates;
 
        if (node->aggstrategy == AGG_HASHED)
        {
@@ -2230,71 +2306,86 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                aggstate->pergroup = pergroup;
        }
 
-       /*
+       /* -----------------
         * Perform lookups of aggregate function info, and initialize the
-        * unchanging fields of the per-agg data.  We also detect duplicate
-        * aggregates (for example, "SELECT sum(x) ... HAVING sum(x) > 0"). When
-        * duplicates are detected, we only make an AggStatePerAgg struct for the
-        * first one.  The clones are simply pointed at the same result entry by
-        * giving them duplicate aggno values.
+        * unchanging fields of the per-agg and per-trans data.
+        *
+        * We try to optimize by detecting duplicate aggregate functions so that
+        * their state and final values are re-used, rather than needlessly being
+        * re-calculated independently. We also detect aggregates that are not
+        * the same, but which can share the same transition state.
+        *
+        * Scenarios:
+        *
+        * 1. An aggregate function appears more than once in query:
+        *
+        *    SELECT SUM(x) FROM ... HAVING SUM(x) > 0
+        *
+        *    Since the aggregates are the identical, we only need to calculate
+        *    the calculate it once. Both aggregates will share the same 'aggno'
+        *    value.
+        *
+        * 2. Two different aggregate functions appear in the query, but the
+        *    aggregates have the same transition function and initial value, but
+        *    different final function:
+        *
+        *    SELECT SUM(x), AVG(x) FROM ...
+        *
+        *    In this case we must create a new peragg for the varying aggregate,
+        *    and need to call the final functions separately, but can share the
+        *    same transition state.
+        *
+        * For either of these optimizations to be valid, the aggregate's
+        * arguments must be the same, including any modifiers such as ORDER BY,
+        * DISTINCT and FILTER, and they mustn't contain any volatile functions.
+        * -----------------
         */
        aggno = -1;
+       transno = -1;
        foreach(l, aggstate->aggs)
        {
                AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l);
                Aggref     *aggref = (Aggref *) aggrefstate->xprstate.expr;
-               AggStatePerAgg peraggstate;
+               AggStatePerAgg peragg;
+               AggStatePerTrans pertrans;
+               int                     existing_aggno;
+               int                     existing_transno;
+               List       *same_input_transnos;
                Oid                     inputTypes[FUNC_MAX_ARGS];
                int                     numArguments;
                int                     numDirectArgs;
-               int                     numInputs;
-               int                     numSortCols;
-               int                     numDistinctCols;
-               List       *sortlist;
                HeapTuple       aggTuple;
                Form_pg_aggregate aggform;
-               Oid                     aggtranstype;
                AclResult       aclresult;
                Oid                     transfn_oid,
                                        finalfn_oid;
-               Expr       *transfnexpr,
-                                  *finalfnexpr;
+               Expr       *finalfnexpr;
+               Oid                     aggtranstype;
                Datum           textInitVal;
-               int                     i;
-               ListCell   *lc;
+               Datum           initValue;
+               bool            initValueIsNull;
 
                /* Planner should have assigned aggregate to correct level */
                Assert(aggref->agglevelsup == 0);
 
-               /* Look for a previous duplicate aggregate */
-               for (i = 0; i <= aggno; i++)
+               /* 1. Check for already processed aggs which can be re-used */
+               existing_aggno = find_compatible_peragg(aggref, aggstate, aggno,
+                                                                                               &same_input_transnos);
+               if (existing_aggno != -1)
                {
-                       if (equal(aggref, peragg[i].aggref) &&
-                               !contain_volatile_functions((Node *) aggref))
-                               break;
-               }
-               if (i <= aggno)
-               {
-                       /* Found a match to an existing entry, so just mark it */
-                       aggrefstate->aggno = i;
+                       /*
+                        * Existing compatible agg found. so just point the Aggref to the
+                        * same per-agg struct.
+                        */
+                       aggrefstate->aggno = existing_aggno;
                        continue;
                }
 
-               /* Nope, so assign a new PerAgg record */
-               peraggstate = &peragg[++aggno];
-
                /* Mark Aggref state node with assigned index in the result array */
+               peragg = &peraggs[++aggno];
+               peragg->aggref = aggref;
                aggrefstate->aggno = aggno;
 
-               /* Begin filling in the peraggstate data */
-               peraggstate->aggrefstate = aggrefstate;
-               peraggstate->aggref = aggref;
-               peraggstate->sortstates = (Tuplesortstate **)
-                       palloc0(sizeof(Tuplesortstate *) * numGroupingSets);
-
-               for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++)
-                       peraggstate->sortstates[currentsortno] = NULL;
-
                /* Fetch the pg_aggregate row */
                aggTuple = SearchSysCache1(AGGFNOID,
                                                                   ObjectIdGetDatum(aggref->aggfnoid));
@@ -2311,8 +2402,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                                                   get_func_name(aggref->aggfnoid));
                InvokeFunctionExecuteHook(aggref->aggfnoid);
 
-               peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn;
-               peraggstate->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
+               transfn_oid = aggform->aggtransfn;
+               peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
 
                /* Check that aggregate owner has permission to call component fns */
                {
@@ -2350,74 +2441,43 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                 * agg accepts ANY or a polymorphic type.
                 */
                numArguments = get_aggregate_argtypes(aggref, inputTypes);
-               peraggstate->numArguments = numArguments;
 
                /* Count the "direct" arguments, if any */
                numDirectArgs = list_length(aggref->aggdirectargs);
 
-               /* Count the number of aggregated input columns */
-               numInputs = list_length(aggref->args);
-               peraggstate->numInputs = numInputs;
-
-               /* Detect how many arguments to pass to the transfn */
-               if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
-                       peraggstate->numTransInputs = numInputs;
-               else
-                       peraggstate->numTransInputs = numArguments;
-
-               /* Detect how many arguments to pass to the finalfn */
-               if (aggform->aggfinalextra)
-                       peraggstate->numFinalArgs = numArguments + 1;
-               else
-                       peraggstate->numFinalArgs = numDirectArgs + 1;
-
                /* resolve actual type of transition state, if polymorphic */
                aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid,
                                                                                                   aggform->aggtranstype,
                                                                                                   inputTypes,
                                                                                                   numArguments);
 
-               /* build expression trees using actual argument & result types */
-               build_aggregate_fnexprs(inputTypes,
-                                                               numArguments,
-                                                               numDirectArgs,
-                                                               peraggstate->numFinalArgs,
-                                                               aggref->aggvariadic,
-                                                               aggtranstype,
-                                                               aggref->aggtype,
-                                                               aggref->inputcollid,
-                                                               transfn_oid,
-                                                               InvalidOid,             /* invtrans is not needed here */
-                                                               finalfn_oid,
-                                                               &transfnexpr,
-                                                               NULL,
-                                                               &finalfnexpr);
-
-               /* set up infrastructure for calling the transfn and finalfn */
-               fmgr_info(transfn_oid, &peraggstate->transfn);
-               fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn);
+               /* Detect how many arguments to pass to the finalfn */
+               if (aggform->aggfinalextra)
+                       peragg->numFinalArgs = numArguments + 1;
+               else
+                       peragg->numFinalArgs = numDirectArgs + 1;
 
+               /*
+                * build expression trees using actual argument & result types for the
+                * finalfn, if it exists
+                */
                if (OidIsValid(finalfn_oid))
                {
-                       fmgr_info(finalfn_oid, &peraggstate->finalfn);
-                       fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
+                       build_aggregate_finalfn_expr(inputTypes,
+                                                                                peragg->numFinalArgs,
+                                                                                aggtranstype,
+                                                                                aggref->aggtype,
+                                                                                aggref->inputcollid,
+                                                                                finalfn_oid,
+                                                                                &finalfnexpr);
+                       fmgr_info(finalfn_oid, &peragg->finalfn);
+                       fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn);
                }
 
-               peraggstate->aggCollation = aggref->inputcollid;
-
-               InitFunctionCallInfoData(peraggstate->transfn_fcinfo,
-                                                                &peraggstate->transfn,
-                                                                peraggstate->numTransInputs + 1,
-                                                                peraggstate->aggCollation,
-                                                                (void *) aggstate, NULL);
-
-               /* get info about relevant datatypes */
+               /* get info about the result type's datatype */
                get_typlenbyval(aggref->aggtype,
-                                               &peraggstate->resulttypeLen,
-                                               &peraggstate->resulttypeByVal);
-               get_typlenbyval(aggtranstype,
-                                               &peraggstate->transtypeLen,
-                                               &peraggstate->transtypeByVal);
+                                               &peragg->resulttypeLen,
+                                               &peragg->resulttypeByVal);
 
                /*
                 * initval is potentially null, so don't try to access it as a struct
@@ -2425,161 +2485,292 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                 */
                textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
                                                                          Anum_pg_aggregate_agginitval,
-                                                                         &peraggstate->initValueIsNull);
-
-               if (peraggstate->initValueIsNull)
-                       peraggstate->initValue = (Datum) 0;
+                                                                         &initValueIsNull);
+               if (initValueIsNull)
+                       initValue = (Datum) 0;
                else
-                       peraggstate->initValue = GetAggInitVal(textInitVal,
-                                                                                                  aggtranstype);
+                       initValue = GetAggInitVal(textInitVal, aggtranstype);
 
                /*
-                * If the transfn is strict and the initval is NULL, make sure input
-                * type and transtype are the same (or at least binary-compatible), so
-                * that it's OK to use the first aggregated input value as the initial
-                * transValue.  This should have been checked at agg definition time,
-                * but we must check again in case the transfn's strictness property
-                * has been changed.
+                * 2. Build working state for invoking the transition function, or
+                * look up previously initialized working state, if we can share it.
+                *
+                * find_compatible_peragg() already collected a list of per-Trans's
+                * with the same inputs. Check if any of them have the same transition
+                * function and initial value.
                 */
-               if (peraggstate->transfn.fn_strict && peraggstate->initValueIsNull)
+               existing_transno = find_compatible_pertrans(aggstate, aggref,
+                                                                                                       transfn_oid, aggtranstype,
+                                                                                                 initValue, initValueIsNull,
+                                                                                                       same_input_transnos);
+               if (existing_transno != -1)
+               {
+                       /*
+                        * Existing compatible trans found, so just point the 'peragg' to
+                        * the same per-trans struct.
+                        */
+                       pertrans = &pertransstates[existing_transno];
+                       peragg->transno = existing_transno;
+               }
+               else
                {
-                       if (numArguments <= numDirectArgs ||
-                               !IsBinaryCoercible(inputTypes[numDirectArgs], aggtranstype))
-                               ereport(ERROR,
-                                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
-                                                errmsg("aggregate %u needs to have compatible input type and transition type",
-                                                               aggref->aggfnoid)));
+                       pertrans = &pertransstates[++transno];
+                       build_pertrans_for_aggref(pertrans, aggstate, estate,
+                                                                         aggref, transfn_oid, aggtranstype,
+                                                                         initValue, initValueIsNull,
+                                                                         inputTypes, numArguments);
+                       peragg->transno = transno;
                }
+               ReleaseSysCache(aggTuple);
+       }
 
-               /*
-                * Get a tupledesc corresponding to the aggregated inputs (including
-                * sort expressions) of the agg.
-                */
-               peraggstate->evaldesc = ExecTypeFromTL(aggref->args, false);
+       /*
+        * Update numaggs to match the number of unique aggregates found. Also set
+        * numstates to the number of unique aggregate states found.
+        */
+       aggstate->numaggs = aggno + 1;
+       aggstate->numtrans = transno + 1;
+
+       return aggstate;
+}
+
+/*
+ * Build the state needed to calculate a state value for an aggregate.
+ *
+ * This initializes all the fields in 'pertrans'. 'aggref' is the aggregate
+ * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest
+ * of the arguments could be calculated from 'aggref', but the caller has
+ * calculated them already, so might as well pass them.
+ */
+static void
+build_pertrans_for_aggref(AggStatePerTrans pertrans,
+                                                 AggState *aggstate, EState *estate,
+                                                 Aggref *aggref,
+                                                 Oid aggtransfn, Oid aggtranstype,
+                                                 Datum initValue, bool initValueIsNull,
+                                                 Oid *inputTypes, int numArguments)
+{
+       int                     numGroupingSets = Max(aggstate->maxsets, 1);
+       Expr       *transfnexpr;
+       ListCell   *lc;
+       int                     numInputs;
+       int                     numDirectArgs;
+       List       *sortlist;
+       int                     numSortCols;
+       int                     numDistinctCols;
+       int                     naggs;
+       int                     i;
+
+       /* Begin filling in the pertrans data */
+       pertrans->aggref = aggref;
+       pertrans->aggCollation = aggref->inputcollid;
+       pertrans->transfn_oid = aggtransfn;
+       pertrans->initValue = initValue;
+       pertrans->initValueIsNull = initValueIsNull;
+
+       /* Count the "direct" arguments, if any */
+       numDirectArgs = list_length(aggref->aggdirectargs);
+
+       /* Count the number of aggregated input columns */
+       pertrans->numInputs = numInputs = list_length(aggref->args);
+
+       pertrans->aggtranstype = aggtranstype;
+
+       /* Detect how many arguments to pass to the transfn */
+       if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+               pertrans->numTransInputs = numInputs;
+       else
+               pertrans->numTransInputs = numArguments;
 
-               /* Create slot we're going to do argument evaluation in */
-               peraggstate->evalslot = ExecInitExtraTupleSlot(estate);
-               ExecSetSlotDescriptor(peraggstate->evalslot, peraggstate->evaldesc);
+       /*
+        * Set up infrastructure for calling the transfn
+        */
+       build_aggregate_transfn_expr(inputTypes,
+                                                                numArguments,
+                                                                numDirectArgs,
+                                                                aggref->aggvariadic,
+                                                                aggtranstype,
+                                                                aggref->inputcollid,
+                                                                aggtransfn,
+                                                                InvalidOid,    /* invtrans is not needed here */
+                                                                &transfnexpr,
+                                                                NULL);
+       fmgr_info(aggtransfn, &pertrans->transfn);
+       fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
+
+       InitFunctionCallInfoData(pertrans->transfn_fcinfo,
+                                                        &pertrans->transfn,
+                                                        pertrans->numTransInputs + 1,
+                                                        pertrans->aggCollation,
+                                                        (void *) aggstate, NULL);
+
+       /*
+        * If the transfn is strict and the initval is NULL, make sure input type
+        * and transtype are the same (or at least binary-compatible), so that
+        * it's OK to use the first aggregated input value as the initial
+        * transValue.  This should have been checked at agg definition time, but
+        * we must check again in case the transfn's strictness property has been
+        * changed.
+        */
+       if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+       {
+               if (numArguments <= numDirectArgs ||
+                       !IsBinaryCoercible(inputTypes[numDirectArgs],
+                                                          aggtranstype))
+                       ereport(ERROR,
+                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                        errmsg("aggregate %u needs to have compatible input type and transition type",
+                                                       aggref->aggfnoid)));
+       }
+
+       /* get info about the state value's datatype */
+       get_typlenbyval(aggtranstype,
+                                       &pertrans->transtypeLen,
+                                       &pertrans->transtypeByVal);
+
+       /*
+        * Get a tupledesc corresponding to the aggregated inputs (including sort
+        * expressions) of the agg.
+        */
+       pertrans->evaldesc = ExecTypeFromTL(aggref->args, false);
+
+       /* Create slot we're going to do argument evaluation in */
+       pertrans->evalslot = ExecInitExtraTupleSlot(estate);
+       ExecSetSlotDescriptor(pertrans->evalslot, pertrans->evaldesc);
+
+       /* Initialize the input and FILTER expressions */
+       naggs = aggstate->numaggs;
+       pertrans->aggfilter = ExecInitExpr(aggref->aggfilter,
+                                                                          (PlanState *) aggstate);
+       pertrans->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
+                                                                                                       (PlanState *) aggstate);
+       pertrans->args = (List *) ExecInitExpr((Expr *) aggref->args,
+                                                                                  (PlanState *) aggstate);
+
+       /*
+        * Complain if the aggregate's arguments contain any aggregates; nested
+        * agg functions are semantically nonsensical.  (This should have been
+        * caught earlier, but we defend against it here anyway.)
+        */
+       if (naggs != aggstate->numaggs)
+               ereport(ERROR,
+                               (errcode(ERRCODE_GROUPING_ERROR),
+                                errmsg("aggregate function calls cannot be nested")));
+
+       /* Set up projection info for evaluation */
+       pertrans->evalproj = ExecBuildProjectionInfo(pertrans->args,
+                                                                                                aggstate->tmpcontext,
+                                                                                                pertrans->evalslot,
+                                                                                                NULL);
+
+       /*
+        * If we're doing either DISTINCT or ORDER BY for a plain agg, then we
+        * have a list of SortGroupClause nodes; fish out the data in them and
+        * stick them into arrays.  We ignore ORDER BY for an ordered-set agg,
+        * however; the agg's transfn and finalfn are responsible for that.
+        *
+        * Note that by construction, if there is a DISTINCT clause then the ORDER
+        * BY clause is a prefix of it (see transformDistinctClause).
+        */
+       if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+       {
+               sortlist = NIL;
+               numSortCols = numDistinctCols = 0;
+       }
+       else if (aggref->aggdistinct)
+       {
+               sortlist = aggref->aggdistinct;
+               numSortCols = numDistinctCols = list_length(sortlist);
+               Assert(numSortCols >= list_length(aggref->aggorder));
+       }
+       else
+       {
+               sortlist = aggref->aggorder;
+               numSortCols = list_length(sortlist);
+               numDistinctCols = 0;
+       }
 
-               /* Set up projection info for evaluation */
-               peraggstate->evalproj = ExecBuildProjectionInfo(aggrefstate->args,
-                                                                                                               aggstate->tmpcontext,
-                                                                                                               peraggstate->evalslot,
-                                                                                                               NULL);
+       pertrans->numSortCols = numSortCols;
+       pertrans->numDistinctCols = numDistinctCols;
 
+       if (numSortCols > 0)
+       {
                /*
-                * If we're doing either DISTINCT or ORDER BY for a plain agg, then we
-                * have a list of SortGroupClause nodes; fish out the data in them and
-                * stick them into arrays.  We ignore ORDER BY for an ordered-set agg,
-                * however; the agg's transfn and finalfn are responsible for that.
-                *
-                * Note that by construction, if there is a DISTINCT clause then the
-                * ORDER BY clause is a prefix of it (see transformDistinctClause).
+                * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
+                * (yet)
                 */
-               if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+               Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED);
+
+               /* If we have only one input, we need its len/byval info. */
+               if (numInputs == 1)
                {
-                       sortlist = NIL;
-                       numSortCols = numDistinctCols = 0;
+                       get_typlenbyval(inputTypes[numDirectArgs],
+                                                       &pertrans->inputtypeLen,
+                                                       &pertrans->inputtypeByVal);
                }
-               else if (aggref->aggdistinct)
+               else if (numDistinctCols > 0)
                {
-                       sortlist = aggref->aggdistinct;
-                       numSortCols = numDistinctCols = list_length(sortlist);
-                       Assert(numSortCols >= list_length(aggref->aggorder));
+                       /* we will need an extra slot to store prior values */
+                       pertrans->uniqslot = ExecInitExtraTupleSlot(estate);
+                       ExecSetSlotDescriptor(pertrans->uniqslot,
+                                                                 pertrans->evaldesc);
                }
-               else
-               {
-                       sortlist = aggref->aggorder;
-                       numSortCols = list_length(sortlist);
-                       numDistinctCols = 0;
-               }
-
-               peraggstate->numSortCols = numSortCols;
-               peraggstate->numDistinctCols = numDistinctCols;
 
-               if (numSortCols > 0)
+               /* Extract the sort information for use later */
+               pertrans->sortColIdx =
+                       (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber));
+               pertrans->sortOperators =
+                       (Oid *) palloc(numSortCols * sizeof(Oid));
+               pertrans->sortCollations =
+                       (Oid *) palloc(numSortCols * sizeof(Oid));
+               pertrans->sortNullsFirst =
+                       (bool *) palloc(numSortCols * sizeof(bool));
+
+               i = 0;
+               foreach(lc, sortlist)
                {
-                       /*
-                        * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
-                        * (yet)
-                        */
-                       Assert(node->aggstrategy != AGG_HASHED);
+                       SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
+                       TargetEntry *tle = get_sortgroupclause_tle(sortcl, aggref->args);
 
-                       /* If we have only one input, we need its len/byval info. */
-                       if (numInputs == 1)
-                       {
-                               get_typlenbyval(inputTypes[numDirectArgs],
-                                                               &peraggstate->inputtypeLen,
-                                                               &peraggstate->inputtypeByVal);
-                       }
-                       else if (numDistinctCols > 0)
-                       {
-                               /* we will need an extra slot to store prior values */
-                               peraggstate->uniqslot = ExecInitExtraTupleSlot(estate);
-                               ExecSetSlotDescriptor(peraggstate->uniqslot,
-                                                                         peraggstate->evaldesc);
-                       }
-
-                       /* Extract the sort information for use later */
-                       peraggstate->sortColIdx =
-                               (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber));
-                       peraggstate->sortOperators =
-                               (Oid *) palloc(numSortCols * sizeof(Oid));
-                       peraggstate->sortCollations =
-                               (Oid *) palloc(numSortCols * sizeof(Oid));
-                       peraggstate->sortNullsFirst =
-                               (bool *) palloc(numSortCols * sizeof(bool));
+                       /* the parser should have made sure of this */
+                       Assert(OidIsValid(sortcl->sortop));
 
-                       i = 0;
-                       foreach(lc, sortlist)
-                       {
-                               SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
-                               TargetEntry *tle = get_sortgroupclause_tle(sortcl,
-                                                                                                                  aggref->args);
-
-                               /* the parser should have made sure of this */
-                               Assert(OidIsValid(sortcl->sortop));
-
-                               peraggstate->sortColIdx[i] = tle->resno;
-                               peraggstate->sortOperators[i] = sortcl->sortop;
-                               peraggstate->sortCollations[i] = exprCollation((Node *) tle->expr);
-                               peraggstate->sortNullsFirst[i] = sortcl->nulls_first;
-                               i++;
-                       }
-                       Assert(i == numSortCols);
+                       pertrans->sortColIdx[i] = tle->resno;
+                       pertrans->sortOperators[i] = sortcl->sortop;
+                       pertrans->sortCollations[i] = exprCollation((Node *) tle->expr);
+                       pertrans->sortNullsFirst[i] = sortcl->nulls_first;
+                       i++;
                }
+               Assert(i == numSortCols);
+       }
 
-               if (aggref->aggdistinct)
-               {
-                       Assert(numArguments > 0);
+       if (aggref->aggdistinct)
+       {
+               Assert(numArguments > 0);
 
-                       /*
-                        * We need the equal function for each DISTINCT comparison we will
-                        * make.
-                        */
-                       peraggstate->equalfns =
-                               (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo));
+               /*
+                * We need the equal function for each DISTINCT comparison we will
+                * make.
+                */
+               pertrans->equalfns =
+                       (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo));
 
-                       i = 0;
-                       foreach(lc, aggref->aggdistinct)
-                       {
-                               SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
+               i = 0;
+               foreach(lc, aggref->aggdistinct)
+               {
+                       SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
 
-                               fmgr_info(get_opcode(sortcl->eqop), &peraggstate->equalfns[i]);
-                               i++;
-                       }
-                       Assert(i == numDistinctCols);
+                       fmgr_info(get_opcode(sortcl->eqop), &pertrans->equalfns[i]);
+                       i++;
                }
-
-               ReleaseSysCache(aggTuple);
+               Assert(i == numDistinctCols);
        }
 
-       /* Update numaggs to match number of unique aggregates found */
-       aggstate->numaggs = aggno + 1;
-
-       return aggstate;
+       pertrans->sortstates = (Tuplesortstate **)
+               palloc0(sizeof(Tuplesortstate *) * numGroupingSets);
 }
 
+
 static Datum
 GetAggInitVal(Datum textInitVal, Oid transtype)
 {
@@ -2596,11 +2787,130 @@ GetAggInitVal(Datum textInitVal, Oid transtype)
        return initVal;
 }
 
+/*
+ * find_compatible_peragg - search for a previously initialized per-Agg struct
+ *
+ * Searches the previously looked at aggregates to find one which is compatible
+ * with this one, with the same input parameters. If no compatible aggregate
+ * can be found, returns -1.
+ *
+ * As a side-effect, this also collects a list of existing per-Trans structs
+ * with matching inputs. If no identical Aggref is found, the list is passed
+ * later to find_compatible_perstate, to see if we can at least reuse the
+ * state value of another aggregate.
+ */
+static int
+find_compatible_peragg(Aggref *newagg, AggState *aggstate,
+                                          int lastaggno, List **same_input_transnos)
+{
+       int                     aggno;
+       AggStatePerAgg peraggs;
+
+       *same_input_transnos = NIL;
+
+       /* we mustn't reuse the aggref if it contains volatile function calls */
+       if (contain_volatile_functions((Node *) newagg))
+               return -1;
+
+       peraggs = aggstate->peragg;
+
+       /*
+        * Search through the list of already seen aggregates. If we find an
+        * existing aggregate with the same aggregate function and input
+        * parameters as an existing one, then we can re-use that one. While
+        * searching, we'll also collect a list of Aggrefs with the same input
+        * parameters. If no matching Aggref is found, the caller can potentially
+        * still re-use the transition state of one of them.
+        */
+       for (aggno = 0; aggno <= lastaggno; aggno++)
+       {
+               AggStatePerAgg peragg;
+               Aggref     *existingRef;
+
+               peragg = &peraggs[aggno];
+               existingRef = peragg->aggref;
+
+               /* all of the following must be the same or it's no match */
+               if (newagg->inputcollid != existingRef->inputcollid ||
+                       newagg->aggstar != existingRef->aggstar ||
+                       newagg->aggvariadic != existingRef->aggvariadic ||
+                       newagg->aggkind != existingRef->aggkind ||
+                       !equal(newagg->aggdirectargs, existingRef->aggdirectargs) ||
+                       !equal(newagg->args, existingRef->args) ||
+                       !equal(newagg->aggorder, existingRef->aggorder) ||
+                       !equal(newagg->aggdistinct, existingRef->aggdistinct) ||
+                       !equal(newagg->aggfilter, existingRef->aggfilter))
+                       continue;
+
+               /* if it's the same aggregate function then report exact match */
+               if (newagg->aggfnoid == existingRef->aggfnoid &&
+                       newagg->aggtype == existingRef->aggtype &&
+                       newagg->aggcollid == existingRef->aggcollid)
+               {
+                       list_free(*same_input_transnos);
+                       *same_input_transnos = NIL;
+                       return aggno;
+               }
+
+               /*
+                * Not identical, but it had the same inputs. Return it to the caller,
+                * in case we can re-use its per-trans state.
+                */
+               *same_input_transnos = lappend_int(*same_input_transnos,
+                                                                                  peragg->transno);
+       }
+
+       return -1;
+}
+
+/*
+ * find_compatible_pertrans - search for a previously initialized per-Trans
+ * struct
+ *
+ * Searches the list of transnos for a per-Trans struct with the same
+ * transition state and initial condition. (The inputs have already been
+ * verified to match.)
+ */
+static int
+find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
+                                                Oid aggtransfn, Oid aggtranstype,
+                                                Datum initValue, bool initValueIsNull,
+                                                List *transnos)
+{
+       ListCell   *lc;
+
+       foreach(lc, transnos)
+       {
+               int                     transno = lfirst_int(lc);
+               AggStatePerTrans pertrans = &aggstate->pertrans[transno];
+
+               /*
+                * if the transfns or transition state types are not the same then the
+                * state can't be shared.
+                */
+               if (aggtransfn != pertrans->transfn_oid ||
+                       aggtranstype != pertrans->aggtranstype)
+                       continue;
+
+               /* Check that the initial condition matches, too. */
+               if (initValueIsNull && pertrans->initValueIsNull)
+                       return transno;
+
+               if (!initValueIsNull && !pertrans->initValueIsNull &&
+                       datumIsEqual(initValue, pertrans->initValue,
+                                                pertrans->transtypeByVal, pertrans->transtypeLen))
+               {
+                       return transno;
+               }
+       }
+       return -1;
+}
+
 void
 ExecEndAgg(AggState *node)
 {
        PlanState  *outerPlan;
-       int                     aggno;
+       int                     transno;
        int                     numGroupingSets = Max(node->maxsets, 1);
        int                     setno;
 
@@ -2611,14 +2921,14 @@ ExecEndAgg(AggState *node)
        if (node->sort_out)
                tuplesort_end(node->sort_out);
 
-       for (aggno = 0; aggno < node->numaggs; aggno++)
+       for (transno = 0; transno < node->numtrans; transno++)
        {
-               AggStatePerAgg peraggstate = &node->peragg[aggno];
+               AggStatePerTrans pertrans = &node->pertrans[transno];
 
                for (setno = 0; setno < numGroupingSets; setno++)
                {
-                       if (peraggstate->sortstates[setno])
-                               tuplesort_end(peraggstate->sortstates[setno]);
+                       if (pertrans->sortstates[setno])
+                               tuplesort_end(pertrans->sortstates[setno]);
                }
        }
 
@@ -2646,7 +2956,7 @@ ExecReScanAgg(AggState *node)
        ExprContext *econtext = node->ss.ps.ps_ExprContext;
        PlanState  *outerPlan = outerPlanState(node);
        Agg                *aggnode = (Agg *) node->ss.ps.plan;
-       int                     aggno;
+       int                     transno;
        int                     numGroupingSets = Max(node->maxsets, 1);
        int                     setno;
 
@@ -2678,16 +2988,16 @@ ExecReScanAgg(AggState *node)
        }
 
        /* Make sure we have closed any open tuplesorts */
-       for (aggno = 0; aggno < node->numaggs; aggno++)
+       for (transno = 0; transno < node->numtrans; transno++)
        {
                for (setno = 0; setno < numGroupingSets; setno++)
                {
-                       AggStatePerAgg peraggstate = &node->peragg[aggno];
+                       AggStatePerTrans pertrans = &node->pertrans[transno];
 
-                       if (peraggstate->sortstates[setno])
+                       if (pertrans->sortstates[setno])
                        {
-                               tuplesort_end(peraggstate->sortstates[setno]);
-                               peraggstate->sortstates[setno] = NULL;
+                               tuplesort_end(pertrans->sortstates[setno]);
+                               pertrans->sortstates[setno] = NULL;
                        }
                }
        }
@@ -2811,10 +3121,12 @@ AggGetAggref(FunctionCallInfo fcinfo)
 {
        if (fcinfo->context && IsA(fcinfo->context, AggState))
        {
-               AggStatePerAgg curperagg = ((AggState *) fcinfo->context)->curperagg;
+               AggStatePerTrans curpertrans;
+
+               curpertrans = ((AggState *) fcinfo->context)->curpertrans;
 
-               if (curperagg)
-                       return curperagg->aggref;
+               if (curpertrans)
+                       return curpertrans->aggref;
        }
        return NULL;
 }
index ecf96f8c1939c52ddcd06d592a8a2b2e823e44c2..c371d4db14106fa57b0ed3fd54566511a25cc69e 100644 (file)
@@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
                                                                                           numArguments);
 
        /* build expression trees using actual argument & result types */
-       build_aggregate_fnexprs(inputTypes,
-                                                       numArguments,
-                                                       0,      /* no ordered-set window functions yet */
-                                                       peraggstate->numFinalArgs,
-                                                       false,          /* no variadic window functions yet */
-                                                       aggtranstype,
-                                                       wfunc->wintype,
-                                                       wfunc->inputcollid,
-                                                       transfn_oid,
-                                                       invtransfn_oid,
-                                                       finalfn_oid,
-                                                       &transfnexpr,
-                                                       &invtransfnexpr,
-                                                       &finalfnexpr);
+       build_aggregate_transfn_expr(inputTypes,
+                                                                numArguments,
+                                                                0,     /* no ordered-set window functions yet */
+                                                                false,         /* no variadic window functions yet */
+                                                                wfunc->wintype,
+                                                                wfunc->inputcollid,
+                                                                transfn_oid,
+                                                                invtransfn_oid,
+                                                                &transfnexpr,
+                                                                &invtransfnexpr);
 
        /* set up infrastructure for calling the transfn(s) and finalfn */
        fmgr_info(transfn_oid, &peraggstate->transfn);
@@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
 
        if (OidIsValid(finalfn_oid))
        {
+               build_aggregate_finalfn_expr(inputTypes,
+                                                                        peraggstate->numFinalArgs,
+                                                                        aggtranstype,
+                                                                        wfunc->wintype,
+                                                                        wfunc->inputcollid,
+                                                                        finalfn_oid,
+                                                                        &finalfnexpr);
                fmgr_info(finalfn_oid, &peraggstate->finalfn);
                fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
        }
index 3846b569d6fa4f23b2236be5bae157ff4644155e..5b0d568478bd61ed3e4f1850b685521904473a21 100644 (file)
@@ -1829,44 +1829,40 @@ resolve_aggregate_transtype(Oid aggfuncid,
 }
 
 /*
- * Create expression trees for the transition and final functions
- * of an aggregate.  These are needed so that polymorphic functions
- * can be used within an aggregate --- without the expression trees,
- * such functions would not know the datatypes they are supposed to use.
- * (The trees will never actually be executed, however, so we can skimp
- * a bit on correctness.)
+ * Create an expression tree for the transition function of an aggregate.
+ * This is needed so that polymorphic functions can be used within an
+ * aggregate --- without the expression tree, such functions would not know
+ * the datatypes they are supposed to use.  (The trees will never actually
+ * be executed, however, so we can skimp a bit on correctness.)
  *
- * agg_input_types, agg_state_type, agg_result_type identify the input,
- * transition, and result types of the aggregate.  These should all be
- * resolved to actual types (ie, none should ever be ANYELEMENT etc).
+ * agg_input_types and agg_state_type identifies the input types of the
+ * aggregate.  These should be resolved to actual types (ie, none should
+ * ever be ANYELEMENT etc).
  * agg_input_collation is the aggregate function's input collation.
  *
  * For an ordered-set aggregate, remember that agg_input_types describes
  * the direct arguments followed by the aggregated arguments.
  *
- * transfn_oid, invtransfn_oid and finalfn_oid identify the funcs to be
- * called; the latter two may be InvalidOid.
+ * transfn_oid and invtransfn_oid identify the funcs to be called; the
+ * latter may be InvalidOid, however if invtransfn_oid is set then
+ * transfn_oid must also be set.
  *
  * Pointers to the constructed trees are returned into *transfnexpr,
- * *invtransfnexpr and *finalfnexpr. If there is no invtransfn or finalfn,
- * the respective pointers are set to NULL.  Since use of the invtransfn is
- * optional, NULL may be passed for invtransfnexpr.
+ * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
+ * to NULL.  Since use of the invtransfn is optional, NULL may be passed for
+ * invtransfnexpr.
  */
 void
-build_aggregate_fnexprs(Oid *agg_input_types,
-                                               int agg_num_inputs,
-                                               int agg_num_direct_inputs,
-                                               int num_finalfn_inputs,
-                                               bool agg_variadic,
-                                               Oid agg_state_type,
-                                               Oid agg_result_type,
-                                               Oid agg_input_collation,
-                                               Oid transfn_oid,
-                                               Oid invtransfn_oid,
-                                               Oid finalfn_oid,
-                                               Expr **transfnexpr,
-                                               Expr **invtransfnexpr,
-                                               Expr **finalfnexpr)
+build_aggregate_transfn_expr(Oid *agg_input_types,
+                                                        int agg_num_inputs,
+                                                        int agg_num_direct_inputs,
+                                                        bool agg_variadic,
+                                                        Oid agg_state_type,
+                                                        Oid agg_input_collation,
+                                                        Oid transfn_oid,
+                                                        Oid invtransfn_oid,
+                                                        Expr **transfnexpr,
+                                                        Expr **invtransfnexpr)
 {
        Param      *argp;
        List       *args;
@@ -1929,13 +1925,24 @@ build_aggregate_fnexprs(Oid *agg_input_types,
                else
                        *invtransfnexpr = NULL;
        }
+}
 
-       /* see if we have a final function */
-       if (!OidIsValid(finalfn_oid))
-       {
-               *finalfnexpr = NULL;
-               return;
-       }
+/*
+ * Like build_aggregate_transfn_expr, but creates an expression tree for the
+ * final function of an aggregate, rather than the transition function.
+ */
+void
+build_aggregate_finalfn_expr(Oid *agg_input_types,
+                                                        int num_finalfn_inputs,
+                                                        Oid agg_state_type,
+                                                        Oid agg_result_type,
+                                                        Oid agg_input_collation,
+                                                        Oid finalfn_oid,
+                                                        Expr **finalfnexpr)
+{
+       Param      *argp;
+       List       *args;
+       int                     i;
 
        /*
         * Build expr tree for final function
index 303fc3c1c77dca2f4c2abd51440c8614b1e88f61..5796de861c464aa27702bc80c7976d2a6feecb16 100644 (file)
@@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState
 typedef struct AggrefExprState
 {
        ExprState       xprstate;
-       List       *aggdirectargs;      /* states of direct-argument expressions */
-       List       *args;                       /* states of aggregated-argument expressions */
-       ExprState  *aggfilter;          /* state of FILTER expression, if any */
        int                     aggno;                  /* ID number for agg within its plan node */
 } AggrefExprState;
 
@@ -1825,6 +1822,7 @@ typedef struct GroupState
  */
 /* these structs are private in nodeAgg.c: */
 typedef struct AggStatePerAggData *AggStatePerAgg;
+typedef struct AggStatePerTransData *AggStatePerTrans;
 typedef struct AggStatePerGroupData *AggStatePerGroup;
 typedef struct AggStatePerPhaseData *AggStatePerPhase;
 
@@ -1833,14 +1831,16 @@ typedef struct AggState
        ScanState       ss;                             /* its first field is NodeTag */
        List       *aggs;                       /* all Aggref nodes in targetlist & quals */
        int                     numaggs;                /* length of list (could be zero!) */
+       int                     numtrans;               /* number of pertrans items */
        AggStatePerPhase phase;         /* pointer to current phase data */
        int                     numphases;              /* number of phases */
        int                     current_phase;  /* current phase number */
        FmgrInfo   *hashfunctions;      /* per-grouping-field hash fns */
        AggStatePerAgg peragg;          /* per-Aggref information */
+       AggStatePerTrans pertrans;      /* per-Trans state information */
        ExprContext **aggcontexts;      /* econtexts for long-lived data (per GS) */
        ExprContext *tmpcontext;        /* econtext for input expressions */
-       AggStatePerAgg curperagg;       /* identifies currently active aggregate */
+       AggStatePerTrans curpertrans;   /* currently active trans state */
        bool            input_done;             /* indicates end of input */
        bool            agg_done;               /* indicates completion of Agg scan */
        int                     projected_set;  /* The last projected grouping set */
index 6a5f9bbdf1556ced6d4d2031791432711e6501c2..e2b3894c282c4c0470fdad7f8e52c28799381506 100644 (file)
@@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid,
                                                        Oid *inputTypes,
                                                        int numArguments);
 
-extern void build_aggregate_fnexprs(Oid *agg_input_types,
+extern void build_aggregate_transfn_expr(Oid *agg_input_types,
                                                int agg_num_inputs,
                                                int agg_num_direct_inputs,
-                                               int num_finalfn_inputs,
                                                bool agg_variadic,
                                                Oid agg_state_type,
-                                               Oid agg_result_type,
                                                Oid agg_input_collation,
                                                Oid transfn_oid,
                                                Oid invtransfn_oid,
-                                               Oid finalfn_oid,
                                                Expr **transfnexpr,
-                                               Expr **invtransfnexpr,
+                                               Expr **invtransfnexpr);
+
+extern void build_aggregate_finalfn_expr(Oid *agg_input_types,
+                                               int num_finalfn_inputs,
+                                               Oid agg_state_type,
+                                               Oid agg_result_type,
+                                               Oid agg_input_collation,
+                                               Oid finalfn_oid,
                                                Expr **finalfnexpr);
 
 #endif   /* PARSE_AGG_H */
index 8852051e9325f4eec89a01def3f0673fbc27a0c6..de826b5e50f8b69f3f80331674757eef6bbc48d6 100644 (file)
@@ -1580,3 +1580,207 @@ select least_agg(variadic array[q1,q2]) from int8_tbl;
  -4567890123456789
 (1 row)
 
+-- test aggregates with common transition functions share the same states
+begin work;
+create type avg_state as (total bigint, count bigint);
+create or replace function avg_transfn(state avg_state, n int) returns avg_state as
+$$
+declare new_state avg_state;
+begin
+       raise notice 'avg_transfn called with %', n;
+       if state is null then
+               if n is not null then
+                       new_state.total := n;
+                       new_state.count := 1;
+                       return new_state;
+               end if;
+               return null;
+       elsif n is not null then
+               state.total := state.total + n;
+               state.count := state.count + 1;
+               return state;
+       end if;
+
+       return null;
+end
+$$ language plpgsql;
+create function avg_finalfn(state avg_state) returns int4 as
+$$
+begin
+       if state is null then
+               return NULL;
+       else
+               return state.total / state.count;
+       end if;
+end
+$$ language plpgsql;
+create function sum_finalfn(state avg_state) returns int4 as
+$$
+begin
+       if state is null then
+               return NULL;
+       else
+               return state.total;
+       end if;
+end
+$$ language plpgsql;
+create aggregate my_avg(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = avg_finalfn
+);
+create aggregate my_sum(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = sum_finalfn
+);
+-- aggregate state should be shared as aggs are the same.
+select my_avg(one),my_avg(one) from (values(1),(3)) t(one);
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+ my_avg | my_avg 
+--------+--------
+      2 |      2
+(1 row)
+
+-- aggregate state should be shared as transfn is the same for both aggs.
+select my_avg(one),my_sum(one) from (values(1),(3)) t(one);
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+ my_avg | my_sum 
+--------+--------
+      2 |      4
+(1 row)
+
+-- shouldn't share states due to the distinctness not matching.
+select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+ my_avg | my_sum 
+--------+--------
+      2 |      4
+(1 row)
+
+-- shouldn't share states due to the filter clause not matching.
+select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one);
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+NOTICE:  avg_transfn called with 3
+ my_avg | my_sum 
+--------+--------
+      3 |      4
+(1 row)
+
+-- this should not share the state due to different input columns.
+select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
+NOTICE:  avg_transfn called with 2
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 4
+NOTICE:  avg_transfn called with 3
+ my_avg | my_sum 
+--------+--------
+      2 |      6
+(1 row)
+
+-- test that aggs with the same sfunc and initcond share the same agg state
+create aggregate my_sum_init(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = sum_finalfn,
+   initcond = '(10,0)'
+);
+create aggregate my_avg_init(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = avg_finalfn,
+   initcond = '(10,0)'
+);
+create aggregate my_avg_init2(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = avg_finalfn,
+   initcond = '(4,0)'
+);
+-- state should be shared if INITCONDs are matching
+select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one);
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+ my_sum_init | my_avg_init 
+-------------+-------------
+          14 |           7
+(1 row)
+
+-- Varying INITCONDs should cause the states not to be shared.
+select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one);
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 1
+NOTICE:  avg_transfn called with 3
+NOTICE:  avg_transfn called with 3
+ my_sum_init | my_avg_init2 
+-------------+--------------
+          14 |            4
+(1 row)
+
+rollback;
+-- test aggregate state sharing to ensure it works if one aggregate has a
+-- finalfn and the other one has none.
+begin work;
+create or replace function sum_transfn(state int4, n int4) returns int4 as
+$$
+declare new_state int4;
+begin
+       raise notice 'sum_transfn called with %', n;
+       if state is null then
+               if n is not null then
+                       new_state := n;
+                       return new_state;
+               end if;
+               return null;
+       elsif n is not null then
+               state := state + n;
+               return state;
+       end if;
+
+       return null;
+end
+$$ language plpgsql;
+create function halfsum_finalfn(state int4) returns int4 as
+$$
+begin
+       if state is null then
+               return NULL;
+       else
+               return state / 2;
+       end if;
+end
+$$ language plpgsql;
+create aggregate my_sum(int4)
+(
+   stype = int4,
+   sfunc = sum_transfn
+);
+create aggregate my_half_sum(int4)
+(
+   stype = int4,
+   sfunc = sum_transfn,
+   finalfunc = halfsum_finalfn
+);
+-- Agg state should be shared even though my_sum has no finalfn
+select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
+NOTICE:  sum_transfn called with 1
+NOTICE:  sum_transfn called with 2
+NOTICE:  sum_transfn called with 3
+NOTICE:  sum_transfn called with 4
+ my_sum | my_half_sum 
+--------+-------------
+     10 |           5
+(1 row)
+
+rollback;
index a84327d24ccc1b02948d33599b8d4fc13dd01fbb..8d501dc008d862138ebaf79d4ec44776a0c22292 100644 (file)
@@ -590,3 +590,168 @@ drop view aggordview1;
 -- variadic aggregates
 select least_agg(q1,q2) from int8_tbl;
 select least_agg(variadic array[q1,q2]) from int8_tbl;
+
+
+-- test aggregates with common transition functions share the same states
+begin work;
+
+create type avg_state as (total bigint, count bigint);
+
+create or replace function avg_transfn(state avg_state, n int) returns avg_state as
+$$
+declare new_state avg_state;
+begin
+       raise notice 'avg_transfn called with %', n;
+       if state is null then
+               if n is not null then
+                       new_state.total := n;
+                       new_state.count := 1;
+                       return new_state;
+               end if;
+               return null;
+       elsif n is not null then
+               state.total := state.total + n;
+               state.count := state.count + 1;
+               return state;
+       end if;
+
+       return null;
+end
+$$ language plpgsql;
+
+create function avg_finalfn(state avg_state) returns int4 as
+$$
+begin
+       if state is null then
+               return NULL;
+       else
+               return state.total / state.count;
+       end if;
+end
+$$ language plpgsql;
+
+create function sum_finalfn(state avg_state) returns int4 as
+$$
+begin
+       if state is null then
+               return NULL;
+       else
+               return state.total;
+       end if;
+end
+$$ language plpgsql;
+
+create aggregate my_avg(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = avg_finalfn
+);
+
+create aggregate my_sum(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = sum_finalfn
+);
+
+-- aggregate state should be shared as aggs are the same.
+select my_avg(one),my_avg(one) from (values(1),(3)) t(one);
+
+-- aggregate state should be shared as transfn is the same for both aggs.
+select my_avg(one),my_sum(one) from (values(1),(3)) t(one);
+
+-- shouldn't share states due to the distinctness not matching.
+select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
+
+-- shouldn't share states due to the filter clause not matching.
+select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one);
+
+-- this should not share the state due to different input columns.
+select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
+
+-- test that aggs with the same sfunc and initcond share the same agg state
+create aggregate my_sum_init(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = sum_finalfn,
+   initcond = '(10,0)'
+);
+
+create aggregate my_avg_init(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = avg_finalfn,
+   initcond = '(10,0)'
+);
+
+create aggregate my_avg_init2(int4)
+(
+   stype = avg_state,
+   sfunc = avg_transfn,
+   finalfunc = avg_finalfn,
+   initcond = '(4,0)'
+);
+
+-- state should be shared if INITCONDs are matching
+select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one);
+
+-- Varying INITCONDs should cause the states not to be shared.
+select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one);
+
+rollback;
+
+-- test aggregate state sharing to ensure it works if one aggregate has a
+-- finalfn and the other one has none.
+begin work;
+
+create or replace function sum_transfn(state int4, n int4) returns int4 as
+$$
+declare new_state int4;
+begin
+       raise notice 'sum_transfn called with %', n;
+       if state is null then
+               if n is not null then
+                       new_state := n;
+                       return new_state;
+               end if;
+               return null;
+       elsif n is not null then
+               state := state + n;
+               return state;
+       end if;
+
+       return null;
+end
+$$ language plpgsql;
+
+create function halfsum_finalfn(state int4) returns int4 as
+$$
+begin
+       if state is null then
+               return NULL;
+       else
+               return state / 2;
+       end if;
+end
+$$ language plpgsql;
+
+create aggregate my_sum(int4)
+(
+   stype = int4,
+   sfunc = sum_transfn
+);
+
+create aggregate my_half_sum(int4)
+(
+   stype = int4,
+   sfunc = sum_transfn,
+   finalfunc = halfsum_finalfn
+);
+
+-- Agg state should be shared even though my_sum has no finalfn
+select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
+
+rollback;