diff --git a/contrib/pg_stat_statements/pg_stat_statements.c b/contrib/pg_stat_statements/pg_stat_statements.c index a6ceaf4..ea930af 100644 --- a/contrib/pg_stat_statements/pg_stat_statements.c +++ b/contrib/pg_stat_statements/pg_stat_statements.c @@ -1546,6 +1546,7 @@ JumbleExpr(pgssJumbleState *jstate, Node *node) JumbleExpr(jstate, (Node *) expr->args); JumbleExpr(jstate, (Node *) expr->aggorder); JumbleExpr(jstate, (Node *) expr->aggdistinct); + JumbleExpr(jstate, (Node *) expr->aggfilter); } break; case T_WindowFunc: @@ -1555,6 +1556,7 @@ JumbleExpr(pgssJumbleState *jstate, Node *node) APP_JUMB(expr->winfnoid); APP_JUMB(expr->winref); JumbleExpr(jstate, (Node *) expr->args); + JumbleExpr(jstate, (Node *) expr->aggfilter); } break; case T_ArrayRef: diff --git a/doc/src/sgml/keywords.sgml b/doc/src/sgml/keywords.sgml index 059c8e4..ecfde99 100644 --- a/doc/src/sgml/keywords.sgml +++ b/doc/src/sgml/keywords.sgml @@ -1786,7 +1786,7 @@ FILTER - + non-reserved reserved reserved diff --git a/doc/src/sgml/ref/select.sgml b/doc/src/sgml/ref/select.sgml index 68309ba..b0cec14 100644 --- a/doc/src/sgml/ref/select.sgml +++ b/doc/src/sgml/ref/select.sgml @@ -598,6 +598,11 @@ GROUP BY expression [, ...] making up each group, producing a separate value for each group (whereas without GROUP BY, an aggregate produces a single value computed across all the selected rows). + The set of rows fed to the aggregate function can be further filtered by + attaching a FILTER clause to the aggregate function + call; see for more information. When + a FILTER clause is present, only those rows matching it + are included. When GROUP BY is present, it is not valid for the SELECT list expressions to refer to ungrouped columns except within aggregate functions or if the diff --git a/doc/src/sgml/syntax.sgml b/doc/src/sgml/syntax.sgml index b139212..803ed85 100644 --- a/doc/src/sgml/syntax.sgml +++ b/doc/src/sgml/syntax.sgml @@ -1554,6 +1554,10 @@ sqrt(2) invocation + + filter + + An aggregate expression represents the application of an aggregate function across the rows selected by a @@ -1562,19 +1566,19 @@ sqrt(2) syntax of an aggregate expression is one of the following: -aggregate_name (expression [ , ... ] [ order_by_clause ] ) -aggregate_name (ALL expression [ , ... ] [ order_by_clause ] ) -aggregate_name (DISTINCT expression [ , ... ] [ order_by_clause ] ) -aggregate_name ( * ) +aggregate_name (expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE filter_clause ) ] +aggregate_name (ALL expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE filter_clause ) ] +aggregate_name (DISTINCT expression [ , ... ] [ order_by_clause ] ) [ FILTER ( WHERE filter_clause ) ] +aggregate_name ( * ) [ FILTER ( WHERE filter_clause ) ] where aggregate_name is a previously - defined aggregate (possibly qualified with a schema name), + defined aggregate (possibly qualified with a schema name) and expression is any value expression that does not itself contain an aggregate - expression or a window function call, and - order_by_clause is a optional - ORDER BY clause as described below. + expression or a window function call. The optional + order_by_clause and + filter_clause are described below. @@ -1607,6 +1611,23 @@ sqrt(2) + If FILTER is specified, then only the input + rows for which the filter_clause + evaluates to true are fed to the aggregate function; other rows + are discarded. For example: + +SELECT + count(*) AS unfiltered, + count(*) FILTER (WHERE i < 5) AS filtered +FROM generate_series(1,10) AS s(i); + unfiltered | filtered +------------+---------- + 10 | 4 +(1 row) + + + + Ordinarily, the input rows are fed to the aggregate function in an unspecified order. In many cases this does not matter; for example, min produces the same result no matter what order it @@ -1709,10 +1730,10 @@ SELECT string_agg(a ORDER BY a, ',') FROM table; -- incorrect The syntax of a window function call is one of the following: -function_name (expression , expression ... ) OVER ( window_definition ) -function_name (expression , expression ... ) OVER window_name -function_name ( * ) OVER ( window_definition ) -function_name ( * ) OVER window_name +function_name (expression , expression ... ) [ FILTER ( WHERE filter_clause ) ] OVER ( window_definition ) +function_name (expression , expression ... ) [ FILTER ( WHERE filter_clause ) ] OVER window_name +function_name ( * ) [ FILTER ( WHERE filter_clause ) ] OVER ( window_definition ) +function_name ( * ) [ FILTER ( WHERE filter_clause ) ] OVER window_name where window_definition has the syntax @@ -1836,7 +1857,8 @@ UNBOUNDED FOLLOWING The built-in window functions are described in . Other window functions can be added by the user. Also, any built-in or user-defined aggregate function can be - used as a window function. + used as a window function. Only aggregate window functions accept + a FILTER clause. diff --git a/src/backend/executor/execQual.c b/src/backend/executor/execQual.c index 1388183..90c2753 100644 --- a/src/backend/executor/execQual.c +++ b/src/backend/executor/execQual.c @@ -4410,6 +4410,8 @@ ExecInitExpr(Expr *node, PlanState *parent) astate->args = (List *) ExecInitExpr((Expr *) aggref->args, parent); + astate->aggfilter = ExecInitExpr(aggref->aggfilter, + parent); /* * Complain if the aggregate's arguments contain any @@ -4448,6 +4450,8 @@ ExecInitExpr(Expr *node, PlanState *parent) wfstate->args = (List *) ExecInitExpr((Expr *) wfunc->args, parent); + wfstate->aggfilter = ExecInitExpr(wfunc->aggfilter, + parent); /* * Complain if the windowfunc's arguments contain any diff --git a/src/backend/executor/execUtils.c b/src/backend/executor/execUtils.c index cf7fb72..b449e0a 100644 --- a/src/backend/executor/execUtils.c +++ b/src/backend/executor/execUtils.c @@ -649,9 +649,9 @@ get_last_attnums(Node *node, ProjectionInfo *projInfo) } /* - * Don't examine the arguments of Aggrefs or WindowFuncs, because those do - * not represent expressions to be evaluated within the overall - * targetlist's econtext. + * Don't examine the arguments or filters of Aggrefs or WindowFuncs, + * because those do not represent expressions to be evaluated within the + * overall targetlist's econtext. */ if (IsA(node, Aggref)) return false; diff --git a/src/backend/executor/functions.c b/src/backend/executor/functions.c index 12e1b8e..ff6a123 100644 --- a/src/backend/executor/functions.c +++ b/src/backend/executor/functions.c @@ -380,7 +380,7 @@ sql_fn_post_column_ref(ParseState *pstate, ColumnRef *cref, Node *var) param = ParseFuncOrColumn(pstate, list_make1(subfield), list_make1(param), - NIL, false, false, false, + NIL, NULL, false, false, false, NULL, true, cref->location); } diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c index c741131..7a0c254 100644 --- a/src/backend/executor/nodeAgg.c +++ b/src/backend/executor/nodeAgg.c @@ -484,10 +484,23 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) { AggStatePerAgg peraggstate = &aggstate->peragg[aggno]; AggStatePerGroup pergroupstate = &pergroup[aggno]; + ExprState *filter = peraggstate->aggrefstate->aggfilter; int nargs = peraggstate->numArguments; int i; TupleTableSlot *slot; + /* Skip anything FILTERed out */ + if (filter) + { + bool isnull; + Datum res; + + res = ExecEvalExprSwitchContext(filter, aggstate->tmpcontext, + &isnull, NULL); + if (isnull || !DatumGetBool(res)) + continue; + } + /* Evaluate the current input expressions for this aggregate */ slot = ExecProject(peraggstate->evalproj, NULL); diff --git a/src/backend/executor/nodeWindowAgg.c b/src/backend/executor/nodeWindowAgg.c index d9f0e79..bbc5336 100644 --- a/src/backend/executor/nodeWindowAgg.c +++ b/src/backend/executor/nodeWindowAgg.c @@ -227,9 +227,23 @@ advance_windowaggregate(WindowAggState *winstate, int i; MemoryContext oldContext; ExprContext *econtext = winstate->tmpcontext; + ExprState *filter = wfuncstate->aggfilter; oldContext = MemoryContextSwitchTo(econtext->ecxt_per_tuple_memory); + /* Skip anything FILTERed out */ + if (filter) + { + bool isnull; + Datum res = ExecEvalExpr(filter, econtext, &isnull, NULL); + + if (isnull || !DatumGetBool(res)) + { + MemoryContextSwitchTo(oldContext); + return; + } + } + /* We start from 1, since the 0th arg will be the transition value */ i = 1; foreach(arg, wfuncstate->args) diff --git a/src/backend/nodes/copyfuncs.c b/src/backend/nodes/copyfuncs.c index b5b8d63..55bdba2 100644 --- a/src/backend/nodes/copyfuncs.c +++ b/src/backend/nodes/copyfuncs.c @@ -1137,6 +1137,7 @@ _copyAggref(const Aggref *from) COPY_NODE_FIELD(args); COPY_NODE_FIELD(aggorder); COPY_NODE_FIELD(aggdistinct); + COPY_NODE_FIELD(aggfilter); COPY_SCALAR_FIELD(aggstar); COPY_SCALAR_FIELD(agglevelsup); COPY_LOCATION_FIELD(location); @@ -1157,6 +1158,7 @@ _copyWindowFunc(const WindowFunc *from) COPY_SCALAR_FIELD(wincollid); COPY_SCALAR_FIELD(inputcollid); COPY_NODE_FIELD(args); + COPY_NODE_FIELD(aggfilter); COPY_SCALAR_FIELD(winref); COPY_SCALAR_FIELD(winstar); COPY_SCALAR_FIELD(winagg); @@ -2152,6 +2154,7 @@ _copyFuncCall(const FuncCall *from) COPY_NODE_FIELD(funcname); COPY_NODE_FIELD(args); COPY_NODE_FIELD(agg_order); + COPY_NODE_FIELD(agg_filter); COPY_SCALAR_FIELD(agg_star); COPY_SCALAR_FIELD(agg_distinct); COPY_SCALAR_FIELD(func_variadic); diff --git a/src/backend/nodes/equalfuncs.c b/src/backend/nodes/equalfuncs.c index 3f96595..79384e7 100644 --- a/src/backend/nodes/equalfuncs.c +++ b/src/backend/nodes/equalfuncs.c @@ -196,6 +196,7 @@ _equalAggref(const Aggref *a, const Aggref *b) COMPARE_NODE_FIELD(args); COMPARE_NODE_FIELD(aggorder); COMPARE_NODE_FIELD(aggdistinct); + COMPARE_NODE_FIELD(aggfilter); COMPARE_SCALAR_FIELD(aggstar); COMPARE_SCALAR_FIELD(agglevelsup); COMPARE_LOCATION_FIELD(location); @@ -211,6 +212,7 @@ _equalWindowFunc(const WindowFunc *a, const WindowFunc *b) COMPARE_SCALAR_FIELD(wincollid); COMPARE_SCALAR_FIELD(inputcollid); COMPARE_NODE_FIELD(args); + COMPARE_NODE_FIELD(aggfilter); COMPARE_SCALAR_FIELD(winref); COMPARE_SCALAR_FIELD(winstar); COMPARE_SCALAR_FIELD(winagg); @@ -1992,6 +1994,7 @@ _equalFuncCall(const FuncCall *a, const FuncCall *b) COMPARE_NODE_FIELD(funcname); COMPARE_NODE_FIELD(args); COMPARE_NODE_FIELD(agg_order); + COMPARE_NODE_FIELD(agg_filter); COMPARE_SCALAR_FIELD(agg_star); COMPARE_SCALAR_FIELD(agg_distinct); COMPARE_SCALAR_FIELD(func_variadic); diff --git a/src/backend/nodes/makefuncs.c b/src/backend/nodes/makefuncs.c index 245aef2..0f8a282 100644 --- a/src/backend/nodes/makefuncs.c +++ b/src/backend/nodes/makefuncs.c @@ -526,6 +526,7 @@ makeFuncCall(List *name, List *args, int location) n->args = args; n->location = location; n->agg_order = NIL; + n->agg_filter = NULL; n->agg_star = FALSE; n->agg_distinct = FALSE; n->func_variadic = FALSE; diff --git a/src/backend/nodes/nodeFuncs.c b/src/backend/nodes/nodeFuncs.c index 42d6621..310400e 100644 --- a/src/backend/nodes/nodeFuncs.c +++ b/src/backend/nodes/nodeFuncs.c @@ -1570,6 +1570,8 @@ expression_tree_walker(Node *node, if (expression_tree_walker((Node *) expr->aggdistinct, walker, context)) return true; + if (walker((Node *) expr->aggfilter, context)) + return true; } break; case T_WindowFunc: @@ -1580,6 +1582,8 @@ expression_tree_walker(Node *node, if (expression_tree_walker((Node *) expr->args, walker, context)) return true; + if (walker((Node *) expr->aggfilter, context)) + return true; } break; case T_ArrayRef: @@ -2079,6 +2083,7 @@ expression_tree_mutator(Node *node, MUTATE(newnode->args, aggref->args, List *); MUTATE(newnode->aggorder, aggref->aggorder, List *); MUTATE(newnode->aggdistinct, aggref->aggdistinct, List *); + MUTATE(newnode->aggfilter, aggref->aggfilter, Expr *); return (Node *) newnode; } break; @@ -2089,6 +2094,7 @@ expression_tree_mutator(Node *node, FLATCOPY(newnode, wfunc, WindowFunc); MUTATE(newnode->args, wfunc->args, List *); + MUTATE(newnode->aggfilter, wfunc->aggfilter, Expr *); return (Node *) newnode; } break; @@ -2951,6 +2957,8 @@ raw_expression_tree_walker(Node *node, return true; if (walker(fcall->agg_order, context)) return true; + if (walker(fcall->agg_filter, context)) + return true; if (walker(fcall->over, context)) return true; /* function name is deemed uninteresting */ diff --git a/src/backend/nodes/outfuncs.c b/src/backend/nodes/outfuncs.c index b2183f4..2475f8d 100644 --- a/src/backend/nodes/outfuncs.c +++ b/src/backend/nodes/outfuncs.c @@ -958,6 +958,7 @@ _outAggref(StringInfo str, const Aggref *node) WRITE_NODE_FIELD(args); WRITE_NODE_FIELD(aggorder); WRITE_NODE_FIELD(aggdistinct); + WRITE_NODE_FIELD(aggfilter); WRITE_BOOL_FIELD(aggstar); WRITE_UINT_FIELD(agglevelsup); WRITE_LOCATION_FIELD(location); @@ -973,6 +974,7 @@ _outWindowFunc(StringInfo str, const WindowFunc *node) WRITE_OID_FIELD(wincollid); WRITE_OID_FIELD(inputcollid); WRITE_NODE_FIELD(args); + WRITE_NODE_FIELD(aggfilter); WRITE_UINT_FIELD(winref); WRITE_BOOL_FIELD(winstar); WRITE_BOOL_FIELD(winagg); @@ -2080,6 +2082,7 @@ _outFuncCall(StringInfo str, const FuncCall *node) WRITE_NODE_FIELD(funcname); WRITE_NODE_FIELD(args); WRITE_NODE_FIELD(agg_order); + WRITE_NODE_FIELD(agg_filter); WRITE_BOOL_FIELD(agg_star); WRITE_BOOL_FIELD(agg_distinct); WRITE_BOOL_FIELD(func_variadic); diff --git a/src/backend/nodes/readfuncs.c b/src/backend/nodes/readfuncs.c index 3a16e9d..30c5150 100644 --- a/src/backend/nodes/readfuncs.c +++ b/src/backend/nodes/readfuncs.c @@ -479,6 +479,7 @@ _readAggref(void) READ_NODE_FIELD(args); READ_NODE_FIELD(aggorder); READ_NODE_FIELD(aggdistinct); + READ_NODE_FIELD(aggfilter); READ_BOOL_FIELD(aggstar); READ_UINT_FIELD(agglevelsup); READ_LOCATION_FIELD(location); @@ -499,6 +500,7 @@ _readWindowFunc(void) READ_OID_FIELD(wincollid); READ_OID_FIELD(inputcollid); READ_NODE_FIELD(args); + READ_NODE_FIELD(aggfilter); READ_UINT_FIELD(winref); READ_BOOL_FIELD(winstar); READ_BOOL_FIELD(winagg); diff --git a/src/backend/optimizer/path/costsize.c b/src/backend/optimizer/path/costsize.c index 3507f18..1732d71 100644 --- a/src/backend/optimizer/path/costsize.c +++ b/src/backend/optimizer/path/costsize.c @@ -1590,6 +1590,14 @@ cost_windowagg(Path *path, PlannerInfo *root, startup_cost += argcosts.startup; wfunccost += argcosts.per_tuple; + /* + * Add the filter's cost to per-input-row costs. XXX We should reduce + * input expression costs according to filter selectivity. + */ + cost_qual_eval_node(&argcosts, (Node *) wfunc->aggfilter, root); + startup_cost += argcosts.startup; + wfunccost += argcosts.per_tuple; + total_cost += wfunccost * input_tuples; } diff --git a/src/backend/optimizer/plan/planagg.c b/src/backend/optimizer/plan/planagg.c index 243aeb3..2f6fc4a 100644 --- a/src/backend/optimizer/plan/planagg.c +++ b/src/backend/optimizer/plan/planagg.c @@ -318,6 +318,13 @@ find_minmax_aggs_walker(Node *node, List **context) return true; /* it couldn't be MIN/MAX */ /* + * We might implement the optimization when a FILTER clause is present + * by adding the filter to the quals of the generated subquery. + */ + if (aggref->aggfilter != NULL) + return true; + + /* * Ignore ORDER BY and DISTINCT, which are valid but pointless on * MIN/MAX. They do not change its result. */ diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 6d5b204..7ec6b0b 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -495,6 +495,15 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) costs->transCost.startup += argcosts.startup; costs->transCost.per_tuple += argcosts.per_tuple; + /* + * Add the filter's cost to per-input-row costs. XXX We should reduce + * input expression costs according to filter selectivity. + */ + cost_qual_eval_node(&argcosts, (Node *) aggref->aggfilter, + context->root); + costs->transCost.startup += argcosts.startup; + costs->transCost.per_tuple += argcosts.per_tuple; + /* extract argument types (ignoring any ORDER BY expressions) */ inputTypes = (Oid *) palloc(sizeof(Oid) * list_length(aggref->args)); numArguments = 0; @@ -565,7 +574,8 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) /* * Complain if the aggregate's arguments contain any aggregates; - * nested agg functions are semantically nonsensical. + * nested agg functions are semantically nonsensical. Aggregates in + * the FILTER clause are detected in transformAggregateCall(). */ if (contain_agg_clause((Node *) aggref->args)) ereport(ERROR, @@ -639,7 +649,8 @@ find_window_functions_walker(Node *node, WindowFuncLists *lists) /* * Complain if the window function's arguments contain window - * functions + * functions. Window functions in the FILTER clause are detected in + * transformAggregateCall(). */ if (contain_window_function((Node *) wfunc->args)) ereport(ERROR, diff --git a/src/backend/parser/gram.y b/src/backend/parser/gram.y index f67ef0c..5ff3e70 100644 --- a/src/backend/parser/gram.y +++ b/src/backend/parser/gram.y @@ -492,6 +492,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); opt_frame_clause frame_extent frame_bound %type opt_existing_window_name %type opt_if_not_exists +%type filter_clause /* * Non-keyword token types. These are hard-wired into the "flex" lexer. @@ -538,8 +539,8 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); EXCLUDE EXCLUDING EXCLUSIVE EXECUTE EXISTS EXPLAIN EXTENSION EXTERNAL EXTRACT - FALSE_P FAMILY FETCH FIRST_P FLOAT_P FOLLOWING FOR FORCE FOREIGN FORWARD - FREEZE FROM FULL FUNCTION FUNCTIONS + FALSE_P FAMILY FETCH FILTER FIRST_P FLOAT_P FOLLOWING FOR + FORCE FOREIGN FORWARD FREEZE FROM FULL FUNCTION FUNCTIONS GLOBAL GRANT GRANTED GREATEST GROUP_P @@ -11111,10 +11112,11 @@ func_application: func_name '(' ')' * (Note that many of the special SQL functions wouldn't actually make any * sense as functional index entries, but we ignore that consideration here.) */ -func_expr: func_application over_clause +func_expr: func_application filter_clause over_clause { FuncCall *n = (FuncCall*)$1; - n->over = $2; + n->agg_filter = $2; + n->over = $3; $$ = (Node*)n; } | func_expr_common_subexpr @@ -11525,6 +11527,11 @@ window_definition: } ; +filter_clause: + FILTER '(' WHERE a_expr ')' { $$ = $4; } + | /*EMPTY*/ { $$ = NULL; } + ; + over_clause: OVER window_specification { $$ = $2; } | OVER ColId @@ -12499,6 +12506,7 @@ unreserved_keyword: | EXTENSION | EXTERNAL | FAMILY + | FILTER | FIRST_P | FOLLOWING | FORCE diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c index 7380618..4e4e1cd 100644 --- a/src/backend/parser/parse_agg.c +++ b/src/backend/parser/parse_agg.c @@ -44,7 +44,7 @@ typedef struct int sublevels_up; } check_ungrouped_columns_context; -static int check_agg_arguments(ParseState *pstate, List *args); +static int check_agg_arguments(ParseState *pstate, List *args, Expr *filter); static bool check_agg_arguments_walker(Node *node, check_agg_arguments_context *context); static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry, @@ -160,7 +160,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, * Check the arguments to compute the aggregate's level and detect * improper nesting. */ - min_varlevel = check_agg_arguments(pstate, agg->args); + min_varlevel = check_agg_arguments(pstate, agg->args, agg->aggfilter); agg->agglevelsup = min_varlevel; /* Mark the correct pstate level as having aggregates */ @@ -207,6 +207,9 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, case EXPR_KIND_HAVING: /* okay */ break; + case EXPR_KIND_FILTER: + errkind = true; + break; case EXPR_KIND_WINDOW_PARTITION: /* okay */ break; @@ -299,8 +302,8 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, * one is its parent, etc). * * The aggregate's level is the same as the level of the lowest-level variable - * or aggregate in its arguments; or if it contains no variables at all, we - * presume it to be local. + * or aggregate in its arguments or filter expression; or if it contains no + * variables at all, we presume it to be local. * * We also take this opportunity to detect any aggregates or window functions * nested within the arguments. We can throw error immediately if we find @@ -309,7 +312,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, * which we can't know until we finish scanning the arguments. */ static int -check_agg_arguments(ParseState *pstate, List *args) +check_agg_arguments(ParseState *pstate, List *args, Expr *filter) { int agglevel; check_agg_arguments_context context; @@ -323,6 +326,10 @@ check_agg_arguments(ParseState *pstate, List *args) check_agg_arguments_walker, (void *) &context); + (void) expression_tree_walker((Node *) filter, + check_agg_arguments_walker, + (void *) &context); + /* * If we found no vars nor aggs at all, it's a level-zero aggregate; * otherwise, its level is the minimum of vars or aggs. @@ -481,6 +488,9 @@ transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc, case EXPR_KIND_HAVING: errkind = true; break; + case EXPR_KIND_FILTER: + errkind = true; + break; case EXPR_KIND_WINDOW_PARTITION: case EXPR_KIND_WINDOW_ORDER: case EXPR_KIND_WINDOW_FRAME_RANGE: @@ -807,11 +817,10 @@ check_ungrouped_columns_walker(Node *node, /* * If we find an aggregate call of the original level, do not recurse into - * its arguments; ungrouped vars in the arguments are not an error. We can - * also skip looking at the arguments of aggregates of higher levels, - * since they could not possibly contain Vars that are of concern to us - * (see transformAggregateCall). We do need to look into the arguments of - * aggregates of lower levels, however. + * its arguments or filter; ungrouped vars there are not an error. We can + * also skip looking at aggregates of higher levels, since they could not + * possibly contain Vars of concern to us (see transformAggregateCall). + * We do need to look at aggregates of lower levels, however. */ if (IsA(node, Aggref) && (int) ((Aggref *) node)->agglevelsup >= context->sublevels_up) diff --git a/src/backend/parser/parse_collate.c b/src/backend/parser/parse_collate.c index 80f6ac7..fe57c59 100644 --- a/src/backend/parser/parse_collate.c +++ b/src/backend/parser/parse_collate.c @@ -575,6 +575,10 @@ assign_collations_walker(Node *node, assign_collations_context *context) * the case above for T_TargetEntry will apply * appropriate checks to agg ORDER BY items. * + * Likewise, we assign collations for the (bool) + * expression in aggfilter, independently of any + * other args. + * * We need not recurse into the aggorder or * aggdistinct lists, because those contain only * SortGroupClause nodes which we need not @@ -595,6 +599,24 @@ assign_collations_walker(Node *node, assign_collations_context *context) (void) assign_collations_walker((Node *) tle, &loccontext); } + + assign_expr_collations(context->pstate, + (Node *) aggref->aggfilter); + } + break; + case T_WindowFunc: + { + /* + * WindowFunc requires special processing only for + * its aggfilter clause, as for aggregates. + */ + WindowFunc *wfunc = (WindowFunc *) node; + + (void) assign_collations_walker((Node *) wfunc->args, + &loccontext); + + assign_expr_collations(context->pstate, + (Node *) wfunc->aggfilter); } break; case T_CaseExpr: diff --git a/src/backend/parser/parse_expr.c b/src/backend/parser/parse_expr.c index 06f6512..68b711d 100644 --- a/src/backend/parser/parse_expr.c +++ b/src/backend/parser/parse_expr.c @@ -22,6 +22,7 @@ #include "nodes/nodeFuncs.h" #include "optimizer/var.h" #include "parser/analyze.h" +#include "parser/parse_clause.h" #include "parser/parse_coerce.h" #include "parser/parse_collate.h" #include "parser/parse_expr.h" @@ -462,7 +463,7 @@ transformIndirection(ParseState *pstate, Node *basenode, List *indirection) newresult = ParseFuncOrColumn(pstate, list_make1(n), list_make1(result), - NIL, false, false, false, + NIL, NULL, false, false, false, NULL, true, location); if (newresult == NULL) unknown_attribute(pstate, result, strVal(n), location); @@ -630,7 +631,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref) node = ParseFuncOrColumn(pstate, list_make1(makeString(colname)), list_make1(node), - NIL, false, false, false, + NIL, NULL, false, false, false, NULL, true, cref->location); } break; @@ -675,7 +676,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref) node = ParseFuncOrColumn(pstate, list_make1(makeString(colname)), list_make1(node), - NIL, false, false, false, + NIL, NULL, false, false, false, NULL, true, cref->location); } break; @@ -733,7 +734,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref) node = ParseFuncOrColumn(pstate, list_make1(makeString(colname)), list_make1(node), - NIL, false, false, false, + NIL, NULL, false, false, false, NULL, true, cref->location); } break; @@ -1241,6 +1242,7 @@ transformFuncCall(ParseState *pstate, FuncCall *fn) { List *targs; ListCell *args; + Expr *tagg_filter; /* Transform the list of arguments ... */ targs = NIL; @@ -1250,11 +1252,22 @@ transformFuncCall(ParseState *pstate, FuncCall *fn) (Node *) lfirst(args))); } + /* + * Transform the aggregate filter using transformWhereClause(), to which + * FILTER is virtually identical... + */ + tagg_filter = NULL; + if (fn->agg_filter != NULL) + tagg_filter = (Expr *) + transformWhereClause(pstate, (Node *) fn->agg_filter, + EXPR_KIND_FILTER, "FILTER"); + /* ... and hand off to ParseFuncOrColumn */ return ParseFuncOrColumn(pstate, fn->funcname, targs, fn->agg_order, + tagg_filter, fn->agg_star, fn->agg_distinct, fn->func_variadic, @@ -1430,6 +1443,7 @@ transformSubLink(ParseState *pstate, SubLink *sublink) case EXPR_KIND_FROM_FUNCTION: case EXPR_KIND_WHERE: case EXPR_KIND_HAVING: + case EXPR_KIND_FILTER: case EXPR_KIND_WINDOW_PARTITION: case EXPR_KIND_WINDOW_ORDER: case EXPR_KIND_WINDOW_FRAME_RANGE: @@ -2579,6 +2593,8 @@ ParseExprKindName(ParseExprKind exprKind) return "WHERE"; case EXPR_KIND_HAVING: return "HAVING"; + case EXPR_KIND_FILTER: + return "FILTER"; case EXPR_KIND_WINDOW_PARTITION: return "window PARTITION BY"; case EXPR_KIND_WINDOW_ORDER: diff --git a/src/backend/parser/parse_func.c b/src/backend/parser/parse_func.c index ae7d195..e54922f 100644 --- a/src/backend/parser/parse_func.c +++ b/src/backend/parser/parse_func.c @@ -56,13 +56,13 @@ static Node *ParseComplexProjection(ParseState *pstate, char *funcname, * Also, when is_column is true, we return NULL on failure rather than * reporting a no-such-function error. * - * The argument expressions (in fargs) must have been transformed already. - * But the agg_order expressions, if any, have not been. + * The argument expressions (in fargs) and filter must have been transformed + * already. But the agg_order expressions, if any, have not been. */ Node * ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, - List *agg_order, bool agg_star, bool agg_distinct, - bool func_variadic, + List *agg_order, Expr *agg_filter, + bool agg_star, bool agg_distinct, bool func_variadic, WindowDef *over, bool is_column, int location) { Oid rettype; @@ -174,8 +174,8 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, * the "function call" could be a projection. We also check that there * wasn't any aggregate or variadic decoration, nor an argument name. */ - if (nargs == 1 && agg_order == NIL && !agg_star && !agg_distinct && - over == NULL && !func_variadic && argnames == NIL && + if (nargs == 1 && agg_order == NIL && agg_filter == NULL && !agg_star && + !agg_distinct && over == NULL && !func_variadic && argnames == NIL && list_length(funcname) == 1) { Oid argtype = actual_arg_types[0]; @@ -251,6 +251,12 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, errmsg("ORDER BY specified, but %s is not an aggregate function", NameListToString(funcname)), parser_errposition(pstate, location))); + if (agg_filter) + ereport(ERROR, + (errcode(ERRCODE_WRONG_OBJECT_TYPE), + errmsg("FILTER specified, but %s is not an aggregate function", + NameListToString(funcname)), + parser_errposition(pstate, location))); if (over) ereport(ERROR, (errcode(ERRCODE_WRONG_OBJECT_TYPE), @@ -402,6 +408,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, /* aggcollid and inputcollid will be set by parse_collate.c */ /* args, aggorder, aggdistinct will be set by transformAggregateCall */ aggref->aggstar = agg_star; + aggref->aggfilter = agg_filter; /* agglevelsup will be set by transformAggregateCall */ aggref->location = location; @@ -460,6 +467,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, /* winref will be set by transformWindowFuncCall */ wfunc->winstar = agg_star; wfunc->winagg = (fdresult == FUNCDETAIL_AGGREGATE); + wfunc->aggfilter = agg_filter; wfunc->location = location; /* @@ -483,6 +491,16 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, parser_errposition(pstate, location))); /* + * Reject window functions which are not aggregates in the case of + * FILTER. + */ + if (!wfunc->winagg && agg_filter) + ereport(ERROR, + (errcode(ERRCODE_WRONG_OBJECT_TYPE), + errmsg("FILTER is not implemented in non-aggregate window functions"), + parser_errposition(pstate, location))); + + /* * ordered aggs not allowed in windows yet */ if (agg_order != NIL) diff --git a/src/backend/utils/adt/ruleutils.c b/src/backend/utils/adt/ruleutils.c index cf9ce3f..976bc98 100644 --- a/src/backend/utils/adt/ruleutils.c +++ b/src/backend/utils/adt/ruleutils.c @@ -7424,6 +7424,13 @@ get_agg_expr(Aggref *aggref, deparse_context *context) appendStringInfoString(buf, " ORDER BY "); get_rule_orderby(aggref->aggorder, aggref->args, false, context); } + + if (aggref->aggfilter != NULL) + { + appendStringInfoString(buf, ") FILTER (WHERE "); + get_rule_expr((Node *) aggref->aggfilter, context, false); + } + appendStringInfoChar(buf, ')'); } @@ -7461,6 +7468,13 @@ get_windowfunc_expr(WindowFunc *wfunc, deparse_context *context) appendStringInfoChar(buf, '*'); else get_rule_expr((Node *) wfunc->args, context, true); + + if (wfunc->aggfilter != NULL) + { + appendStringInfoString(buf, ") FILTER (WHERE "); + get_rule_expr((Node *) wfunc->aggfilter, context, false); + } + appendStringInfoString(buf, ") OVER "); foreach(l, context->windowClause) diff --git a/src/include/nodes/execnodes.h b/src/include/nodes/execnodes.h index 4f77016..5de5db7 100644 --- a/src/include/nodes/execnodes.h +++ b/src/include/nodes/execnodes.h @@ -584,6 +584,7 @@ typedef struct AggrefExprState { ExprState xprstate; List *args; /* states of argument expressions */ + ExprState *aggfilter; /* FILTER expression */ int aggno; /* ID number for agg within its plan node */ } AggrefExprState; @@ -595,6 +596,7 @@ typedef struct WindowFuncExprState { ExprState xprstate; List *args; /* states of argument expressions */ + ExprState *aggfilter; /* FILTER expression */ int wfuncno; /* ID number for wfunc within its plan node */ } WindowFuncExprState; diff --git a/src/include/nodes/parsenodes.h b/src/include/nodes/parsenodes.h index de22dff..fd6cb5a 100644 --- a/src/include/nodes/parsenodes.h +++ b/src/include/nodes/parsenodes.h @@ -283,8 +283,8 @@ typedef struct CollateClause * agg_star indicates we saw a 'foo(*)' construct, while agg_distinct * indicates we saw 'foo(DISTINCT ...)'. In any of these cases, the * construct *must* be an aggregate call. Otherwise, it might be either an - * aggregate or some other kind of function. However, if OVER is present - * it had better be an aggregate or window function. + * aggregate or some other kind of function. However, if FILTER or OVER is + * present it had better be an aggregate or window function. * * Normally, you'd initialize this via makeFuncCall() and then only * change the parts of the struct its defaults don't match afterwards @@ -297,6 +297,7 @@ typedef struct FuncCall List *funcname; /* qualified name of function */ List *args; /* the arguments (list of exprs) */ List *agg_order; /* ORDER BY (list of SortBy) */ + Node *agg_filter; /* FILTER clause, if any */ bool agg_star; /* argument was really '*' */ bool agg_distinct; /* arguments were labeled DISTINCT */ bool func_variadic; /* last argument was labeled VARIADIC */ diff --git a/src/include/nodes/primnodes.h b/src/include/nodes/primnodes.h index 75b716a..a778951 100644 --- a/src/include/nodes/primnodes.h +++ b/src/include/nodes/primnodes.h @@ -247,6 +247,7 @@ typedef struct Aggref List *args; /* arguments and sort expressions */ List *aggorder; /* ORDER BY (list of SortGroupClause) */ List *aggdistinct; /* DISTINCT (list of SortGroupClause) */ + Expr *aggfilter; /* FILTER expression */ bool aggstar; /* TRUE if argument list was really '*' */ Index agglevelsup; /* > 0 if agg belongs to outer query */ int location; /* token location, or -1 if unknown */ @@ -263,6 +264,7 @@ typedef struct WindowFunc Oid wincollid; /* OID of collation of result */ Oid inputcollid; /* OID of collation that function should use */ List *args; /* arguments to the window function */ + Expr *aggfilter; /* FILTER expression */ Index winref; /* index of associated WindowClause */ bool winstar; /* TRUE if argument list was really '*' */ bool winagg; /* is function a simple aggregate? */ diff --git a/src/include/parser/kwlist.h b/src/include/parser/kwlist.h index b3d72a9..287f78e 100644 --- a/src/include/parser/kwlist.h +++ b/src/include/parser/kwlist.h @@ -155,6 +155,7 @@ PG_KEYWORD("extract", EXTRACT, COL_NAME_KEYWORD) PG_KEYWORD("false", FALSE_P, RESERVED_KEYWORD) PG_KEYWORD("family", FAMILY, UNRESERVED_KEYWORD) PG_KEYWORD("fetch", FETCH, RESERVED_KEYWORD) +PG_KEYWORD("filter", FILTER, UNRESERVED_KEYWORD) PG_KEYWORD("first", FIRST_P, UNRESERVED_KEYWORD) PG_KEYWORD("float", FLOAT_P, COL_NAME_KEYWORD) PG_KEYWORD("following", FOLLOWING, UNRESERVED_KEYWORD) diff --git a/src/include/parser/parse_func.h b/src/include/parser/parse_func.h index 6e09dc4..d63cb95 100644 --- a/src/include/parser/parse_func.h +++ b/src/include/parser/parse_func.h @@ -42,10 +42,9 @@ typedef enum } FuncDetailCode; -extern Node *ParseFuncOrColumn(ParseState *pstate, - List *funcname, List *fargs, - List *agg_order, bool agg_star, bool agg_distinct, - bool func_variadic, +extern Node *ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, + List *agg_order, Expr *agg_filter, + bool agg_star, bool agg_distinct, bool func_variadic, WindowDef *over, bool is_column, int location); extern FuncDetailCode func_get_detail(List *funcname, diff --git a/src/include/parser/parse_node.h b/src/include/parser/parse_node.h index 49ca764..bea3b07 100644 --- a/src/include/parser/parse_node.h +++ b/src/include/parser/parse_node.h @@ -39,6 +39,7 @@ typedef enum ParseExprKind EXPR_KIND_FROM_FUNCTION, /* function in FROM clause */ EXPR_KIND_WHERE, /* WHERE */ EXPR_KIND_HAVING, /* HAVING */ + EXPR_KIND_FILTER, /* FILTER */ EXPR_KIND_WINDOW_PARTITION, /* window definition PARTITION BY */ EXPR_KIND_WINDOW_ORDER, /* window definition ORDER BY */ EXPR_KIND_WINDOW_FRAME_RANGE, /* window frame clause with RANGE */ diff --git a/src/test/regress/expected/aggregates.out b/src/test/regress/expected/aggregates.out index d379c0d..7fa9005 100644 --- a/src/test/regress/expected/aggregates.out +++ b/src/test/regress/expected/aggregates.out @@ -1154,3 +1154,98 @@ select string_agg(v, decode('ee', 'hex')) from bytea_test_table; (1 row) drop table bytea_test_table; +-- FILTER tests +select min(unique1) filter (where unique1 > 100) from tenk1; + min +----- + 101 +(1 row) + +select ten, sum(distinct four) filter (where four::text ~ '123') from onek a +group by ten; + ten | sum +-----+----- + 0 | + 1 | + 2 | + 3 | + 4 | + 5 | + 6 | + 7 | + 8 | + 9 | +(10 rows) + +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four); + ten | sum +-----+----- + 0 | + 2 | + 4 | + 6 | + 8 | +(5 rows) + +select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0') +from (values ('a', 'b')) AS v(foo,bar); + max +----- + a +(1 row) + +-- outer reference in FILTER (PostgreSQL extension) +select (select count(*) + from (values (1)) t0(inner_c)) +from (values (2),(3)) t1(outer_c); -- inner query is aggregation query + count +------- + 1 + 1 +(2 rows) + +select (select count(*) filter (where outer_c <> 0) + from (values (1)) t0(inner_c)) +from (values (2),(3)) t1(outer_c); -- outer query is aggregation query + count +------- + 2 +(1 row) + +select (select count(inner_c) filter (where outer_c <> 0) + from (values (1)) t0(inner_c)) +from (values (2),(3)) t1(outer_c); -- inner query is aggregation query + count +------- + 1 + 1 +(2 rows) + +select + (select max((select i.unique2 from tenk1 i where i.unique1 = o.unique1)) + filter (where o.unique1 < 10)) +from tenk1 o; -- outer query is aggregation query + max +------ + 9998 +(1 row) + +-- subquery in FILTER clause (PostgreSQL extension) +select sum(unique1) FILTER (WHERE + unique1 IN (SELECT unique1 FROM onek where unique1 < 100)) FROM tenk1; + sum +------ + 4950 +(1 row) + +-- exercise lots of aggregate parts with FILTER +select aggfns(distinct a,b,c order by a,c using ~<~,b) filter (where a > 1) + from (values (1,3,'foo'),(0,null,null),(2,2,'bar'),(3,1,'baz')) v(a,b,c), + generate_series(1,2) i; + aggfns +--------------------------- + {"(2,2,bar)","(3,1,baz)"} +(1 row) + diff --git a/src/test/regress/expected/window.out b/src/test/regress/expected/window.out index ecc1c2c..7b31d13 100644 --- a/src/test/regress/expected/window.out +++ b/src/test/regress/expected/window.out @@ -1020,5 +1020,18 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1; ERROR: argument of ntile must be greater than zero SELECT nth_value(four, 0) OVER (ORDER BY ten), ten, four FROM tenk1; ERROR: argument of nth_value must be greater than zero +-- filter +SELECT sum(salary), row_number() OVER (ORDER BY depname), sum( + sum(salary) FILTER (WHERE enroll_date > '2007-01-01') +) FILTER (WHERE depname <> 'sales') OVER (ORDER BY depname DESC) AS "filtered_sum", + depname +FROM empsalary GROUP BY depname; + sum | row_number | filtered_sum | depname +-------+------------+--------------+----------- + 14600 | 3 | | sales + 7400 | 2 | 3500 | personnel + 25100 | 1 | 22600 | develop +(3 rows) + -- cleanup DROP TABLE empsalary; diff --git a/src/test/regress/sql/aggregates.sql b/src/test/regress/sql/aggregates.sql index 38d4757..5c0196f 100644 --- a/src/test/regress/sql/aggregates.sql +++ b/src/test/regress/sql/aggregates.sql @@ -442,3 +442,41 @@ select string_agg(v, NULL) from bytea_test_table; select string_agg(v, decode('ee', 'hex')) from bytea_test_table; drop table bytea_test_table; + +-- FILTER tests + +select min(unique1) filter (where unique1 > 100) from tenk1; + +select ten, sum(distinct four) filter (where four::text ~ '123') from onek a +group by ten; + +select ten, sum(distinct four) filter (where four > 10) from onek a +group by ten +having exists (select 1 from onek b where sum(distinct a.four) = b.four); + +select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0') +from (values ('a', 'b')) AS v(foo,bar); + +-- outer reference in FILTER (PostgreSQL extension) +select (select count(*) + from (values (1)) t0(inner_c)) +from (values (2),(3)) t1(outer_c); -- inner query is aggregation query +select (select count(*) filter (where outer_c <> 0) + from (values (1)) t0(inner_c)) +from (values (2),(3)) t1(outer_c); -- outer query is aggregation query +select (select count(inner_c) filter (where outer_c <> 0) + from (values (1)) t0(inner_c)) +from (values (2),(3)) t1(outer_c); -- inner query is aggregation query +select + (select max((select i.unique2 from tenk1 i where i.unique1 = o.unique1)) + filter (where o.unique1 < 10)) +from tenk1 o; -- outer query is aggregation query + +-- subquery in FILTER clause (PostgreSQL extension) +select sum(unique1) FILTER (WHERE + unique1 IN (SELECT unique1 FROM onek where unique1 < 100)) FROM tenk1; + +-- exercise lots of aggregate parts with FILTER +select aggfns(distinct a,b,c order by a,c using ~<~,b) filter (where a > 1) + from (values (1,3,'foo'),(0,null,null),(2,2,'bar'),(3,1,'baz')) v(a,b,c), + generate_series(1,2) i; diff --git a/src/test/regress/sql/window.sql b/src/test/regress/sql/window.sql index 769be0f..6ee3696 100644 --- a/src/test/regress/sql/window.sql +++ b/src/test/regress/sql/window.sql @@ -264,5 +264,13 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1; SELECT nth_value(four, 0) OVER (ORDER BY ten), ten, four FROM tenk1; +-- filter + +SELECT sum(salary), row_number() OVER (ORDER BY depname), sum( + sum(salary) FILTER (WHERE enroll_date > '2007-01-01') +) FILTER (WHERE depname <> 'sales') OVER (ORDER BY depname DESC) AS "filtered_sum", + depname +FROM empsalary GROUP BY depname; + -- cleanup DROP TABLE empsalary;