granicus.if.org Git - postgresql/commitdiff
Simplify representation of aggregate transition values a bit.
author Andres Freund <andres@anarazel.de>
Wed, 3 Jan 2018 02:02:37 +0000 (18:02 -0800)
committer Andres Freund <andres@anarazel.de>
Wed, 3 Jan 2018 02:23:37 +0000 (18:23 -0800)
Previously, aggregate transition values for hash and other forms of
aggregation (i.e. sort and no group by) were represented
differently. Hash-based aggregation used a grouping-set-indexed array
whose elements each point to an array of transition values, whereas
other forms of aggregation used one flattened array, with the index
computed from the grouping set and transition offsets.

That made upcoming changes hard, so represent both as a
grouping-set-indexed array of per-group data.

As a nice side-effect this also makes aggregation slightly faster,
because computing offsets with `transno + (setno * numTrans)` turns
out not to be that cheap (too big for x86 lea for example).

Author: Andres Freund
Discussion: https://postgr.es/m/20171128003121.nmxbm2ounxzb6n2t@alap3.anarazel.de

src/backend/executor/nodeAgg.c
src/include/nodes/execnodes.h

index da6ef1a94c419704fe7ef07963d0fd70eebd821a..a3454e52f6d1144510e4e844d9a7ed300b0084bf 100644 (file)
@@ -532,13 +532,14 @@ static void select_current_set(AggState *aggstate, int setno, bool is_hash);
 static void initialize_phase(AggState *aggstate, int newphase);
 static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
 static void initialize_aggregates(AggState *aggstate,
-                                         AggStatePerGroup pergroup,
+                                         AggStatePerGroup *pergroups,
                                          int numReset);
 static void advance_transition_function(AggState *aggstate,
                                                        AggStatePerTrans pertrans,
                                                        AggStatePerGroup pergroupstate);
-static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup,
-                                  AggStatePerGroup *pergroups);
+static void advance_aggregates(AggState *aggstate,
+                                  AggStatePerGroup *sort_pergroups,
+                                  AggStatePerGroup *hash_pergroups);
 static void advance_combine_function(AggState *aggstate,
                                                 AggStatePerTrans pertrans,
                                                 AggStatePerGroup pergroupstate);
@@ -793,14 +794,16 @@ initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
  * If there are multiple grouping sets, we initialize only the first numReset
  * of them (the grouping sets are ordered so that the most specific one, which
  * is reset most often, is first). As a convenience, if numReset is 0, we
- * reinitialize all sets. numReset is -1 to initialize a hashtable entry, in
- * which case the caller must have used select_current_set appropriately.
+ * reinitialize all sets.
+ *
+ * NB: This cannot be used for hash aggregates, as for those the grouping set
+ * number has to be specified from further up.
  *
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
 initialize_aggregates(AggState *aggstate,
-                                         AggStatePerGroup pergroup,
+                                         AggStatePerGroup *pergroups,
                                          int numReset)
 {
        int                     transno;
@@ -812,30 +815,18 @@ initialize_aggregates(AggState *aggstate,
        if (numReset == 0)
                numReset = numGroupingSets;
 
-       for (transno = 0; transno < numTrans; transno++)
+       for (setno = 0; setno < numReset; setno++)
        {
-               AggStatePerTrans pertrans = &transstates[transno];
-
-               if (numReset < 0)
-               {
-                       AggStatePerGroup pergroupstate;
+               AggStatePerGroup pergroup = pergroups[setno];
 
-                       pergroupstate = &pergroup[transno];
+               select_current_set(aggstate, setno, false);
 
-                       initialize_aggregate(aggstate, pertrans, pergroupstate);
-               }
-               else
+               for (transno = 0; transno < numTrans; transno++)
                {
-                       for (setno = 0; setno < numReset; setno++)
-                       {
-                               AggStatePerGroup pergroupstate;
-
-                               pergroupstate = &pergroup[transno + (setno * numTrans)];
+                       AggStatePerTrans pertrans = &transstates[transno];
+                       AggStatePerGroup pergroupstate = &pergroup[transno];
 
-                               select_current_set(aggstate, setno, false);
-
-                               initialize_aggregate(aggstate, pertrans, pergroupstate);
-                       }
+                       initialize_aggregate(aggstate, pertrans, pergroupstate);
                }
        }
 }
@@ -976,7 +967,9 @@ advance_transition_function(AggState *aggstate,
  * When called, CurrentMemoryContext should be the per-query context.
  */
 static void
-advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGroup *pergroups)
+advance_aggregates(AggState *aggstate,
+                                  AggStatePerGroup *sort_pergroups,
+                                  AggStatePerGroup *hash_pergroups)
 {
        int                     transno;
        int                     setno = 0;
@@ -1019,7 +1012,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
                {
                        /* DISTINCT and/or ORDER BY case */
                        Assert(slot->tts_nvalid >= (pertrans->numInputs + inputoff));
-                       Assert(!pergroups);
+                       Assert(!hash_pergroups);
 
                        /*
                         * If the transfn is strict, we want to check for nullity before
@@ -1090,7 +1083,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
                                fcinfo->argnull[i + 1] = slot->tts_isnull[i + inputoff];
                        }
 
-                       if (pergroup)
+                       if (sort_pergroups)
                        {
                                /* advance transition states for ordered grouping */
 
@@ -1100,13 +1093,13 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 
                                        select_current_set(aggstate, setno, false);
 
-                                       pergroupstate = &pergroup[transno + (setno * numTrans)];
+                                       pergroupstate = &sort_pergroups[setno][transno];
 
                                        advance_transition_function(aggstate, pertrans, pergroupstate);
                                }
                        }
 
-                       if (pergroups)
+                       if (hash_pergroups)
                        {
                                /* advance transition states for hashed grouping */
 
@@ -1116,7 +1109,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
 
                                        select_current_set(aggstate, setno, true);
 
-                                       pergroupstate = &pergroups[setno][transno];
+                                       pergroupstate = &hash_pergroups[setno][transno];
 
                                        advance_transition_function(aggstate, pertrans, pergroupstate);
                                }
@@ -2095,12 +2088,25 @@ lookup_hash_entry(AggState *aggstate)
 
        if (isnew)
        {
-               entry->additional = (AggStatePerGroup)
+               AggStatePerGroup pergroup;
+               int                     transno;
+
+               pergroup = (AggStatePerGroup)
                        MemoryContextAlloc(perhash->hashtable->tablecxt,
                                                           sizeof(AggStatePerGroupData) * aggstate->numtrans);
-               /* initialize aggregates for new tuple group */
-               initialize_aggregates(aggstate, (AggStatePerGroup) entry->additional,
-                                                         -1);
+               entry->additional = pergroup;
+
+               /*
+                * Initialize aggregates for new tuple group, lookup_hash_entries()
+                * already has selected the relevant grouping set.
+                */
+               for (transno = 0; transno < aggstate->numtrans; transno++)
+               {
+                       AggStatePerTrans pertrans = &aggstate->pertrans[transno];
+                       AggStatePerGroup pergroupstate = &pergroup[transno];
+
+                       initialize_aggregate(aggstate, pertrans, pergroupstate);
+               }
        }
 
        return entry;
@@ -2184,7 +2190,7 @@ agg_retrieve_direct(AggState *aggstate)
        ExprContext *econtext;
        ExprContext *tmpcontext;
        AggStatePerAgg peragg;
-       AggStatePerGroup pergroup;
+       AggStatePerGroup *pergroups;
        AggStatePerGroup *hash_pergroups = NULL;
        TupleTableSlot *outerslot;
        TupleTableSlot *firstSlot;
@@ -2207,7 +2213,7 @@ agg_retrieve_direct(AggState *aggstate)
        tmpcontext = aggstate->tmpcontext;
 
        peragg = aggstate->peragg;
-       pergroup = aggstate->pergroup;
+       pergroups = aggstate->pergroups;
        firstSlot = aggstate->ss.ss_ScanTupleSlot;
 
        /*
@@ -2409,7 +2415,7 @@ agg_retrieve_direct(AggState *aggstate)
                        /*
                         * Initialize working state for a new input tuple group.
                         */
-                       initialize_aggregates(aggstate, pergroup, numReset);
+                       initialize_aggregates(aggstate, pergroups, numReset);
 
                        if (aggstate->grp_firstTuple != NULL)
                        {
@@ -2446,9 +2452,9 @@ agg_retrieve_direct(AggState *aggstate)
                                                hash_pergroups = NULL;
 
                                        if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
-                                               combine_aggregates(aggstate, pergroup);
+                                               combine_aggregates(aggstate, pergroups[0]);
                                        else
-                                               advance_aggregates(aggstate, pergroup, hash_pergroups);
+                                               advance_aggregates(aggstate, pergroups, hash_pergroups);
 
                                        /* Reset per-input-tuple context after each tuple */
                                        ResetExprContext(tmpcontext);
@@ -2512,7 +2518,7 @@ agg_retrieve_direct(AggState *aggstate)
 
                finalize_aggregates(aggstate,
                                                        peragg,
-                                                       pergroup + (currentSet * aggstate->numtrans));
+                                                       pergroups[currentSet]);
 
                /*
                 * If there's no row to project right now, we must continue rather
@@ -2756,7 +2762,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
        aggstate->curpertrans = NULL;
        aggstate->input_done = false;
        aggstate->agg_done = false;
-       aggstate->pergroup = NULL;
+       aggstate->pergroups = NULL;
        aggstate->grp_firstTuple = NULL;
        aggstate->sort_in = NULL;
        aggstate->sort_out = NULL;
@@ -3052,13 +3058,16 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 
        if (node->aggstrategy != AGG_HASHED)
        {
-               AggStatePerGroup pergroup;
+               AggStatePerGroup *pergroups;
+
+               pergroups = (AggStatePerGroup *) palloc0(sizeof(AggStatePerGroup) *
+                                                                                                numGroupingSets);
 
-               pergroup = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData)
-                                                                                         * numaggs
-                                                                                         * numGroupingSets);
+               for (i = 0; i < numGroupingSets; i++)
+                       pergroups[i] = (AggStatePerGroup) palloc0(sizeof(AggStatePerGroupData)
+                                                                                                         * numaggs);
 
-               aggstate->pergroup = pergroup;
+               aggstate->pergroups = pergroups;
        }
 
        /*
@@ -4086,8 +4095,11 @@ ExecReScanAgg(AggState *node)
                /*
                 * Reset the per-group state (in particular, mark transvalues null)
                 */
-               MemSet(node->pergroup, 0,
-                          sizeof(AggStatePerGroupData) * node->numaggs * numGroupingSets);
+               for (setno = 0; setno < numGroupingSets; setno++)
+               {
+                       MemSet(node->pergroups[setno], 0,
+                                  sizeof(AggStatePerGroupData) * node->numaggs);
+               }
 
                /* reset to phase 1 */
                initialize_phase(node, 1);
index 94351eafaddd7cfa19d6a689a5b5e371b5267669..bbc3ec3f3f65c8e30ce8eeb83fc2b680193d9294 100644 (file)
@@ -1852,13 +1852,15 @@ typedef struct AggState
        Tuplesortstate *sort_out;       /* input is copied here for next phase */
        TupleTableSlot *sort_slot;      /* slot for sort results */
        /* these fields are used in AGG_PLAIN and AGG_SORTED modes: */
-       AggStatePerGroup pergroup;      /* per-Aggref-per-group working state */
+       AggStatePerGroup *pergroups;    /* grouping set indexed array of per-group
+                                                                        * pointers */
        HeapTuple       grp_firstTuple; /* copy of first tuple of current group */
        /* these fields are used in AGG_HASHED and AGG_MIXED modes: */
        bool            table_filled;   /* hash table filled yet? */
        int                     num_hashes;
        AggStatePerHash perhash;
-       AggStatePerGroup *hash_pergroup;        /* array of per-group pointers */
+       AggStatePerGroup *hash_pergroup;        /* grouping set indexed array of
+                                                                                * per-group pointers */
        /* support for evaluation of agg input expressions: */
        ProjectionInfo *combinedproj;   /* projection machinery */
 } AggState;