]> granicus.if.org Git - ejabberd/commitdiff
More refactoring on session management
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Thu, 29 Dec 2016 21:00:36 +0000 (00:00 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Thu, 29 Dec 2016 21:00:36 +0000 (00:00 +0300)
14 files changed:
src/ejabberd_auth.erl
src/ejabberd_c2s.erl
src/ejabberd_http.erl
src/ejabberd_piefxis.erl
src/ejabberd_s2s_in.erl
src/ejabberd_s2s_out.erl
src/ejabberd_service.erl
src/ejabberd_sm.erl
src/ejabberd_web_admin.erl
src/mod_legacy_auth.erl
src/mod_s2s_dialback.erl
src/mod_sm.erl
src/xmpp_stream_in.erl
src/xmpp_stream_out.erl

index eba0a403855f7461d7c5ae2ac960238111e502af..17bc729cc92a677e9b0cf71b30a63bab5a31ab94 100644 (file)
@@ -188,7 +188,7 @@ try_register(User, Server, Password) ->
       true -> {atomic, exists};
       false ->
          LServer = jid:nameprep(Server),
-         case lists:member(LServer, ?MYHOSTS) of
+         case ejabberd_router:is_my_host(LServer) of
            true ->
                Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res;
                                      (M, _) ->
index 07d04fbc486b37652285b83379e57b9340c267e4..f22960c50019ebfd8d72aa4c1d161dee942c7c0f 100644 (file)
         compress_methods/1, bind/2, get_password_fun/1,
         check_password_fun/1, check_password_digest_fun/1,
         unauthenticated_stream_features/1, authenticated_stream_features/1,
-        handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
+        handle_stream_start/2, handle_stream_end/2,
         handle_unauthenticated_packet/2, handle_authenticated_packet/2,
         handle_auth_success/4, handle_auth_failure/4, handle_send/3,
         handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]).
 %% Hooks
--export([handle_unexpected_info/2, handle_unexpected_cast/2,
-        reject_unauthenticated_packet/2, process_closed/2]).
+-export([handle_unexpected_cast/2,
+        reject_unauthenticated_packet/2, process_closed/2,
+        process_terminated/2, process_info/2]).
 %% API
 -export([get_presence/1, get_subscription/2, get_subscribed/1,
         open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
-        copy_state/2, add_hooks/0]).
+        reply/2, copy_state/2, set_timeout/2, add_hooks/1]).
 
 -include("ejabberd.hrl").
 -include("xmpp.hrl").
@@ -76,6 +77,9 @@ socket_type() ->
 call(Ref, Msg, Timeout) ->
     xmpp_stream_in:call(Ref, Msg, Timeout).
 
+reply(Ref, Reply) ->
+    xmpp_stream_in:reply(Ref, Reply).
+
 -spec get_presence(pid()) -> presence().
 get_presence(Ref) ->
     call(Ref, get_presence, 1000).
@@ -112,37 +116,39 @@ stop(Ref) ->
 send(Pid, Pkt) when is_pid(Pid) ->
     xmpp_stream_in:send(Pid, Pkt);
 send(#{lserver := LServer} = State, Pkt) ->
-    case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of
-       drop -> State;
-       Pkt1 -> xmpp_stream_in:send(State, Pkt1)
+    case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt, State}, []) of
+       {drop, State1} -> State1;
+       {Pkt1, State1} -> xmpp_stream_in:send(State1, Pkt1)
     end.
 
+-spec set_timeout(state(), timeout()) -> state().
+set_timeout(State, Timeout) ->
+    xmpp_stream_in:set_timeout(State, Timeout).
+
 -spec establish(state()) -> state().
 establish(State) ->
     xmpp_stream_in:establish(State).
 
--spec add_hooks() -> ok.
-add_hooks() ->
-    lists:foreach(
-      fun(Host) ->
-             ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
-             ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
-                                reject_unauthenticated_packet, 100),
-             ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
-                                handle_unexpected_info, 100),
-             ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
-                                handle_unexpected_cast, 100)
-             
-      end, ?MYHOSTS).
+-spec add_hooks(binary()) -> ok.
+add_hooks(Host) ->
+    ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
+    ejabberd_hooks:add(c2s_terminated, Host, ?MODULE,
+                      process_terminated, 100),
+    ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
+                      reject_unauthenticated_packet, 100),
+    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
+                      process_info, 100),
+    ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
+                      handle_unexpected_cast, 100).
 
 %% Copies content of one c2s state to another.
 %% This is needed for session migration from one pid to another.
 -spec copy_state(state(), state()) -> state().
 copy_state(#{owner := Owner} = NewState,
-         #{jid := JID, resource := Resource, sid := {Time, _},
-           auth_module := AuthModule, lserver := LServer,
-           pres_t := PresT, pres_a := PresA,
-           pres_f := PresF} = OldState) ->
+          #{jid := JID, resource := Resource, sid := {Time, _},
+            auth_module := AuthModule, lserver := LServer,
+            pres_t := PresT, pres_a := PresA,
+            pres_f := PresF} = OldState) ->
     State1 = case OldState of
                 #{pres_last := Pres, pres_timestamp := PresTS} ->
                     NewState#{pres_last => Pres, pres_timestamp => PresTS};
@@ -158,10 +164,46 @@ copy_state(#{owner := Owner} = NewState,
                     pres_f => PresF},
     ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
 
+-spec open_session(state()) -> {ok, state()} | state().
+open_session(#{user := U, server := S, resource := R,
+              sid := SID, ip := IP, auth_module := AuthModule} = State) ->
+    JID = jid:make(U, S, R),
+    change_shaper(State),
+    Conn = get_conn_type(State),
+    State1 = State#{conn => Conn, resource => R, jid => JID},
+    Prio = try maps:get(pres_last, State) of
+              Pres -> get_priority_from_presence(Pres)
+          catch _:{badkey, _} ->
+                  undefined
+          end,
+    Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
+    ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
+    xmpp_stream_in:establish(State1).
+
 %%%===================================================================
 %%% Hooks
 %%%===================================================================
-handle_unexpected_info(State, Info) ->
+process_info(#{lserver := LServer} = State,
+            {route, From, To, Packet0}) ->
+    Packet = xmpp:set_from_to(Packet0, From, To),
+    {Pass, State1} = case Packet of
+                        #presence{} ->
+                            process_presence_in(State, Packet);
+                        #message{} ->
+                            process_message_in(State, Packet);
+                        #iq{} ->
+                            process_iq_in(State, Packet)
+                    end,
+    if Pass ->
+           Packet1 = ejabberd_hooks:run_fold(
+                       user_receive_packet, LServer, Packet, [State1]),
+           ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
+           send(State1, Packet1);
+       true ->
+           ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
+           State1
+    end;
+process_info(State, Info) ->
     ?WARNING_MSG("got unexpected info: ~p", [Info]),
     State.
 
@@ -173,8 +215,22 @@ reject_unauthenticated_packet(State, Pkt) ->
     Err = xmpp:err_not_authorized(),
     xmpp_stream_in:send_error(State, Pkt, Err).
 
-process_closed(State, _Reason) ->
-    stop(State).
+process_closed(State, Reason) ->
+    stop(State#{stop_reason => Reason}).
+
+process_terminated(#{socket := Socket, jid := JID} = State,
+                  Reason) ->
+    Status = format_reason(State, Reason),
+    ?INFO_MSG("(~s) Closing c2s connection for ~s: ~s",
+             [ejabberd_socket:pp(Socket), jid:to_string(JID), Status]),
+    Pres = #presence{type = unavailable,
+                    status = xmpp:mk_text(Status),
+                    from = JID, to = jid:remove_resource(JID)},
+    State1 = broadcast_presence_unavailable(State, Pres),
+    bounce_message_queue(),
+    State1;
+process_terminated(State, _Reason) ->
+    State.
 
 %%%===================================================================
 %%% xmpp_stream_in callbacks
@@ -248,25 +304,9 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang,
            end
     end.
 
--spec open_session(state()) -> {ok, state()} | state().
-open_session(#{user := U, server := S, resource := R,
-              sid := SID, ip := IP, auth_module := AuthModule} = State) ->
-    JID = jid:make(U, S, R),
-    change_shaper(State),
-    Conn = get_conn_type(State),
-    State1 = State#{conn => Conn, resource => R, jid => JID},
-    Prio = try maps:get(pres_last, State) of
-              Pres -> get_priority_from_presence(Pres)
-          catch _:{badkey, _} ->
-                  undefined
-          end,
-    Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
-    ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
-    State1.
-
 handle_stream_start(StreamStart,
                    #{lserver := LServer, ip := IP, lang := Lang} = State) ->
-    case lists:member(LServer, ?MYHOSTS) of
+    case ejabberd_router:is_my_host(LServer) of
        false ->
            send(State, xmpp:serr_host_unknown());
        true ->
@@ -284,10 +324,8 @@ handle_stream_start(StreamStart,
     end.
 
 handle_stream_end(Reason, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]).
-
-handle_stream_close(_Reason, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]).
+    State1 = State#{stop_reason => Reason},
+    ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]).
 
 handle_auth_success(User, Mech, AuthModule,
                    #{socket := Socket, ip := IP, lserver := LServer} = State) ->
@@ -296,8 +334,7 @@ handle_auth_success(User, Mech, AuthModule,
               ejabberd_auth:backend_type(AuthModule),
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     State1 = State#{auth_module => AuthModule},
-    ejabberd_hooks:run_fold(c2s_auth_result, LServer,
-                           State1, [true, User]).
+    ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]).
 
 handle_auth_failure(User, Mech, Reason,
                    #{socket := Socket, ip := IP, lserver := LServer} = State) ->
@@ -307,16 +344,13 @@ handle_auth_failure(User, Mech, Reason,
                  true -> ""
               end,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
-    ejabberd_hooks:run_fold(c2s_auth_result, LServer,
-                           State, [false, User]).
+    ejabberd_hooks:run_fold(c2s_auth_result, LServer, State, [false, User]).
 
 handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer,
-                           State, [Pkt]).
+    ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, State, [Pkt]).
 
 handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_unauthenticated_packet,
-                           LServer, State, [Pkt]).
+    ejabberd_hooks:run_fold(c2s_unauthenticated_packet, LServer, State, [Pkt]).
 
 handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
     ejabberd_hooks:run_fold(c2s_authenticated_packet,
@@ -366,20 +400,22 @@ init([State, Opts]) ->
                    zlib => Zlib,
                    lang => ?MYLANG,
                    server => ?MYNAME,
+                   lserver => ?MYNAME,
                    access => Access,
                    shaper => Shaper},
     ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]).
 
-handle_call(get_presence, _From, #{jid := JID} = State) ->
+handle_call(get_presence, From, #{jid := JID} = State) ->
     Pres = try maps:get(pres_last, State)
           catch _:{badkey, _} ->
                   BareJID = jid:remove_resource(JID),
                   #presence{from = JID, to = BareJID, type = unavailable}
           end,
-    {reply, Pres, State};
-handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
-    Subscribed = ?SETS:to_list(PresF),
-    {reply, Subscribed, State};
+    reply(From, Pres),
+    State;
+handle_call(get_subscribed, From, #{pres_f := PresF} = State) ->
+    reply(From, ?SETS:to_list(PresF)),
+    State;
 handle_call(Request, From, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(
       c2s_handle_call, LServer, State, [Request, From]).
@@ -387,30 +423,22 @@ handle_call(Request, From, #{lserver := LServer} = State) ->
 handle_cast(Msg, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]).
 
-handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
-    Packet = xmpp:set_from_to(Packet0, From, To),
-    {Pass, NewState} = case Packet of
-                          #presence{} ->
-                              process_presence_in(State, Packet);
-                          #message{} ->
-                              process_message_in(State, Packet);
-                          #iq{} ->
-                              process_iq_in(State, Packet)
-                      end,
-    if Pass ->
-           Packet1 = ejabberd_hooks:run_fold(
-                       user_receive_packet, LServer, Packet, [NewState]),
-           ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
-           send(NewState, Packet1);
-       true ->
-           ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
-           NewState
-    end;
 handle_info(Info, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
 
-terminate(_Reason, _State) ->
-    ok.
+terminate(Reason, #{sid := SID, jid := _,
+                   user := U, server := S, resource := R,
+                   lserver := LServer} = State) ->
+    Status = format_reason(State, Reason),
+    case maps:is_key(pres_last, State) of
+       true ->
+           ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status);
+       false ->
+           ejabberd_sm:close_session(SID, U, S, R)
+    end,
+    ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]);
+terminate(Reason, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]).
 
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
@@ -684,6 +712,15 @@ resource_conflict_action(U, S, R) ->
            {accept_resource, Rnew}
     end.
 
+-spec bounce_message_queue() -> ok.
+bounce_message_queue() ->
+    receive {route, From, To, Pkt} ->
+           ejabberd_router:route(From, To, Pkt),
+           bounce_message_queue()
+    after 0 ->
+           ok
+    end.
+
 -spec new_uniq_id() -> binary().
 new_uniq_id() ->
     iolist_to_binary(
@@ -735,6 +772,14 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
            end
     end.
 
+-spec format_reason(state(), term()) -> binary().
+format_reason(#{stop_reason := Reason}, _) ->
+    xmpp_stream_in:format_error(Reason);
+format_reason(_, Reason) when Reason /= normal ->
+    <<"internal server error">>;
+format_reason(_, _) ->
+    <<"">>.
+
 transform_listen_option(Opt, Opts) ->
     [Opt|Opts].
 
index c0c7bbbd6863aef744261c8d4ba51e3e043c01cc..1a33580eaf509bd89e69fb698edf36e34cff63fa 100644 (file)
@@ -322,7 +322,7 @@ add_header(Name, Value, State)->
 get_host_really_served(undefined, Provided) ->
     Provided;
 get_host_really_served(Default, Provided) ->
-    case lists:member(Provided, ?MYHOSTS) of
+    case ejabberd_router:is_my_host(Provided) of
       true -> Provided;
       false -> Default
     end.
index b6f90ccf844ffc72cef3eb8f4a55c2af642b6988..36d734004264ca982c4bbaf125a78a32c117fd5d 100644 (file)
@@ -350,7 +350,7 @@ process_el({xmlstreamelement, #xmlel{name = <<"host">>,
     JIDS = fxml:get_attr_s(<<"jid">>, Attrs),
     case jid:from_string(JIDS) of
         #jid{lserver = S} ->
-            case lists:member(S, ?MYHOSTS) of
+            case ejabberd_router:is_my_host(S) of
                 true ->
                     process_users(Els, State#state{server = S});
                 false ->
index 93f75bfcf32ce767bab9d0d7834b47ab40d9f8ac..a31af337e6af3f0f93ca1daacaf1bceb188ea33e 100644 (file)
@@ -34,7 +34,7 @@
 -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
         compress_methods/1,
         unauthenticated_stream_features/1, authenticated_stream_features/1,
-        handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
+        handle_stream_start/2, handle_stream_end/2,
         handle_stream_established/1, handle_auth_success/4,
         handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2,
         handle_unauthenticated_packet/2, handle_authenticated_packet/2]).
@@ -160,9 +160,6 @@ handle_stream_start(_StreamStart, #{lserver := LServer} = State) ->
 handle_stream_end(Reason, #{server_host := LServer} = State) ->
     ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]).
 
-handle_stream_close(_Reason, #{server_host := LServer} = State) ->
-    ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]).
-
 handle_stream_established(State) ->
     set_idle_timeout(State#{established => true}).
 
index 72d9dfea8594c50f5c58f509c5e7172dfb3df830..6069c786ca465c5a71427d64d003d6a5cf08351f 100644 (file)
 %% xmpp_stream_out callbacks
 -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
         handle_auth_success/2, handle_auth_failure/3, handle_packet/2,
-        handle_stream_end/2, handle_stream_close/2,
+        handle_stream_end/2, handle_stream_downgraded/2,
         handle_recv/3, handle_send/4, handle_cdata/2,
         handle_stream_established/1, handle_timeout/1]).
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
         terminate/2, code_change/3]).
 %% Hooks
 -export([process_auth_result/2, process_closed/2, handle_unexpected_info/2,
-        handle_unexpected_cast/2]).
+        handle_unexpected_cast/2, process_downgraded/2]).
 %% API
 -export([start/3, start_link/3, connect/1, close/1, stop/1, send/2,
         route/2, establish/1, update_state/2, add_hooks/0]).
@@ -83,7 +83,9 @@ add_hooks() ->
              ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE,
                                 handle_unexpected_info, 100),
              ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE,
-                                handle_unexpected_cast, 100)
+                                handle_unexpected_cast, 100),
+             ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE,
+                                process_downgraded, 100)
       end, ?MYHOSTS).
 
 %%%===================================================================
@@ -95,25 +97,28 @@ process_auth_result(#{server := LServer, remote_server := RServer} = State,
     ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;"
              " bouncing for ~p seconds",
              [LServer, RServer, Delay]),
-    State1 = close(State),
-    State2 = bounce_queue(State1),
-    xmpp_stream_out:set_timeout(State2, timer:seconds(Delay));
+    State1 = State#{on_route => bounce},
+    State2 = close(State1),
+    State3 = bounce_queue(State2),
+    xmpp_stream_out:set_timeout(State3, timer:seconds(Delay));
 process_auth_result(State, true) ->
     State.
 
+process_closed(#{server := LServer, remote_server := RServer,
+                on_route := send} = State,
+              Reason) ->
+    ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s",
+             [LServer, RServer, xmpp_stream_out:format_error(Reason)]),
+    stop(State);
 process_closed(#{server := LServer, remote_server := RServer} = State,
-              _Reason) ->
+              Reason) ->
     Delay = get_delay(),
     ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; "
              "bouncing for ~p seconds",
-             [LServer, RServer,
-              try maps:get(stop_reason, State) of
-                  {error, Why} -> xmpp_stream_out:format_error(Why)
-              catch _:undef -> <<"unexplained reason">>
-              end,
-              Delay]),
-    State1 = bounce_queue(State),
-    xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)).
+             [LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]),
+    State1 = State#{on_route => bounce},
+    State2 = bounce_queue(State1),
+    xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)).
 
 handle_unexpected_info(State, Info) ->
     ?WARNING_MSG("got unexpected info: ~p", [Info]),
@@ -123,6 +128,9 @@ handle_unexpected_cast(State, Msg) ->
     ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
     State.
 
+process_downgraded(State, _StreamStart) ->
+    send(State, xmpp:serr_unsupported_version()).
+
 %%%===================================================================
 %%% gen_server callbacks
 %%%===================================================================
@@ -153,21 +161,19 @@ handle_auth_failure(Mech, Reason,
     ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s",
              [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
-    State1 = State#{on_route => bounce,
-                   stop_reason => {error, {auth, Reason}}},
+    State1 = State#{stop_reason => {auth, Reason}},
     ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]).
 
 handle_packet(Pkt, #{server := LServer} = State) ->
     ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]).
 
 handle_stream_end(Reason, #{server := LServer} = State) ->
-    State1 = State#{on_route => bounce, stop_reason => Reason},
-    ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]).
-
-handle_stream_close(Reason, #{server := LServer} = State) ->
-    State1 = State#{on_route => bounce, stop_reason => Reason},
+    State1 = State#{stop_reason => Reason},
     ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]).
 
+handle_stream_downgraded(StreamStart, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_downgraded, LServer, State, [StreamStart]).
+
 handle_stream_established(State) ->
     State1 = State#{on_route => send},
     State2 = resend_queue(State1),
@@ -183,15 +189,10 @@ handle_send(Pkt, El, Data, #{server := LServer} = State) ->
     ejabberd_hooks:run_fold(s2s_out_handle_send, LServer,
                            State, [Pkt, El, Data]).
 
-handle_timeout(#{server := LServer, remote_server := RServer,
-                on_route := Action} = State) ->
+handle_timeout(#{on_route := Action} = State) ->
     case Action of
        bounce -> stop(State);
-       queue -> send(State, xmpp:serr_connection_timeout());
-       send ->
-           ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive",
-                     [LServer, RServer]),
-           stop(State)
+       _ -> send(State, xmpp:serr_connection_timeout())
     end.
 
 init([#{server := LServer, remote_server := RServer} = State, Opts]) ->
@@ -229,7 +230,7 @@ terminate(Reason, #{server := LServer,
     ejabberd_s2s:remove_connection({LServer, RServer}, self()),
     State1 = case Reason of
                 normal -> State;
-                _ -> State#{stop_reason => {error, internal_failure}}
+                _ -> State#{stop_reason => internal_failure}
             end,
     bounce_queue(State1),
     bounce_message_queue(State1).
@@ -258,8 +259,7 @@ bounce_queue(#{queue := Q} = State) ->
 
 -spec bounce_message_queue(state()) -> state().
 bounce_message_queue(State) ->
-    receive
-       {route, Pkt} ->
+    receive {route, Pkt} ->
            State1 = bounce_packet(Pkt, State),
            bounce_message_queue(State1)
     after 0 ->
@@ -278,21 +278,19 @@ bounce_packet(_, State) ->
     State.
 
 -spec mk_bounce_error(binary(), state()) -> stanza_error().
-mk_bounce_error(Lang, State) ->
-    try maps:get(stop_reason, State) of
-       {error, internal_failure} ->
+mk_bounce_error(Lang, #{stop_reason := Why}) ->
+    Reason = xmpp_stream_out:format_error(Why),
+    case Why of
+       internal_failure ->
            xmpp:err_internal_server_error();
-       {error, Why} ->
-           Reason = xmpp_stream_out:format_error(Why),
-           case Why of
-               {dns, _} ->
-                   xmpp:err_remote_server_timeout(Reason, Lang);
-               _ ->
-                   xmpp:err_remote_server_not_found(Reason, Lang)
-           end
-    catch _:{badkey, _} ->
-           xmpp:err_remote_server_not_found()
-    end.
+       {dns, _} ->
+           xmpp:err_remote_server_not_found(Reason, Lang);
+       _ ->
+           xmpp:err_remote_server_timeout(Reason, Lang)
+    end;
+mk_bounce_error(_Lang, _State) ->
+    %% We should not be here. Probably :)
+    xmpp:err_remote_server_not_found().
 
 -spec get_delay() -> non_neg_integer().
 get_delay() ->
index 13efd15e7917e09bb23460f32f19517d6897f327..6ecd03a4c2008672ac88004c0e6fc88d22038d0c 100644 (file)
@@ -99,7 +99,7 @@ handle_stream_start(_StreamStart,
                    #{remote_server := RemoteServer,
                      lang := Lang,
                      host_opts := HostOpts} = State) ->
-    case lists:member(RemoteServer, ?MYHOSTS) of
+    case ejabberd_router:is_my_host(RemoteServer) of
        true ->
            Txt = <<"Unable to register route on existing local domain">>,
            xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang));
index 46008bec47e1b9f25682d6efce53b6efac5e8452..a15d788d0abdf33aeba77a047ca19521d559e7f3 100644 (file)
@@ -390,7 +390,8 @@ init([]) ->
              ejabberd_hooks:add(offline_message_hook, Host,
                                 ejabberd_sm, bounce_offline_message, 100),
              ejabberd_hooks:add(remove_user, Host,
-                                ejabberd_sm, disconnect_removed_user, 100)
+                                ejabberd_sm, disconnect_removed_user, 100),
+             ejabberd_c2s:add_hooks(Host)
       end, ?MYHOSTS),
     ejabberd_commands:register_commands(get_commands_spec()),
     {ok, #state{}}.
index 3836beda76c63f318076cd90fa62049b26325bdb..e1c0760e984ae88389446f1d89906a1b82746621 100644 (file)
@@ -192,7 +192,7 @@ process([<<"server">>, SHost | RPath] = Path,
                 method = Method} =
            Request) ->
     Host = jid:nameprep(SHost),
-    case lists:member(Host, ?MYHOSTS) of
+    case ejabberd_router:is_my_host(Host) of
       true ->
          case get_auth_admin(Auth, HostHTTP, Path, Method) of
            {ok, {User, Server}} ->
index f93b67e05a71db7969b1cbb2509c32cfb38f44c5..e9057b4327c8ae56b4e6dff756186ad94b1519ea 100644 (file)
@@ -133,8 +133,7 @@ open_session(State, IQ, R) ->
     case ejabberd_c2s:bind(R, State) of
        {ok, State1} ->
            Res = xmpp:make_iq_result(IQ),
-           State2 = ejabberd_c2s:send(State1, Res),
-           ejabberd_c2s:establish(State2);
+           ejabberd_c2s:send(State1, Res);
        {error, Err, State1} ->
            Res = xmpp:make_error(IQ, Err),
            ejabberd_c2s:send(State1, Res)
index ce9d2705b6088460e8a1f516c73a8418d287188a..4bdda2ca7a9dbec2b2ad2a5db69e9511cf06744a 100644 (file)
@@ -28,7 +28,8 @@
 %% gen_mod API
 -export([start/2, stop/1, depends/2, mod_opt_type/1]).
 %% Hooks
--export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2,
+-export([s2s_out_auth_result/2, s2s_out_downgraded/2,
+        s2s_in_packet/2, s2s_out_packet/2,
         s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
 
 -include("ejabberd.hrl").
@@ -57,6 +58,8 @@ start(Host, _Opts) ->
                               s2s_in_packet, 50),
            ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE,
                               s2s_out_packet, 50),
+           ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE,
+                              s2s_out_downgraded, 50),
            ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
                               s2s_out_auth_result, 50)
     end.
@@ -74,6 +77,8 @@ stop(Host) ->
                          s2s_in_packet, 50),
     ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE,
                          s2s_out_packet, 50),
+    ejabberd_hooks:delete(s2s_out_downgraded, Host, ?MODULE,
+                         s2s_out_downgraded, 50),
     ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE,
                          s2s_out_auth_result, 50).
 
@@ -104,47 +109,56 @@ s2s_out_init(Acc, _Opts) ->
 
 s2s_out_closed(#{server := LServer,
                 remote_server := RServer,
-                db_verify := {StreamID, _Key, _Pid}} = State, _Reason) ->
+                db_verify := {StreamID, _Key, _Pid}} = State, Reason) ->
     %% Outbound s2s verificating connection (created at step 1) is
     %% closed suddenly without receiving the response.
     %% Building a response on our own
     Response = #db_verify{from = RServer, to = LServer,
                          id = StreamID, type = error,
-                         sub_els = [mk_error(internal_server_error)]},
+                         sub_els = [mk_error(Reason)]},
     s2s_out_packet(State, Response);
 s2s_out_closed(State, _Reason) ->
     State.
 
-s2s_out_auth_result(#{server := LServer,
-                     remote_server := RServer,
-                     db_verify := {StreamID, Key, _Pid}} = State,
-                   _) ->
+s2s_out_auth_result(#{db_verify := _} = State, _) ->
     %% The temporary outbound s2s connect (intended for verification)
     %% has passed authentication state (either successfully or not, no matter)
     %% and at this point we can send verification request as described
     %% in section 2.1.2, step 2
-    Request = #db_verify{from = LServer, to = RServer,
-                        key = Key, id = StreamID},
-    {stop, ejabberd_s2s_out:send(State, Request)};
+    {stop, send_verify_request(State)};
 s2s_out_auth_result(#{db_enabled := true,
                      socket := Socket, ip := IP,
                      server := LServer,
-                     remote_server := RServer,
-                     stream_remote_id := StreamID} = State, false) ->
+                     remote_server := RServer} = State, false) ->
     %% SASL authentication has failed, retrying with dialback
     %% Sending dialback request, section 2.1.1, step 1
     ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
              [ejabberd_socket:pp(Socket), LServer, RServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
-    Key = make_key(LServer, RServer, StreamID),
     State1 = maps:remove(stop_reason, State#{on_route => queue}),
-    State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer,
-                                                     to = RServer,
-                                                     key = Key}),
-    {stop, State2};
+    {stop, send_db_request(State1)};
 s2s_out_auth_result(State, _) ->
     State.
 
+s2s_out_downgraded(#{db_verify := _} = State, _) ->
+    %% The verifying outbound s2s connection detected non-RFC compliant
+    %% server, send verification request immediately without auth phase,
+    %% section 2.1.2, step 2
+    {stop, send_verify_request(State)};
+s2s_out_downgraded(#{db_enabled := true,
+                    socket := Socket, ip := IP,
+                    server := LServer,
+                    remote_server := RServer} = State, _) ->
+    %% non-RFC compliant server detected, send dialback request instantly,
+    %% section 2.1.1, step 1
+    ?INFO_MSG("(~s) Trying s2s dialback authentication with "
+             "non-RFC compliant server: ~s -> ~s (~s)",
+             [ejabberd_socket:pp(Socket), LServer, RServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+    {stop, send_db_request(State)};
+s2s_out_downgraded(State, _) ->
+    State.
+
 s2s_in_packet(#{stream_id := StreamID} = State,
              #db_result{from = From, to = To, key = Key, type = undefined}) ->
     %% Received dialback request, section 2.2.1, step 1
@@ -220,6 +234,23 @@ make_key(From, To, StreamID) ->
       crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
                  [To, " ", From, " ", StreamID])).
 
+-spec send_verify_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state().
+send_verify_request(#{server := LServer,
+                     remote_server := RServer,
+                     db_verify := {StreamID, Key, _Pid}} = State) ->
+    Request = #db_verify{from = LServer, to = RServer,
+                        key = Key, id = StreamID},
+    ejabberd_s2s_out:send(State, Request).
+
+-spec send_db_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state().
+send_db_request(#{server := LServer,
+                 remote_server := RServer,
+                 stream_remote_id := StreamID} = State) ->
+    Key = make_key(LServer, RServer, StreamID),
+    ejabberd_s2s_out:send(State, #db_result{from = LServer,
+                                           to = RServer,
+                                           key = Key}).
+
 -spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state().
 send_db_result(State, #db_verify{from = From, to = To,
                                 type = Type, sub_els = Els}) ->
@@ -255,6 +286,9 @@ mk_error(forbidden) ->
     xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
 mk_error(host_unknown) ->
     xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
+mk_error({_Class, _Reason} = Why) ->
+    Txt = xmpp_stream_out:format_error(Why),
+    xmpp:err_remote_server_not_found(Txt, ?MYLANG);
 mk_error(_) ->
     xmpp:err_internal_server_error().
 
index 82d68702dced979a3870f906336ad6b669982181..7032344196a113d63f16cdcdfb96dbf7e7d82337 100644 (file)
@@ -30,8 +30,9 @@
 %% hooks
 -export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
         c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
-        c2s_unbinded_packet/2, c2s_closed/2,
-        c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]).
+        c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2,
+        c2s_handle_send/3, c2s_filter_send/1, c2s_handle_info/2,
+        c2s_handle_call/3, c2s_handle_recv/3]).
 
 -include("xmpp.hrl").
 -include("logger.hrl").
@@ -60,13 +61,13 @@ start(Host, _Opts) ->
                       c2s_unbinded_packet, 50),
     ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
                       c2s_authenticated_packet, 50),
-    ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE,
-                      c2s_handle_send, 50),
-    ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
-                      c2s_filter_send, 50),
-    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
-                      c2s_handle_info, 50),
-    ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50).
+    ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
+    ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
+    ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
+    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
+    ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
+    ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50),
+    ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, c2s_terminated, 50).
 
 stop(Host) ->
     %% TODO: do something with global 'c2s_init' hook
@@ -80,13 +81,13 @@ stop(Host) ->
                          c2s_unbinded_packet, 50),
     ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
                          c2s_authenticated_packet, 50),
-    ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE,
-                         c2s_handle_send, 50),
-    ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
-                         c2s_filter_send, 50),
-    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
-                         c2s_handle_info, 50),
-    ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50).
+    ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
+    ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
+    ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
+    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
+    ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
+    ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50),
+    ejabberd_hooks:delete(c2s_terminated, Host, ?MODULE, c2s_terminated, 50).
 
 depends(_Host, _Opts) ->
     [].
@@ -115,7 +116,10 @@ c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State,
            mgmt_timeout => ResumeTimeout,
            mgmt_max_timeout => MaxResumeTimeout,
            mgmt_ack_timeout => get_ack_timeout(LServer, Opts),
-           mgmt_resend => get_resend_on_timeout(LServer, Opts)};
+           mgmt_resend => get_resend_on_timeout(LServer, Opts),
+           mgmt_stanzas_in => 0,
+           mgmt_stanzas_out => 0,
+           mgmt_stanzas_req => 0};
 c2s_stream_started(State, _StreamStart) ->
     State.
 
@@ -143,8 +147,8 @@ c2s_unbinded_packet(State, #sm_resume{} = Pkt) ->
     case handle_resume(State, Pkt) of
        {ok, ResumedState} ->
            {stop, ResumedState};
-       error ->
-           {stop, State}
+       {error, State1} ->
+           {stop, State1}
     end;
 c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
     c2s_unauthenticated_packet(State, Pkt);
@@ -161,12 +165,26 @@ c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt)
 c2s_authenticated_packet(State, Pkt) ->
     update_num_stanzas_in(State, Pkt).
 
+c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) ->
+    Xmlns = xmpp:get_ns(El),
+    if Xmlns == ?NS_STREAM_MGMT_2; Xmlns == ?NS_STREAM_MGMT_3 ->
+           Txt = xmpp:io_format_error(Why),
+           Err = #sm_failed{reason = 'bad-request',
+                            text = xmpp:mk_text(Txt, Lang),
+                            xmlns = Xmlns},
+           send(State, Err);
+       true ->
+           State
+    end;
+c2s_handle_recv(State, _, _) ->
+    State.
+
 c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
   when MgmtState == pending; MgmtState == active ->
     State1 = mgmt_queue_add(State, Pkt),
     case Result of
        ok when ?is_stanza(Pkt) ->
-           send_ack(State1);
+           send_rack(State1);
        ok ->
            State1;
        {error, _} ->
@@ -175,21 +193,57 @@ c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
 c2s_handle_send(State, _Pkt, _Result) ->
     State.
 
-c2s_filter_send(Pkt, _State) ->
-    Pkt.
+c2s_filter_send({Pkt, State}) ->
+    {Pkt, State}.
+
+c2s_handle_call(#{sid := {Time, _}} = State,
+               {resume_session, Time}, From) ->
+    ejabberd_c2s:reply(From, {resume, State}),
+    {stop, State#{mgmt_state => resumed}};
+c2s_handle_call(State, {resume_session, _}, From) ->
+    ejabberd_c2s:reply(From, {error, <<"Previous session not found">>}),
+    {stop, State};
+c2s_handle_call(State, _Call, _From) ->
+    State.
 
-c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State,
-               {timeout, T, ack_timeout}) ->
-    ?DEBUG("Timeout waiting for stream management acknowledgement of ~s",
+c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State,
+               {timeout, TRef, ack_timeout}) ->
+    ?DEBUG("Timeout waiting for stream management acknowledgement of ~s",
           [jid:to_string(JID)]),
     State1 = ejabberd_c2s:close(State, _SendTrailer = false),
-    c2s_closed(State1, ack_timeout);
+    {stop, transition_to_pending(State1)};
+c2s_handle_info(#{mgmt_state := pending, jid := JID} = State,
+               {timeout, _, pending_timeout}) ->
+    ?DEBUG("Timed out waiting for resumption of stream for ~s",
+          [jid:to_string(JID)]),
+    ejabberd_c2s:stop(State#{mgmt_state => timeout});
 c2s_handle_info(State, _) ->
     State.
 
-c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal ->
-    {stop, transition_to_pending(State)};
-c2s_closed(State, _) ->
+c2s_closed(State, {stream, _}) ->
+    State;
+c2s_closed(#{mgmt_state := active} = State, Reason) ->
+    {stop, transition_to_pending(State#{stop_reason => Reason})};
+c2s_closed(State, _Reason) ->
+    State.
+
+c2s_terminated(#{mgmt_state := resumed, jid := JID} = State, _Reason) ->
+    ?INFO_MSG("Closing former stream of resumed session for ~s",
+             [jid:to_string(JID)]),
+    bounce_message_queue(),
+    {stop, State};
+c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In, sid := SID,
+                user := U, server := S, resource := R} = State, _Reason) ->
+    case MgmtState of
+       timeout ->
+           Info = [{num_stanzas_in, In}],
+           ejabberd_sm:set_offline_info(SID, U, S, R, Info);
+       _ ->
+           ok
+    end,
+    route_unacked_stanzas(State),
+    State;
+c2s_terminated(State, _Reason) ->
     State.
 
 %%%===================================================================
@@ -201,17 +255,14 @@ negotiate_stream_mgmt(Pkt, State) ->
     case Pkt of
        #sm_enable{} ->
            handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt);
+       _ when is_record(Pkt, sm_a);
+              is_record(Pkt, sm_r);
+              is_record(Pkt, sm_resume) ->
+           Err = #sm_failed{reason = 'unexpected-request', xmlns = Xmlns},
+           send(State, Err);
        _ ->
-           Res = if is_record(Pkt, sm_a);
-                    is_record(Pkt, sm_r);
-                    is_record(Pkt, sm_resume) ->
-                         #sm_failed{reason = 'unexpected-request',
-                                    xmlns = Xmlns};
-                    true ->
-                         #sm_failed{reason = 'bad-request',
-                                    xmlns = Xmlns}
-                 end,
-           send(State, Res)
+           Err = #sm_failed{reason = 'bad-request', xmlns = Xmlns},
+           send(State, Err)
     end.
 
 -spec perform_stream_mgmt(xmpp_element(), state()) -> state().
@@ -223,16 +274,13 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
                    handle_r(State);
                #sm_a{} ->
                    handle_a(State, Pkt);
+               _ when is_record(Pkt, sm_enable);
+                      is_record(Pkt, sm_resume) ->
+                   send(State, #sm_failed{reason = 'unexpected-request',
+                                          xmlns = Xmlns});
                _ ->
-                   Res = if is_record(Pkt, sm_enable);
-                            is_record(Pkt, sm_resume) ->
-                                 #sm_failed{reason = 'unexpected-request',
-                                            xmlns = Xmlns};
-                            true ->
-                                 #sm_failed{reason = 'bad-request',
-                                            xmlns = Xmlns}
-                         end,
-                   send(State, Res)
+                   send(State, #sm_failed{reason = 'bad-request',
+                                          xmlns = Xmlns})
            end;
        _ ->
            send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns})
@@ -241,7 +289,7 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
 -spec handle_enable(state(), sm_enable()) -> state().
 handle_enable(#{mgmt_timeout := DefaultTimeout,
                mgmt_max_timeout := MaxTimeout,
-               xmlns := Xmlns, jid := JID} = State,
+               mgmt_xmlns := Xmlns, jid := JID} = State,
              #sm_enable{resume = Resume, max = Max}) ->
     Timeout = if Resume == false ->
                      0;
@@ -264,7 +312,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout,
          end,
     State1 = State#{mgmt_state => active,
                    mgmt_queue => queue_new(),
-                   mgmt_timeout => Timeout * 1000},
+                   mgmt_timeout => Timeout},
     send(State1, Res).
 
 -spec handle_r(state()) -> state().
@@ -275,23 +323,26 @@ handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
 -spec handle_a(state(), sm_a()) -> state().
 handle_a(State, #sm_a{h = H}) ->
     State1 = check_h_attribute(State, H),
-    resend_ack(State1).
+    resend_rack(State1).
 
 -spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
-handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
+handle_resume(#{user := User, lserver := LServer,
+               lang := Lang, socket := Socket} = State,
              #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
     R = case inherit_session_state(State, PrevID) of
            {ok, InheritedState} ->
                {ok, InheritedState, H};
            {error, Err, InH} ->
                {error, #sm_failed{reason = 'item-not-found',
+                                  text = xmpp:mk_text(Err, Lang),
                                   h = InH, xmlns = Xmlns}, Err};
            {error, Err} ->
                {error, #sm_failed{reason = 'item-not-found',
+                                  text = xmpp:mk_text(Err, Lang),
                                   xmlns = Xmlns}, Err}
        end,
     case R of
-       {ok, ResumedState, NumHandled} ->
+       {ok, #{jid := JID} = ResumedState, NumHandled} ->
            State1 = check_h_attribute(ResumedState, NumHandled),
            #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
            AttrId = make_resume_id(State1),
@@ -307,14 +358,20 @@ handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
                      [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
            {ok, State5};
        {error, El, Msg} ->
-           ?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]),
+           ?INFO_MSG("Cannot resume session for ~s@~s: ~s",
+                     [User, LServer, Msg]),
            {error, send(State, El)}
     end.
 
 -spec transition_to_pending(state()) -> state().
-transition_to_pending(#{mgmt_state := active} = State) ->
-    %% TODO
-    State;
+transition_to_pending(#{mgmt_state := active, jid := JID,
+                       lserver := LServer, mgmt_timeout := Timeout} = State) ->
+    State1 = cancel_ack_timer(State),
+    ?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]),
+    State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []),
+    State3 = ejabberd_c2s:close(State2, _SendTrailer = false),
+    erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout),
+    State3#{mgmt_state => pending};
 transition_to_pending(State) ->
     State.
 
@@ -345,25 +402,25 @@ update_num_stanzas_in(#{mgmt_state := MgmtState,
 update_num_stanzas_in(State, _El) ->
     State.
 
-send_ack(#{mgmt_ack_timer := _} = State) ->
+send_rack(#{mgmt_ack_timer := _} = State) ->
     State;
-send_ack(#{mgmt_xmlns := Xmlns,
+send_rack(#{mgmt_xmlns := Xmlns,
           mgmt_stanzas_out := NumStanzasOut,
           mgmt_ack_timeout := AckTimeout} = State) ->
     State1 = send(State, #sm_r{xmlns = Xmlns}),
     TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
     State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
 
-resend_ack(#{mgmt_ack_timer := _,
-            mgmt_queue := Queue,
-            mgmt_stanzas_out := NumStanzasOut,
-            mgmt_stanzas_req := NumStanzasReq} = State) ->
+resend_rack(#{mgmt_ack_timer := _,
+             mgmt_queue := Queue,
+             mgmt_stanzas_out := NumStanzasOut,
+             mgmt_stanzas_req := NumStanzasReq} = State) ->
     State1 = cancel_ack_timer(State),
     case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of
-       true -> send_ack(State1);
+       true -> send_rack(State1);
        false -> State1
     end;
-resend_ack(State) ->
+resend_rack(State) ->
     State.
 
 -spec mgmt_queue_add(state(), xmpp_element()) -> state().
@@ -492,10 +549,22 @@ inherit_session_state(#{user := U, server := S} = State, ResumeID) ->
                OldPID ->
                    OldSID = {Time, OldPID},
                    try resume_session(OldSID, State) of
-                       {resume, OldState} ->
+                       {resume, #{mgmt_xmlns := Xmlns,
+                                  mgmt_queue := Queue,
+                                  mgmt_timeout := Timeout,
+                                  mgmt_stanzas_in := NumStanzasIn,
+                                  mgmt_stanzas_out := NumStanzasOut} = OldState} ->
                            State1 = ejabberd_c2s:copy_state(State, OldState),
-                           State2 = ejabberd_c2s:open_session(State1),
-                           {ok, State2};
+                           State2 = State1#{mgmt_xmlns => Xmlns,
+                                            mgmt_queue => Queue,
+                                            mgmt_timeout => Timeout,
+                                            mgmt_stanzas_in => NumStanzasIn,
+                                            mgmt_stanzas_out => NumStanzasOut,
+                                            mgmt_state => active},
+                           ejabberd_sm:close_session(OldSID, U, S, R),
+                           State3 = ejabberd_c2s:open_session(State2),
+                           ejabberd_c2s:stop(OldPID),
+                           {ok, State3};
                        {error, Msg} ->
                            {error, Msg}
                    catch exit:{noproc, _} ->
@@ -591,6 +660,15 @@ cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) ->
 cancel_ack_timer(State) ->
     State.
 
+-spec bounce_message_queue() -> ok.
+bounce_message_queue() ->
+    receive {route, From, To, Pkt} ->
+           ejabberd_router:route(From, To, Pkt),
+           bounce_message_queue()
+    after 0 ->
+           ok
+    end.
+
 %%%===================================================================
 %%% Configuration processing
 %%%===================================================================
index e9c1b333903b8b4cdafa88baf0daff7763bdfe7e..a0387064376521f22d2cb2584559f8ec0e0b3569 100644 (file)
@@ -44,7 +44,8 @@
 -type state() :: map().
 -type stop_reason() :: {stream, reset | stream_error()} |
                       {tls, term()} |
-                      {socket, inet:posix() | closed | timeout}.
+                      {socket, inet:posix() | closed | timeout} |
+                      internal_failure.
 
 -callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
 -callback handle_cast(term(), state()) -> state().
@@ -54,7 +55,6 @@
 -callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
 -callback handle_stream_start(state()) -> state().
 -callback handle_stream_end(stop_reason(), state()) -> state().
--callback handle_stream_close(stop_reason(), state()) -> state().
 -callback handle_cdata(binary(), state()) -> state().
 -callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
 -callback handle_authenticated_packet(xmpp_element(), state()) -> state().
@@ -83,7 +83,6 @@
                     code_change/3,
                     handle_stream_start/1,
                     handle_stream_end/2,
-                    handle_stream_close/2,
                     handle_cdata/2,
                     handle_authenticated_packet/2,
                     handle_unauthenticated_packet/2,
@@ -193,6 +192,8 @@ format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
     format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
 format_error({tls, Reason}) ->
     format("TLS failed: ~w", [Reason]);
+format_error(internal_failure) ->
+    <<"Internal server error">>;
 format_error(Err) ->
     format("Unrecognized error: ~w", [Err]).
 
@@ -263,75 +264,78 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
            #{stream_state := wait_for_stream,
              xmlns := XMLNS, lang := MyLang} = State) ->
     El = #xmlel{name = Name, attrs = Attrs},
-    try xmpp:decode(El, XMLNS, []) of
-       #stream_start{} = Pkt ->
-           State1 = send_header(State, Pkt),
-           case is_disconnected(State1) of
-               true -> State1;
-               false -> noreply(process_stream(Pkt, State1))
-           end;
-       _ ->
-           State1 = send_header(State),
-           case is_disconnected(State1) of
-               true -> State1;
-               false -> noreply(send_element(State1, xmpp:serr_invalid_xml()))
-           end
-    catch _:{xmpp_codec, Why} ->
-           State1 = send_header(State),
-           case is_disconnected(State1) of
-               true -> State1;
-               false ->
-                   Txt = xmpp:io_format_error(Why),
-                   Lang = select_lang(MyLang, xmpp:get_lang(El)),
-                   Err = xmpp:serr_invalid_xml(Txt, Lang),
-                   noreply(send_element(State1, Err))
-           end
-    end;
+    noreply(
+      try xmpp:decode(El, XMLNS, []) of
+         #stream_start{} = Pkt ->
+             State1 = send_header(State, Pkt),
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false -> process_stream(Pkt, State1)
+             end;
+         _ ->
+             State1 = send_header(State),
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false -> send_element(State1, xmpp:serr_invalid_xml())
+             end
+      catch _:{xmpp_codec, Why} ->
+             State1 = send_header(State),
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false ->
+                     Txt = xmpp:io_format_error(Why),
+                     Lang = select_lang(MyLang, xmpp:get_lang(El)),
+                     Err = xmpp:serr_invalid_xml(Txt, Lang),
+                     send_element(State1, Err)
+             end
+      end);
 handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
     State1 = send_header(State),
-    case is_disconnected(State1) of
-       true -> State1;
-       false ->
-           Err = case Reason of
-                     <<"XML stanza is too big">> ->
-                         xmpp:serr_policy_violation(Reason, Lang);
-                     _ ->
-                         xmpp:serr_not_well_formed()
-                 end,
-           noreply(send_element(State1, Err))
-    end;
+    noreply(
+      case is_disconnected(State1) of
+         true -> State1;
+         false ->
+             Err = case Reason of
+                       <<"XML stanza is too big">> ->
+                           xmpp:serr_policy_violation(Reason, Lang);
+                       _ ->
+                           xmpp:serr_not_well_formed()
+                   end,
+             send_element(State1, Err)
+      end);
 handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
-    try xmpp:decode(El, NS, [ignore_els]) of
-       Pkt ->
-           State1 = try Mod:handle_recv(El, Pkt, State)
-                    catch _:undef -> State
-                    end,
-           case is_disconnected(State1) of
-               true -> State1;
-               false -> noreply(process_element(Pkt, State1))
-           end
-    catch _:{xmpp_codec, Why} ->
-           State1 = try Mod:handle_recv(El, {error, Why}, State)
-                    catch _:undef -> State
-                    end,
-           case is_disconnected(State1) of
-               true -> State1;
-               false ->
-                   Txt = xmpp:io_format_error(Why),
-                   Lang = select_lang(MyLang, xmpp:get_lang(El)),
-                   noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
-           end
-    end;
+    noreply(
+      try xmpp:decode(El, NS, [ignore_els]) of
+         Pkt ->
+             State1 = try Mod:handle_recv(El, Pkt, State)
+                      catch _:undef -> State
+                      end,
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false -> process_element(Pkt, State1)
+             end
+      catch _:{xmpp_codec, Why} ->
+             State1 = try Mod:handle_recv(El, {error, Why}, State)
+                      catch _:undef -> State
+                      end,
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false ->
+                     Txt = xmpp:io_format_error(Why),
+                     Lang = select_lang(MyLang, xmpp:get_lang(El)),
+                     send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+             end
+      end);
 handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
            #{mod := Mod} = State) ->
     noreply(try Mod:handle_cdata(Data, State)
            catch _:undef -> State
            end);
 handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
-    noreply(process_stream_end({error, {stream, reset}}, State));
+    noreply(process_stream_end({stream, reset}, State));
 handle_info({'$gen_event', closed}, State) ->
-    noreply(process_stream_close({error, {socket, closed}}, State));
+    noreply(process_stream_end({socket, closed}, State));
 handle_info(timeout, #{mod := Mod} = State) ->
     Disconnected = is_disconnected(State),
     noreply(try Mod:handle_timeout(State)
@@ -342,7 +346,7 @@ handle_info(timeout, #{mod := Mod} = State) ->
            end);
 handle_info({'DOWN', MRef, _Type, _Object, _Info},
            #{socket_monitor := MRef} = State) ->
-    noreply(process_stream_close({error, {socket, closed}}, State));
+    noreply(process_stream_end({socket, closed}, State));
 handle_info(Info, #{mod := Mod} = State) ->
     noreply(try Mod:handle_info(Info, State)
            catch _:undef -> State
@@ -390,15 +394,6 @@ peername(SockMod, Socket) ->
        _ -> SockMod:peername(Socket)
     end.
 
--spec process_stream_close(stop_reason(), state()) -> state().
-process_stream_close(_, #{stream_state := disconnected} = State) ->
-    State;
-process_stream_close(Reason, #{mod := Mod} = State) ->
-    State1 = send_trailer(State),
-    try Mod:handle_stream_close(Reason, State1)
-    catch _:undef -> stop(State1)
-    end.
-
 -spec process_stream_end(stop_reason(), state()) -> state().
 process_stream_end(_, #{stream_state := disconnected} = State) ->
     State;
@@ -414,6 +409,8 @@ process_stream(#stream_start{xmlns = XML_NS,
               #{xmlns := NS} = State)
   when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
     send_element(State, xmpp:serr_invalid_namespace());
+process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
+    send_element(State, xmpp:serr_unsupported_version());
 process_stream(#stream_start{lang = Lang},
               #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
   when size(Lang) > 35 ->
@@ -520,7 +517,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
        #handshake{} ->
            State;
        #stream_error{} ->
-           process_stream_end({error, {stream, Pkt}}, State);
+           process_stream_end({stream, Pkt}, State);
        _ when StateName == wait_for_sasl_request;
               StateName == wait_for_handshake;
               StateName == wait_for_sasl_response ->
@@ -704,7 +701,7 @@ process_starttls_failure(Why, State) ->
     State1 = send_element(State, #starttls_failure{}),
     case is_disconnected(State1) of
        true -> State1;
-       false -> process_stream_end({error, {tls, Why}}, State1)
+       false -> process_stream_end({tls, Why}, State1)
     end.
 
 -spec process_sasl_request(sasl_auth(), state()) -> state().
@@ -939,8 +936,8 @@ set_from_to(Pkt, #{lang := Lang}) ->
     end.
 
 -spec send_header(state()) -> state().
-send_header(State) ->
-    send_header(State, #stream_start{}).
+send_header(#{stream_version := Version} = State) ->
+    send_header(State, #stream_start{version = Version}).
 
 -spec send_header(state(), stream_start()) -> state().
 send_header(#{stream_id := StreamID,
@@ -959,8 +956,9 @@ send_header(#{stream_id := StreamID,
               undefined -> jid:make(DefaultServer)
           end,
     Version = case HisVersion of
-                 undefined -> MyVersion;
-                 _ -> HisVersion
+                 undefined -> undefined;
+                 {0,_} -> HisVersion;
+                 _ -> MyVersion
              end,
     Header = xmpp:encode(#stream_start{version = Version,
                                       lang = Lang,
@@ -969,10 +967,12 @@ send_header(#{stream_id := StreamID,
                                       db_xmlns = NS_DB,
                                       id = StreamID,
                                       from = From}),
-    State1 = State#{lang => Lang, stream_header_sent => true},
+    State1 = State#{lang => Lang,
+                   stream_version => Version,
+                   stream_header_sent => true},
     case send_text(State1, fxml:element_to_header(Header)) of
        ok -> State1;
-       {error, Why} -> process_stream_close({error, {socket, Why}}, State1)
+       {error, Why} -> process_stream_end({socket, Why}, State1)
     end;
 send_header(State, _) ->
     State.
@@ -987,11 +987,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
             end,
     case Result of
        _ when is_record(Pkt, stream_error) ->
-           process_stream_end({error, {stream, Pkt}}, State1);
+           process_stream_end({stream, Pkt}, State1);
        ok ->
            State1;
        {error, Why} ->
-           process_stream_close({error, {socket, Why}}, State1)
+           process_stream_end({socket, Why}, State1)
     end.
 
 -spec send_error(state(), xmpp_element(), stanza_error()) -> state().
@@ -1022,7 +1022,7 @@ send_text(#{socket := Sock, sockmod := SockMod,
            stream_header_sent := true}, Data) when StateName /= disconnected ->
     SockMod:send(Sock, Data);
 send_text(_, _) ->
-    {error, einval}.
+    {error, closed}.
 
 -spec close_socket(state()) -> state().
 close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
index fc373fff84ad1874d87fb48c5cd8756a7c0facf3..08804e43282927e29cb06bf6dea763b871ead57b 100644 (file)
@@ -33,6 +33,7 @@
 -include_lib("kernel/include/inet.hrl").
 
 -type state() :: map().
+-type noreply() :: {noreply, state(), timeout()}.
 -type host_port() :: {inet:hostname(), inet:port_number()}.
 -type ip_port() :: {inet:ip_address(), inet:port_number()}.
 -type network_error() :: {error, inet:posix() | inet_res:res_error()}.
@@ -42,7 +43,8 @@
                       {tls, term()} |
                       {pkix, binary()} |
                       {auth, atom() | binary() | string()} |
-                      {socket, inet:posix() | closed | timeout}.
+                      {socket, inet:posix() | closed | timeout} |
+                      internal_failure.
 
 -callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
 
@@ -107,7 +109,7 @@ close(_, _) ->
 establish(State) ->
     process_stream_established(State).
 
--spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+-spec set_timeout(state(), timeout()) -> state().
 set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
     case Timeout of
        infinity -> State#{stream_timeout => infinity};
@@ -148,12 +150,15 @@ format_error({tls, Reason}) ->
     format("TLS failed: ~w", [Reason]);
 format_error({auth, Reason}) ->
     format("Authentication failed: ~s", [Reason]);
+format_error(internal_failure) ->
+    <<"Internal server error">>;
 format_error(Err) ->
     format("Unrecognized error: ~w", [Err]).
 
 %%%===================================================================
 %%% gen_server callbacks
 %%%===================================================================
+-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
 init([Mod, SockMod, From, To, Opts]) ->
     Time = p1_time_compat:monotonic_time(milli_seconds),
     State = #{owner => self(),
@@ -183,36 +188,38 @@ init([Mod, SockMod, From, To, Opts]) ->
            Err
     end.
 
+-spec handle_call(term(), term(), state()) -> noreply().
 handle_call(Call, From, #{mod := Mod} = State) ->
     noreply(try Mod:handle_call(Call, From, State)
            catch _:undef -> State
            end).
 
+-spec handle_cast(term(), state()) -> noreply().
 handle_cast(connect, #{remote_server := RemoteServer,
                       sockmod := SockMod,
                       stream_state := connecting} = State) ->
-    case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
-       false ->
-           noreply(process_stream_close({error, {idna, bad_string}}, State));
-       ASCIIName ->
-           case resolve(binary_to_list(ASCIIName), State) of
-               {ok, AddrPorts} ->
-                   case connect(AddrPorts, State) of
-                       {ok, Socket, AddrPort} ->
-                           SocketMonitor = SockMod:monitor(Socket),
-                           State1 = State#{ip => AddrPort,
-                                           socket => Socket,
-                                           socket_monitor => SocketMonitor},
-                           State2 = State1#{stream_state => wait_for_stream},
-                           noreply(send_header(State2));
-                       {error, Why} ->
-                           Err = {error, {socket, Why}},
-                           noreply(process_stream_close(Err, State))
-                   end;
-               {error, Why} ->
-                   noreply(process_stream_close({error, {dns, Why}}, State))
-           end
-    end;
+    noreply(
+      case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
+         false ->
+             process_stream_end({idna, bad_string}, State);
+         ASCIIName ->
+             case resolve(binary_to_list(ASCIIName), State) of
+                 {ok, AddrPorts} ->
+                     case connect(AddrPorts, State) of
+                         {ok, Socket, AddrPort} ->
+                             SocketMonitor = SockMod:monitor(Socket),
+                             State1 = State#{ip => AddrPort,
+                                             socket => Socket,
+                                             socket_monitor => SocketMonitor},
+                             State2 = State1#{stream_state => wait_for_stream},
+                             send_header(State2);
+                         {error, Why} ->
+                             process_stream_end({socket, Why}, State)
+                     end;
+                 {error, Why} ->
+                     process_stream_end({dns, Why}, State)
+             end
+      end);
 handle_cast(connect, State) ->
     %% Ignoring connection attempts in other states
     noreply(State);
@@ -225,66 +232,70 @@ handle_cast(Cast, #{mod := Mod} = State) ->
            catch _:undef -> State
            end).
 
+-spec handle_info(term(), state()) -> noreply().
 handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
            #{stream_state := wait_for_stream,
              xmlns := XMLNS, lang := MyLang} = State) ->
     El = #xmlel{name = Name, attrs = Attrs},
-    try xmpp:decode(El, XMLNS, []) of
-       #stream_start{} = Pkt ->
-           noreply(process_stream(Pkt, State));
-       _ ->
-           noreply(send_element(State, xmpp:serr_invalid_xml()))
-    catch _:{xmpp_codec, Why} ->
-           Txt = xmpp:io_format_error(Why),
-           Lang = select_lang(MyLang, xmpp:get_lang(El)),
-           Err = xmpp:serr_invalid_xml(Txt, Lang),
-           noreply(send_element(State, Err))
-    end;
+    noreply(
+      try xmpp:decode(El, XMLNS, []) of
+         #stream_start{} = Pkt ->
+             process_stream(Pkt, State);
+         _ ->
+             send_element(State, xmpp:serr_invalid_xml())
+      catch _:{xmpp_codec, Why} ->
+             Txt = xmpp:io_format_error(Why),
+             Lang = select_lang(MyLang, xmpp:get_lang(El)),
+             Err = xmpp:serr_invalid_xml(Txt, Lang),
+             send_element(State, Err)
+      end);
 handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
     State1 = send_header(State),
-    case is_disconnected(State1) of
-       true -> State1;
-       false ->
-           Err = case Reason of
-                     <<"XML stanza is too big">> ->
-                         xmpp:serr_policy_violation(Reason, Lang);
-                     _ ->
-                         xmpp:serr_not_well_formed()
-                 end,
-           noreply(send_element(State1, Err))
-    end;
+    noreply(
+      case is_disconnected(State1) of
+         true -> State1;
+         false ->
+             Err = case Reason of
+                       <<"XML stanza is too big">> ->
+                           xmpp:serr_policy_violation(Reason, Lang);
+                       _ ->
+                           xmpp:serr_not_well_formed()
+                   end,
+             send_element(State1, Err)
+      end);
 handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
-    try xmpp:decode(El, NS, [ignore_els]) of
-       Pkt ->
-           State1 = try Mod:handle_recv(El, Pkt, State)
-                    catch _:undef -> State
-                    end,
-           case is_disconnected(State1) of
-               true -> State1;
-               false -> noreply(process_element(Pkt, State1))
-           end
-    catch _:{xmpp_codec, Why} ->
-           State1 = try Mod:handle_recv(El, undefined, State)
-                    catch _:undef -> State
-                    end,
-           case is_disconnected(State1) of
-               true -> State1;
-               false ->
-                   Txt = xmpp:io_format_error(Why),
-                   Lang = select_lang(MyLang, xmpp:get_lang(El)),
-                   noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
-           end
-    end;
+    noreply(
+      try xmpp:decode(El, NS, [ignore_els]) of
+         Pkt ->
+             State1 = try Mod:handle_recv(El, Pkt, State)
+                      catch _:undef -> State
+                      end,
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false -> process_element(Pkt, State1)
+             end
+      catch _:{xmpp_codec, Why} ->
+             State1 = try Mod:handle_recv(El, undefined, State)
+                      catch _:undef -> State
+                      end,
+             case is_disconnected(State1) of
+                 true -> State1;
+                 false ->
+                     Txt = xmpp:io_format_error(Why),
+                     Lang = select_lang(MyLang, xmpp:get_lang(El)),
+                     send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+             end
+      end);
 handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
            #{mod := Mod} = State) ->
     noreply(try Mod:handle_cdata(Data, State)
            catch _:undef -> State
            end);
 handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
-    noreply(process_stream_end({error, {stream, reset}}, State));
+    noreply(process_stream_end({stream, reset}, State));
 handle_info({'$gen_event', closed}, State) ->
-    noreply(process_stream_close({error, {socket, closed}}, State));
+    noreply(process_stream_end({socket, closed}, State));
 handle_info(timeout, #{mod := Mod} = State) ->
     Disconnected = is_disconnected(State),
     noreply(try Mod:handle_timeout(State)
@@ -295,12 +306,13 @@ handle_info(timeout, #{mod := Mod} = State) ->
            end);
 handle_info({'DOWN', MRef, _Type, _Object, _Info},
            #{socket_monitor := MRef} = State) ->
-    noreply(process_stream_close({error, {socket, closed}}, State));
+    noreply(process_stream_end({socket, closed}, State));
 handle_info(Info, #{mod := Mod} = State) ->
     noreply(try Mod:handle_info(Info, State)
            catch _:undef -> State
            end).
 
+-spec terminate(term(), state()) -> any().
 terminate(Reason, #{mod := Mod} = State) ->
     case get(already_terminated) of
        true ->
@@ -319,7 +331,7 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) ->
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
--spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+-spec noreply(state()) -> noreply().
 noreply(#{stream_timeout := infinity} = State) ->
     {noreply, State, infinity};
 noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
@@ -335,15 +347,6 @@ new_id() ->
 is_disconnected(#{stream_state := StreamState}) ->
     StreamState == disconnected.
 
--spec process_stream_close(stop_reason(), state()) -> state().
-process_stream_close(_, #{stream_state := disconnected} = State) ->
-    State;
-process_stream_close(Reason, #{mod := Mod} = State) ->
-    State1 = send_trailer(State),
-    try Mod:handle_stream_close(Reason, State1)
-    catch _:undef -> stop(State1)
-    end.
-
 -spec process_stream_end(stop_reason(), state()) -> state().
 process_stream_end(_, #{stream_state := disconnected} = State) ->
     State;
@@ -359,6 +362,8 @@ process_stream(#stream_start{xmlns = XML_NS,
               #{xmlns := NS} = State)
   when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
     send_element(State, xmpp:serr_invalid_namespace());
+process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
+    send_element(State, xmpp:serr_unsupported_version());
 process_stream(#stream_start{lang = Lang, id = ID,
                             version = Version} = StreamStart,
               #{mod := Mod} = State) ->
@@ -370,8 +375,10 @@ process_stream(#stream_start{lang = Lang, id = ID,
        true -> State2;
        false ->
            case Version of
-               {1,0} -> State2#{stream_state => wait_for_features};
-               _ -> process_stream_downgrade(StreamStart, State)
+               {1, _} ->
+                   State2#{stream_state => wait_for_features};
+               _ ->
+                   process_stream_downgrade(StreamStart, State2)
            end
     end.
 
@@ -387,7 +394,7 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
        #sasl_failure{} when StateName == wait_for_sasl_response ->
            process_sasl_failure(Pkt, State);
        #stream_error{} ->
-           process_stream_end({error, {stream, Pkt}}, State);
+           process_stream_end({stream, Pkt}, State);
        _ when is_record(Pkt, stream_features);
               is_record(Pkt, starttls_proceed);
               is_record(Pkt, starttls);
@@ -487,14 +494,23 @@ process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
                            stream_encrypted => true},
            send_header(State1);
        {error, Why} ->
-           process_stream_close({error, {tls, Why}}, State)
+           process_stream_end({tls, Why}, State)
     end.
 
 -spec process_stream_downgrade(stream_start(), state()) -> state().
-process_stream_downgrade(StreamStart, #{mod := Mod} = State) ->
-    try Mod:downgrade_stream(StreamStart, State)
-    catch _:undef ->
-           send_element(State, xmpp:serr_unsupported_version())
+process_stream_downgrade(StreamStart,
+                        #{mod := Mod, lang := Lang,
+                          stream_encrypted := Encrypted} = State) ->
+    TLSRequired = is_starttls_required(State),
+    if not Encrypted and TLSRequired ->
+           Txt = <<"Use of STARTTLS required">>,
+           send_element(State, xmpp:err_policy_violation(Txt, Lang));
+       true ->
+           State1 = State#{stream_state => downgraded},
+           try Mod:handle_stream_downgraded(StreamStart, State1)
+           catch _:undef ->
+                   send_element(State1, xmpp:serr_unsupported_version())
+           end
     end.
 
 -spec process_cert_verification(state()) -> state().
@@ -509,7 +525,7 @@ process_cert_verification(#{stream_encrypted := true,
                {ok, _} ->
                    State#{stream_verified => true};
                {error, Why, _Peer} ->
-                   process_stream_close({error, {pkix, Why}}, State)
+                   process_stream_end({pkix, Why}, State)
            end;
        false ->
            State#{stream_verified => true}
@@ -538,7 +554,7 @@ process_sasl_success(#{mod := Mod,
 -spec process_sasl_failure(sasl_failure(), state()) -> state().
 process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
     try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
-    catch _:undef -> process_stream_close({error, {auth, Reason}}, State)
+    catch _:undef -> process_stream_end({auth, Reason}, State)
     end.
 
 -spec process_packet(xmpp_element(), state()) -> state().
@@ -581,7 +597,7 @@ send_header(#{remote_server := RemoteServer,
                             version = {1,0}}),
     case send_text(State, fxml:element_to_header(Header)) of
        ok -> State;
-       {error, Why} -> process_stream_close({error, {socket, Why}}, State)
+       {error, Why} -> process_stream_end({socket, Why}, State)
     end.
 
 -spec send_element(state(), xmpp_element()) -> state().
@@ -596,11 +612,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
        false ->
            case send_text(State1, Data) of
                _ when is_record(Pkt, stream_error) ->
-                   process_stream_end({error, {stream, Pkt}}, State1);
+                   process_stream_end({stream, Pkt}, State1);
                ok ->
                    State1;
                {error, Why} ->
-                   process_stream_close({error, {socket, Why}}, State1)
+                   process_stream_end({socket, Why}, State1)
            end
     end.
 
@@ -626,7 +642,7 @@ send_text(#{sockmod := SockMod, socket := Socket,
            stream_state := StateName}, Data) when StateName /= disconnected ->
     SockMod:send(Socket, Data);
 send_text(_, _) ->
-    {error, einval}.
+    {error, closed}.
 
 -spec send_trailer(state()) -> state().
 send_trailer(State) ->