From d2ea9059262e4119004d4f70cdaed6c530e385fc Mon Sep 17 00:00:00 2001 From: =?utf8?q?Pawe=C5=82=20Chmielowski?= Date: Tue, 23 Apr 2019 17:46:14 +0200 Subject: [PATCH] Fix handling of list arguments on pgsql --- include/ejabberd_sql_pt.hrl | 3 +- src/ejabberd_sql.erl | 28 +++++++++------ src/ejabberd_sql_pt.erl | 71 ++++++++++++++++++++++++++++++++----- 3 files changed, 81 insertions(+), 21 deletions(-) diff --git a/include/ejabberd_sql_pt.hrl b/include/ejabberd_sql_pt.hrl index 50f902fc6..5906a2efb 100644 --- a/include/ejabberd_sql_pt.hrl +++ b/include/ejabberd_sql_pt.hrl @@ -32,5 +32,4 @@ -record(sql_query, {hash, format_query, format_res, args, loc}). --record(sql_escape, {string, integer, boolean}). - +-record(sql_escape, {string, integer, boolean, in_array_string}). diff --git a/src/ejabberd_sql.erl b/src/ejabberd_sql.erl index a02aece4d..f807b0bca 100644 --- a/src/ejabberd_sql.erl +++ b/src/ejabberd_sql.erl @@ -56,7 +56,8 @@ odbcinst_config/0, init_mssql/1, keep_alive/2, - to_list/2]). + to_list/2, + to_array/2]). %% gen_fsm callbacks -export([init/1, handle_event/3, handle_sync_event/4, @@ -264,6 +265,10 @@ to_list(EscapeFun, Val) -> Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)), [<<"(">>, Escaped, <<")">>]. +to_array(EscapeFun, Val) -> + Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)), + [<<"{">>, Escaped, <<"}">>]. + encode_term(Term) -> escape(list_to_binary( erl_prettypr:format(erl_syntax:abstract(Term), @@ -676,10 +681,11 @@ generic_sql_query_format(SQLQuery) -> generic_escape() -> #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end, - integer = fun(X) -> misc:i2l(X) end, - boolean = fun(true) -> <<"1">>; + integer = fun(X) -> misc:i2l(X) end, + boolean = fun(true) -> <<"1">>; (false) -> <<"0">> - end + end, + in_array_string = fun(X) -> <<"'", (escape(X))/binary, "'">> end }. sqlite_sql_query(SQLQuery) -> @@ -693,10 +699,11 @@ sqlite_sql_query_format(SQLQuery) -> sqlite_escape() -> #sql_escape{string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end, - integer = fun(X) -> misc:i2l(X) end, - boolean = fun(true) -> <<"1">>; + integer = fun(X) -> misc:i2l(X) end, + boolean = fun(true) -> <<"1">>; (false) -> <<"0">> - end + end, + in_array_string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end }. standard_escape(S) -> @@ -717,10 +724,11 @@ pgsql_prepare(SQLQuery, State) -> pgsql_execute_escape() -> #sql_escape{string = fun(X) -> X end, - integer = fun(X) -> [misc:i2l(X)] end, - boolean = fun(true) -> "1"; + integer = fun(X) -> [misc:i2l(X)] end, + boolean = fun(true) -> "1"; (false) -> "0" - end + end, + in_array_string = fun(X) -> <<"\"", (escape(X))/binary, "\"">> end }. pgsql_execute_sql_query(SQLQuery, State) -> diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl index 0ae04c64d..2497c2a74 100644 --- a/src/ejabberd_sql_pt.erl +++ b/src/ejabberd_sql_pt.erl @@ -42,7 +42,8 @@ res_pos = 0, server_host_used = false, used_vars = [], - use_new_schema}). + use_new_schema, + need_array_pass = false}). -define(QUERY_RECORD, "sql_query"). @@ -183,12 +184,24 @@ transform_sql(Arg) -> Pos, no_server_host), [] end, - set_pos( - make_schema_check( - make_sql_query(ParseRes), - make_sql_query(ParseResOld) - ), - Pos). + case ParseRes#state.need_array_pass of + true -> + {PR1, PR2} = perform_array_pass(ParseRes), + {PRO1, PRO2} = perform_array_pass(ParseResOld), + set_pos(make_schema_check( + erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PR2)]), + erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PR1)])]), + erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PRO2)]), + erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PRO1)])])), + Pos); + false -> + set_pos( + make_schema_check( + make_sql_query(ParseRes), + make_sql_query(ParseResOld) + ), + Pos) + end. transform_upsert(Form, TableArg, FieldsArg) -> Table = erl_syntax:string_value(TableArg), @@ -315,8 +328,23 @@ parse1([$%, $( | S], Acc, State) -> erl_syntax:atom(?ESCAPE_RECORD), erl_syntax:atom(InternalType)), erl_syntax:variable(Name)]), - State2#state{'query' = [{var, Var} | State2#state.'query'], - args = [Convert | State2#state.args], + IT2 = case InternalType of + string -> + in_array_string; + _ -> + InternalType + end, + ConvertArr = erl_syntax:application( + erl_syntax:atom(ejabberd_sql), + erl_syntax:atom(to_array), + [erl_syntax:record_access( + erl_syntax:variable(?ESCAPE_VAR), + erl_syntax:atom(?ESCAPE_RECORD), + erl_syntax:atom(IT2)), + erl_syntax:variable(Name)]), + State2#state{'query' = [[{var, Var}] | State2#state.'query'], + need_array_pass = true, + args = [[Convert, ConvertArr] | State2#state.args], params = [Var | State2#state.params], param_pos = State2#state.param_pos + 1, used_vars = [Name | State2#state.used_vars]}; @@ -389,6 +417,31 @@ make_var(V) -> Var = "__V" ++ integer_to_list(V), erl_syntax:variable(Var). +perform_array_pass(State) -> + {NQ, PQ, Rest} = lists:foldl( + fun([{var, _} = Var], {N, P, {str, Str} = Prev}) -> + Str2 = re:replace(Str, "(^|\s+)in\s*$", " = any(", [{return, list}]), + {[Var, Prev | N], [{str, ")"}, Var, {str, Str2} | P], none}; + ([{var, _}], _) -> + throw({error, State#state.loc, ["List variable not following 'in' operator"]}); + (Other, {N, P, none}) -> + {N, P, Other}; + (Other, {N, P, Prev}) -> + {[Prev | N], [Prev | P], Other} + end, {[], [], none}, State#state.query), + {NQ2, PQ2} = case Rest of + none -> + {NQ, PQ}; + _ -> {[Rest | NQ], [Rest | PQ]} + end, + {NA, PA} = lists:foldl( + fun([V1, V2], {N, P}) -> + {[V1 | N], [V2 | P]}; + (Other, {N, P}) -> + {[Other | N], [Other | P]} + end, {[], []}, State#state.args), + {State#state{query = lists:reverse(NQ2), args = lists:reverse(NA), need_array_pass = false}, + State#state{query = lists:reverse(PQ2), args = lists:reverse(PA), need_array_pass = false}}. make_sql_query(State) -> Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}), -- 2.40.0