From 27dc7e240bfd230ee1315cc00577a6ed72aff94a Mon Sep 17 00:00:00 2001 From: Tom Lane Date: Thu, 24 Mar 2011 20:30:14 -0400 Subject: [PATCH] Fix handling of collation in SQL-language functions. Ensure that parameter symbols receive collation from the function's resolved input collation, and fix inlining to behave properly. BTW, this commit lays about 90% of the infrastructure needed to support use of argument names in SQL functions. Parsing of parameters is now done via the parser-hook infrastructure ... we'd just need to supply a column-ref hook ... --- src/backend/catalog/pg_proc.c | 13 +- src/backend/executor/functions.c | 174 ++++++++++++++---- src/backend/optimizer/util/clauses.c | 106 +++++++---- src/backend/parser/parse_param.c | 5 + src/include/executor/functions.h | 9 + .../regress/expected/collate.linux.utf8.out | 53 ++++++ src/test/regress/sql/collate.linux.utf8.sql | 19 ++ 7 files changed, 303 insertions(+), 76 deletions(-) diff --git a/src/backend/catalog/pg_proc.c b/src/backend/catalog/pg_proc.c index 2523653f37..6138165cc3 100644 --- a/src/backend/catalog/pg_proc.c +++ b/src/backend/catalog/pg_proc.c @@ -845,16 +845,21 @@ fmgr_sql_validator(PG_FUNCTION_ARGS) * OK to do full precheck: analyze and rewrite the queries, * then verify the result type. */ + SQLFunctionParseInfoPtr pinfo; + + /* But first, set up parameter information */ + pinfo = prepare_sql_fn_parse_info(tuple, NULL, InvalidOid); + querytree_list = NIL; foreach(lc, raw_parsetree_list) { Node *parsetree = (Node *) lfirst(lc); List *querytree_sublist; - querytree_sublist = pg_analyze_and_rewrite(parsetree, - prosrc, - proc->proargtypes.values, - proc->pronargs); + querytree_sublist = pg_analyze_and_rewrite_params(parsetree, + prosrc, + (ParserSetupHook) sql_fn_parser_setup, + pinfo); querytree_list = list_concat(querytree_list, querytree_sublist); } diff --git a/src/backend/executor/functions.c b/src/backend/executor/functions.c index 0421be57a4..ce3b77b847 100644 --- a/src/backend/executor/functions.c +++ b/src/backend/executor/functions.c @@ -81,7 +81,8 @@ typedef struct char *fname; /* function name (for error msgs) */ char *src; /* function body text (for error msgs) */ - Oid *argtypes; /* resolved types of arguments */ + SQLFunctionParseInfoPtr pinfo; /* data for parser callback hooks */ + Oid rettype; /* actual return type */ int16 typlen; /* length of the return type */ bool typbyval; /* true if return type is pass by value */ @@ -108,8 +109,21 @@ typedef struct typedef SQLFunctionCache *SQLFunctionCachePtr; +/* + * Data structure needed by the parser callback hooks to resolve parameter + * references during parsing of a SQL function's body. This is separate from + * SQLFunctionCache since we sometimes do parsing separately from execution. + */ +typedef struct SQLFunctionParseInfo +{ + Oid *argtypes; /* resolved types of input arguments */ + int nargs; /* number of input arguments */ + Oid collation; /* function's input collation, if known */ +} SQLFunctionParseInfo; + /* non-export function prototypes */ +static Node *sql_fn_param_ref(ParseState *pstate, ParamRef *pref); static List *init_execution_state(List *queryTree_list, SQLFunctionCachePtr fcache, bool lazyEvalOK); @@ -131,6 +145,112 @@ static void sqlfunction_shutdown(DestReceiver *self); static void sqlfunction_destroy(DestReceiver *self); +/* + * Prepare the SQLFunctionParseInfo struct for parsing a SQL function body + * + * This includes resolving actual types of polymorphic arguments. + * + * call_expr can be passed as NULL, but then we will fail if there are any + * polymorphic arguments. + */ +SQLFunctionParseInfoPtr +prepare_sql_fn_parse_info(HeapTuple procedureTuple, + Node *call_expr, + Oid inputCollation) +{ + SQLFunctionParseInfoPtr pinfo; + Form_pg_proc procedureStruct = (Form_pg_proc) GETSTRUCT(procedureTuple); + int nargs; + + pinfo = (SQLFunctionParseInfoPtr) palloc0(sizeof(SQLFunctionParseInfo)); + + /* Save the function's input collation */ + pinfo->collation = inputCollation; + + /* + * Copy input argument types from the pg_proc entry, then resolve any + * polymorphic types. + */ + pinfo->nargs = nargs = procedureStruct->pronargs; + if (nargs > 0) + { + Oid *argOidVect; + int argnum; + + argOidVect = (Oid *) palloc(nargs * sizeof(Oid)); + memcpy(argOidVect, + procedureStruct->proargtypes.values, + nargs * sizeof(Oid)); + + for (argnum = 0; argnum < nargs; argnum++) + { + Oid argtype = argOidVect[argnum]; + + if (IsPolymorphicType(argtype)) + { + argtype = get_call_expr_argtype(call_expr, argnum); + if (argtype == InvalidOid) + ereport(ERROR, + (errcode(ERRCODE_DATATYPE_MISMATCH), + errmsg("could not determine actual type of argument declared %s", + format_type_be(argOidVect[argnum])))); + argOidVect[argnum] = argtype; + } + } + + pinfo->argtypes = argOidVect; + } + + return pinfo; +} + +/* + * Parser setup hook for parsing a SQL function body. + */ +void +sql_fn_parser_setup(struct ParseState *pstate, SQLFunctionParseInfoPtr pinfo) +{ + /* Later we might use these hooks to support parameter names */ + pstate->p_pre_columnref_hook = NULL; + pstate->p_post_columnref_hook = NULL; + pstate->p_paramref_hook = sql_fn_param_ref; + /* no need to use p_coerce_param_hook */ + pstate->p_ref_hook_state = (void *) pinfo; +} + +/* + * sql_fn_param_ref parser callback for ParamRefs ($n symbols) + */ +static Node * +sql_fn_param_ref(ParseState *pstate, ParamRef *pref) +{ + SQLFunctionParseInfoPtr pinfo = (SQLFunctionParseInfoPtr) pstate->p_ref_hook_state; + int paramno = pref->number; + Param *param; + + /* Check parameter number is valid */ + if (paramno <= 0 || paramno > pinfo->nargs) + return NULL; /* unknown parameter number */ + + param = makeNode(Param); + param->paramkind = PARAM_EXTERN; + param->paramid = paramno; + param->paramtype = pinfo->argtypes[paramno - 1]; + param->paramtypmod = -1; + param->paramcollid = get_typcollation(param->paramtype); + param->location = pref->location; + + /* + * If we have a function input collation, allow it to override the + * type-derived collation for parameter symbols. (XXX perhaps this should + * not happen if the type collation is not default?) + */ + if (OidIsValid(pinfo->collation) && OidIsValid(param->paramcollid)) + param->paramcollid = pinfo->collation; + + return (Node *) param; +} + /* * Set up the per-query execution_state records for a SQL function. * @@ -239,7 +359,9 @@ init_execution_state(List *queryTree_list, return eslist; } -/* Initialize the SQLFunctionCache for a SQL function */ +/* + * Initialize the SQLFunctionCache for a SQL function + */ static void init_sql_fcache(FmgrInfo *finfo, bool lazyEvalOK) { @@ -248,8 +370,6 @@ init_sql_fcache(FmgrInfo *finfo, bool lazyEvalOK) HeapTuple procedureTuple; Form_pg_proc procedureStruct; SQLFunctionCachePtr fcache; - Oid *argOidVect; - int nargs; List *raw_parsetree_list; List *queryTree_list; List *flat_query_list; @@ -302,37 +422,13 @@ init_sql_fcache(FmgrInfo *finfo, bool lazyEvalOK) (procedureStruct->provolatile != PROVOLATILE_VOLATILE); /* - * We need the actual argument types to pass to the parser. + * We need the actual argument types to pass to the parser. Also make + * sure that parameter symbols are considered to have the function's + * resolved input collation. */ - nargs = procedureStruct->pronargs; - if (nargs > 0) - { - int argnum; - - argOidVect = (Oid *) palloc(nargs * sizeof(Oid)); - memcpy(argOidVect, - procedureStruct->proargtypes.values, - nargs * sizeof(Oid)); - /* Resolve any polymorphic argument types */ - for (argnum = 0; argnum < nargs; argnum++) - { - Oid argtype = argOidVect[argnum]; - - if (IsPolymorphicType(argtype)) - { - argtype = get_fn_expr_argtype(finfo, argnum); - if (argtype == InvalidOid) - ereport(ERROR, - (errcode(ERRCODE_DATATYPE_MISMATCH), - errmsg("could not determine actual type of argument declared %s", - format_type_be(argOidVect[argnum])))); - argOidVect[argnum] = argtype; - } - } - } - else - argOidVect = NULL; - fcache->argtypes = argOidVect; + fcache->pinfo = prepare_sql_fn_parse_info(procedureTuple, + finfo->fn_expr, + finfo->fn_collation); /* * And of course we need the function body text. @@ -364,10 +460,10 @@ init_sql_fcache(FmgrInfo *finfo, bool lazyEvalOK) Node *parsetree = (Node *) lfirst(lc); List *queryTree_sublist; - queryTree_sublist = pg_analyze_and_rewrite(parsetree, - fcache->src, - argOidVect, - nargs); + queryTree_sublist = pg_analyze_and_rewrite_params(parsetree, + fcache->src, + (ParserSetupHook) sql_fn_parser_setup, + fcache->pinfo); queryTree_list = lappend(queryTree_list, queryTree_sublist); flat_query_list = list_concat(flat_query_list, list_copy(queryTree_sublist)); @@ -583,7 +679,7 @@ postquel_sub_params(SQLFunctionCachePtr fcache, prm->value = fcinfo->arg[i]; prm->isnull = fcinfo->argnull[i]; prm->pflags = 0; - prm->ptype = fcache->argtypes[i]; + prm->ptype = fcache->pinfo->argtypes[i]; } } else diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c index 7b31b6b4fa..0fddbae60a 100644 --- a/src/backend/optimizer/util/clauses.c +++ b/src/backend/optimizer/util/clauses.c @@ -118,8 +118,8 @@ static Expr *evaluate_function(Oid funcid, Oid result_type, int32 result_typmod, Oid input_collid, List *args, HeapTuple func_tuple, eval_const_expressions_context *context); -static Expr *inline_function(Oid funcid, Oid result_type, List *args, - HeapTuple func_tuple, +static Expr *inline_function(Oid funcid, Oid result_type, Oid input_collid, + List *args, HeapTuple func_tuple, eval_const_expressions_context *context); static Node *substitute_actual_parameters(Node *expr, int nargs, List *args, int *usecounts); @@ -3431,7 +3431,7 @@ simplify_function(Oid funcid, Oid result_type, int32 result_typmod, func_tuple, context); if (!newexpr && allow_inline) - newexpr = inline_function(funcid, result_type, *args, + newexpr = inline_function(funcid, result_type, input_collid, *args, func_tuple, context); ReleaseSysCache(func_tuple); @@ -3798,12 +3798,11 @@ evaluate_function(Oid funcid, Oid result_type, int32 result_typmod, * simplify the function. */ static Expr * -inline_function(Oid funcid, Oid result_type, List *args, +inline_function(Oid funcid, Oid result_type, Oid input_collid, List *args, HeapTuple func_tuple, eval_const_expressions_context *context) { Form_pg_proc funcform = (Form_pg_proc) GETSTRUCT(func_tuple); - Oid *argtypes; char *src; Datum tmp; bool isNull; @@ -3812,6 +3811,9 @@ inline_function(Oid funcid, Oid result_type, List *args, MemoryContext mycxt; inline_error_callback_arg callback_arg; ErrorContextCallback sqlerrcontext; + FuncExpr *fexpr; + SQLFunctionParseInfoPtr pinfo; + ParseState *pstate; List *raw_parsetree_list; Query *querytree; Node *newexpr; @@ -3875,17 +3877,25 @@ inline_function(Oid funcid, Oid result_type, List *args, sqlerrcontext.previous = error_context_stack; error_context_stack = &sqlerrcontext; - /* Check for polymorphic arguments, and substitute actual arg types */ - argtypes = (Oid *) palloc(funcform->pronargs * sizeof(Oid)); - memcpy(argtypes, funcform->proargtypes.values, - funcform->pronargs * sizeof(Oid)); - for (i = 0; i < funcform->pronargs; i++) - { - if (IsPolymorphicType(argtypes[i])) - { - argtypes[i] = exprType((Node *) list_nth(args, i)); - } - } + /* + * Set up to handle parameters while parsing the function body. We need a + * dummy FuncExpr node containing the already-simplified arguments to pass + * to prepare_sql_fn_parse_info. (It is really only needed if there are + * some polymorphic arguments, but for simplicity we always build it.) + */ + fexpr = makeNode(FuncExpr); + fexpr->funcid = funcid; + fexpr->funcresulttype = result_type; + fexpr->funcretset = false; + fexpr->funcformat = COERCE_DONTCARE; /* doesn't matter */ + fexpr->funccollid = InvalidOid; /* doesn't matter */ + fexpr->inputcollid = input_collid; + fexpr->args = args; + fexpr->location = -1; + + pinfo = prepare_sql_fn_parse_info(func_tuple, + (Node *) fexpr, + input_collid); /* * We just do parsing and parse analysis, not rewriting, because rewriting @@ -3897,8 +3907,13 @@ inline_function(Oid funcid, Oid result_type, List *args, if (list_length(raw_parsetree_list) != 1) goto fail; - querytree = parse_analyze(linitial(raw_parsetree_list), src, - argtypes, funcform->pronargs); + pstate = make_parsestate(NULL); + pstate->p_sourcetext = src; + sql_fn_parser_setup(pstate, pinfo); + + querytree = transformStmt(pstate, linitial(raw_parsetree_list)); + + free_parsestate(pstate); /* * The single command must be a simple "SELECT expression". @@ -4030,6 +4045,28 @@ inline_function(Oid funcid, Oid result_type, List *args, MemoryContextDelete(mycxt); + /* + * If the result is of a collatable type, force the result to expose + * the correct collation. In most cases this does not matter, but + * it's possible that the function result is used directly as a sort key + * or in other places where we expect exprCollation() to tell the truth. + */ + if (OidIsValid(input_collid)) + { + Oid exprcoll = exprCollation(newexpr); + + if (OidIsValid(exprcoll) && exprcoll != input_collid) + { + CollateExpr *newnode = makeNode(CollateExpr); + + newnode->arg = (Expr *) newexpr; + newnode->collOid = input_collid; + newnode->location = -1; + + newexpr = (Node *) newnode; + } + } + /* * Since there is now no trace of the function in the plan tree, we must * explicitly record the plan's dependency on the function. @@ -4219,7 +4256,6 @@ inline_set_returning_function(PlannerInfo *root, RangeTblEntry *rte) Oid func_oid; HeapTuple func_tuple; Form_pg_proc funcform; - Oid *argtypes; char *src; Datum tmp; bool isNull; @@ -4229,10 +4265,10 @@ inline_set_returning_function(PlannerInfo *root, RangeTblEntry *rte) List *saveInvalItems; inline_error_callback_arg callback_arg; ErrorContextCallback sqlerrcontext; + SQLFunctionParseInfoPtr pinfo; List *raw_parsetree_list; List *querytree_list; Query *querytree; - int i; Assert(rte->rtekind == RTE_FUNCTION); @@ -4366,17 +4402,14 @@ inline_set_returning_function(PlannerInfo *root, RangeTblEntry *rte) if (list_length(fexpr->args) != funcform->pronargs) goto fail; - /* Check for polymorphic arguments, and substitute actual arg types */ - argtypes = (Oid *) palloc(funcform->pronargs * sizeof(Oid)); - memcpy(argtypes, funcform->proargtypes.values, - funcform->pronargs * sizeof(Oid)); - for (i = 0; i < funcform->pronargs; i++) - { - if (IsPolymorphicType(argtypes[i])) - { - argtypes[i] = exprType((Node *) list_nth(fexpr->args, i)); - } - } + /* + * Set up to handle parameters while parsing the function body. We + * can use the FuncExpr just created as the input for + * prepare_sql_fn_parse_info. + */ + pinfo = prepare_sql_fn_parse_info(func_tuple, + (Node *) fexpr, + fexpr->inputcollid); /* * Parse, analyze, and rewrite (unlike inline_function(), we can't skip @@ -4387,8 +4420,10 @@ inline_set_returning_function(PlannerInfo *root, RangeTblEntry *rte) if (list_length(raw_parsetree_list) != 1) goto fail; - querytree_list = pg_analyze_and_rewrite(linitial(raw_parsetree_list), src, - argtypes, funcform->pronargs); + querytree_list = pg_analyze_and_rewrite_params(linitial(raw_parsetree_list), + src, + (ParserSetupHook) sql_fn_parser_setup, + pinfo); if (list_length(querytree_list) != 1) goto fail; querytree = linitial(querytree_list); @@ -4461,6 +4496,11 @@ inline_set_returning_function(PlannerInfo *root, RangeTblEntry *rte) error_context_stack = sqlerrcontext.previous; ReleaseSysCache(func_tuple); + /* + * We don't have to fix collations here because the upper query is + * already parsed, ie, the collations in the RTE are what count. + */ + /* * Since there is now no trace of the function in the plan tree, we must * explicitly record the plan's dependency on the function. diff --git a/src/backend/parser/parse_param.c b/src/backend/parser/parse_param.c index 1cf255669a..1895f92d7c 100644 --- a/src/backend/parser/parse_param.c +++ b/src/backend/parser/parse_param.c @@ -231,6 +231,11 @@ variable_coerce_param_hook(ParseState *pstate, Param *param, */ param->paramtypmod = -1; + /* + * This module always sets a Param's collation to be the default for + * its datatype. If that's not what you want, you should be using + * the more general parser substitution hooks. + */ param->paramcollid = get_typcollation(param->paramtype); /* Use the leftmost of the param's and coercion's locations */ diff --git a/src/include/executor/functions.h b/src/include/executor/functions.h index e725b24be2..b926e99cbb 100644 --- a/src/include/executor/functions.h +++ b/src/include/executor/functions.h @@ -17,9 +17,18 @@ #include "nodes/execnodes.h" #include "tcop/dest.h" +/* This struct is known only within executor/functions.c */ +typedef struct SQLFunctionParseInfo *SQLFunctionParseInfoPtr; extern Datum fmgr_sql(PG_FUNCTION_ARGS); +extern SQLFunctionParseInfoPtr prepare_sql_fn_parse_info(HeapTuple procedureTuple, + Node *call_expr, + Oid inputCollation); + +extern void sql_fn_parser_setup(struct ParseState *pstate, + SQLFunctionParseInfoPtr pinfo); + extern bool check_sql_fn_retval(Oid func_id, Oid rettype, List *queryTreeList, bool *modifyTargetList, diff --git a/src/test/regress/expected/collate.linux.utf8.out b/src/test/regress/expected/collate.linux.utf8.out index 879f97327d..4680ffd009 100644 --- a/src/test/regress/expected/collate.linux.utf8.out +++ b/src/test/regress/expected/collate.linux.utf8.out @@ -686,6 +686,59 @@ SELECT a, CAST(b AS varchar) FROM collate_test3 ORDER BY 2; 2 | äbc (4 rows) +-- propagation of collation in inlined and non-inlined cases +CREATE FUNCTION mylt (text, text) RETURNS boolean LANGUAGE sql + AS $$ select $1 < $2 $$; +CREATE FUNCTION mylt_noninline (text, text) RETURNS boolean LANGUAGE sql + AS $$ select $1 < $2 limit 1 $$; +SELECT a.b AS a, b.b AS b, a.b < b.b AS lt, + mylt(a.b, b.b), mylt_noninline(a.b, b.b) +FROM collate_test1 a, collate_test1 b +ORDER BY a.b, b.b; + a | b | lt | mylt | mylt_noninline +-----+-----+----+------+---------------- + abc | abc | f | f | f + abc | ABC | t | t | t + abc | äbc | t | t | t + abc | bbc | t | t | t + ABC | abc | f | f | f + ABC | ABC | f | f | f + ABC | äbc | t | t | t + ABC | bbc | t | t | t + äbc | abc | f | f | f + äbc | ABC | f | f | f + äbc | äbc | f | f | f + äbc | bbc | t | t | t + bbc | abc | f | f | f + bbc | ABC | f | f | f + bbc | äbc | f | f | f + bbc | bbc | f | f | f +(16 rows) + +SELECT a.b AS a, b.b AS b, a.b < b.b COLLATE "C" AS lt, + mylt(a.b, b.b COLLATE "C"), mylt_noninline(a.b, b.b COLLATE "C") +FROM collate_test1 a, collate_test1 b +ORDER BY a.b, b.b; + a | b | lt | mylt | mylt_noninline +-----+-----+----+------+---------------- + abc | abc | f | f | f + abc | ABC | f | f | f + abc | äbc | t | t | t + abc | bbc | t | t | t + ABC | abc | t | t | t + ABC | ABC | f | f | f + ABC | äbc | t | t | t + ABC | bbc | t | t | t + äbc | abc | f | f | f + äbc | ABC | f | f | f + äbc | äbc | f | f | f + äbc | bbc | f | f | f + bbc | abc | f | f | f + bbc | ABC | f | f | f + bbc | äbc | t | t | t + bbc | bbc | f | f | f +(16 rows) + -- polymorphism SELECT * FROM unnest((SELECT array_agg(b ORDER BY b) FROM collate_test1)) ORDER BY 1; unnest diff --git a/src/test/regress/sql/collate.linux.utf8.sql b/src/test/regress/sql/collate.linux.utf8.sql index 4aec27d880..2a1f2113b3 100644 --- a/src/test/regress/sql/collate.linux.utf8.sql +++ b/src/test/regress/sql/collate.linux.utf8.sql @@ -212,6 +212,25 @@ SELECT a, CAST(b AS varchar) FROM collate_test2 ORDER BY 2; SELECT a, CAST(b AS varchar) FROM collate_test3 ORDER BY 2; +-- propagation of collation in inlined and non-inlined cases + +CREATE FUNCTION mylt (text, text) RETURNS boolean LANGUAGE sql + AS $$ select $1 < $2 $$; + +CREATE FUNCTION mylt_noninline (text, text) RETURNS boolean LANGUAGE sql + AS $$ select $1 < $2 limit 1 $$; + +SELECT a.b AS a, b.b AS b, a.b < b.b AS lt, + mylt(a.b, b.b), mylt_noninline(a.b, b.b) +FROM collate_test1 a, collate_test1 b +ORDER BY a.b, b.b; + +SELECT a.b AS a, b.b AS b, a.b < b.b COLLATE "C" AS lt, + mylt(a.b, b.b COLLATE "C"), mylt_noninline(a.b, b.b COLLATE "C") +FROM collate_test1 a, collate_test1 b +ORDER BY a.b, b.b; + + -- polymorphism SELECT * FROM unnest((SELECT array_agg(b ORDER BY b) FROM collate_test1)) ORDER BY 1; -- 2.40.0