]> granicus.if.org Git - ejabberd/commitdiff
Fix handling of list arguments on pgsql
authorPaweł Chmielowski <pchmielowski@process-one.net>
Tue, 23 Apr 2019 15:46:14 +0000 (17:46 +0200)
committerPaweł Chmielowski <pchmielowski@process-one.net>
Tue, 23 Apr 2019 15:46:42 +0000 (17:46 +0200)
include/ejabberd_sql_pt.hrl
src/ejabberd_sql.erl
src/ejabberd_sql_pt.erl

index 50f902fc6a50265665e5c78ab355ee3a51cbdbb0..5906a2efb6dbce715693731362e170ff039ac5d8 100644 (file)
@@ -32,5 +32,4 @@
 
 -record(sql_query, {hash, format_query, format_res, args, loc}).
 
--record(sql_escape, {string, integer, boolean}).
-
+-record(sql_escape, {string, integer, boolean, in_array_string}).
index a02aece4d4739b14e5631d951b4c4c0367d33d55..f807b0bca8672f05564f02c5ea8e2ff057312a50 100644 (file)
@@ -56,7 +56,8 @@
         odbcinst_config/0,
         init_mssql/1,
         keep_alive/2,
-        to_list/2]).
+        to_list/2,
+        to_array/2]).
 
 %% gen_fsm callbacks
 -export([init/1, handle_event/3, handle_sync_event/4,
@@ -264,6 +265,10 @@ to_list(EscapeFun, Val) ->
     Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)),
     [<<"(">>, Escaped, <<")">>].
 
+to_array(EscapeFun, Val) ->
+    Escaped = lists:join(<<",">>, lists:map(EscapeFun, Val)),
+    [<<"{">>, Escaped, <<"}">>].
+
 encode_term(Term) ->
     escape(list_to_binary(
              erl_prettypr:format(erl_syntax:abstract(Term),
@@ -676,10 +681,11 @@ generic_sql_query_format(SQLQuery) ->
 
 generic_escape() ->
     #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
-                integer = fun(X) -> misc:i2l(X) end,
-                boolean = fun(true) -> <<"1">>;
+               integer = fun(X) -> misc:i2l(X) end,
+               boolean = fun(true) -> <<"1">>;
                              (false) -> <<"0">>
-                          end
+                          end,
+               in_array_string = fun(X) -> <<"'", (escape(X))/binary, "'">> end
                }.
 
 sqlite_sql_query(SQLQuery) ->
@@ -693,10 +699,11 @@ sqlite_sql_query_format(SQLQuery) ->
 
 sqlite_escape() ->
     #sql_escape{string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end,
-                integer = fun(X) -> misc:i2l(X) end,
-                boolean = fun(true) -> <<"1">>;
+               integer = fun(X) -> misc:i2l(X) end,
+               boolean = fun(true) -> <<"1">>;
                              (false) -> <<"0">>
-                          end
+                          end,
+               in_array_string = fun(X) -> <<"'", (standard_escape(X))/binary, "'">> end
                }.
 
 standard_escape(S) ->
@@ -717,10 +724,11 @@ pgsql_prepare(SQLQuery, State) ->
 
 pgsql_execute_escape() ->
     #sql_escape{string = fun(X) -> X end,
-                integer = fun(X) -> [misc:i2l(X)] end,
-                boolean = fun(true) -> "1";
+               integer = fun(X) -> [misc:i2l(X)] end,
+               boolean = fun(true) -> "1";
                              (false) -> "0"
-                          end
+                          end,
+               in_array_string = fun(X) -> <<"\"", (escape(X))/binary, "\"">> end
                }.
 
 pgsql_execute_sql_query(SQLQuery, State) ->
index 0ae04c64d468e5dc40fba1dc1e113b41f0782075..2497c2a74db2d1fe3a06d9bff6ba2482ae7077fc 100644 (file)
@@ -42,7 +42,8 @@
                 res_pos = 0,
                 server_host_used = false,
                 used_vars = [],
-                use_new_schema}).
+                use_new_schema,
+                need_array_pass = false}).
 
 -define(QUERY_RECORD, "sql_query").
 
@@ -183,12 +184,24 @@ transform_sql(Arg) ->
               Pos, no_server_host),
             []
     end,
-    set_pos(
-      make_schema_check(
-        make_sql_query(ParseRes),
-        make_sql_query(ParseResOld)
-       ),
-      Pos).
+    case ParseRes#state.need_array_pass of
+        true ->
+            {PR1, PR2} = perform_array_pass(ParseRes),
+            {PRO1, PRO2} = perform_array_pass(ParseResOld),
+            set_pos(make_schema_check(
+                    erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PR2)]),
+                                     erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PR1)])]),
+                    erl_syntax:list([erl_syntax:tuple([erl_syntax:atom(pgsql), make_sql_query(PRO2)]),
+                                     erl_syntax:tuple([erl_syntax:atom(any), make_sql_query(PRO1)])])),
+                Pos);
+        false ->
+            set_pos(
+                make_schema_check(
+                    make_sql_query(ParseRes),
+                    make_sql_query(ParseResOld)
+                ),
+                Pos)
+    end.
 
 transform_upsert(Form, TableArg, FieldsArg) ->
     Table = erl_syntax:string_value(TableArg),
@@ -315,8 +328,23 @@ parse1([$%, $( | S], Acc, State) ->
                         erl_syntax:atom(?ESCAPE_RECORD),
                         erl_syntax:atom(InternalType)),
                      erl_syntax:variable(Name)]),
-                State2#state{'query' = [{var, Var} | State2#state.'query'],
-                             args = [Convert | State2#state.args],
+                IT2 = case InternalType of
+                          string ->
+                              in_array_string;
+                          _ ->
+                              InternalType
+                      end,
+                ConvertArr = erl_syntax:application(
+                    erl_syntax:atom(ejabberd_sql),
+                    erl_syntax:atom(to_array),
+                    [erl_syntax:record_access(
+                        erl_syntax:variable(?ESCAPE_VAR),
+                        erl_syntax:atom(?ESCAPE_RECORD),
+                        erl_syntax:atom(IT2)),
+                     erl_syntax:variable(Name)]),
+                State2#state{'query' = [[{var, Var}] | State2#state.'query'],
+                             need_array_pass = true,
+                             args = [[Convert, ConvertArr] | State2#state.args],
                              params = [Var | State2#state.params],
                              param_pos = State2#state.param_pos + 1,
                              used_vars = [Name | State2#state.used_vars]};
@@ -389,6 +417,31 @@ make_var(V) ->
     Var = "__V" ++ integer_to_list(V),
     erl_syntax:variable(Var).
 
+perform_array_pass(State) ->
+    {NQ, PQ, Rest} = lists:foldl(
+        fun([{var, _} = Var], {N, P, {str, Str} = Prev}) ->
+            Str2 = re:replace(Str, "(^|\s+)in\s*$", " = any(", [{return, list}]),
+            {[Var, Prev | N], [{str, ")"}, Var, {str, Str2} | P], none};
+           ([{var, _}], _) ->
+               throw({error, State#state.loc, ["List variable not following 'in' operator"]});
+           (Other, {N, P, none}) ->
+               {N, P, Other};
+           (Other, {N, P, Prev}) ->
+               {[Prev | N], [Prev | P], Other}
+        end, {[], [], none}, State#state.query),
+    {NQ2, PQ2} = case Rest of
+                     none ->
+                         {NQ, PQ};
+                     _ -> {[Rest | NQ], [Rest | PQ]}
+                 end,
+    {NA, PA} = lists:foldl(
+        fun([V1, V2], {N, P}) ->
+            {[V1 | N], [V2 | P]};
+           (Other, {N, P}) ->
+               {[Other | N], [Other | P]}
+        end, {[], []}, State#state.args),
+    {State#state{query = lists:reverse(NQ2), args = lists:reverse(NA), need_array_pass = false},
+     State#state{query = lists:reverse(PQ2), args = lists:reverse(PA), need_array_pass = false}}.
 
 make_sql_query(State) ->
     Hash = erlang:phash2(State#state{loc = undefined, use_new_schema = true}),