-include("ejabberd.hrl").
-include("logger.hrl").
+-include("ejabberd_sql_pt.hrl").
-record(state,
{db_ref = self() :: pid(),
-define(KEEPALIVE_QUERY, [<<"SELECT 1;">>]).
+-define(PREPARE_KEY, ejabberd_odbc_prepare).
+
%%-define(DBGFSM, true).
-ifdef(DBGFSM).
[Host, StartInterval],
fsm_limit_opts() ++ (?FSMOPTS)).
--type sql_query() :: [sql_query() | binary()].
+-type sql_query() :: [sql_query() | binary()] | #sql_query{}.
-type sql_query_result() :: {updated, non_neg_integer()} |
{error, binary()} |
{selected, [binary()],
- [[binary()]]}.
+ [[binary()]]} |
+ {selected, [any]}.
-spec sql_query(binary(), sql_query()) -> sql_query_result().
Res -> {atomic, Res}
end.
+sql_query_internal(#sql_query{} = Query) ->
+ State = get(?STATE_KEY),
+ Res =
+ try
+ case State#state.db_type of
+ odbc ->
+ generic_sql_query(Query);
+ pgsql ->
+ Key = {?PREPARE_KEY, Query#sql_query.hash},
+ case get(Key) of
+ undefined ->
+ case pgsql_prepare(Query, State) of
+ {ok, _, _, _} ->
+ put(Key, prepared);
+ {error, Error} ->
+ ?ERROR_MSG("PREPARE failed for SQL query "
+ "at ~p: ~p",
+ [Query#sql_query.loc, Error]),
+ put(Key, ignore)
+ end;
+ _ ->
+ ok
+ end,
+ case get(Key) of
+ prepared ->
+ pgsql_execute_sql_query(Query, State);
+ _ ->
+ generic_sql_query(Query)
+ end;
+ mysql ->
+ generic_sql_query(Query);
+ sqlite ->
+ generic_sql_query(Query)
+ end
+ catch
+ Class:Reason ->
+ ST = erlang:get_stacktrace(),
+ ?ERROR_MSG("Internal error while processing SQL query: ~p",
+ [{Class, Reason, ST}]),
+ {error, <<"internal error">>}
+ end,
+ case Res of
+ {error, <<"No SQL-driver information available.">>} ->
+ {updated, 0};
+ _Else -> Res
+ end;
sql_query_internal(Query) ->
State = get(?STATE_KEY),
?DEBUG("SQL: \"~s\"", [Query]),
_Else -> Res
end.
+generic_sql_query(SQLQuery) ->
+ sql_query_format_res(
+ sql_query_internal(generic_sql_query_format(SQLQuery)),
+ SQLQuery).
+
+generic_sql_query_format(SQLQuery) ->
+ Args = (SQLQuery#sql_query.args)(generic_escape()),
+ (SQLQuery#sql_query.format_query)(Args).
+
+generic_escape() ->
+ #sql_escape{string = fun(X) -> <<"'", (escape(X))/binary, "'">> end,
+ integer = fun(X) -> integer_to_binary(X) end,
+ boolean = fun(true) -> <<"1">>;
+ (false) -> <<"0">>
+ end
+ }.
+
+pgsql_prepare(SQLQuery, State) ->
+ Escape = #sql_escape{_ = fun(X) -> X end},
+ N = length((SQLQuery#sql_query.args)(Escape)),
+ Args = [<<$$, (integer_to_binary(I))/binary>> || I <- lists:seq(1, N)],
+ Query = (SQLQuery#sql_query.format_query)(Args),
+ pgsql:prepare(State#state.db_ref, SQLQuery#sql_query.hash, Query).
+
+pgsql_execute_escape() ->
+ #sql_escape{string = fun(X) -> X end,
+ integer = fun(X) -> integer_to_binary(X) end,
+ boolean = fun(true) -> <<"1">>;
+ (false) -> <<"0">>
+ end
+ }.
+
+pgsql_execute_sql_query(SQLQuery, State) ->
+ Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
+ ExecuteRes =
+ pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args),
+ Res = pgsql_execute_to_odbc(ExecuteRes),
+ sql_query_format_res(Res, SQLQuery).
+
+
+sql_query_format_res({selected, _, Rows}, SQLQuery) ->
+ Res =
+ lists:flatmap(
+ fun(Row) ->
+ try
+ [(SQLQuery#sql_query.format_res)(Row)]
+ catch
+ Class:Reason ->
+ ST = erlang:get_stacktrace(),
+ ?ERROR_MSG("Error while processing "
+ "SQL query result: ~p~n"
+ "row: ~p",
+ [{Class, Reason, ST}, Row]),
+ []
+ end
+ end, Rows),
+ {selected, Res};
+sql_query_format_res(Res, _SQLQuery) ->
+ Res.
+
%% Generate the OTP callback return tuple depending on the driver result.
abort_on_driver_error({error, <<"query timed out">>} =
Reply,
pgsql_item_to_odbc({error, Error}) -> {error, Error};
pgsql_item_to_odbc(_) -> {updated, undefined}.
+pgsql_execute_to_odbc({ok, {<<"SELECT", _/binary>>, Rows}}) ->
+ {selected, [], [[Field || {_, Field} <- Row] || Row <- Rows]};
+pgsql_execute_to_odbc({ok, {'INSERT', N}}) ->
+ {updated, N};
+pgsql_execute_to_odbc({ok, {'DELETE', N}}) ->
+ {updated, N};
+pgsql_execute_to_odbc({ok, {'UPDATE', N}}) ->
+ {updated, N};
+pgsql_execute_to_odbc({error, Error}) -> {error, Error};
+pgsql_execute_to_odbc(_) -> {updated, undefined}.
+
+
%% == Native MySQL code
%% part of init/1
_ -> []
end.
+check_error({error, Why} = Err, #sql_query{} = Query) ->
+ ?ERROR_MSG("SQL query '~s' at ~p failed: ~p",
+ [Query#sql_query.hash, Query#sql_query.loc, Why]),
+ Err;
check_error({error, Why} = Err, Query) ->
?ERROR_MSG("SQL query '~s' failed: ~p", [Query, Why]),
Err;
--- /dev/null
+%%%-------------------------------------------------------------------
+%%% File : ejabberd_sql_pt.erl
+%%% Author : Alexey Shchepin <alexey@process-one.net>
+%%% Description : Parse transform for SQL queries
+%%%
+%%% Created : 20 Jan 2016 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+-module(ejabberd_sql_pt).
+
+%% API
+-export([parse_transform/2]).
+
+-export([parse/2]).
+
+-include("ejabberd_sql_pt.hrl").
+
+-record(state, {loc,
+ 'query' = [],
+ params = [],
+ param_pos = 0,
+ args = [],
+ res = [],
+ res_vars = [],
+ res_pos = 0}).
+
+-define(QUERY_RECORD, "sql_query").
+
+-define(ESCAPE_RECORD, "sql_escape").
+-define(ESCAPE_VAR, "__SQLEscape").
+
+-define(MOD, sql__module_).
+
+%%====================================================================
+%% API
+%%====================================================================
+%%--------------------------------------------------------------------
+%% Function:
+%% Description:
+%%--------------------------------------------------------------------
+parse_transform(AST, _Options) ->
+ %io:format("PT: ~p~nOpts: ~p~n", [AST, Options]),
+ NewAST = top_transform(AST),
+ %io:format("NewPT: ~p~n", [NewAST]),
+ NewAST.
+
+
+
+%%====================================================================
+%% Internal functions
+%%====================================================================
+
+
+transform(Form) ->
+ case erl_syntax:type(Form) of
+ application ->
+ case erl_syntax_lib:analyze_application(Form) of
+ {?SQL_MARK, 1} ->
+ case erl_syntax:application_arguments(Form) of
+ [Arg] ->
+ case erl_syntax:type(Arg) of
+ string ->
+ S = erl_syntax:string_value(Arg),
+ ParseRes =
+ parse(S, erl_syntax:get_pos(Arg)),
+ make_sql_query(ParseRes);
+ _ ->
+ throw({error, erl_syntax:get_pos(Form),
+ "?SQL argument must be "
+ "a constant string"})
+ end;
+ _ ->
+ throw({error, erl_syntax:get_pos(Form),
+ "wrong number of ?SQL args"})
+ end;
+ _ ->
+ Form
+ end;
+ attribute ->
+ case erl_syntax:atom_value(erl_syntax:attribute_name(Form)) of
+ module ->
+ case erl_syntax:attribute_arguments(Form) of
+ [M | _] ->
+ Module = erl_syntax:atom_value(M),
+ %io:format("module ~p~n", [Module]),
+ put(?MOD, Module),
+ Form;
+ _ ->
+ Form
+ end;
+ _ ->
+ Form
+ end;
+ _ ->
+ Form
+ end.
+
+top_transform(Forms) when is_list(Forms) ->
+ lists:map(
+ fun(Form) ->
+ try
+ Form2 = erl_syntax_lib:map(
+ fun(Node) ->
+ %io:format("asd ~p~n", [Node]),
+ transform(Node)
+ end, Form),
+ Form3 = erl_syntax:revert(Form2),
+ Form3
+ catch
+ throw:{error, Line, Error} ->
+ {error, {Line, erl_parse, Error}}
+ end
+ end, Forms).
+
+parse(S, Loc) ->
+ parse1(S, [], #state{loc = Loc}).
+
+parse1([], Acc, State) ->
+ State1 = append_string(lists:reverse(Acc), State),
+ State1#state{'query' = lists:reverse(State1#state.'query'),
+ params = lists:reverse(State1#state.params),
+ args = lists:reverse(State1#state.args),
+ res = lists:reverse(State1#state.res),
+ res_vars = lists:reverse(State1#state.res_vars)
+ };
+parse1([$@, $( | S], Acc, State) ->
+ State1 = append_string(lists:reverse(Acc), State),
+ {Name, Type, S1, State2} = parse_name(S, State1),
+ Var = "__V" ++ integer_to_list(State2#state.res_pos),
+ EVar = erl_syntax:variable(Var),
+ Convert =
+ case Type of
+ integer ->
+ erl_syntax:application(
+ erl_syntax:atom(binary_to_integer),
+ [EVar]);
+ string ->
+ EVar;
+ boolean ->
+ erl_syntax:application(
+ erl_syntax:atom(ejabberd_odbc),
+ erl_syntax:atom(to_bool),
+ [EVar])
+ end,
+ State3 = append_string(Name, State2),
+ State4 = State3#state{res_pos = State3#state.res_pos + 1,
+ res = [Convert | State3#state.res],
+ res_vars = [EVar | State3#state.res_vars]},
+ parse1(S1, [], State4);
+parse1([$%, $( | S], Acc, State) ->
+ State1 = append_string(lists:reverse(Acc), State),
+ {Name, Type, S1, State2} = parse_name(S, State1),
+ Var = "__V" ++ integer_to_list(State2#state.param_pos),
+ EVar = erl_syntax:variable(Var),
+ Convert =
+ erl_syntax:application(
+ erl_syntax:record_access(
+ erl_syntax:variable(?ESCAPE_VAR),
+ erl_syntax:atom(?ESCAPE_RECORD),
+ erl_syntax:atom(Type)),
+ [erl_syntax:variable(Name)]),
+ State3 = State2,
+ State4 =
+ State3#state{'query' = [{var, EVar} | State3#state.'query'],
+ args = [Convert | State3#state.args],
+ params = [EVar | State3#state.params],
+ param_pos = State3#state.param_pos + 1},
+ parse1(S1, [], State4);
+parse1([C | S], Acc, State) ->
+ parse1(S, [C | Acc], State).
+
+append_string([], State) ->
+ State;
+append_string(S, State) ->
+ State#state{query = [{str, S} | State#state.query]}.
+
+parse_name(S, State) ->
+ parse_name(S, [], State).
+
+parse_name([], Acc, State) ->
+ % todo
+ error;
+parse_name([$), T | S], Acc, State) ->
+ Type =
+ case T of
+ $d -> integer;
+ $s -> string;
+ $b -> boolean;
+ _ ->
+ % todo
+ error
+ end,
+ {lists:reverse(Acc), Type, S, State};
+parse_name([$) | _], Acc, State) ->
+ % todo
+ error;
+parse_name([C | S], Acc, State) ->
+ parse_name(S, [C | Acc], State).
+
+
+make_sql_query(State) ->
+ Hash = erlang:phash2(State#state{loc = undefined}),
+ SHash = <<"Q", (integer_to_binary(Hash))/binary>>,
+ Query = pack_query(State#state.'query'),
+ EQuery =
+ lists:map(
+ fun({str, S}) ->
+ erl_syntax:binary(
+ [erl_syntax:binary_field(
+ erl_syntax:string(S))]);
+ ({var, V}) -> V
+ end, Query),
+ erl_syntax:record_expr(
+ erl_syntax:atom(?QUERY_RECORD),
+ [erl_syntax:record_field(
+ erl_syntax:atom(hash),
+ %erl_syntax:abstract(SHash)
+ erl_syntax:binary(
+ [erl_syntax:binary_field(
+ erl_syntax:string(binary_to_list(SHash)))])),
+ erl_syntax:record_field(
+ erl_syntax:atom(args),
+ erl_syntax:fun_expr(
+ [erl_syntax:clause(
+ [erl_syntax:variable(?ESCAPE_VAR)],
+ none,
+ [erl_syntax:list(State#state.args)]
+ )])),
+ erl_syntax:record_field(
+ erl_syntax:atom(format_query),
+ erl_syntax:fun_expr(
+ [erl_syntax:clause(
+ [erl_syntax:list(State#state.params)],
+ none,
+ [erl_syntax:list(EQuery)]
+ )])),
+ erl_syntax:record_field(
+ erl_syntax:atom(format_res),
+ erl_syntax:fun_expr(
+ [erl_syntax:clause(
+ [erl_syntax:list(State#state.res_vars)],
+ none,
+ [erl_syntax:tuple(State#state.res)]
+ )])),
+ erl_syntax:record_field(
+ erl_syntax:atom(loc),
+ erl_syntax:abstract({get(?MOD), State#state.loc}))
+ ]).
+
+pack_query([]) ->
+ [];
+pack_query([{str, S1}, {str, S2} | Rest]) ->
+ pack_query([{str, S1 ++ S2} | Rest]);
+pack_query([X | Rest]) ->
+ [X | pack_query(Rest)].
+