]> granicus.if.org Git - ejabberd/commitdiff
Support for run-time SQL queries selection depending on DBMS version
authorAlexey Shchepin <alexey@process-one.net>
Thu, 11 Feb 2016 17:00:00 +0000 (20:00 +0300)
committerAlexey Shchepin <alexey@process-one.net>
Tue, 1 Mar 2016 19:49:56 +0000 (22:49 +0300)
src/ejabberd_odbc.erl

index ef3c61d0a2006111ede438d05434a2ffc143a768..2a253082b133ab76830fbbbdbbf56f9966efc0d7 100644 (file)
@@ -68,6 +68,7 @@
 -record(state,
        {db_ref = self()                     :: pid(),
         db_type = odbc                      :: pgsql | mysql | sqlite | odbc | mssql,
+        db_version = undefined              :: undefined | non_neg_integer(),
         start_interval = 0                  :: non_neg_integer(),
         host = <<"">>                       :: binary(),
         max_pending_requests_len            :: non_neg_integer(),
@@ -263,15 +264,16 @@ connecting(connect, #state{host = Host} = State) ->
                 end,
     {_, PendingRequests} = State#state.pending_requests,
     case ConnectRes of
-      {ok, Ref} ->
-         erlang:monitor(process, Ref),
-         lists:foreach(fun (Req) ->
-                               (?GEN_FSM):send_event(self(), Req)
-                       end,
-                       queue:to_list(PendingRequests)),
-         {next_state, session_established,
-          State#state{db_ref = Ref,
-                      pending_requests = {0, queue:new()}}};
+        {ok, Ref} ->
+            erlang:monitor(process, Ref),
+            lists:foreach(fun (Req) ->
+                                  (?GEN_FSM):send_event(self(), Req)
+                          end,
+                          queue:to_list(PendingRequests)),
+            State1 = State#state{db_ref = Ref,
+                                 pending_requests = {0, queue:new()}},
+            State2 = get_db_version(State1),
+            {next_state, session_established, State2};
       {error, Reason} ->
          ?INFO_MSG("~p connection failed:~n** Reason: ~p~n** "
                    "Retry after: ~p seconds",
@@ -473,6 +475,14 @@ execute_bloc(F) ->
       Res -> {atomic, Res}
     end.
 
+sql_query_internal([{_, _} | _] = Queries) ->
+    State = get(?STATE_KEY),
+    case select_sql_query(Queries, State) of
+        undefined ->
+            {error, <<"no matching query for the current DBMS found">>};
+        Query ->
+            sql_query_internal(Query)
+    end;
 sql_query_internal(#sql_query{} = Query) ->
     State = get(?STATE_KEY),
     Res =
@@ -545,6 +555,28 @@ sql_query_internal(Query) ->
       _Else -> Res
     end.
 
+select_sql_query(Queries, State) ->
+    select_sql_query(
+      Queries, State#state.db_type, State#state.db_version, undefined).
+
+select_sql_query([], _Type, _Version, undefined) ->
+    undefined;
+select_sql_query([], _Type, _Version, Query) ->
+    Query;
+select_sql_query([{Type, Query} | _], Type, _Version, _) ->
+    Query;
+select_sql_query([{{Type, _Version1}, Query1} | Rest], Type, undefined, _) ->
+    select_sql_query(Rest, Type, undefined, Query1);
+select_sql_query([{{Type, Version1}, Query1} | Rest], Type, Version, Query) ->
+    if
+        Version >= Version1 ->
+            Query1;
+        true ->
+            select_sql_query(Rest, Type, Version, Query)
+    end;
+select_sql_query([{_, _} | Rest], Type, Version, Query) ->
+    select_sql_query(Rest, Type, Version, Query).
+
 generic_sql_query(SQLQuery) ->
     sql_query_format_res(
       sql_query_internal(generic_sql_query_format(SQLQuery)),
@@ -780,6 +812,24 @@ to_odbc({error, Reason}) when is_list(Reason) ->
 to_odbc(Res) ->
     Res.
 
+get_db_version(#state{db_type = pgsql} = State) ->
+    case pgsql:squery(State#state.db_ref,
+                      <<"select current_setting('server_version_num')">>) of
+        {ok, [{_, _, [[SVersion]]}]} ->
+            case catch binary_to_integer(SVersion) of
+                Version when is_integer(Version) ->
+                    State#state{db_version = Version};
+                Error ->
+                    ?WARNING_MSG("error getting pgsql version: ~p", [Error]),
+                    State
+            end;
+        Res ->
+            ?WARNING_MSG("error getting pgsql version: ~p", [Res]),
+            State
+    end;
+get_db_version(State) ->
+    State.
+
 log(Level, Format, Args) ->
     case Level of
       debug -> ?DEBUG(Format, Args);
@@ -927,7 +977,12 @@ check_error({error, Why} = Err, #sql_query{} = Query) ->
                [Query#sql_query.hash, Query#sql_query.loc, Why]),
     Err;
 check_error({error, Why} = Err, Query) ->
-    ?ERROR_MSG("SQL query '~s' failed: ~p", [Query, Why]),
+    case catch iolist_to_binary(Query) of
+        SQuery when is_binary(SQuery) ->
+            ?ERROR_MSG("SQL query '~s' failed: ~p", [SQuery, Why]);
+        _ ->
+            ?ERROR_MSG("SQL query ~p failed: ~p", [Query, Why])
+    end,
     Err;
 check_error(Result, _Query) ->
     Result.