]> granicus.if.org Git - ejabberd/commitdiff
Fix SQL connections leakage
authorEvgeny Khramtsov <ekhramtsov@process-one.net>
Tue, 30 Jul 2019 11:26:11 +0000 (14:26 +0300)
committerEvgeny Khramtsov <ekhramtsov@process-one.net>
Tue, 30 Jul 2019 11:26:11 +0000 (14:26 +0300)
src/ejabberd_sql.erl

index e0f1e9e102a6e14fa430226ee4d4153a67cea154..b225a107a27e8c3f424497e41e982fff294d162b 100644 (file)
@@ -343,31 +343,29 @@ connecting(connect, #state{host = Host} = State) ->
                 end,
     case ConnectRes of
         {ok, Ref} ->
-            erlang:monitor(process, Ref),
-            lists:foreach(
-              fun({{?PREPARE_KEY, _} = Key, _}) ->
-                      erase(Key);
-                 (_) ->
-                      ok
-              end, get()),
-           PendingRequests =
-               p1_queue:dropwhile(
-                 fun(Req) ->
-                         p1_fsm:send_event(self(), Req),
-                         true
-                 end, State#state.pending_requests),
-            State1 = State#state{db_ref = Ref,
-                                 pending_requests = PendingRequests},
-            State2 = get_db_version(State1),
-            {next_state, session_established, State2};
-      {error, Reason} ->
-           StartInterval = ejabberd_option:sql_start_interval(Host),
-           ?WARNING_MSG("~p connection failed:~n** Reason: ~p~n** "
-                        "Retry after: ~B seconds",
-                        [State#state.db_type, Reason,
-                         StartInterval div 1000]),
-           p1_fsm:send_event_after(StartInterval, connect),
-           {next_state, connecting, State}
+           try link(Ref) of
+               _ ->
+                   lists:foreach(
+                     fun({{?PREPARE_KEY, _} = Key, _}) ->
+                             erase(Key);
+                        (_) ->
+                             ok
+                     end, get()),
+                   PendingRequests =
+                       p1_queue:dropwhile(
+                         fun(Req) ->
+                                 p1_fsm:send_event(self(), Req),
+                                 true
+                         end, State#state.pending_requests),
+                   State1 = State#state{db_ref = Ref,
+                                        pending_requests = PendingRequests},
+                   State2 = get_db_version(State1),
+                   {next_state, session_established, State2}
+           catch _:Reason ->
+                   handle_reconnect(Reason, State)
+           end;
+       {error, Reason} ->
+           handle_reconnect(Reason, State)
     end;
 connecting(Event, State) ->
     ?WARNING_MSG("Unexpected event in 'connecting': ~p",
@@ -431,12 +429,8 @@ handle_sync_event(_Event, _From, StateName, State) ->
 code_change(_OldVsn, StateName, State, _Extra) ->
     {ok, StateName, State}.
 
-%% We receive the down signal when we loose the MySQL connection (we are
-%% monitoring the connection)
-handle_info({'DOWN', _MonitorRef, process, _Pid, _Info},
-           _StateName, State) ->
-    p1_fsm:send_event(self(), connect),
-    {next_state, connecting, State};
+handle_info({'EXIT', _Pid, Reason}, _StateName, State) ->
+    handle_reconnect(Reason, State);
 handle_info(Info, StateName, State) ->
     ?WARNING_MSG("Unexpected info in ~p: ~p",
                 [StateName, Info]),
@@ -460,6 +454,15 @@ print_state(State) -> State.
 %%%----------------------------------------------------------------------
 %%% Internal functions
 %%%----------------------------------------------------------------------
+handle_reconnect(Reason, #state{host = Host} = State) ->
+    StartInterval = ejabberd_option:sql_start_interval(Host),
+    ?WARNING_MSG("~p connection failed:~n"
+                "** Reason: ~p~n"
+                "** Retry after: ~B seconds",
+                [State#state.db_type, Reason,
+                 StartInterval div 1000]),
+    p1_fsm:send_event_after(StartInterval, connect),
+    {next_state, connecting, State}.
 
 run_sql_cmd(Command, From, State, Timestamp) ->
     QueryTimeout = query_timeout(State#state.host),