Cleanup some aggregate code in the executor
authorDavid Rowley <drowley@postgresql.org>
Sun, 4 Jul 2021 06:47:31 +0000 (18:47 +1200)
committerDavid Rowley <drowley@postgresql.org>
Sun, 4 Jul 2021 06:47:31 +0000 (18:47 +1200)
Here we alter the code that calls build_pertrans_for_aggref() so that the
function no longer needs to special-case whether it's dealing with an
aggtransfn or an aggcombinefn.  This allows us to reuse the
build_aggregate_transfn_expr() function and just get rid of the
build_aggregate_combinefn_expr() completely.

All of the special case code that was in build_pertrans_for_aggref() has
been moved up to the calling functions.

This saves about a dozen lines of code in nodeAgg.c and a few dozen more
in parse_agg.c

Also, rename a few variables in nodeAgg.c to try to make it more clear
that we're working with either a aggtransfn or an aggcombinefn.  Some of
the old names would have you believe that we were always working with an
aggtransfn.

Discussion: https://postgr.es/m/CAApHDvptMQ9FmF0D67zC_w88yVnoNVR2+kkOQGUrCmdxWxLULQ@mail.gmail.com

src/backend/executor/nodeAgg.c
src/backend/parser/parse_agg.c
src/include/parser/parse_agg.h

index 8440a76fbdc9e60e83b6e60125ac7df510645ad0..914b02ceee48e240f50ff854980fad07a3833ee2 100644 (file)
@@ -461,10 +461,11 @@ static void hashagg_tapeinfo_release(HashTapeInfo *tapeinfo, int tapenum);
 static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
 static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
                                                                          AggState *aggstate, EState *estate,
-                                                                         Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
-                                                                         Oid aggserialfn, Oid aggdeserialfn,
-                                                                         Datum initValue, bool initValueIsNull,
-                                                                         Oid *inputTypes, int numArguments);
+                                                                         Aggref *aggref, Oid transfn_oid,
+                                                                         Oid aggtranstype, Oid aggserialfn,
+                                                                         Oid aggdeserialfn, Datum initValue,
+                                                                         bool initValueIsNull, Oid *inputTypes,
+                                                                         int numArguments);
 
 
 /*
@@ -3724,8 +3725,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                Aggref     *aggref = lfirst(l);
                AggStatePerAgg peragg;
                AggStatePerTrans pertrans;
-               Oid                     inputTypes[FUNC_MAX_ARGS];
-               int                     numArguments;
+               Oid                     aggTransFnInputTypes[FUNC_MAX_ARGS];
+               int                     numAggTransFnArgs;
                int                     numDirectArgs;
                HeapTuple       aggTuple;
                Form_pg_aggregate aggform;
@@ -3859,14 +3860,15 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                 * could be different from the agg's declared input types, when the
                 * agg accepts ANY or a polymorphic type.
                 */
-               numArguments = get_aggregate_argtypes(aggref, inputTypes);
+               numAggTransFnArgs = get_aggregate_argtypes(aggref,
+                                                                                                  aggTransFnInputTypes);
 
                /* Count the "direct" arguments, if any */
                numDirectArgs = list_length(aggref->aggdirectargs);
 
                /* Detect how many arguments to pass to the finalfn */
                if (aggform->aggfinalextra)
-                       peragg->numFinalArgs = numArguments + 1;
+                       peragg->numFinalArgs = numAggTransFnArgs + 1;
                else
                        peragg->numFinalArgs = numDirectArgs + 1;
 
@@ -3880,7 +3882,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                 */
                if (OidIsValid(finalfn_oid))
                {
-                       build_aggregate_finalfn_expr(inputTypes,
+                       build_aggregate_finalfn_expr(aggTransFnInputTypes,
                                                                                 peragg->numFinalArgs,
                                                                                 aggtranstype,
                                                                                 aggref->aggtype,
@@ -3911,7 +3913,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                        /*
                         * If this aggregation is performing state combines, then instead
                         * of using the transition function, we'll use the combine
-                        * function
+                        * function.
                         */
                        if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
                        {
@@ -3924,8 +3926,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                        else
                                transfn_oid = aggform->aggtransfn;
 
-                       aclresult = pg_proc_aclcheck(transfn_oid, aggOwner,
-                                                                                ACL_EXECUTE);
+                       aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, ACL_EXECUTE);
                        if (aclresult != ACLCHECK_OK)
                                aclcheck_error(aclresult, OBJECT_FUNCTION,
                                                           get_func_name(transfn_oid));
@@ -3943,11 +3944,72 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
                        else
                                initValue = GetAggInitVal(textInitVal, aggtranstype);
 
-                       build_pertrans_for_aggref(pertrans, aggstate, estate,
-                                                                         aggref, transfn_oid, aggtranstype,
-                                                                         serialfn_oid, deserialfn_oid,
-                                                                         initValue, initValueIsNull,
-                                                                         inputTypes, numArguments);
+                       if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
+                       {
+                               Oid                     combineFnInputTypes[] = {aggtranstype,
+                               aggtranstype};
+
+                               /*
+                                * When combining there's only one input, the to-be-combined
+                                * transition value.  The transition value is not counted
+                                * here.
+                                */
+                               pertrans->numTransInputs = 1;
+
+                               /* aggcombinefn always has two arguments of aggtranstype */
+                               build_pertrans_for_aggref(pertrans, aggstate, estate,
+                                                                                 aggref, transfn_oid, aggtranstype,
+                                                                                 serialfn_oid, deserialfn_oid,
+                                                                                 initValue, initValueIsNull,
+                                                                                 combineFnInputTypes, 2);
+
+                               /*
+                                * Ensure that a combine function to combine INTERNAL states
+                                * is not strict. This should have been checked during CREATE
+                                * AGGREGATE, but the strict property could have been changed
+                                * since then.
+                                */
+                               if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
+                                       ereport(ERROR,
+                                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                                        errmsg("combine function with transition type %s must not be declared STRICT",
+                                                                       format_type_be(aggtranstype))));
+                       }
+                       else
+                       {
+                               /* Detect how many arguments to pass to the transfn */
+                               if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+                                       pertrans->numTransInputs = list_length(aggref->args);
+                               else
+                                       pertrans->numTransInputs = numAggTransFnArgs;
+
+                               build_pertrans_for_aggref(pertrans, aggstate, estate,
+                                                                                 aggref, transfn_oid, aggtranstype,
+                                                                                 serialfn_oid, deserialfn_oid,
+                                                                                 initValue, initValueIsNull,
+                                                                                 aggTransFnInputTypes,
+                                                                                 numAggTransFnArgs);
+
+                               /*
+                                * If the transfn is strict and the initval is NULL, make sure
+                                * input type and transtype are the same (or at least
+                                * binary-compatible), so that it's OK to use the first
+                                * aggregated input value as the initial transValue.  This
+                                * should have been checked at agg definition time, but we
+                                * must check again in case the transfn's strictness property
+                                * has been changed.
+                                */
+                               if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+                               {
+                                       if (numAggTransFnArgs <= numDirectArgs ||
+                                               !IsBinaryCoercible(aggTransFnInputTypes[numDirectArgs],
+                                                                                  aggtranstype))
+                                               ereport(ERROR,
+                                                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+                                                                errmsg("aggregate %u needs to have compatible input type and transition type",
+                                                                               aggref->aggfnoid)));
+                               }
+                       }
                }
                else
                        pertrans->aggshared = true;
@@ -4039,20 +4101,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
  * Build the state needed to calculate a state value for an aggregate.
  *
  * This initializes all the fields in 'pertrans'. 'aggref' is the aggregate
- * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest
+ * to initialize the state for. 'transfn_oid', 'aggtranstype', and the rest
  * of the arguments could be calculated from 'aggref', but the caller has
  * calculated them already, so might as well pass them.
+ *
+ * 'transfn_oid' may be either the Oid of the aggtransfn or the aggcombinefn.
  */
 static void
 build_pertrans_for_aggref(AggStatePerTrans pertrans,
                                                  AggState *aggstate, EState *estate,
                                                  Aggref *aggref,
-                                                 Oid aggtransfn, Oid aggtranstype,
+                                                 Oid transfn_oid, Oid aggtranstype,
                                                  Oid aggserialfn, Oid aggdeserialfn,
                                                  Datum initValue, bool initValueIsNull,
                                                  Oid *inputTypes, int numArguments)
 {
        int                     numGroupingSets = Max(aggstate->maxsets, 1);
+       Expr       *transfnexpr;
+       int                     numTransArgs;
        Expr       *serialfnexpr = NULL;
        Expr       *deserialfnexpr = NULL;
        ListCell   *lc;
@@ -4067,7 +4133,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
        pertrans->aggref = aggref;
        pertrans->aggshared = false;
        pertrans->aggCollation = aggref->inputcollid;
-       pertrans->transfn_oid = aggtransfn;
+       pertrans->transfn_oid = transfn_oid;
        pertrans->serialfn_oid = aggserialfn;
        pertrans->deserialfn_oid = aggdeserialfn;
        pertrans->initValue = initValue;
@@ -4081,111 +4147,34 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 
        pertrans->aggtranstype = aggtranstype;
 
+       /* account for the current transition state */
+       numTransArgs = pertrans->numTransInputs + 1;
+
        /*
-        * When combining states, we have no use at all for the aggregate
-        * function's transfn. Instead we use the combinefn.  In this case, the
-        * transfn and transfn_oid fields of pertrans refer to the combine
-        * function rather than the transition function.
+        * Set up infrastructure for calling the transfn.  Note that invtrans is
+        * not needed here.
         */
-       if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
-       {
-               Expr       *combinefnexpr;
-               size_t          numTransArgs;
-
-               /*
-                * When combining there's only one input, the to-be-combined added
-                * transition value from below (this node's transition value is
-                * counted separately).
-                */
-               pertrans->numTransInputs = 1;
-
-               /* account for the current transition state */
-               numTransArgs = pertrans->numTransInputs + 1;
-
-               build_aggregate_combinefn_expr(aggtranstype,
-                                                                          aggref->inputcollid,
-                                                                          aggtransfn,
-                                                                          &combinefnexpr);
-               fmgr_info(aggtransfn, &pertrans->transfn);
-               fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn);
-
-               pertrans->transfn_fcinfo =
-                       (FunctionCallInfo) palloc(SizeForFunctionCallInfo(2));
-               InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
-                                                                &pertrans->transfn,
-                                                                numTransArgs,
-                                                                pertrans->aggCollation,
-                                                                (void *) aggstate, NULL);
-
-               /*
-                * Ensure that a combine function to combine INTERNAL states is not
-                * strict. This should have been checked during CREATE AGGREGATE, but
-                * the strict property could have been changed since then.
-                */
-               if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
-                       ereport(ERROR,
-                                       (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
-                                        errmsg("combine function with transition type %s must not be declared STRICT",
-                                                       format_type_be(aggtranstype))));
-       }
-       else
-       {
-               Expr       *transfnexpr;
-               size_t          numTransArgs;
-
-               /* Detect how many arguments to pass to the transfn */
-               if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
-                       pertrans->numTransInputs = numInputs;
-               else
-                       pertrans->numTransInputs = numArguments;
+       build_aggregate_transfn_expr(inputTypes,
+                                                                numArguments,
+                                                                numDirectArgs,
+                                                                aggref->aggvariadic,
+                                                                aggtranstype,
+                                                                aggref->inputcollid,
+                                                                transfn_oid,
+                                                                InvalidOid,
+                                                                &transfnexpr,
+                                                                NULL);
 
-               /* account for the current transition state */
-               numTransArgs = pertrans->numTransInputs + 1;
+       fmgr_info(transfn_oid, &pertrans->transfn);
+       fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
 
-               /*
-                * Set up infrastructure for calling the transfn.  Note that
-                * invtransfn is not needed here.
-                */
-               build_aggregate_transfn_expr(inputTypes,
-                                                                        numArguments,
-                                                                        numDirectArgs,
-                                                                        aggref->aggvariadic,
-                                                                        aggtranstype,
-                                                                        aggref->inputcollid,
-                                                                        aggtransfn,
-                                                                        InvalidOid,
-                                                                        &transfnexpr,
-                                                                        NULL);
-               fmgr_info(aggtransfn, &pertrans->transfn);
-               fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
-
-               pertrans->transfn_fcinfo =
-                       (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs));
-               InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
-                                                                &pertrans->transfn,
-                                                                numTransArgs,
-                                                                pertrans->aggCollation,
-                                                                (void *) aggstate, NULL);
-
-               /*
-                * If the transfn is strict and the initval is NULL, make sure input
-                * type and transtype are the same (or at least binary-compatible), so
-                * that it's OK to use the first aggregated input value as the initial
-                * transValue.  This should have been checked at agg definition time,
-                * but we must check again in case the transfn's strictness property
-                * has been changed.
-                */
-               if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
-               {
-                       if (numArguments <= numDirectArgs ||
-                               !IsBinaryCoercible(inputTypes[numDirectArgs],
-                                                                  aggtranstype))
-                               ereport(ERROR,
-                                               (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
-                                                errmsg("aggregate %u needs to have compatible input type and transition type",
-                                                               aggref->aggfnoid)));
-               }
-       }
+       pertrans->transfn_fcinfo =
+               (FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs));
+       InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
+                                                        &pertrans->transfn,
+                                                        numTransArgs,
+                                                        pertrans->aggCollation,
+                                                        (void *) aggstate, NULL);
 
        /* get info about the state value's datatype */
        get_typlenbyval(aggtranstype,
@@ -4276,6 +4265,9 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
                 */
                Assert(aggstate->aggstrategy != AGG_HASHED && aggstate->aggstrategy != AGG_MIXED);
 
+               /* ORDER BY aggregates are not supported with partial aggregation */
+               Assert(!DO_AGGSPLIT_COMBINE(aggstate->aggsplit));
+
                /* If we have only one input, we need its len/byval info. */
                if (numInputs == 1)
                {
index a25f8d5b98991d15059a0df46d6fdb43a6dc56be..24268eb5024df5413013377e487018c5df2b0b2e 100644 (file)
@@ -1959,6 +1959,11 @@ resolve_aggregate_transtype(Oid aggfuncid,
  * latter may be InvalidOid, however if invtransfn_oid is set then
  * transfn_oid must also be set.
  *
+ * transfn_oid may also be passed as the aggcombinefn when the *transfnexpr is
+ * to be used for a combine aggregate phase.  We expect invtransfn_oid to be
+ * InvalidOid in this case since there is no such thing as an inverse
+ * combinefn.
+ *
  * Pointers to the constructed trees are returned into *transfnexpr,
  * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
  * to NULL.  Since use of the invtransfn is optional, NULL may be passed for
@@ -2021,35 +2026,6 @@ build_aggregate_transfn_expr(Oid *agg_input_types,
        }
 }
 
-/*
- * Like build_aggregate_transfn_expr, but creates an expression tree for the
- * combine function of an aggregate, rather than the transition function.
- */
-void
-build_aggregate_combinefn_expr(Oid agg_state_type,
-                                                          Oid agg_input_collation,
-                                                          Oid combinefn_oid,
-                                                          Expr **combinefnexpr)
-{
-       Node       *argp;
-       List       *args;
-       FuncExpr   *fexpr;
-
-       /* combinefn takes two arguments of the aggregate state type */
-       argp = make_agg_arg(agg_state_type, agg_input_collation);
-
-       args = list_make2(argp, argp);
-
-       fexpr = makeFuncExpr(combinefn_oid,
-                                                agg_state_type,
-                                                args,
-                                                InvalidOid,
-                                                agg_input_collation,
-                                                COERCE_EXPLICIT_CALL);
-       /* combinefn is currently never treated as variadic */
-       *combinefnexpr = (Expr *) fexpr;
-}
-
 /*
  * Like build_aggregate_transfn_expr, but creates an expression tree for the
  * serialization function of an aggregate.
index 4dea01752af3a5283039fa1cb141e50c81fccbe8..bffbb82df66cd9ebaf96687fc52f6772e0b1cda2 100644 (file)
@@ -46,11 +46,6 @@ extern void build_aggregate_transfn_expr(Oid *agg_input_types,
                                                                                 Expr **transfnexpr,
                                                                                 Expr **invtransfnexpr);
 
-extern void build_aggregate_combinefn_expr(Oid agg_state_type,
-                                                                                  Oid agg_input_collation,
-                                                                                  Oid combinefn_oid,
-                                                                                  Expr **combinefnexpr);
-
 extern void build_aggregate_serialfn_expr(Oid serialfn_oid,
                                                                                  Expr **serialfnexpr);