]> granicus.if.org Git - ejabberd/commitdiff
New parse transform for ?SQL_UPSERT and ?SQL_UPSERT_T
authorAlexey Shchepin <alexey@process-one.net>
Thu, 18 Feb 2016 14:38:25 +0000 (17:38 +0300)
committerAlexey Shchepin <alexey@process-one.net>
Tue, 1 Mar 2016 21:12:49 +0000 (00:12 +0300)
include/ejabberd_sql_pt.hrl
src/ejabberd_odbc.erl
src/ejabberd_sql_pt.erl

index ca6df9ec9eb38d9fb3b177bac57775c7422fab04..f189fdcf61fc93e3304e746e92755eceaa79535b 100644 (file)
 -define(SQL_MARK, sql__mark_).
 -define(SQL(SQL), ?SQL_MARK(SQL)).
 
+-define(SQL_UPSERT_MARK, sql_upsert__mark_).
+-define(SQL_UPSERT(Host, Table, Fields),
+        ejabberd_odbc:sql_query(Host, ?SQL_UPSERT_MARK(Table, Fields))).
+-define(SQL_UPSERT_T(Table, Fields),
+        ejabberd_odbc:sql_query_t(Host, ?SQL_UPSERT_MARK(Table, Fields))).
+
 -record(sql_query, {hash, format_query, format_res, args, loc}).
 
 -record(sql_escape, {string, integer, boolean}).
index b430d920a4499c4edc6873e4e95e4743607188e0..4f818f51323ed8ab89c15b3fc0cbbbaf1f6513ca 100644 (file)
@@ -475,6 +475,12 @@ execute_bloc(F) ->
       Res -> {atomic, Res}
     end.
 
+execute_fun(F) when is_function(F, 0) ->
+    F();
+execute_fun(F) when is_function(F, 2) ->
+    State = get(?STATE_KEY),
+    F(State#state.db_type, State#state.db_version).
+
 sql_query_internal([{_, _} | _] = Queries) ->
     State = get(?STATE_KEY),
     case select_sql_query(Queries, State) of
@@ -529,6 +535,11 @@ sql_query_internal(#sql_query{} = Query) ->
             {updated, 0};
         _Else -> Res
     end;
+sql_query_internal(F) when is_function(F) ->
+    case catch execute_fun(F) of
+        {'EXIT', Reason} -> {error, Reason};
+        Res -> Res
+    end;
 sql_query_internal(Query) ->
     State = get(?STATE_KEY),
     ?DEBUG("SQL: \"~s\"", [Query]),
@@ -615,6 +626,9 @@ pgsql_execute_sql_query(SQLQuery, State) ->
     Args = (SQLQuery#sql_query.args)(pgsql_execute_escape()),
     ExecuteRes =
         pgsql:execute(State#state.db_ref, SQLQuery#sql_query.hash, Args),
+%    {T, ExecuteRes} =
+%        timer:tc(pgsql, execute, [State#state.db_ref, SQLQuery#sql_query.hash, Args]),
+%    io:format("T ~s ~p~n", [SQLQuery#sql_query.hash, T]),
     Res = pgsql_execute_to_odbc(ExecuteRes),
     sql_query_format_res(Res, SQLQuery).
 
index 6b26cbcd66a738cb091eb30a54db7168ba505b5a..cb7a82e0f1e89f523e14cac8e516fef1830268bd 100644 (file)
@@ -72,6 +72,26 @@ transform(Form) ->
                             throw({error, erl_syntax:get_pos(Form),
                                    "wrong number of ?SQL args"})
                     end;
+                {?SQL_UPSERT_MARK, 2} ->
+                    case erl_syntax:application_arguments(Form) of
+                        [TableArg, FieldsArg] ->
+                            case {erl_syntax:type(TableArg),
+                                  erl_syntax:is_proper_list(FieldsArg)}of
+                                {string, true} ->
+                                    Table = erl_syntax:string_value(TableArg),
+                                    ParseRes =
+                                        parse_upsert(
+                                          erl_syntax:list_elements(FieldsArg)),
+                                    make_sql_upsert(Table, ParseRes);
+                                _ ->
+                                    throw({error, erl_syntax:get_pos(Form),
+                                           "?SQL_UPSERT arguments must be "
+                                           "a constant string and a list"})
+                            end;
+                        _ ->
+                            throw({error, erl_syntax:get_pos(Form),
+                                   "wrong number of ?SQL_UPSERT args"})
+                    end;
                 _ ->
                     Form
             end;
@@ -114,6 +134,9 @@ top_transform(Forms) when is_list(Forms) ->
 parse(S, Loc) ->
     parse1(S, [], #state{loc = Loc}).
 
+parse(S, ParamPos, Loc) ->
+    parse1(S, [], #state{loc = Loc, param_pos = ParamPos}).
+
 parse1([], Acc, State) ->
     State1 = append_string(lists:reverse(Acc), State),
     State1#state{'query' = lists:reverse(State1#state.'query'),
@@ -149,8 +172,7 @@ parse1([$@, $( | S], Acc, State) ->
 parse1([$%, $( | S], Acc, State) ->
     State1 = append_string(lists:reverse(Acc), State),
     {Name, Type, S1, State2} = parse_name(S, State1),
-    Var = "__V" ++ integer_to_list(State2#state.param_pos),
-    EVar = erl_syntax:variable(Var),
+    Var = State2#state.param_pos,
     Convert =
         erl_syntax:application(
           erl_syntax:record_access(
@@ -160,9 +182,9 @@ parse1([$%, $( | S], Acc, State) ->
           [erl_syntax:variable(Name)]),
     State3 = State2,
     State4 =
-        State3#state{'query' = [{var, EVar} | State3#state.'query'],
+        State3#state{'query' = [{var, Var} | State3#state.'query'],
                      args = [Convert | State3#state.args],
-                     params = [EVar | State3#state.params],
+                     params = [Var | State3#state.params],
                      param_pos = State3#state.param_pos + 1},
     parse1(S1, [], State4);
 parse1([C | S], Acc, State) ->
@@ -190,7 +212,7 @@ parse_name([$), T | S], Acc, 0, State) ->
                        ["unknown type specifier '", T, "'"]})
         end,
     {lists:reverse(Acc), Type, S, State};
-parse_name([$)], Acc, 0, State) ->
+parse_name([$)], _Acc, 0, State) ->
     throw({error, State#state.loc,
            "expected type specifier, found end of string"});
 parse_name([$( = C | S], Acc, Depth, State) ->
@@ -201,6 +223,11 @@ parse_name([C | S], Acc, Depth, State) ->
     parse_name(S, [C | Acc], Depth, State).
 
 
+make_var(V) ->
+    Var = "__V" ++ integer_to_list(V),
+    erl_syntax:variable(Var).
+
+
 make_sql_query(State) ->
     Hash = erlang:phash2(State#state{loc = undefined}),
     SHash = <<"Q", (integer_to_binary(Hash))/binary>>,
@@ -211,7 +238,7 @@ make_sql_query(State) ->
                   erl_syntax:binary(
                     [erl_syntax:binary_field(
                        erl_syntax:string(S))]);
-             ({var, V}) -> V
+             ({var, V}) -> make_var(V)
           end, Query),
     erl_syntax:record_expr(
      erl_syntax:atom(?QUERY_RECORD),
@@ -233,7 +260,7 @@ make_sql_query(State) ->
        erl_syntax:atom(format_query),
         erl_syntax:fun_expr(
           [erl_syntax:clause(
-             [erl_syntax:list(State#state.params)],
+             [erl_syntax:list(lists:map(fun make_var/1, State#state.params))],
              none,
              [erl_syntax:list(EQuery)]
             )])),
@@ -257,3 +284,232 @@ pack_query([{str, S1}, {str, S2} | Rest]) ->
 pack_query([X | Rest]) ->
     [X | pack_query(Rest)].
 
+
+parse_upsert(Fields) ->
+    {Fs, _} =
+        lists:foldr(
+          fun(F, {Acc, Param}) ->
+                  case erl_syntax:type(F) of
+                      string ->
+                          V = erl_syntax:string_value(F),
+                          {_, _, State} = Res =
+                              parse_upsert_field(
+                                V, Param, erl_syntax:get_pos(F)),
+                          {[Res | Acc], State#state.param_pos};
+                      _ ->
+                          throw({error, erl_syntax:get_pos(F),
+                                 "?SQL_UPSERT field must be "
+                                 "a constant string"})
+                  end
+          end, {[], 0}, Fields),
+    %io:format("asd ~p~n", [{Fields, Fs}]),
+    Fs.
+
+parse_upsert_field([$! | S], ParamPos, Loc) ->
+    {Name, ParseState} = parse_upsert_field1(S, [], ParamPos, Loc),
+    {Name, true, ParseState};
+parse_upsert_field(S, ParamPos, Loc) ->
+    {Name, ParseState} = parse_upsert_field1(S, [], ParamPos, Loc),
+    {Name, false, ParseState}.
+
+parse_upsert_field1([], _Acc, _ParamPos, Loc) ->
+    throw({error, Loc,
+           "?SQL_UPSERT fields must have the "
+           "following form: \"[!]name=value\""});
+parse_upsert_field1([$= | S], Acc, ParamPos, Loc) ->
+    {lists:reverse(Acc), parse(S, ParamPos, Loc)};
+parse_upsert_field1([C | S], Acc, ParamPos, Loc) ->
+    parse_upsert_field1(S, [C | Acc], ParamPos, Loc).
+
+
+make_sql_upsert(Table, ParseRes) ->
+    erl_syntax:fun_expr(
+      [erl_syntax:clause(
+         [erl_syntax:atom(pgsql), erl_syntax:variable("__Version")],
+         [erl_syntax:infix_expr(
+            erl_syntax:variable("__Version"),
+            erl_syntax:operator('>='),
+            erl_syntax:integer(90100))],
+         [make_sql_upsert_pgsql901(Table, ParseRes),
+          erl_syntax:atom(ok)]),
+       erl_syntax:clause(
+         [erl_syntax:underscore(), erl_syntax:underscore()],
+         none,
+         [make_sql_upsert_generic(Table, ParseRes)])
+      ]).
+
+make_sql_upsert_generic(Table, ParseRes) ->
+    Update = make_sql_query(make_sql_upsert_update(Table, ParseRes)),
+    Insert = make_sql_query(make_sql_upsert_insert(Table, ParseRes)),
+    InsertBranch =
+        erl_syntax:case_expr(
+          erl_syntax:application(
+            erl_syntax:atom(ejabberd_odbc),
+            erl_syntax:atom(sql_query_t),
+            [Insert]),
+          [erl_syntax:clause(
+             [erl_syntax:abstract({updated, 1})],
+             none,
+             [erl_syntax:atom(ok)]),
+           erl_syntax:clause(
+             [erl_syntax:variable("__UpdateRes")],
+             none,
+             [erl_syntax:variable("__UpdateRes")])]),
+    erl_syntax:case_expr(
+      erl_syntax:application(
+        erl_syntax:atom(ejabberd_odbc),
+        erl_syntax:atom(sql_query_t),
+        [Update]),
+      [erl_syntax:clause(
+         [erl_syntax:abstract({updated, 1})],
+         none,
+         [erl_syntax:atom(ok)]),
+       erl_syntax:clause(
+         [erl_syntax:underscore()],
+         none,
+         [InsertBranch])]).
+
+make_sql_upsert_update(Table, ParseRes) ->
+    WPairs =
+        lists:flatmap(
+          fun({_Field, false, _ST}) ->
+                  [];
+             ({Field, true, ST}) ->
+                  [ST#state{
+                     'query' = [{str, Field}, {str, "="}] ++ ST#state.'query'
+                    }]
+          end, ParseRes),
+    Where = join_states(WPairs, " AND "),
+    SPairs =
+        lists:flatmap(
+          fun({_Field, true, _ST}) ->
+                  [];
+             ({Field, false, ST}) ->
+                  [ST#state{
+                     'query' = [{str, Field}, {str, "="}] ++ ST#state.'query'
+                    }]
+          end, ParseRes),
+    Set = join_states(SPairs, ", "),
+    State =
+        concat_states(
+          [#state{'query' = [{str, "UPDATE "}, {str, Table}, {str, " SET "}]},
+           Set,
+           #state{'query' = [{str, " WHERE "}]},
+           Where
+          ]),
+    State.
+
+make_sql_upsert_insert(Table, ParseRes) ->
+    Vals =
+        lists:map(
+          fun({_Field, _, ST}) ->
+                  ST
+          end, ParseRes),
+    Fields =
+        lists:map(
+          fun({Field, _, _ST}) ->
+                  #state{'query' = [{str, Field}]}
+          end, ParseRes),
+    State =
+        concat_states(
+          [#state{'query' = [{str, "INSERT INTO "}, {str, Table}, {str, "("}]},
+           join_states(Fields, ", "),
+           #state{'query' = [{str, ") VALUES ("}]},
+           join_states(Vals, ", "),
+           #state{'query' = [{str, ")"}]}
+          ]),
+    State.
+
+make_sql_upsert_pgsql901(Table, ParseRes) ->
+    Update = make_sql_upsert_update(Table, ParseRes),
+    Vals =
+        lists:map(
+          fun({_Field, _, ST}) ->
+                  ST
+          end, ParseRes),
+    Fields =
+        lists:map(
+          fun({Field, _, _ST}) ->
+                  #state{'query' = [{str, Field}]}
+          end, ParseRes),
+    Insert =
+        concat_states(
+          [#state{'query' = [{str, "INSERT INTO "}, {str, Table}, {str, "("}]},
+           join_states(Fields, ", "),
+           #state{'query' = [{str, ") SELECT "}]},
+           join_states(Vals, ", "),
+           #state{'query' = [{str, " WHERE NOT EXISTS (SELECT * FROM upsert)"}]}
+          ]),
+    State =
+        concat_states(
+          [#state{'query' = [{str, "WITH upsert AS ("}]},
+           Update,
+           #state{'query' = [{str, " RETURNING *) "}]},
+           Insert
+          ]),
+    Upsert = make_sql_query(State),
+    erl_syntax:application(
+      erl_syntax:atom(ejabberd_odbc),
+      erl_syntax:atom(sql_query_t),
+      [Upsert]).
+
+
+concat_states(States) ->
+    lists:foldr(
+      fun(ST11, ST2) ->
+              ST1 = resolve_vars(ST11, ST2),
+              ST1#state{
+                'query' = ST1#state.'query' ++ ST2#state.'query',
+                params = ST1#state.params ++ ST2#state.params,
+                args = ST1#state.args ++ ST2#state.args,
+                res = ST1#state.res ++ ST2#state.res,
+                res_vars = ST1#state.res_vars ++ ST2#state.res_vars,
+                loc = case ST1#state.loc of
+                          undefined -> ST2#state.loc;
+                          _ -> ST1#state.loc
+                      end
+               }
+      end, #state{}, States).
+
+resolve_vars(ST1, ST2) ->
+    Max = lists:max([0 | ST1#state.params ++ ST2#state.params]),
+    {Map, _} =
+        lists:foldl(
+          fun(Var, {Acc, New}) ->
+                  case lists:member(Var, ST2#state.params) of
+                      true ->
+                          {dict:store(Var, New, Acc), New + 1};
+                      false ->
+                          {Acc, New}
+                  end
+          end, {dict:new(), Max + 1}, ST1#state.params),
+    NewParams =
+        lists:map(
+          fun(Var) ->
+                  case dict:find(Var, Map) of
+                      {ok, New} ->
+                          New;
+                      error ->
+                          Var
+                  end
+          end, ST1#state.params),
+    NewQuery =
+        lists:map(
+          fun({var, Var}) ->
+                  case dict:find(Var, Map) of
+                      {ok, New} ->
+                          {var, New};
+                      error ->
+                          {var, Var}
+                  end;
+             (S) -> S
+          end, ST1#state.'query'),
+    ST1#state{params = NewParams, 'query' = NewQuery}.
+
+
+
+join_states([], _Sep) ->
+    #state{};
+join_states([H | T], Sep) ->
+    J = [[H] | [[#state{'query' = [{str, Sep}]}, X] || X <- T]],
+    concat_states(lists:append(J)).