]> granicus.if.org Git - ejabberd/commitdiff
* src/ejabberd_s2s_out.erl: Implements s2s negotiation timeouts and s2s connection...
authorMickaël Rémond <mickael.remond@process-one.net>
Fri, 14 Sep 2007 14:18:32 +0000 (14:18 +0000)
committerMickaël Rémond <mickael.remond@process-one.net>
Fri, 14 Sep 2007 14:18:32 +0000 (14:18 +0000)
SVN Revision: 936

src/ejabberd_s2s_out.erl

index 7dce6b159b7b9547a5baaed441f7ea0c331939a3..136a10893ba5bef359ccc0cdc0d3048dba47d8d5 100644 (file)
@@ -25,6 +25,7 @@
         wait_for_auth_result/2,
         wait_for_starttls_proceed/2,
         reopen_socket/2,
+        wait_before_retry/2,
         stream_established/2,
         handle_event/3,
         handle_sync_event/4,
@@ -61,7 +62,7 @@
 %% Module start with or without supervisor:
 -ifdef(NO_TRANSIENT_SUPERVISORS).
 -define(SUPERVISOR_START, p1_fsm:start(ejabberd_s2s_out, [From, Host, Type],
-                           ?FSMOPTS)).
+                           ?FSMLIMITS ++ ?FSMOPTS)).
 -else.
 -define(SUPERVISOR_START, supervisor:start_child(ejabberd_s2s_out_sup,
                                      [From, Host, Type])).
@@ -70,6 +71,7 @@
 %% Only change this value if you now what your are doing:
 -define(FSMLIMITS,[]).
 %% -define(FSMLIMITS, [{max_queue, 2000}]).
+-define(FSMTIMEOUT, 5000).
 
 -define(STREAM_HEADER,
        "<?xml version='1.0'?>"
@@ -99,7 +101,7 @@ start(From, Host, Type) ->
 
 start_link(From, Host, Type) ->
     p1_fsm:start_link(ejabberd_s2s_out, [From, Host, Type],
-                      ?FSMLIMITS ++ ?FSMOPTS).
+                       ?FSMLIMITS ++ ?FSMOPTS).
 
 start_connection(Pid) ->
     p1_fsm:send_event(Pid, init).
@@ -150,7 +152,7 @@ init([From, Server, Type]) ->
                             server = Server,
                             new = New,
                             verify = Verify,
-                            timer = Timer}}.
+                            timer = Timer}, ?FSMTIMEOUT}.
 
 %%----------------------------------------------------------------------
 %% Func: StateName/2
@@ -192,16 +194,23 @@ open_socket(init, StateData) ->
            send_text(NewStateData, io_lib:format(?STREAM_HEADER,
                                            [StateData#state.server,
                                             Version])),
-           {next_state, wait_for_stream, NewStateData};
+           {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT};
        {error, _Reason} ->
            ?INFO_MSG("s2s connection: ~s -> ~s (remote server not found)",
                      [StateData#state.myname, StateData#state.server]),
-           {stop, normal, StateData}
+           wait_before_reconnect(StateData, 300000)
+           %%{stop, normal, StateData}
     end;
 open_socket(stop, StateData) ->
+    ?INFO_MSG("s2s connection: ~s -> ~s (stopped in open socket)",
+             [StateData#state.myname, StateData#state.server]),
+    {stop, normal, StateData};
+open_socket(timeout, StateData) ->
+    ?INFO_MSG("s2s connection: ~s -> ~s (timeout in open socket)",
+             [StateData#state.myname, StateData#state.server]),
     {stop, normal, StateData};
 open_socket(_, StateData) ->
-    {next_state, open_socket, StateData}.
+    {next_state, open_socket, StateData, ?FSMTIMEOUT}.
 
 %%----------------------------------------------------------------------
 open_socket1(Addr, Port) ->
@@ -246,9 +255,9 @@ wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
            send_db_request(StateData);
        {"jabber:server", "jabber:server:dialback", true} when
              StateData#state.use_v10 ->
-           {next_state, wait_for_features, StateData};
+           {next_state, wait_for_features, StateData, ?FSMTIMEOUT};
        {"jabber:server", "", true} when StateData#state.use_v10 ->
-           {next_state, wait_for_features, StateData#state{db_enabled = false}};
+           {next_state, wait_for_features, StateData#state{db_enabled = false}, ?FSMTIMEOUT};
        _ ->
            send_text(StateData, ?INVALID_NAMESPACE_ERR),
            ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid namespace)",
@@ -296,7 +305,8 @@ wait_for_validation({xmlstreamelement, El}, StateData) ->
            ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]),
            case StateData#state.verify of
                false ->
-                   {next_state, wait_for_validation, StateData};
+                   %% TODO: Shouldn't we close the connection here?
+                   {next_state, wait_for_validation, StateData, ?FSMTIMEOUT};
                {Pid, _Key, _SID} ->
                    case Type of
                        "valid" ->
@@ -314,25 +324,34 @@ wait_for_validation({xmlstreamelement, El}, StateData) ->
                        StateData#state.verify == false ->
                            {stop, normal, StateData};
                        true ->
-                           {next_state, wait_for_validation, StateData}
+                           {next_state, wait_for_validation, StateData,
+                            ?FSMTIMEOUT*3}
                    end
            end;
        _ ->
-           {next_state, wait_for_validation, StateData}
+           {next_state, wait_for_validation, StateData, ?FSMTIMEOUT*3}
     end;
 
 wait_for_validation({xmlstreamend, Name}, StateData) ->
+    ?INFO_MSG("wait for validation: ~s -> ~s (xmlstreamend)",
+             [StateData#state.myname, StateData#state.server]),
     {stop, normal, StateData};
 
 wait_for_validation({xmlstreamerror, _}, StateData) ->
+    ?INFO_MSG("wait for validation: ~s -> ~s (xmlstreamerror)",
+             [StateData#state.myname, StateData#state.server]),
     send_text(StateData,
              ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
     {stop, normal, StateData};
 
 wait_for_validation(timeout, StateData) ->
+    ?INFO_MSG("wait_for_validation: ~s -> ~s (connect timeout)",
+             [StateData#state.myname, StateData#state.server]),
     {stop, normal, StateData};
 
 wait_for_validation(closed, StateData) ->
+    ?INFO_MSG("wait for validation: ~s -> ~s (closed)",
+             [StateData#state.myname, StateData#state.server]),
     {stop, normal, StateData}.
 
 
@@ -391,20 +410,21 @@ wait_for_features({xmlstreamelement, El}, StateData) ->
                                    jlib:encode_base64(
                                      StateData#state.myname)}]}),
                    {next_state, wait_for_auth_result,
-                    StateData#state{try_auth = false}};
+                    StateData#state{try_auth = false}, ?FSMTIMEOUT};
                StartTLS and StateData#state.tls and
                (not StateData#state.tls_enabled) ->
                    send_element(StateData,
                                 {xmlelement, "starttls",
                                  [{"xmlns", ?NS_TLS}], []}),
-                   {next_state, wait_for_starttls_proceed, StateData};
+                   {next_state, wait_for_starttls_proceed, StateData,
+                    ?FSMTIMEOUT};
                StartTLSRequired and (not StateData#state.tls) ->
                    ?DEBUG("restarted: ~p", [{StateData#state.myname,
                                                 StateData#state.server}]),
                    ejabberd_socket:close(StateData#state.socket),
                    {next_state, reopen_socket,
                     StateData#state{socket = undefined,
-                                    use_v10 = false}};
+                                    use_v10 = false}, ?FSMTIMEOUT};
                StateData#state.db_enabled ->
                    send_db_request(StateData);
                true ->
@@ -413,7 +433,7 @@ wait_for_features({xmlstreamelement, El}, StateData) ->
                    % TODO: clear message queue
                    ejabberd_socket:close(StateData#state.socket),
                    {next_state, reopen_socket, StateData#state{socket = undefined,
-                                                               use_v10 = false}}
+                                                               use_v10 = false}, ?FSMTIMEOUT}
            end;
        _ ->
            send_text(StateData,
@@ -458,7 +478,7 @@ wait_for_auth_result({xmlstreamelement, El}, StateData) ->
                    {next_state, wait_for_stream,
                     StateData#state{streamid = new_id(),
                                     authenticated = true
-                                   }};
+                                   }, ?FSMTIMEOUT};
                _ ->
                    send_text(StateData,
                              xml:element_to_string(?SERR_BAD_FORMAT) ++
@@ -474,7 +494,7 @@ wait_for_auth_result({xmlstreamelement, El}, StateData) ->
                                                 StateData#state.server}]),
                    ejabberd_socket:close(StateData#state.socket),
                    {next_state, reopen_socket,
-                    StateData#state{socket = undefined}};
+                    StateData#state{socket = undefined}, ?FSMTIMEOUT};
                _ ->
                    send_text(StateData,
                              xml:element_to_string(?SERR_BAD_FORMAT) ++
@@ -539,7 +559,7 @@ wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
                              io_lib:format(?STREAM_HEADER,
                                            [StateData#state.server,
                                             " version='1.0'"])),
-                   {next_state, wait_for_stream, NewStateData};
+                   {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT};
                _ ->
                    send_text(StateData,
                              xml:element_to_string(?SERR_BAD_FORMAT) ++
@@ -574,17 +594,21 @@ wait_for_starttls_proceed(closed, StateData) ->
 
 
 reopen_socket({xmlstreamelement, El}, StateData) ->
-    {next_state, reopen_socket, StateData};
+    {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
 reopen_socket({xmlstreamend, Name}, StateData) ->
-    {next_state, reopen_socket, StateData};
+    {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
 reopen_socket({xmlstreamerror, _}, StateData) ->
-    {next_state, reopen_socket, StateData};
+    {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
 reopen_socket(timeout, StateData) ->
+    ?INFO_MSG("reopen socket: timeout", []),
     {stop, normal, StateData};
 reopen_socket(closed, StateData) ->
     p1_fsm:send_event(self(), init),
-    {next_state, open_socket, StateData}.
+    {next_state, open_socket, StateData, ?FSMTIMEOUT}.
 
+%% This state is used to avoid reconnecting too often to bad sockets
+wait_before_retry(Event, StateData) ->
+    {next_state, wait_before_retry, StateData, ?FSMTIMEOUT}.
 
 stream_established({xmlstreamelement, El}, StateData) ->
     ?DEBUG("s2S stream established", []),
@@ -657,7 +681,7 @@ stream_established(closed, StateData) ->
 %%          {stop, Reason, NewStateData}
 %%----------------------------------------------------------------------
 handle_event(Event, StateName, StateData) ->
-    {next_state, StateName, StateData}.
+    {next_state, StateName, StateData, get_timeout_interval(StateName)}.
 
 %%----------------------------------------------------------------------
 %% Func: handle_sync_event/4
@@ -670,7 +694,7 @@ handle_event(Event, StateName, StateData) ->
 %%----------------------------------------------------------------------
 handle_sync_event(Event, From, StateName, StateData) ->
     Reply = ok,
-    {reply, Reply, StateName, StateData}.
+    {reply, Reply, StateName, StateData, get_timeout_interval(StateName)}.
 
 code_change(OldVsn, StateName, StateData, Extra) ->
     {ok, StateName, StateData}.
@@ -685,27 +709,39 @@ handle_info({send_text, Text}, StateName, StateData) ->
     send_text(StateData, Text),
     cancel_timer(StateData#state.timer),
     Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-    {next_state, StateName, StateData#state{timer = Timer}};
+    {next_state, StateName, StateData#state{timer = Timer},
+     get_timeout_interval(StateName)};
 
 handle_info({send_element, El}, StateName, StateData) ->
-    cancel_timer(StateData#state.timer),
-    Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
     case StateName of
        stream_established ->
+           cancel_timer(StateData#state.timer),
+           Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
            send_element(StateData, El),
            {next_state, StateName, StateData#state{timer = Timer}};
+       %% In this state we bounce all messages: We are waiting before
+       %% trying to reconnect
+       wait_before_retry ->
+           bounce_element(El, ?ERR_REMOTE_SERVER_NOT_FOUND),
+           {next_state, StateName, StateData};
        _ ->
            Q = queue:in(El, StateData#state.queue),
-           {next_state, StateName, StateData#state{queue = Q,
-                                                   timer = Timer}}
+           {next_state, StateName, StateData#state{queue = Q},
+            get_timeout_interval(StateName)}
     end;
 
+handle_info({timeout, Timer, _}, wait_before_retry,
+           #state{timer = Timer} = StateData) ->
+    ?INFO_MSG("Reconnect delay expired: Will now retry to connect to ~s when needed.", [StateData#state.server]),
+    {stop, normal, StateData};
+
 handle_info({timeout, Timer, _}, StateName,
            #state{timer = Timer} = StateData) ->
+    ?INFO_MSG("Closing connection with ~s: timeout", [StateData#state.server]),
     {stop, normal, StateData};
 
 handle_info(_, StateName, StateData) ->
-    {next_state, StateName, StateData}.
+    {next_state, StateName, StateData, get_timeout_interval(StateName)}.
 
 %%----------------------------------------------------------------------
 %% Func: terminate/3
@@ -751,15 +787,19 @@ send_queue(StateData, Q) ->
            ok
     end.
 
+%% Bounce a single message (xmlelement)
+bounce_element(El, Error) ->
+    Err = jlib:make_error_reply(El, Error),
+    From = jlib:string_to_jid(xml:get_tag_attr_s("from", El)),
+    To = jlib:string_to_jid(xml:get_tag_attr_s("to", El)),
+    ejabberd_router:route(To, From, Err).
+
 bounce_queue(Q, Error) ->
     case queue:out(Q) of
        {{value, El}, Q1} ->
-           Err = jlib:make_error_reply(El, Error),
-           From = jlib:string_to_jid(xml:get_tag_attr_s("from", El)),
-           To = jlib:string_to_jid(xml:get_tag_attr_s("to", El)),
-           ejabberd_router:route(To, From, Err),
+           bounce_element(El, Error),
            bounce_queue(Q1, Error);
-       {empty, Q1} ->
+       {empty, _} ->
            ok
     end.
 
@@ -783,10 +823,7 @@ bounce_messages(Error) ->
                "error" ->
                    ok;
                _ ->
-                   Err = jlib:make_error_reply(El, Error),
-                   From = jlib:string_to_jid(xml:get_attr_s("from", Attrs)),
-                   To = jlib:string_to_jid(xml:get_attr_s("to", Attrs)),
-                   ejabberd_router:route(To, From, Err)
+                   bounce_element(El, Error)
            end,
            bounce_messages(Error)
     after 0 ->
@@ -831,7 +868,7 @@ send_db_request(StateData) ->
                           {"id", SID}],
                          [{xmlcdata, Key2}]})
     end,
-    {next_state, wait_for_validation, StateData#state{new = New}}.
+    {next_state, wait_for_validation, StateData#state{new = New}, ?FSMTIMEOUT*6}.
 
 
 is_verify_res({xmlelement, Name, Attrs, Els}) when Name == "db:result" ->
@@ -918,3 +955,28 @@ log_s2s_out(false, _, _) -> ok;
 %% Log new outgoing connections:
 log_s2s_out(_, Myname, Server) ->
     ?INFO_MSG("Trying to open s2s connection: ~s -> ~s",[Myname, Server]).
+
+%% Calculate timeout depending on which state we are in:
+%% Can return integer > 0 | infinity
+get_timeout_interval(StateName) ->
+    case StateName of
+       %% Validation implies dialback: Networking can take longer:
+       wait_for_validation ->
+           ?FSMTIMEOUT*6;
+       %% When stream is established, we only rely on S2S Timeout timer:
+       stream_established ->
+           infinity;
+       _ ->
+           ?FSMTIMEOUT
+    end.
+
+%% This function is intended to be called at the end of a state
+%% function that wants to wait for a reconnect delay before stopping.
+wait_before_reconnect(StateData, Delay) ->
+    %% Bounce both the queue managed by this process and the Erlang message queue
+    bounce_queue(StateData#state.queue, ?ERR_REMOTE_SERVER_NOT_FOUND),
+    bounce_messages(?ERR_REMOTE_SERVER_NOT_FOUND),
+    cancel_timer(StateData#state.timer),
+    Timer = erlang:start_timer(Delay, self(), []),
+    {next_state, wait_before_retry, StateData#state{timer=Timer,
+                                                   queue = queue:new()}}.