]> granicus.if.org Git - ejabberd/commitdiff
New parse transform for SQL queries, use prepare/execute calls with Postgres
authorAlexey Shchepin <alexey@process-one.net>
Tue, 9 Feb 2016 16:23:15 +0000 (19:23 +0300)
committerAlexey Shchepin <alexey@process-one.net>
Tue, 1 Mar 2016 19:48:30 +0000 (22:48 +0300)
include/ejabberd_sql_pt.hrl [new file with mode: 0644]
src/ejabberd_odbc.erl
src/ejabberd_sql_pt.erl [new file with mode: 0644]

diff --git a/include/ejabberd_sql_pt.hrl b/include/ejabberd_sql_pt.hrl
new file mode 100644 (file)
index 0000000..ca6df9e
--- /dev/null
@@ -0,0 +1,27 @@
+%%%----------------------------------------------------------------------
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%----------------------------------------------------------------------
+
+-define(SQL_MARK, sql__mark_).
+-define(SQL(SQL), ?SQL_MARK(SQL)).
+
+-record(sql_query, {hash, format_query, format_res, args, loc}).
+
+-record(sql_escape, {string, integer, boolean}).
+
index a15c66b5d29e86984b6890f8a15bd23f66840259..ef3c61d0a2006111ede438d05434a2ffc143a768 100644 (file)
@@ -63,6 +63,7 @@
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
+-include("ejabberd_sql_pt.hrl").
 
 -record(state,
        {db_ref = self()                     :: pid(),
@@ -92,6 +93,8 @@
 
 -define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
 
+-define(PREPARE_KEY, ejabberd_odbc_prepare).
+
 %%-define(DBGFSM, true).
 
 -ifdef(DBGFSM).
@@ -116,11 +119,12 @@ start_link(Host, StartInterval) ->
                          [Host, StartInterval],
                          fsm_limit_opts() ++ (?FSMOPTS)).
 
--type sql_query() :: [sql_query() | binary()].
+-type sql_query() :: [sql_query() | binary()] | #sql_query{}.
 -type sql_query_result() :: {updated, non_neg_integer()} |
                             {error, binary()} |
                             {selected, [binary()],
-                             [[binary()]]}.
+                             [[binary()]]} |
+                            {selected, [any]}.
 
 -spec sql_query(binary(), sql_query()) -> sql_query_result().
 
@@ -469,6 +473,52 @@ execute_bloc(F) ->
       Res -> {atomic, Res}
     end.
 
+sql_query_internal(#sql_query{} = Query) ->
+    State = get(?STATE_KEY),
+    Res =
+        try
+            case State#state.db_type of
+                odbc ->
+                    generic_sql_query(Query);
+                pgsql ->
+                    Key = {?PREPARE_KEY, Query#sql_query.hash},
+                    case get(Key) of
+                        undefined ->
+                            case pgsql_prepare(Query, State) of
+                                {ok, _, _, _} ->
+                                    put(Key, prepared);
+                                {error, Error} ->
+                                    ?ERROR_MSG("PREPARE failed for SQL query "
+                                               "at ~p: ~p",
+                                               [Query#sql_query.loc, Error]),
+                                    put(Key, ignore)
+                            end;
+                        _ ->
+                            ok
+                    end,
+                    case get(Key) of
+                        prepared ->
+                            pgsql_execute_sql_query(Query, State);
+                        _ ->
+                            generic_sql_query(Query)
+                    end;
+                mysql ->
+                    generic_sql_query(Query);
+                sqlite ->
+                    generic_sql_query(Query)
+            end
+        catch
+            Class:Reason ->
+                ST = erlang:get_stacktrace(),
+                ?ERROR_MSG("Internal error while processing SQL query: ~p",
+                           [{Class, Reason, ST}]),
+                {error, <<"internal error">>}
+        end,
+    case Res of
+        {error, <<"No SQL-driver information available.">>} ->
+            {updated, 0};
+        _Else -> Res
+    end;
 sql_query_internal(Query) ->
     State = get(?STATE_KEY),
     ?DEBUG("SQL: \"~s\"", [Query]),
@@ -495,6 +545,66 @@ sql_query_internal(Query) ->
       _Else -> Res
     end.
 
+generic_sql_query(SQLQuery) ->
+    sql_query_format_res(
+      sql_query_internal(generic_sql_query_format(SQLQuery)),
+      SQLQuery).
+
+generic_sql_query_format(SQLQuery) ->
+    Args = (SQLQuery#sql_query.args)(generic_escape()),
+    (SQLQuery#sql_query.format_query)(Args).
+
+generic_escape() ->
+    #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
+                integer = fun(X) -> integer_to_binary(X) end,
+                boolean = fun(true) -> <<"1">>;
+                             (false) -> <<"0">>
+                          end
+               }.
+
+pgsql_prepare(SQLQuery, State) ->
+    Escape = #sql_escape{_ = fun(X) -> X end},
+    N = length((SQLQuery#sql_query.args)(Escape)),
+    Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)],
+    Query = (SQLQuery#sql_query.format_query)(Args),
+    pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
+
+pgsql_execute_escape() ->
+    #sql_escape{string = fun(X) -> X end,
+                integer = fun(X) -> integer_to_binary(X) end,
+                boolean = fun(true) -> <<"1">>;
+                             (false) -> <<"0">>
+                          end
+               }.
+
+pgsql_execute_sql_query(SQLQuery, State) ->
+    Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
+    ExecuteRes =
+        pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args),
+    Res = pgsql_execute_to_odbc(ExecuteRes),
+    sql_query_format_res(Res, SQLQuery).
+
+
+sql_query_format_res({selected, _, Rows}, SQLQuery) ->
+    Res =
+        lists:flatmap(
+          fun(Row) ->
+                  try
+                      [(SQLQuery#sql_query.format_res)(Row)]
+                  catch
+                      Class:Reason ->
+                          ST = erlang:get_stacktrace(),
+                          ?ERROR_MSG("Error while processing "
+                                     "SQL query result: ~p~n"
+                                     "row: ~p",
+                                     [{Class, Reason, ST}, Row]),
+                          []
+                  end
+          end, Rows),
+    {selected, Res};
+sql_query_format_res(Res, _SQLQuery) ->
+    Res.
+
 %% Generate the OTP callback return tuple depending on the driver result.
 abort_on_driver_error({error, <<"query timed out">>} =
                          Reply,
@@ -606,6 +716,18 @@ pgsql_item_to_odbc(<<"UPDATE ", N/binary>>) ->
 pgsql_item_to_odbc({error, Error}) -> {error, Error};
 pgsql_item_to_odbc(_) -> {updated, undefined}.
 
+pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
+    {selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]};
+pgsql_execute_to_odbc({ok, {'INSERT', N}}) ->
+    {updated, N};
+pgsql_execute_to_odbc({ok, {'DELETE', N}}) ->
+    {updated, N};
+pgsql_execute_to_odbc({ok, {'UPDATE', N}}) ->
+    {updated, N};
+pgsql_execute_to_odbc({error, Error}) -> {error, Error};
+pgsql_execute_to_odbc(_) -> {updated, undefined}.
+
+
 %% == Native MySQL code
 
 %% part of init/1
@@ -800,6 +922,10 @@ fsm_limit_opts() ->
       _ -> []
     end.
 
+check_error({error, Why} = Err, #sql_query{} = Query) ->
+    ?ERROR_MSG("SQL query '~s' at ~p failed: ~p",
+               [Query#sql_query.hash, Query#sql_query.loc, Why]),
+    Err;
 check_error({error, Why} = Err, Query) ->
     ?ERROR_MSG("SQL query '~s' failed: ~p", [Query, Why]),
     Err;
diff --git a/src/ejabberd_sql_pt.erl b/src/ejabberd_sql_pt.erl
new file mode 100644 (file)
index 0000000..f9701a0
--- /dev/null
@@ -0,0 +1,255 @@
+%%%-------------------------------------------------------------------
+%%% File    : ejabberd_sql_pt.erl
+%%% Author  : Alexey Shchepin <alexey@process-one.net>
+%%% Description : Parse transform for SQL queries
+%%%
+%%% Created : 20 Jan 2016 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+-module(ejabberd_sql_pt).
+
+%% API
+-export([parse_transform/2]).
+
+-export([parse/2]).
+
+-include("ejabberd_sql_pt.hrl").
+
+-record(state, {loc,
+                'query' = [],
+                params = [],
+                param_pos = 0,
+                args = [],
+                res = [],
+                res_vars = [],
+                res_pos = 0}).
+
+-define(QUERY_RECORD, "sql_query").
+
+-define(ESCAPE_RECORD, "sql_escape").
+-define(ESCAPE_VAR, "__SQLEscape").
+
+-define(MOD, sql__module_).
+
+%%====================================================================
+%% API
+%%====================================================================
+%%--------------------------------------------------------------------
+%% Function:
+%% Description:
+%%--------------------------------------------------------------------
+parse_transform(AST, _Options) ->
+    %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]),
+    NewAST = top_transform(AST),
+    %io:format("NewPT: ~p~n", [NewAST]),
+    NewAST.
+
+
+
+%%====================================================================
+%% Internal functions
+%%====================================================================
+
+
+transform(Form) ->
+    case erl_syntax:type(Form) of
+        application ->
+            case erl_syntax_lib:analyze_application(Form) of
+                {?SQL_MARK, 1} ->
+                    case erl_syntax:application_arguments(Form) of
+                        [Arg] ->
+                            case erl_syntax:type(Arg) of
+                                string ->
+                                    S = erl_syntax:string_value(Arg),
+                                    ParseRes =
+                                        parse(S, erl_syntax:get_pos(Arg)),
+                                    make_sql_query(ParseRes);
+                                _ ->
+                                    throw({error, erl_syntax:get_pos(Form),
+                                           "?SQL argument must be "
+                                           "a constant string"})
+                            end;
+                        _ ->
+                            throw({error, erl_syntax:get_pos(Form),
+                                   "wrong number of ?SQL args"})
+                    end;
+                _ ->
+                    Form
+            end;
+        attribute ->
+            case erl_syntax:atom_value(erl_syntax:attribute_name(Form)) of
+                module ->
+                    case erl_syntax:attribute_arguments(Form) of
+                        [M | _] ->
+                            Module = erl_syntax:atom_value(M),
+                            %io:format("module ~p~n", [Module]),
+                            put(?MOD, Module),
+                            Form;
+                        _ ->
+                            Form
+                    end;
+                _ ->
+                    Form
+            end;
+        _ ->
+            Form
+    end.
+
+top_transform(Forms) when is_list(Forms) ->
+    lists:map(
+      fun(Form) ->
+              try
+                  Form2 = erl_syntax_lib:map(
+                            fun(Node) ->
+                                                %io:format("asd ~p~n", [Node]),
+                                    transform(Node)
+                            end, Form),
+                  Form3 = erl_syntax:revert(Form2),
+                  Form3
+             catch
+                 throw:{error, Line, Error} ->
+                     {error, {Line, erl_parse, Error}}
+             end
+      end, Forms).
+
+parse(S, Loc) ->
+    parse1(S, [], #state{loc = Loc}).
+
+parse1([], Acc, State) ->
+    State1 = append_string(lists:reverse(Acc), State),
+    State1#state{'query' = lists:reverse(State1#state.'query'),
+                 params = lists:reverse(State1#state.params),
+                 args = lists:reverse(State1#state.args),
+                 res = lists:reverse(State1#state.res),
+                 res_vars = lists:reverse(State1#state.res_vars)
+                };
+parse1([$@, $( | S], Acc, State) ->
+    State1 = append_string(lists:reverse(Acc), State),
+    {Name, Type, S1, State2} = parse_name(S, State1),
+    Var = "__V" ++ integer_to_list(State2#state.res_pos),
+    EVar = erl_syntax:variable(Var),
+    Convert =
+        case Type of
+            integer ->
+                erl_syntax:application(
+                  erl_syntax:atom(binary_to_integer),
+                  [EVar]);
+            string ->
+                EVar;
+            boolean ->
+                erl_syntax:application(
+                  erl_syntax:atom(ejabberd_odbc),
+                  erl_syntax:atom(to_bool),
+                  [EVar])
+        end,
+    State3 = append_string(Name, State2),
+    State4 = State3#state{res_pos = State3#state.res_pos + 1,
+                          res = [Convert | State3#state.res],
+                          res_vars = [EVar | State3#state.res_vars]},
+    parse1(S1, [], State4);
+parse1([$%, $( | S], Acc, State) ->
+    State1 = append_string(lists:reverse(Acc), State),
+    {Name, Type, S1, State2} = parse_name(S, State1),
+    Var = "__V" ++ integer_to_list(State2#state.param_pos),
+    EVar = erl_syntax:variable(Var),
+    Convert =
+        erl_syntax:application(
+          erl_syntax:record_access(
+            erl_syntax:variable(?ESCAPE_VAR),
+            erl_syntax:atom(?ESCAPE_RECORD),
+            erl_syntax:atom(Type)),
+          [erl_syntax:variable(Name)]),
+    State3 = State2,
+    State4 =
+        State3#state{'query' = [{var, EVar} | State3#state.'query'],
+                     args = [Convert | State3#state.args],
+                     params = [EVar | State3#state.params],
+                     param_pos = State3#state.param_pos + 1},
+    parse1(S1, [], State4);
+parse1([C | S], Acc, State) ->
+    parse1(S, [C | Acc], State).
+
+append_string([], State) ->
+    State;
+append_string(S, State) ->
+    State#state{query = [{str, S} | State#state.query]}.
+
+parse_name(S, State) ->
+    parse_name(S, [], State).
+
+parse_name([], Acc, State) ->
+                                                % todo
+    error;
+parse_name([$), T | S], Acc, State) ->
+    Type =
+        case T of
+            $d -> integer;
+            $s -> string;
+            $b -> boolean;
+            _ ->
+                                                % todo
+                error
+        end,
+    {lists:reverse(Acc), Type, S, State};
+parse_name([$) | _], Acc, State) ->
+                                                % todo
+    error;
+parse_name([C | S], Acc, State) ->
+    parse_name(S, [C | Acc], State).
+
+
+make_sql_query(State) ->
+    Hash = erlang:phash2(State#state{loc = undefined}),
+    SHash = <<"Q", (integer_to_binary(Hash))/binary>>,
+    Query = pack_query(State#state.'query'),
+    EQuery =
+        lists:map(
+          fun({str, S}) ->
+                  erl_syntax:binary(
+                    [erl_syntax:binary_field(
+                       erl_syntax:string(S))]);
+             ({var, V}) -> V
+          end, Query),
+    erl_syntax:record_expr(
+     erl_syntax:atom(?QUERY_RECORD),
+     [erl_syntax:record_field(
+       erl_syntax:atom(hash),
+        %erl_syntax:abstract(SHash)
+        erl_syntax:binary(
+          [erl_syntax:binary_field(
+             erl_syntax:string(binary_to_list(SHash)))])),
+      erl_syntax:record_field(
+       erl_syntax:atom(args),
+        erl_syntax:fun_expr(
+          [erl_syntax:clause(
+             [erl_syntax:variable(?ESCAPE_VAR)],
+             none,
+             [erl_syntax:list(State#state.args)]
+            )])),
+      erl_syntax:record_field(
+       erl_syntax:atom(format_query),
+        erl_syntax:fun_expr(
+          [erl_syntax:clause(
+             [erl_syntax:list(State#state.params)],
+             none,
+             [erl_syntax:list(EQuery)]
+            )])),
+      erl_syntax:record_field(
+       erl_syntax:atom(format_res),
+        erl_syntax:fun_expr(
+          [erl_syntax:clause(
+             [erl_syntax:list(State#state.res_vars)],
+             none,
+             [erl_syntax:tuple(State#state.res)]
+            )])),
+      erl_syntax:record_field(
+       erl_syntax:atom(loc),
+        erl_syntax:abstract({get(?MOD), State#state.loc}))
+     ]).
+
+pack_query([]) ->
+    [];
+pack_query([{str, S1}, {str, S2} | Rest]) ->
+    pack_query([{str, S1 ++ S2} | Rest]);
+pack_query([X | Rest]) ->
+    [X | pack_query(Rest)].
+