]> granicus.if.org Git - ejabberd/commitdiff
improved SQL reconnect behaviour
authorEvgeniy Khramtsov <xramtsov@gmail.com>
Sun, 31 Jan 2010 11:41:28 +0000 (11:41 +0000)
committerEvgeniy Khramtsov <xramtsov@gmail.com>
Sun, 31 Jan 2010 11:41:28 +0000 (11:41 +0000)
SVN Revision: 2947

src/odbc/ejabberd_odbc.erl

index b2c1c20fc5cbea020cc0a19be0527828688ef8ef..de38260c4fe3164dfb3b6a3ec35fff5b01806e66 100644 (file)
@@ -27,7 +27,9 @@
 -module(ejabberd_odbc).
 -author('alexey@process-one.net').
 
--behaviour(gen_server).
+-define(GEN_FSM, p1_fsm).
+
+-behaviour(?GEN_FSM).
 
 %% External exports
 -export([start/1, start_link/2,
         escape_like/1,
         keep_alive/1]).
 
-%% gen_server callbacks
+%% gen_fsm callbacks
 -export([init/1,
-        handle_call/3,
-        handle_cast/2,
-        code_change/3,
-        handle_info/2,
-        terminate/2]).
+        handle_event/3,
+        handle_sync_event/4,
+        handle_info/3,
+        terminate/3,
+        code_change/4]).
+
+%% gen_fsm states
+-export([connecting/2,
+        connecting/3,
+        session_established/2,
+        session_established/3]).
 
 -include("ejabberd.hrl").
 
--record(state, {db_ref, db_type}).
+-record(state, {db_ref,
+               db_type,
+               start_interval,
+               host,
+               max_pending_requests_len,
+               pending_requests}).
 
 -define(STATE_KEY, ejabberd_odbc_state).
 -define(NESTING_KEY, ejabberd_odbc_nesting_level).
 -define(KEEPALIVE_TIMEOUT, 60000).
 -define(KEEPALIVE_QUERY, "SELECT 1;").
 
+%%-define(DBGFSM, true).
+
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
 %%%----------------------------------------------------------------------
 %%% API
 %%%----------------------------------------------------------------------
 start(Host) ->
-    gen_server:start(ejabberd_odbc, [Host], []).
+    ?GEN_FSM:start(ejabberd_odbc, [Host], fsm_limit_opts() ++ ?FSMOPTS).
 
 start_link(Host, StartInterval) ->
-    gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []).
+    ?GEN_FSM:start_link(ejabberd_odbc, [Host, StartInterval],
+                       fsm_limit_opts() ++ ?FSMOPTS).
 
 sql_query(Host, Query) ->
     sql_call(Host, {sql_query, Query}).
@@ -95,12 +117,16 @@ sql_bloc(Host, F) ->
 sql_call(Host, Msg) ->
     case get(?STATE_KEY) of
         undefined ->
-            gen_server:call(ejabberd_odbc_sup:get_random_pid(Host),
-                           {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT);
+            ?GEN_FSM:sync_send_event(ejabberd_odbc_sup:get_random_pid(Host),
+                                    {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT);
         _State ->
             nested_op(Msg)
     end.
 
+% perform a harmless query on all opened connexions to avoid connexion close.
+keep_alive(PID) ->
+    ?GEN_FSM:sync_send_event(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
+                            ?KEEPALIVE_TIMEOUT).
 
 %% This function is intended to be used from inside an sql_transaction:
 sql_query_t(Query) ->
@@ -134,16 +160,8 @@ escape_like(C)  -> odbc_queries:escape(C).
 
 
 %%%----------------------------------------------------------------------
-%%% Callback functions from gen_server
+%%% Callback functions from gen_fsm
 %%%----------------------------------------------------------------------
-
-%%----------------------------------------------------------------------
-%% Func: init/1
-%% Returns: {ok, State}          |
-%%          {ok, State, Timeout} |
-%%          ignore               |
-%%          {stop, Reason}
-%%----------------------------------------------------------------------
 init([Host, StartInterval]) ->
     case ejabberd_config:get_local_option({odbc_keepalive_interval, Host}) of
        KeepaliveInterval when is_integer(KeepaliveInterval) ->
@@ -155,80 +173,114 @@ init([Host, StartInterval]) ->
            ?ERROR_MSG("Wrong odbc_keepalive_interval definition '~p'"
                       " for host ~p.~n", [_Other, Host])
     end,
-    SQLServer = ejabberd_config:get_local_option({odbc_server, Host}),
-    case SQLServer of
-       %% Default pgsql port
-       {pgsql, Server, DB, Username, Password} ->
-           pgsql_connect(Server, ?PGSQL_PORT, DB, Username, Password,
-                         StartInterval);
-       {pgsql, Server, Port, DB, Username, Password} when is_integer(Port) ->
-           pgsql_connect(Server, Port, DB, Username, Password,
-                         StartInterval);
-       %% Default mysql port
-       {mysql, Server, DB, Username, Password} ->
-           mysql_connect(Server, ?MYSQL_PORT, DB, Username, Password,
-                         StartInterval);
-       {mysql, Server, Port, DB, Username, Password} when is_integer(Port) ->
-           mysql_connect(Server, Port, DB, Username, Password,
-                         StartInterval);
-       _ when is_list(SQLServer) ->
-           odbc_connect(SQLServer, StartInterval)
-    end.
+    [DBType | _] = db_opts(Host),
+    ?GEN_FSM:send_event(self(), connect),
+    {ok, connecting, #state{db_type = DBType,
+                           host = Host,
+                           max_pending_requests_len = max_fsm_queue(),
+                           pending_requests = {0, queue:new()},
+                           start_interval = StartInterval}}.
+
+connecting(connect, #state{host = Host} = State) ->
+    ConnectRes = case db_opts(Host) of
+                    [mysql | Args] ->
+                        apply(fun mysql_connect/5, Args);
+                    [pgsql | Args] ->
+                        apply(fun pgsql_connect/5, Args);
+                    [odbc | Args] ->
+                        apply(fun odbc_connect/1, Args)
+                end,
+    {_, PendingRequests} = State#state.pending_requests,
+    case ConnectRes of
+       {ok, Ref} ->
+           erlang:monitor(process, Ref),
+           queue:filter(
+             fun(Req) ->
+                     ?GEN_FSM:send_event(self(), Req),
+                     false
+             end, PendingRequests),
+           {next_state, session_established,
+            State#state{db_ref = Ref,
+                        pending_requests = {0, queue:new()}}};
+       {error, Reason} ->
+           ?INFO_MSG("~p connection failed:~n"
+                     "** Reason: ~p~n"
+                     "** Retry after: ~p seconds",
+                     [State#state.db_type, Reason,
+                      State#state.start_interval div 1000]),
+           ?GEN_FSM:send_event_after(State#state.start_interval,
+                                     connect),
+           {next_state, connecting, State}
+    end;
+connecting(Event, State) ->
+    ?WARNING_MSG("unexpected event in 'connecting': ~p", [Event]),
+    {next_state, connecting, State}.
+
+connecting({sql_cmd, {sql_query, ?KEEPALIVE_QUERY}}, From, State) ->
+    ?GEN_FSM:reply(From, {error, "SQL connection failed"}),
+    {next_state, connecting, State};
+connecting({sql_cmd, Command} = Req, From, State) ->
+    ?DEBUG("queueing pending request while connecting:~n\t~p", [Req]),
+    {Len, PendingRequests} = State#state.pending_requests,
+    NewPendingRequests =
+       if Len < State#state.max_pending_requests_len ->
+               {Len + 1, queue:in({sql_cmd, Command, From}, PendingRequests)};
+          true ->
+               queue:filter(
+                 fun({sql_cmd, _, To}) ->
+                         ?GEN_FSM:reply(To,
+                                        {error, "SQL connection failed"}),
+                         false
+                 end, PendingRequests),
+               {1, queue:from_list([{sql_cmd, Command, From}])}
+       end,
+    {next_state, connecting,
+     State#state{pending_requests = NewPendingRequests}};
+connecting(Request, {Who, _Ref}, State) ->
+    ?WARNING_MSG("unexpected call ~p from ~p in 'connecting'",
+                [Request, Who]),
+    {reply, {error, badarg}, connecting, State}.
+
+session_established({sql_cmd, Command}, From, State) ->
+    put(?NESTING_KEY, ?TOP_LEVEL_TXN),
+    put(?STATE_KEY, State),
+    abort_on_driver_error(outer_op(Command), From);
+session_established(Request, {Who, _Ref}, State) ->
+    ?WARNING_MSG("unexpected call ~p from ~p in 'session_established'",
+                [Request, Who]),
+    {reply, {error, badarg}, session_established, State}.
 
-%%----------------------------------------------------------------------
-%% Func: handle_call/3
-%% Returns: {reply, Reply, State}          |
-%%          {reply, Reply, State, Timeout} |
-%%          {noreply, State}               |
-%%          {noreply, State, Timeout}      |
-%%          {stop, Reason, Reply, State}   | (terminate/2 is called)
-%%          {stop, Reason, State}            (terminate/2 is called)
-%%----------------------------------------------------------------------
-handle_call({sql_cmd, Command}, _From, State) ->
+session_established({sql_cmd, Command, From}, State) ->
     put(?NESTING_KEY, ?TOP_LEVEL_TXN),
     put(?STATE_KEY, State),
-    abort_on_driver_error(outer_op(Command));
-handle_call(Request, {Who, _Ref}, State) ->
-    ?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]),
-    {reply, ok, State}.
-
-%%----------------------------------------------------------------------
-%% Func: handle_cast/2
-%% Returns: {noreply, State}          |
-%%          {noreply, State, Timeout} |
-%%          {stop, Reason, State}            (terminate/2 is called)
-%%----------------------------------------------------------------------
-handle_cast(_Msg, State) ->
-    {noreply, State}.
-
-
-code_change(_OldVsn, State, _Extra) ->
-    {ok, State}.
-
-%%----------------------------------------------------------------------
-%% Func: handle_info/2
-%% Returns: {noreply, State}          |
-%%          {noreply, State, Timeout} |
-%%          {stop, Reason, State}            (terminate/2 is called)
-%%----------------------------------------------------------------------
+    abort_on_driver_error(outer_op(Command), From);
+session_established(Event, State) ->
+    ?WARNING_MSG("unexpected event in 'session_established': ~p", [Event]),
+    {next_state, session_established, State}.
+
+handle_event(_Event, StateName, State) ->
+    {next_state, StateName, State}.
+
+handle_sync_event(_Event, _From, StateName, State) ->
+    {reply, {error, badarg}, StateName, State}.
+
+code_change(_OldVsn, StateName, State, _Extra) ->
+    {ok, StateName, State}.
+
 %% We receive the down signal when we loose the MySQL connection (we are
 %% monitoring the connection)
-%% => We exit and let the supervisor restart the connection.
-handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, State) ->
-    {stop, connection_dropped, State};
-handle_info(_Info, State) ->
-    {noreply, State}.
-
-%%----------------------------------------------------------------------
-%% Func: terminate/2
-%% Purpose: Shutdown the server
-%% Returns: any (ignored by gen_server)
-%%----------------------------------------------------------------------
-terminate(_Reason, State) ->
+handle_info({'DOWN', _MonitorRef, process, _Pid, _Info}, _StateName, State) ->
+    ?GEN_FSM:send_event(self(), connect),
+    {next_state, connecting, State};
+handle_info(Info, StateName, State) ->
+    ?WARNING_MSG("unexpected info in ~p: ~p", [StateName, Info]),
+    {next_state, StateName, State}.
+
+terminate(_Reason, _StateName, State) ->
     case State#state.db_type of
        mysql ->
-           % old versions of mysql driver don't have the stop function
-           % so the catch
+           %% old versions of mysql driver don't have the stop function
+           %% so the catch
            catch mysql_conn:stop(State#state.db_ref);
        _ ->
            ok
@@ -367,50 +419,34 @@ sql_query_internal(Query) ->
     end.
 
 %% Generate the OTP callback return tuple depending on the driver result.
-abort_on_driver_error({error, "query timed out"} = Reply) ->
+abort_on_driver_error({error, "query timed out"} = Reply, From) ->
     %% mysql driver error
-    {stop, timeout, Reply, get(?STATE_KEY)};
-abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply) ->
+    ?GEN_FSM:reply(From, Reply),
+    {stop, timeout, get(?STATE_KEY)};
+abort_on_driver_error({error, "Failed sending data on socket" ++ _} = Reply,
+                     From) ->
     %% mysql driver error
-    {stop, closed, Reply, get(?STATE_KEY)};
-abort_on_driver_error(Reply) ->
-    {reply, Reply, get(?STATE_KEY)}.
+    ?GEN_FSM:reply(From, Reply),
+    {stop, closed, get(?STATE_KEY)};
+abort_on_driver_error(Reply, From) ->
+    ?GEN_FSM:reply(From, Reply),
+    {next_state, session_established, get(?STATE_KEY)}.
 
 
 %% == pure ODBC code
 
 %% part of init/1
 %% Open an ODBC database connection
-odbc_connect(SQLServer, StartInterval) ->
+odbc_connect(SQLServer) ->
     application:start(odbc),
-    case odbc:connect(SQLServer,[{scrollable_cursors, off}]) of
-       {ok, Ref} ->
-           erlang:monitor(process, Ref),
-           {ok, #state{db_ref = Ref, db_type = odbc}};
-       {error, Reason} ->
-           ?ERROR_MSG("ODBC connection (~s) failed: ~p~n",
-                      [SQLServer, Reason]),
-           %% If we can't connect we wait before retrying
-           timer:sleep(StartInterval),
-           {stop, odbc_connection_failed}
-    end.
-
+    odbc:connect(SQLServer, [{scrollable_cursors, off}]).
 
 %% == Native PostgreSQL code
 
 %% part of init/1
 %% Open a database connection to PostgreSQL
-pgsql_connect(Server, Port, DB, Username, Password, StartInterval) ->
-    case pgsql:connect(Server, DB, Username, Password, Port) of
-       {ok, Ref} ->
-           erlang:monitor(process, Ref),
-           {ok, #state{db_ref = Ref, db_type = pgsql}};
-       {error, Reason} ->
-           ?ERROR_MSG("PostgreSQL connection failed: ~p~n", [Reason]),
-           %% If we can't connect we wait before retrying
-           timer:sleep(StartInterval),
-           {stop, pgsql_connection_failed}
-    end.
+pgsql_connect(Server, Port, DB, Username, Password) ->
+    pgsql:connect(Server, DB, Username, Password, Port).
 
 %% Convert PostgreSQL query result to Erlang ODBC result formalism
 pgsql_to_odbc({ok, PGSQLResult}) ->
@@ -441,19 +477,13 @@ pgsql_item_to_odbc(_) ->
 
 %% part of init/1
 %% Open a database connection to MySQL
-mysql_connect(Server, Port, DB, Username, Password, StartInterval) ->
+mysql_connect(Server, Port, DB, Username, Password) ->
     case mysql_conn:start(Server, Port, Username, Password, DB, fun log/3) of
        {ok, Ref} ->
-           erlang:monitor(process, Ref),
             mysql_conn:fetch(Ref, ["set names 'utf8';"], self()),
-           {ok, #state{db_ref = Ref, db_type = mysql}};
-       {error, Reason} ->
-           ?ERROR_MSG("MySQL connection failed: ~p~n"
-                      "Waiting ~p seconds before retrying...~n",
-                      [Reason, StartInterval div 1000]),
-           %% If we can't connect we wait before retrying
-           timer:sleep(StartInterval),
-           {stop, mysql_connection_failed}
+           {ok, Ref};
+       Err ->
+           Err
     end.
 
 %% Convert MySQL query result to Erlang ODBC result formalism
@@ -475,11 +505,6 @@ mysql_item_to_odbc(Columns, Recs) ->
      [element(2, Column) || Column <- Columns],
      [list_to_tuple(Rec) || Rec <- Recs]}.
 
-% perform a harmless query on all opened connexions to avoid connexion close.
-keep_alive(PID) ->
-    gen_server:call(PID, {sql_cmd, {sql_query, ?KEEPALIVE_QUERY}},
-                   ?KEEPALIVE_TIMEOUT).
-
 % log function used by MySQL driver
 log(Level, Format, Args) ->
     case Level of
@@ -490,3 +515,35 @@ log(Level, Format, Args) ->
        error ->
            ?ERROR_MSG(Format, Args)
     end.
+
+db_opts(Host) ->
+    case ejabberd_config:get_local_option({odbc_server, Host}) of
+       %% Default pgsql port
+       {pgsql, Server, DB, User, Pass} ->
+           [pgsql, Server, ?PGSQL_PORT, DB, User, Pass];
+       {pgsql, Server, Port, DB, User, Pass} when is_integer(Port) ->
+           [pgsql, Server, Port, DB, User, Pass];
+       %% Default mysql port
+       {mysql, Server, DB, User, Pass} ->
+           [mysql, Server, ?MYSQL_PORT, DB, User, Pass];
+       {mysql, Server, Port, DB, User, Pass} when is_integer(Port) ->
+           [mysql, Server, Port, DB, User, Pass];
+       SQLServer when is_list(SQLServer) ->
+           [odbc, SQLServer]
+    end.
+
+max_fsm_queue() ->
+    case ejabberd_config:get_local_option(max_fsm_queue) of
+       N when is_integer(N), N>0 ->
+           N;
+       _ ->
+           undefined
+    end.
+           
+fsm_limit_opts() ->
+    case max_fsm_queue() of
+       N when is_integer(N) ->
+           [{max_queue, N}];
+       _ ->
+           []
+    end.