]> granicus.if.org Git - ejabberd/commitdiff
Add xmpp_stream_out behaviour and rewrite s2s/SM code
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Wed, 28 Dec 2016 06:47:11 +0000 (09:47 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Wed, 28 Dec 2016 06:47:11 +0000 (09:47 +0300)
28 files changed:
include/ejabberd.hrl
src/cyrsasl.erl
src/ejabberd_app.erl
src/ejabberd_auth.erl
src/ejabberd_c2s.erl
src/ejabberd_config.erl
src/ejabberd_hooks.erl
src/ejabberd_listener.erl
src/ejabberd_router.erl
src/ejabberd_s2s.erl
src/ejabberd_s2s_in.erl
src/ejabberd_s2s_out.erl
src/ejabberd_service.erl
src/ejabberd_sm.erl
src/ejabberd_socket.erl
src/jlib.erl
src/mod_blocking.erl
src/mod_legacy_auth.erl
src/mod_offline.erl
src/mod_privacy.erl
src/mod_pubsub.erl
src/mod_register.erl
src/mod_roster.erl
src/mod_s2s_dialback.erl [new file with mode: 0644]
src/mod_sm.erl [new file with mode: 0644]
src/xmpp_stream_in.erl
src/xmpp_stream_out.erl [new file with mode: 0644]
src/xmpp_stream_pkix.erl [new file with mode: 0644]

index 391089a0e7d92a08b2d39459a4140b51e32b6771..ddf41f0946caa092015ad25baddb170e58268404 100644 (file)
@@ -41,8 +41,6 @@
 
 -define(COPYRIGHT, "Copyright (c) 2002-2016 ProcessOne").
 
--define(S2STIMEOUT, timer:minutes(10)).
-
 %%-define(DBGFSM, true).
 
 -record(scram,
index c49f8a3cbc51066137bce34d0b06f913f1b5f370..874a417a17a98edee1eb6abe119e9841e98803e9 100644 (file)
 -include("ejabberd.hrl").
 -include("logger.hrl").
 
-%%
--export_type([
-    mechanism/0,
-    mechanisms/0,
-    sasl_mechanism/0
-]).
-
 -record(sasl_mechanism,
         {mechanism = <<"">>    :: mechanism() | '$1',
          module                :: atom(),
 -type(mechanism() :: binary()).
 -type(mechanisms() :: [mechanism(),...]).
 -type(password_type() :: plain | digest | scram).
--type(props() :: [{username, binary()} |
-                  {authzid, binary()} |
-                 {mechanism, binary()} |
-                  {auth_module, atom()}]).
+-type sasl_property() :: {username, binary()} |
+                        {authzid, binary()} |
+                        {mechanism, binary()} |
+                        {auth_module, atom()}.
+-type sasl_return() :: {ok, [sasl_property()]} |
+                      {ok, [sasl_property()], binary()} |
+                      {continue, binary(), any()} |
+                      {error, atom()} |
+                      {error, atom(), binary()}.
 
 -type(sasl_mechanism() :: #sasl_mechanism{}).
 
     mech_state
 }).
 -type sasl_state() :: #sasl_state{}.
--export_type([sasl_state/0]).
+-export_type([mechanism/0, mechanisms/0, sasl_mechanism/0,
+             sasl_state/0, sasl_return/0, sasl_property/0]).
 
 -callback mech_new(binary(), fun(), fun(), fun()) -> any().
--callback mech_step(any(), binary()) -> {ok, props()} |
-                                        {ok, props(), binary()} |
-                                        {continue, binary(), any()} |
-                                        {error, atom()} |
-                                        {error, atom(), binary()}.
+-callback mech_step(any(), binary()) -> sasl_return().
 
 start() ->
     ets:new(sasl_mechanism,
index e4333c8168e01f293aa1da91b041c5b5cfe813a9..eb25fe65689ba32523ea4aecdb16e5ac5811d01b 100644 (file)
@@ -169,7 +169,7 @@ broadcast_c2s_shutdown() ->
     Children = ejabberd_sm:get_all_pids(),
     lists:foreach(
       fun(C2SPid) when node(C2SPid) == node() ->
-             C2SPid ! system_shutdown;
+             ejabberd_c2s:send(C2SPid, xmpp:serr_system_shutdown());
         (_) ->
              ok
       end, Children).
index 74c8009c2bfdde98834eb1c21a7e6e04a2d518f2..eba0a403855f7461d7c5ae2ac960238111e502af 100644 (file)
@@ -42,7 +42,7 @@
         get_password_s/2, get_password_with_authmodule/2,
         is_user_exists/2, is_user_exists_in_other_modules/3,
         remove_user/2, remove_user/3, plain_password_required/1,
-        store_type/1, entropy/1]).
+        store_type/1, entropy/1, backend_type/1]).
 
 -export([auth_modules/1, opt_type/1]).
 
@@ -412,6 +412,13 @@ entropy(B) ->
          length(S) * math:log(lists:sum(Set)) / math:log(2)
     end.
 
+-spec backend_type(atom()) -> atom().
+backend_type(Mod) ->
+    case atom_to_list(Mod) of
+       "ejabberd_auth_" ++ T -> list_to_atom(T);
+       _ -> Mod
+    end.
+
 %%%----------------------------------------------------------------------
 %%% Internal functions
 %%%----------------------------------------------------------------------
index b5113c34bf943659b1e7d6ebcd04b08920bedba2..07d04fbc486b37652285b83379e57b9340c267e4 100644 (file)
 -module(ejabberd_c2s).
 -behaviour(xmpp_stream_in).
 -behaviour(ejabberd_config).
+-behaviour(ejabberd_socket).
 
 -protocol({rfc, 6121}).
 
 %% ejabberd_socket callbacks
--export([start/2, socket_type/0]).
+-export([start/2, start_link/2, socket_type/0]).
 %% ejabberd_config callbacks
 -export([opt_type/1, transform_listen_option/2]).
 %% xmpp_stream_in callbacks
 -export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3]).
--export([tls_options/1, tls_required/1, compress_methods/1,
-        sasl_mechanisms/1, init_sasl/1, bind/2, handshake/2,
+-export([tls_options/1, tls_required/1, tls_verify/1,
+        compress_methods/1, bind/2, get_password_fun/1,
+        check_password_fun/1, check_password_digest_fun/1,
         unauthenticated_stream_features/1, authenticated_stream_features/1,
-        handle_stream_start/1, handle_stream_end/1, handle_stream_close/1,
+        handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
         handle_unauthenticated_packet/2, handle_authenticated_packet/2,
-        handle_auth_success/4, handle_auth_failure/4, handle_send/5,
-        handle_unbinded_packet/2, handle_cdata/2]).
+        handle_auth_success/4, handle_auth_failure/4, handle_send/3,
+        handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]).
+%% Hooks
+-export([handle_unexpected_info/2, handle_unexpected_cast/2,
+        reject_unauthenticated_packet/2, process_closed/2]).
 %% API
 -export([get_presence/1, get_subscription/2, get_subscribed/1,
-        send/2, close/1]).
+        open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
+        copy_state/2, add_hooks/0]).
 
 -include("ejabberd.hrl").
 -include("xmpp.hrl").
 
 -define(SETS, gb_sets).
 
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
-
 -type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
--export_type([state/0, next_state/0]).
+-export_type([state/0]).
 
 %%%===================================================================
 %%% ejabberd_socket API
 %%%===================================================================
 start(SockData, Opts) ->
     xmpp_stream_in:start(?MODULE, [SockData, Opts],
-                        fsm_limit_opts(Opts) ++ ?FSMOPTS).
+                        ejabberd_config:fsm_limit_opts(Opts)).
+
+start_link(SockData, Opts) ->
+    xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
+                             ejabberd_config:fsm_limit_opts(Opts)).
 
 socket_type() ->
     xml_stream.
 
+-spec call(pid(), term(), non_neg_integer() | infinity) -> term().
+call(Ref, Msg, Timeout) ->
+    xmpp_stream_in:call(Ref, Msg, Timeout).
+
 -spec get_presence(pid()) -> presence().
 get_presence(Ref) ->
-    xmpp_stream_in:call(Ref, get_presence, 1000).
+    call(Ref, get_presence, 1000).
 
 -spec get_subscription(jid() | ljid(), state()) -> both | from | to | none.
 get_subscription(#jid{} = From, State) ->
@@ -90,15 +96,85 @@ get_subscription(LFrom, #{pres_f := PresF, pres_t := PresT}) ->
 -spec get_subscribed(pid()) -> [ljid()].
 %% Return list of all available resources of contacts
 get_subscribed(Ref) ->
-    xmpp_stream_in:call(Ref, get_subscribed, 1000).
+    call(Ref, get_subscribed, 1000).
 
--spec close(pid()) -> ok.
 close(Ref) ->
-    xmpp_stream_in:cast(Ref, closed).
+    xmpp_stream_in:close(Ref).
+
+close(Ref, SendTrailer) ->
+    xmpp_stream_in:close(Ref, SendTrailer).
+
+stop(Ref) ->
+    xmpp_stream_in:stop(Ref).
+
+-spec send(pid(), xmpp_element()) -> ok;
+         (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+    xmpp_stream_in:send(Pid, Pkt);
+send(#{lserver := LServer} = State, Pkt) ->
+    case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of
+       drop -> State;
+       Pkt1 -> xmpp_stream_in:send(State, Pkt1)
+    end.
+
+-spec establish(state()) -> state().
+establish(State) ->
+    xmpp_stream_in:establish(State).
+
+-spec add_hooks() -> ok.
+add_hooks() ->
+    lists:foreach(
+      fun(Host) ->
+             ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
+             ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
+                                reject_unauthenticated_packet, 100),
+             ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
+                                handle_unexpected_info, 100),
+             ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
+                                handle_unexpected_cast, 100)
+             
+      end, ?MYHOSTS).
+
+%% Copies content of one c2s state to another.
+%% This is needed for session migration from one pid to another.
+-spec copy_state(state(), state()) -> state().
+copy_state(#{owner := Owner} = NewState,
+         #{jid := JID, resource := Resource, sid := {Time, _},
+           auth_module := AuthModule, lserver := LServer,
+           pres_t := PresT, pres_a := PresA,
+           pres_f := PresF} = OldState) ->
+    State1 = case OldState of
+                #{pres_last := Pres, pres_timestamp := PresTS} ->
+                    NewState#{pres_last => Pres, pres_timestamp => PresTS};
+                _ ->
+                    NewState
+            end,
+    Conn = get_conn_type(State1),
+    State2 = State1#{jid => JID, resource => Resource,
+                    conn => Conn,
+                    sid => {Time, Owner},
+                    auth_module => AuthModule,
+                    pres_t => PresT, pres_a => PresA,
+                    pres_f => PresF},
+    ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
+
+%%%===================================================================
+%%% Hooks
+%%%===================================================================
+handle_unexpected_info(State, Info) ->
+    ?WARNING_MSG("got unexpected info: ~p", [Info]),
+    State.
 
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
-    xmpp_stream_in:send(State, Pkt).
+handle_unexpected_cast(State, Msg) ->
+    ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
+    State.
+
+reject_unauthenticated_packet(State, Pkt) ->
+    Err = xmpp:err_not_authorized(),
+    xmpp_stream_in:send_error(State, Pkt, Err).
+
+process_closed(State, _Reason) ->
+    stop(State).
 
 %%%===================================================================
 %%% xmpp_stream_in callbacks
@@ -115,128 +191,158 @@ tls_options(#{lserver := LServer, tls_options := TLSOpts}) ->
 tls_required(#{tls_required := TLSRequired}) ->
     TLSRequired.
 
+tls_verify(#{tls_verify := TLSVerify}) ->
+    TLSVerify.
+
 compress_methods(#{zlib := true}) ->
     [<<"zlib">>];
 compress_methods(_) ->
     [].
 
-sasl_mechanisms(#{lserver := LServer}) ->
-    cyrsasl:listmech(LServer).
-
 unauthenticated_stream_features(#{lserver := LServer}) ->
     ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]).
 
 authenticated_stream_features(#{lserver := LServer}) ->
     ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]).
 
-init_sasl(#{lserver := LServer}) ->
-    cyrsasl:server_new(
-      <<"jabber">>, LServer, <<"">>, [],
-      fun(U) ->
-             ejabberd_auth:get_password_with_authmodule(U, LServer)
-      end,
-      fun(U, AuthzId, P) ->
-             ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
-      end,
-      fun(U, AuthzId, P, D, DG) ->
-             ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
-      end).
+get_password_fun(#{lserver := LServer}) ->
+    fun(U) ->
+           ejabberd_auth:get_password_with_authmodule(U, LServer)
+    end.
+
+check_password_fun(#{lserver := LServer}) ->
+    fun(U, AuthzId, P) ->
+           ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
+    end.
+
+check_password_digest_fun(#{lserver := LServer}) ->
+    fun(U, AuthzId, P, D, DG) ->
+           ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
+    end.
 
 bind(<<"">>, State) ->
     bind(new_uniq_id(), State);
-bind(R, #{user := U, server := S} = State) ->
+bind(R, #{user := U, server := S, access := Access, lang := Lang,
+         lserver := LServer, socket := Socket, ip := IP} = State) ->
     case resource_conflict_action(U, S, R) of
        closenew ->
            {error, xmpp:err_conflict(), State};
        {accept_resource, Resource} ->
-           open_session(State, Resource)
+           JID = jid:make(U, S, Resource),
+           case acl:access_matches(Access,
+                                   #{usr => jid:split(JID), ip => IP},
+                                   LServer) of
+               allow ->
+                   State1 = open_session(State#{resource => Resource}),
+                   State2 = ejabberd_hooks:run_fold(
+                              c2s_session_opened, LServer, State1, []),
+                   ?INFO_MSG("(~s) Opened session for ~s",
+                             [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+                   {ok, State2};
+               deny ->
+                   ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
+                   ?INFO_MSG("(~s) Forbidden session for ~s",
+                             [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+                   Txt = <<"Denied by ACL">>,
+                   {error, xmpp:err_not_allowed(Txt, Lang), State}
+           end
     end.
 
-handshake(_Data, State) ->
-    %% This is only for jabber component
-    {ok, State}.
+-spec open_session(state()) -> {ok, state()} | state().
+open_session(#{user := U, server := S, resource := R,
+              sid := SID, ip := IP, auth_module := AuthModule} = State) ->
+    JID = jid:make(U, S, R),
+    change_shaper(State),
+    Conn = get_conn_type(State),
+    State1 = State#{conn => Conn, resource => R, jid => JID},
+    Prio = try maps:get(pres_last, State) of
+              Pres -> get_priority_from_presence(Pres)
+          catch _:{badkey, _} ->
+                  undefined
+          end,
+    Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
+    ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
+    State1.
 
-handle_stream_start(#{lserver := LServer, ip := IP, lang := Lang} = State) ->
+handle_stream_start(StreamStart,
+                   #{lserver := LServer, ip := IP, lang := Lang} = State) ->
     case lists:member(LServer, ?MYHOSTS) of
        false ->
-           xmpp_stream_in:send(State, xmpp:serr_host_unknown());
+           send(State, xmpp:serr_host_unknown());
        true ->
            case check_bl_c2s(IP, Lang) of
                false ->
                    change_shaper(State),
-                   {noreply, State};
+                   ejabberd_hooks:run_fold(
+                     c2s_stream_started, LServer, State, [StreamStart]);
                {true, LogReason, ReasonT} ->
                    ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s",
                              [jlib:ip_to_list(IP), LogReason]),
                    Err = xmpp:serr_policy_violation(ReasonT, Lang),
-                   xmpp_stream_in:send(State, Err)
+                   send(State, Err)
            end
     end.
 
-handle_stream_end(State) ->
-    {stop, normal, State}.
+handle_stream_end(Reason, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]).
 
-handle_stream_close(State) ->
-    {stop, normal, State}.
+handle_stream_close(_Reason, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]).
 
 handle_auth_success(User, Mech, AuthModule,
                    #{socket := Socket, ip := IP, lserver := LServer} = State) ->
-    ?INFO_MSG("(~w) Accepted ~s authentication for ~s@~s by ~p from ~s",
-             [Socket, Mech, User, LServer, AuthModule,
+    ?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s",
+             [ejabberd_socket:pp(Socket), Mech, User, LServer,
+              ejabberd_auth:backend_type(AuthModule),
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     State1 = State#{auth_module => AuthModule},
     ejabberd_hooks:run_fold(c2s_auth_result, LServer,
-                           {noreply, State1}, [true, User]).
+                           State1, [true, User]).
 
 handle_auth_failure(User, Mech, Reason,
                    #{socket := Socket, ip := IP, lserver := LServer} = State) ->
-    ?INFO_MSG("(~w) Failed ~s authentication ~sfrom ~s: ~s",
-             [Socket, Mech,
+    ?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s",
+             [ejabberd_socket:pp(Socket), Mech,
               if User /= <<"">> -> ["for ", User, "@", LServer, " "];
                  true -> ""
               end,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
     ejabberd_hooks:run_fold(c2s_auth_result, LServer,
-                           {noreply, State}, [false, User]).
+                           State, [false, User]).
 
 handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer,
-                           {noreply, State}, [Pkt]).
+                           State, [Pkt]).
 
 handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_unauthenticated_packet,
-                           LServer, {noreply, State}, [Pkt]).
+                           LServer, State, [Pkt]).
 
 handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
     ejabberd_hooks:run_fold(c2s_authenticated_packet,
-                           LServer, {noreply, State}, [Pkt]);
+                           LServer, State, [Pkt]);
 handle_authenticated_packet(Pkt, #{lserver := LServer} = State) ->
-    case ejabberd_hooks:run_fold(c2s_authenticated_packet,
-                                LServer, {noreply, State}, [Pkt]) of
-       {noreply, State1} ->
-           Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
-           Res = case Pkt1 of
-                     #presence{to = #jid{lresource = <<"">>}} ->
-                         process_self_presence(State1, Pkt1);
-                     #presence{} ->
-                         process_presence_out(State1, Pkt1);
-                     _ ->
-                         check_privacy_then_route(State1, Pkt1)
-                 end,
-           ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]),
-           Res;
-       Err ->
-           ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]),
-           Err
+    State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet,
+                                    LServer, State, [Pkt]),
+    Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
+    case Pkt1 of
+       #presence{to = #jid{lresource = <<"">>}} ->
+           process_self_presence(State1, Pkt1);
+       #presence{} ->
+           process_presence_out(State1, Pkt1);
+       _ ->
+           check_privacy_then_route(State1, Pkt1)
     end.
 
 handle_cdata(Data, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_handle_cdata, LServer,
-                           {noreply, State}, [Data]).
+                           State, [Data]).
+
+handle_recv(El, Pkt, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_handle_recv, LServer, State, [El, Pkt]).
 
-handle_send(Reason, Pkt, El, Data, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_handle_send, LServer,
-                           {noreply, State}, [Reason, Pkt, El, Data]).
+handle_send(Pkt, Result, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_handle_send, LServer, State, [Pkt, Result]).
 
 init([State, Opts]) ->
     Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
@@ -262,15 +368,13 @@ init([State, Opts]) ->
                    server => ?MYNAME,
                    access => Access,
                    shaper => Shaper},
-    ejabberd_hooks:run_fold(c2s_init, {ok, State1}, []).
+    ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]).
 
 handle_call(get_presence, _From, #{jid := JID} = State) ->
-    Pres = case maps:get(pres_last, State, undefined) of
-              undefined ->
+    Pres = try maps:get(pres_last, State)
+          catch _:{badkey, _} ->
                   BareJID = jid:remove_resource(JID),
-                  #presence{from = JID, to = BareJID, type = unavailable};
-              P ->
-                  P
+                  #presence{from = JID, to = BareJID, type = unavailable}
           end,
     {reply, Pres, State};
 handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
@@ -278,12 +382,10 @@ handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
     {reply, Subscribed, State};
 handle_call(Request, From, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(
-      c2s_handle_call, LServer, {noreply, State}, [Request, From]).
+      c2s_handle_call, LServer, State, [Request, From]).
 
-handle_cast(closed, State) ->
-    handle_stream_close(State);
 handle_cast(Msg, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_handle_cast, LServer, {noreply, State}, [Msg]).
+    ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]).
 
 handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
     Packet = xmpp:set_from_to(Packet0, From, To),
@@ -299,15 +401,13 @@ handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
            Packet1 = ejabberd_hooks:run_fold(
                        user_receive_packet, LServer, Packet, [NewState]),
            ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
-           xmpp_stream_in:send(NewState, Packet1);
+           send(NewState, Packet1);
        true ->
            ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
-           {noreply, NewState}
+           NewState
     end;
-handle_info(system_shutdown, State) ->
-    xmpp_stream_in:send(State, xmpp:serr_system_shutdown());
 handle_info(Info, #{lserver := LServer} = State) ->
-    ejabberd_hooks:run_fold(c2s_handle_info, LServer, {noreply, State}, [Info]).
+    ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
 
 terminate(_Reason, _State) ->
     ok.
@@ -323,33 +423,6 @@ code_change(_OldVsn, State, _Extra) ->
 check_bl_c2s({IP, _Port}, Lang) ->
     ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]).
 
--spec open_session(state(), binary()) -> {ok, state()} | {error, stanza_error(), state()}.
-open_session(#{user := U, server := S, lserver := LServer, sid := SID,
-              socket := Socket, ip := IP, auth_module := AuthMod,
-              access := Access, lang := Lang} = State, R) ->
-    JID = jid:make(U, S, R),
-    case acl:access_matches(Access,
-                           #{usr => jid:split(JID), ip => IP},
-                           LServer) of
-       allow ->
-           ?INFO_MSG("(~w) Opened session for ~s",
-                     [Socket, jid:to_string(JID)]),
-           change_shaper(State),
-           Conn = get_conn_type(State),
-            Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
-           ejabberd_sm:open_session(SID, U, LServer, R, Info),
-           State1 = State#{conn => Conn, resource => R, jid => JID},
-           State2 = ejabberd_hooks:run_fold(
-                      c2s_session_opened, LServer, State1, []),
-           {ok, State2};
-       deny ->
-           ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
-            ?INFO_MSG("(~w) Forbidden session for ~s",
-                      [Socket, jid:to_string(JID)]),
-           Txt = <<"Denied by ACL">>,
-           {error, xmpp:err_not_allowed(Txt, Lang), State}
-    end.
-
 -spec process_iq_in(state(), iq()) -> {boolean(), state()}.
 process_iq_in(State, #iq{} = IQ) ->
     case privacy_check_packet(State, IQ, in) of
@@ -433,7 +506,7 @@ route_probe_reply(From, To, #{lserver := LServer, pres_f := PresF,
 route_probe_reply(_, _, _) ->
     ok.
 
--spec process_presence_out(state(), presence()) -> next_state().
+-spec process_presence_out(state(), presence()) -> state().
 process_presence_out(#{user := User, server := Server, lserver := LServer,
                       jid := JID, lang := Lang, pres_a := PresA} = State,
                     #presence{from = From, to = To, type = Type} = Pres) ->
@@ -461,21 +534,21 @@ process_presence_out(#{user := User, server := Server, lserver := LServer,
                                       [User, Server, To, Type]),
                    BareFrom = jid:remove_resource(From),
                    route(xmpp:set_from_to(Pres, BareFrom, To)),
-                   {noreply, State}
+                   State
            end;
        allow when Type == error; Type == probe ->
            route(Pres),
-           {noreply, State};
+           State;
        allow ->
            route(Pres),
            A = case Type of
                    available -> ?SETS:add_element(LTo, PresA);
                    unavailable -> ?SETS:del_element(LTo, PresA)
                end,
-           {noreply, State#{pres_a => A}}
+           State#{pres_a => A}
     end.
 
--spec process_self_presence(state(), presence()) -> {noreply, state()}.
+-spec process_self_presence(state(), presence()) -> state().
 process_self_presence(#{ip := IP, conn := Conn,
                        auth_module := AuthMod, sid := SID,
                        user := U, server := S, resource := R} = State,
@@ -484,8 +557,7 @@ process_self_presence(#{ip := IP, conn := Conn,
     Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
     ejabberd_sm:unset_presence(SID, U, S, R, Status, Info),
     State1 = broadcast_presence_unavailable(State, Pres),
-    State2 = maps:remove(pres_last, maps:remove(pres_timestamp, State1)),
-    {noreply, State2};
+    maps:remove(pres_last, maps:remove(pres_timestamp, State1));
 process_self_presence(#{lserver := LServer} = State,
                      #presence{type = available} = Pres) ->
     PreviousPres = maps:get(pres_last, State, undefined),
@@ -494,10 +566,9 @@ process_self_presence(#{lserver := LServer} = State,
     State2 = State1#{pres_last => Pres,
                     pres_timestamp => p1_time_compat:timestamp()},
     FromUnavailable = PreviousPres == undefined,
-    State3 = broadcast_presence_available(State2, Pres, FromUnavailable),
-    {noreply, State3};
+    broadcast_presence_available(State2, Pres, FromUnavailable);
 process_self_presence(State, _Pres) ->
-    {noreply, State}.
+    State.
 
 -spec update_priority(state(), presence()) -> ok.
 update_priority(#{ip := IP, conn := Conn, auth_module := AuthMod,
@@ -529,7 +600,7 @@ broadcast_presence_available(#{pres_a := PresA, pres_f := PresF} = State,
     route_multiple(State, JIDs, Pres),
     State.
 
--spec check_privacy_then_route(state(), stanza()) -> next_state().
+-spec check_privacy_then_route(state(), stanza()) -> state().
 check_privacy_then_route(#{lang := Lang} = State, Pkt) ->
     case privacy_check_packet(State, Pkt, out) of
         deny ->
@@ -539,7 +610,7 @@ check_privacy_then_route(#{lang := Lang} = State, Pkt) ->
            xmpp_stream_in:send_error(State, Pkt, Err);
         allow ->
            route(Pkt),
-           {noreply, State}
+           State
     end.
 
 -spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny.
@@ -664,25 +735,10 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
            end
     end.
 
--spec fsm_limit_opts([proplists:property()]) -> [proplists:property()].
-fsm_limit_opts(Opts) ->
-    case lists:keysearch(max_fsm_queue, 1, Opts) of
-       {value, {_, N}} when is_integer(N) -> [{max_queue, N}];
-       _ ->
-           case ejabberd_config:get_option(
-                  max_fsm_queue,
-                  fun(I) when is_integer(I), I > 0 -> I end) of
-               undefined -> [];
-               N -> [{max_queue, N}]
-           end
-    end.
-
 transform_listen_option(Opt, Opts) ->
     [Opt|Opts].
 
 opt_type(domain_certfile) -> fun iolist_to_binary/1;
-opt_type(max_fsm_queue) ->
-    fun (I) when is_integer(I), I > 0 -> I end;
 opt_type(resource_conflict) ->
     fun (setresource) -> setresource;
        (closeold) -> closeold;
@@ -690,4 +746,4 @@ opt_type(resource_conflict) ->
        (acceptnew) -> acceptnew
     end;
 opt_type(_) ->
-    [domain_certfile, max_fsm_queue, resource_conflict].
+    [domain_certfile, resource_conflict].
index e930e36b165541909eb06078df25b81b1eafef0b..9014bfabdab8c9d99517271807b3ef16a25b6ad1 100644 (file)
@@ -38,7 +38,8 @@
         transform_options/1, collect_options/1, default_db/2,
         convert_to_yaml/1, convert_to_yaml/2, v_db/2,
         env_binary_to_list/2, opt_type/1, may_hide_data/1,
-        is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1]).
+        is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1,
+        fsm_limit_opts/1]).
 
 -export([start/2]).
 
@@ -1403,6 +1404,8 @@ opt_type(hosts) ->
     end;
 opt_type(language) ->
     fun iolist_to_binary/1;
+opt_type(max_fsm_queue) ->
+    fun (I) when is_integer(I), I > 0 -> I end;
 opt_type(_) ->
     [hide_sensitive_log_data, hosts, language].
 
@@ -1421,3 +1424,17 @@ may_hide_data(Data) ->
        true ->
            "hidden_by_ejabberd"
     end.
+
+-spec fsm_limit_opts([proplists:property()]) -> [{max_queue, pos_integer()}].
+fsm_limit_opts(Opts) ->
+    case lists:keyfind(max_fsm_queue, 1, Opts) of
+       {_, I} when is_integer(I), I>0 ->
+           [{max_queue, I}];
+       false ->
+           case get_option(
+                  max_fsm_queue,
+                  fun(I) when is_integer(I), I>0 -> I end) of
+               undefined -> [];
+               N -> [{max_queue, N}]
+           end
+    end.
index c1daa4c0e4130aee1fa61418622422f5e4a8d839..612d5afe5e7c1a88d930a36f1a4bf90a49b2b38f 100644 (file)
@@ -376,8 +376,11 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) ->
     end.
 
 safe_apply(Module, Function, Args) ->
-    if is_function(Function) ->
-            catch apply(Function, Args);
-       true ->
-            catch apply(Module, Function, Args)
+    try if is_function(Function) ->
+               apply(Function, Args);
+          true ->
+               apply(Module, Function, Args)
+       end
+    catch E:R when E /= exit, R /= normal ->
+           {'EXIT', {E, {R, erlang:get_stacktrace()}}}
     end.
index a9cc441e9891ad1fbb2b8fef3b43312e8dd592b6..f720fc5850c0198215d6dd241a9c12b6e3c7d570 100644 (file)
@@ -330,9 +330,9 @@ accept(ListenSocket, Module, Opts, Interval) ->
        {ok, Socket} ->
            case {inet:sockname(Socket), inet:peername(Socket)} of
                {{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} ->
-                   ?INFO_MSG("(~w) Accepted connection ~s:~p -> ~s:~p",
-                             [Socket, ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), PPort,
-                              inet_parse:ntoa(Addr), Port]);
+                   ?INFO_MSG("Accepted connection ~s:~p -> ~s:~p",
+                             [ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)),
+                              PPort, inet_parse:ntoa(Addr), Port]);
                _ ->
                    ok
            end,
index 33093abb091e5d060196c8336e6b533234aa5ffc..5ce8a8afb89dacd103359ef8d7f91e072f010762 100644 (file)
@@ -43,7 +43,9 @@
         unregister_route/1,
         unregister_routes/1,
         dirty_get_all_routes/0,
-        dirty_get_all_domains/0
+        dirty_get_all_domains/0,
+        is_my_route/1,
+        is_my_host/1
        ]).
 
 -export([start_link/0]).
@@ -110,12 +112,12 @@ register_route(Domain) ->
                 [?MODULE, ?MODULE]),
     register_route(Domain, ?MYNAME).
 
--spec register_route(binary(), binary()) -> term().
+-spec register_route(binary(), binary()) -> ok.
 
 register_route(Domain, ServerHost) ->
     register_route(Domain, ServerHost, undefined).
 
--spec register_route(binary(), binary(), local_hint()) -> term().
+-spec register_route(binary(), binary(), local_hint()) -> ok.
 
 register_route(Domain, ServerHost, LocalHint) ->
     case {jid:nameprep(Domain), jid:nameprep(ServerHost)} of
@@ -165,6 +167,11 @@ register_route(Domain, ServerHost, LocalHint) ->
                            end
                    end,
                mnesia:transaction(F)
+         end,
+         if LocalHint == undefined ->
+                 ?INFO_MSG("Route registered: ~s", [LDomain]);
+            true ->
+                 ok
          end
     end.
 
@@ -175,7 +182,7 @@ register_routes(Domains) ->
                  end,
                  Domains).
 
--spec unregister_route(binary()) -> term().
+-spec unregister_route(binary()) -> ok.
 
 unregister_route(Domain) ->
     case jid:nameprep(Domain) of
@@ -210,7 +217,8 @@ unregister_route(Domain) ->
                            end
                    end,
                mnesia:transaction(F)
-         end
+         end,
+         ?INFO_MSG("Route unregistered: ~s", [LDomain])
     end.
 
 -spec unregister_routes([binary()]) -> ok.
@@ -245,6 +253,29 @@ host_of_route(Domain) ->
            end
     end.
 
+-spec is_my_route(binary()) -> boolean().
+is_my_route(Domain) ->
+    case jid:nameprep(Domain) of
+       error ->
+           erlang:error({invalid_domain, Domain});
+       LDomain ->
+           mnesia:dirty_read(route, LDomain) /= []
+    end.
+
+-spec is_my_host(binary()) -> boolean().
+is_my_host(Domain) ->
+    case jid:nameprep(Domain) of
+       error ->
+           erlang:error({invalid_domain, Domain});
+       LDomain ->
+           case mnesia:dirty_read(route, LDomain) of
+               [#route{server_host = Host}|_] ->
+                   Host == LDomain;
+               [] ->
+                   false
+           end
+    end.
+
 -spec process_iq(jid(), jid(), iq() | xmlel()) -> any().
 process_iq(From, To, #iq{} = IQ) ->
     if To#jid.luser == <<"">> ->
index 4df1761cb537677b5c6f7f46de8b78ead232a66d..af4d6a66218058e80b232cdf54ba5211901d7757 100644 (file)
 
 %% API
 -export([start_link/0, route/3, have_connection/1,
-        make_key/2, get_connections_pids/1, try_register/1,
-        remove_connection/2, find_connection/2,
+        get_connections_pids/1, try_register/1,
+        remove_connection/2, start_connection/2, start_connection/3,
         dirty_get_connections/0, allow_host/2,
         incoming_s2s_number/0, outgoing_s2s_number/0,
         stop_all_connections/0,
         clean_temporarily_blocked_table/0,
         list_temporarily_blocked_hosts/0,
         external_host_overloaded/1, is_temporarly_blocked/1,
-        check_peer_certificate/3,
-        get_commands_spec/0]).
+        get_commands_spec/0, zlib_enabled/1, get_idle_timeout/1,
+        tls_required/1, tls_verify/1, tls_enabled/1, tls_options/2]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2,
@@ -196,39 +196,94 @@ try_register(FromTo) ->
 dirty_get_connections() ->
     mnesia:dirty_all_keys(s2s).
 
-check_peer_certificate(SockMod, Sock, Peer) ->
-    case SockMod:get_peer_certificate(Sock) of
-      {ok, Cert} ->
-         case SockMod:get_verify_result(Sock) of
-           0 ->
-               case ejabberd_idna:domain_utf8_to_ascii(Peer) of
-                 false ->
-                     {error, <<"Cannot decode remote server name">>};
-                 AsciiPeer ->
-                     case
-                       lists:any(fun(D) -> match_domain(AsciiPeer, D) end,
-                                 get_cert_domains(Cert)) of
-                       true ->
-                           {ok, <<"Verification successful">>};
-                       false ->
-                           {error, <<"Certificate host name mismatch">>}
-                     end
-               end;
-           VerifyRes ->
-               {error, fast_tls:get_cert_verify_string(VerifyRes, Cert)}
-         end;
-      {error, _Reason} ->
-           {error, <<"Cannot get peer certificate">>};
-      error ->
-           {error, <<"Cannot get peer certificate">>}
+-spec tls_options(binary(), [proplists:property()]) -> [proplists:property()].
+tls_options(LServer, DefaultOpts) ->
+    TLSOpts1 = case ejabberd_config:get_option(
+                     {s2s_certfile, LServer},
+                     fun iolist_to_binary/1,
+                     ejabberd_config:get_option(
+                       {domain_certfile, LServer},
+                       fun iolist_to_binary/1)) of
+                  undefined -> [];
+                  CertFile -> lists:keystore(certfile, 1, DefaultOpts,
+                                             {certfile, CertFile})
+              end,
+    TLSOpts2 = case ejabberd_config:get_option(
+                      {s2s_ciphers, LServer},
+                     fun iolist_to_binary/1) of
+                   undefined -> TLSOpts1;
+                   Ciphers -> lists:keystore(ciphers, 1, TLSOpts1,
+                                            {ciphers, Ciphers})
+               end,
+    TLSOpts3 = case ejabberd_config:get_option(
+                      {s2s_protocol_options, LServer},
+                      fun (Options) -> str:join(Options, <<$|>>) end) of
+                   undefined -> TLSOpts2;
+                   ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2,
+                                              {protocol_options, ProtoOpts})
+               end,
+    TLSOpts4 = case ejabberd_config:get_option(
+                      {s2s_dhfile, LServer},
+                     fun iolist_to_binary/1) of
+                   undefined -> TLSOpts3;
+                   DHFile -> lists:keystore(dhfile, 1, TLSOpts3,
+                                           {dhfile, DHFile})
+               end,
+    TLSOpts5 = case ejabberd_config:get_option(
+                     {s2s_cafile, LServer},
+                     fun iolist_to_binary/1) of
+                  undefined -> TLSOpts4;
+                  CAFile -> lists:keystore(cafile, 1, TLSOpts4,
+                                           {cafile, CAFile})
+              end,
+    case ejabberd_config:get_option(
+          {s2s_tls_compression, LServer},
+          fun(B) when is_boolean(B) -> B end) of
+       undefined -> TLSOpts5;
+       false -> [compression_none | TLSOpts5];
+       true -> lists:delete(compression_none, TLSOpts5)
     end.
 
--spec make_key({binary(), binary()}, binary()) -> binary().
-make_key({From, To}, StreamID) ->
-    Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
-    p1_sha:to_hexlist(
-      crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
-                 [To, " ", From, " ", StreamID])).
+-spec tls_required(binary()) -> boolean().
+tls_required(LServer) ->
+    TLS = use_starttls(LServer),
+    TLS == required orelse TLS == required_trusted.
+
+-spec tls_verify(binary()) -> boolean().
+tls_verify(LServer) ->
+    TLS = use_starttls(LServer),
+    TLS == required_trusted.
+
+-spec tls_enabled(binary()) -> boolean().
+tls_enabled(LServer) ->
+    TLS = use_starttls(LServer),
+    TLS == true orelse TLS == optional.
+
+-spec zlib_enabled(binary()) -> boolean().
+zlib_enabled(LServer) ->
+    ejabberd_config:get_option(
+      {s2s_zlib, LServer},
+      fun(B) when is_boolean(B) -> B end,
+      false).
+
+-spec use_starttls(binary()) -> boolean() | optional | required | required_trusted.
+use_starttls(LServer) ->
+    ejabberd_config:get_option(
+      {s2s_use_starttls, LServer},
+      fun(true) -> true;
+        (false) -> false;
+        (optional) -> optional;
+        (required) -> required;
+        (required_trusted) -> required_trusted
+      end, false).
+
+-spec get_idle_timeout(binary()) -> non_neg_integer() | infinity.
+get_idle_timeout(LServer) ->
+    ejabberd_config:get_option(
+      {s2s_timeout, LServer},
+      fun(I) when is_integer(I), I >= 0 -> timer:seconds(I);
+        (infinity) -> infinity
+      end, timer:minutes(10)).
 
 %%====================================================================
 %% gen_server callbacks
@@ -246,6 +301,8 @@ init([]) ->
     ejabberd_mnesia:create(?MODULE, temporarily_blocked,
                        [{ram_copies, [node()]},
                         {attributes, record_info(fields, temporarily_blocked)}]),
+    ejabberd_s2s_in:add_hooks(),
+    ejabberd_s2s_out:add_hooks(),
     {ok, #state{}}.
 
 handle_call(_Request, _From, State) ->
@@ -291,30 +348,36 @@ clean_table_from_bad_node(Node) ->
        end,
     mnesia:async_dirty(F).
 
--spec do_route(jid(), jid(), stanza()) -> ok | false.
+-spec do_route(jid(), jid(), stanza()) -> ok.
 do_route(From, To, Packet) ->
     ?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket "
           "~P~n",
           [From, To, Packet, 8]),
-    case find_connection(From, To) of
-      {atomic, Pid} when is_pid(Pid) ->
-         ?DEBUG("sending to process ~p~n", [Pid]),
-         #jid{lserver = MyServer} = From,
-         ejabberd_hooks:run(s2s_send_packet, MyServer,
-                            [From, To, Packet]),
-         send_element(Pid, xmpp:set_from_to(Packet, From, To)),
-         ok;
-      {aborted, _Reason} ->
-         Lang = xmpp:get_lang(Packet),
-         Txt = <<"No s2s connection found">>,
-         Err = xmpp:err_service_unavailable(Txt, Lang),
-         ejabberd_router:route_error(To, From, Packet, Err),
-         false
+    case start_connection(From, To) of
+       {ok, Pid} when is_pid(Pid) ->
+           ?DEBUG("sending to process ~p~n", [Pid]),
+           #jid{lserver = MyServer} = From,
+           ejabberd_hooks:run(s2s_send_packet, MyServer, [From, To, Packet]),
+           ejabberd_s2s_out:route(Pid, xmpp:set_from_to(Packet, From, To));
+       {error, Reason} ->
+           Err = case Reason of
+                     forbidden ->
+                         Lang = xmpp:get_lang(Packet),
+                         xmpp:err_forbidden(<<"Denied by ACL">>, Lang);
+                     internal_server_error ->
+                         xmpp:err_internal_server_error()
+                 end,
+           ejabberd_router:route_error(To, From, Packet, Err)
     end.
 
--spec find_connection(jid(), jid()) -> {aborted, any()} | {atomic, pid()}.
+-spec start_connection(jid(), jid()) -> {ok, pid()} |
+                                       {error, forbidden | internal_server_error}.
+start_connection(From, To) ->
+    start_connection(From, To, []).
 
-find_connection(From, To) ->
+-spec start_connection(jid(), jid(), [proplists:property()])
+      -> {ok, pid()} | {error, forbidden | internal_server_error}.
+start_connection(From, To, Opts) ->
     #jid{lserver = MyServer} = From,
     #jid{lserver = Server} = To,
     FromTo = {MyServer, Server},
@@ -323,15 +386,13 @@ find_connection(From, To) ->
     MaxS2SConnectionsNumberPerNode =
        max_s2s_connections_number_per_node(FromTo),
     ?DEBUG("Finding connection for ~p~n", [FromTo]),
-    case catch mnesia:dirty_read(s2s, FromTo) of
-      {'EXIT', Reason} -> {aborted, Reason};
+    case mnesia:dirty_read(s2s, FromTo) of
       [] ->
          %% We try to establish all the connections if the host is not a
          %% service and if the s2s host is not blacklisted or
          %% is in whitelist:
-         case not is_service(From, To) andalso
-                allow_host(MyServer, Server)
-             of
+         LServer = ejabberd_router:host_of_route(MyServer),
+         case not is_service(From, To) andalso allow_host(LServer, Server) of
            true ->
                NeededConnections = needed_connections_number([],
                                                              MaxS2SConnectionsNumber,
@@ -339,8 +400,8 @@ find_connection(From, To) ->
                open_several_connections(NeededConnections, MyServer,
                                         Server, From, FromTo,
                                         MaxS2SConnectionsNumber,
-                                        MaxS2SConnectionsNumberPerNode);
-           false -> {aborted, error}
+                                        MaxS2SConnectionsNumberPerNode, Opts);
+           false -> {error, forbidden}
          end;
       L when is_list(L) ->
          NeededConnections = needed_connections_number(L,
@@ -351,10 +412,10 @@ find_connection(From, To) ->
                 open_several_connections(NeededConnections, MyServer,
                                          Server, From, FromTo,
                                          MaxS2SConnectionsNumber,
-                                         MaxS2SConnectionsNumberPerNode);
+                                         MaxS2SConnectionsNumberPerNode, Opts);
             true ->
                 %% We choose a connexion from the pool of opened ones.
-                {atomic, choose_connection(From, L)}
+                {ok, choose_connection(From, L)}
          end
     end.
 
@@ -377,20 +438,22 @@ choose_pid(From, Pids) ->
 
 open_several_connections(N, MyServer, Server, From,
                         FromTo, MaxS2SConnectionsNumber,
-                        MaxS2SConnectionsNumberPerNode) ->
-    ConnectionsResult = [new_connection(MyServer, Server,
-                                       From, FromTo, MaxS2SConnectionsNumber,
-                                       MaxS2SConnectionsNumberPerNode)
-                        || _N <- lists:seq(1, N)],
-    case [PID || {atomic, PID} <- ConnectionsResult] of
-      [] -> hd(ConnectionsResult);
-      PIDs -> {atomic, choose_pid(From, PIDs)}
+                        MaxS2SConnectionsNumberPerNode, Opts) ->
+    case lists:flatmap(
+          fun(_) ->
+                  new_connection(MyServer, Server,
+                                 From, FromTo, MaxS2SConnectionsNumber,
+                                 MaxS2SConnectionsNumberPerNode, Opts)
+          end, lists:seq(1, N)) of
+       [] ->
+           {error, internal_server_error};
+       PIDs ->
+           {ok, choose_pid(From, PIDs)}
     end.
 
 new_connection(MyServer, Server, From, FromTo,
-              MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) ->
-    {ok, Pid} = ejabberd_s2s_out:start(
-                 MyServer, Server, new),
+              MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) ->
+    {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts),
     F = fun() ->
                L = mnesia:read({s2s, FromTo}),
                NeededConnections = needed_connections_number(L,
@@ -398,17 +461,21 @@ new_connection(MyServer, Server, From, FromTo,
                                                              MaxS2SConnectionsNumberPerNode),
                if NeededConnections > 0 ->
                       mnesia:write(#s2s{fromto = FromTo, pid = Pid}),
-                      ?INFO_MSG("New s2s connection started ~p", [Pid]),
                       Pid;
                   true -> choose_connection(From, L)
                end
        end,
     TRes = mnesia:transaction(F),
     case TRes of
-      {atomic, Pid} -> ejabberd_s2s_out:start_connection(Pid);
-      _ -> ejabberd_s2s_out:stop_connection(Pid)
-    end,
-    TRes.
+      {atomic, Pid} ->
+           ejabberd_s2s_out:connect(Pid),
+           [Pid];
+      {aborted, Reason} ->
+           ?ERROR_MSG("failed to register connection ~s -> ~s: ~p",
+                      [MyServer, Server, Reason]),
+           ejabberd_s2s_out:stop(Pid),
+           []
+    end.
 
 -spec max_s2s_connections_number({binary(), binary()}) -> integer().
 max_s2s_connections_number({From, To}) ->
@@ -459,9 +526,6 @@ parent_domains(Domain) ->
                end,
                [], lists:reverse(str:tokens(Domain, <<".">>))).
 
-send_element(Pid, El) ->
-    Pid ! {send_element, El}.
-
 %%%----------------------------------------------------------------------
 %%% ejabberd commands
 
@@ -536,24 +600,13 @@ update_tables() ->
 
 %% Check if host is in blacklist or white list
 allow_host(MyServer, S2SHost) ->
-    allow_host2(MyServer, S2SHost) andalso
+    allow_host1(MyServer, S2SHost) andalso
       not is_temporarly_blocked(S2SHost).
 
-allow_host2(MyServer, S2SHost) ->
-    Hosts = (?MYHOSTS),
-    case lists:dropwhile(fun (ParentDomain) ->
-                                not lists:member(ParentDomain, Hosts)
-                        end,
-                        parent_domains(MyServer))
-       of
-      [MyHost | _] -> allow_host1(MyHost, S2SHost);
-      [] -> allow_host1(MyServer, S2SHost)
-    end.
-
 allow_host1(MyHost, S2SHost) ->
     Rule = ejabberd_config:get_option(
-             s2s_access,
-             fun(A) -> A end,
+             {s2s_access, MyHost},
+             fun acl:access_rules_validator/1,
              all),
     JID = jid:make(S2SHost),
     case acl:match_rule(MyHost, Rule, JID) of
@@ -624,133 +677,34 @@ get_s2s_state(S2sPid) ->
            end,
     [{s2s_pid, S2sPid} | Infos].
 
-get_cert_domains(Cert) ->
-    TBSCert = Cert#'Certificate'.tbsCertificate,
-    Subject = case TBSCert#'TBSCertificate'.subject of
-                 {rdnSequence, Subj} -> lists:flatten(Subj);
-                 _ -> []
-             end,
-    Extensions = case TBSCert#'TBSCertificate'.extensions of
-                    Exts when is_list(Exts) -> Exts;
-                    _ -> []
-                end,
-    lists:flatmap(fun (#'AttributeTypeAndValue'{type =
-                                                   ?'id-at-commonName',
-                                               value = Val}) ->
-                         case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
-                           {ok, {_, D1}} ->
-                               D = if is_binary(D1) -> D1;
-                                      is_list(D1) -> list_to_binary(D1);
-                                      true -> error
-                                   end,
-                               if D /= error ->
-                                      case jid:from_string(D) of
-                                        #jid{luser = <<"">>, lserver = LD,
-                                             lresource = <<"">>} ->
-                                            [LD];
-                                        _ -> []
-                                      end;
-                                  true -> []
-                               end;
-                           _ -> []
-                         end;
-                     (_) -> []
-                 end,
-                 Subject)
-      ++
-      lists:flatmap(fun (#'Extension'{extnID =
-                                         ?'id-ce-subjectAltName',
-                                     extnValue = Val}) ->
-                           BVal = if is_list(Val) -> list_to_binary(Val);
-                                     true -> Val
-                                  end,
-                           case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
-                               of
-                             {ok, SANs} ->
-                                 lists:flatmap(fun ({otherName,
-                                                     #'AnotherName'{'type-id' =
-                                                                        ?'id-on-xmppAddr',
-                                                                    value =
-                                                                        XmppAddr}}) ->
-                                                       case
-                                                         'XmppAddr':decode('XmppAddr',
-                                                                           XmppAddr)
-                                                           of
-                                                         {ok, D}
-                                                             when
-                                                               is_binary(D) ->
-                                                             case
-                                                               jid:from_string((D))
-                                                                 of
-                                                               #jid{luser =
-                                                                        <<"">>,
-                                                                    lserver =
-                                                                        LD,
-                                                                    lresource =
-                                                                        <<"">>} ->
-                                                                   case
-                                                                     ejabberd_idna:domain_utf8_to_ascii(LD)
-                                                                       of
-                                                                     false ->
-                                                                         [];
-                                                                     PCLD ->
-                                                                         [PCLD]
-                                                                   end;
-                                                               _ -> []
-                                                             end;
-                                                         _ -> []
-                                                       end;
-                                                   ({dNSName, D})
-                                                       when is_list(D) ->
-                                                       case
-                                                         jid:from_string(list_to_binary(D))
-                                                           of
-                                                         #jid{luser = <<"">>,
-                                                              lserver = LD,
-                                                              lresource =
-                                                                  <<"">>} ->
-                                                             [LD];
-                                                         _ -> []
-                                                       end;
-                                                   (_) -> []
-                                               end,
-                                               SANs);
-                             _ -> []
-                           end;
-                       (_) -> []
-                   end,
-                   Extensions).
-
-match_domain(Domain, Domain) -> true;
-match_domain(Domain, Pattern) ->
-    DLabels = str:tokens(Domain, <<".">>),
-    PLabels = str:tokens(Pattern, <<".">>),
-    match_labels(DLabels, PLabels).
-
-match_labels([], []) -> true;
-match_labels([], [_ | _]) -> false;
-match_labels([_ | _], []) -> false;
-match_labels([DL | DLabels], [PL | PLabels]) ->
-    case lists:all(fun (C) ->
-                          $a =< C andalso C =< $z orelse
-                            $0 =< C andalso C =< $9 orelse
-                              C == $- orelse C == $*
-                  end,
-                  binary_to_list(PL))
-       of
-      true ->
-         Regexp = ejabberd_regexp:sh_to_awk(PL),
-         case ejabberd_regexp:run(DL, Regexp) of
-           match -> match_labels(DLabels, PLabels);
-           nomatch -> false
-         end;
-      false -> false
-    end.
-
 opt_type(route_subdomains) ->
     fun (s2s) -> s2s;
        (local) -> local
     end;
 opt_type(s2s_access) ->
     fun acl:access_rules_validator/1;
-opt_type(_) -> [route_subdomains, s2s_access].
+opt_type(domain_certfile) -> fun iolist_to_binary/1;
+opt_type(s2s_certfile) -> fun iolist_to_binary/1;
+opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
+opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
+opt_type(s2s_protocol_options) ->
+    fun (Options) -> str:join(Options, <<"|">>) end;
+opt_type(s2s_tls_compression) ->
+    fun (true) -> true;
+       (false) -> false
+    end;
+opt_type(s2s_use_starttls) ->
+    fun (true) -> true;
+       (false) -> false;
+       (optional) -> optional;
+       (required) -> required;
+       (required_trusted) -> required_trusted
+    end;
+opt_type(s2s_timeout) ->
+    fun(I) when is_integer(I), I>=0 -> I;
+       (infinity) -> infinity
+    end;
+opt_type(_) ->
+    [route_subdomains, s2s_access,  s2s_certfile,
+     s2s_ciphers, s2s_dhfile, s2s_protocol_options,
+     s2s_tls_compression, s2s_use_starttls, s2s_timeout].
index 395a0fce7eb9a8e67e2dcf2a90391cba1f600e3e..93f75bfcf32ce767bab9d0d7834b47ab40d9f8ac 100644 (file)
@@ -1,8 +1,5 @@
-%%%----------------------------------------------------------------------
-%%% File    : ejabberd_s2s_in.erl
-%%% Author  : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Serve incoming s2s connection
-%%% Created :  6 Dec 2002 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+%%% Created : 12 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
 %%%
 %%%
 %%% ejabberd, Copyright (C) 2002-2016   ProcessOne
 %%% with this program; if not, write to the Free Software Foundation, Inc.,
 %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 %%%
-%%%----------------------------------------------------------------------
-
+%%%-------------------------------------------------------------------
 -module(ejabberd_s2s_in).
-
+-behaviour(xmpp_stream_in).
 -behaviour(ejabberd_config).
+-behaviour(ejabberd_socket).
 
--author('alexey@process-one.net').
-
--behaviour(p1_fsm).
-
-%% External exports
+%% ejabberd_socket callbacks
 -export([start/2, start_link/2, socket_type/0]).
-
--export([init/1, wait_for_stream/2,
-        wait_for_feature_request/2, stream_established/2,
-        handle_event/3, handle_sync_event/4, code_change/4,
-        handle_info/3, print_state/1, terminate/3, opt_type/1]).
+%% ejabberd_config callbacks
+-export([opt_type/1]).
+%% xmpp_stream_in callbacks
+-export([init/1, handle_call/3, handle_cast/2,
+        handle_info/2, terminate/2, code_change/3]).
+-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+        compress_methods/1,
+        unauthenticated_stream_features/1, authenticated_stream_features/1,
+        handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
+        handle_stream_established/1, handle_auth_success/4,
+        handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2,
+        handle_unauthenticated_packet/2, handle_authenticated_packet/2]).
+%% Hooks
+-export([handle_unexpected_info/2, handle_unexpected_cast/2,
+        reject_unauthenticated_packet/2, process_closed/2]).
+%% API
+-export([stop/1, close/1, send/2, update_state/2, establish/1, add_hooks/0]).
 
 -include("ejabberd.hrl").
--include("logger.hrl").
-
 -include("xmpp.hrl").
+-include("logger.hrl").
 
--define(DICT, dict).
-
--record(state,
-       {socket                      :: ejabberd_socket:socket_state(),
-         sockmod = ejabberd_socket   :: ejabberd_socket | ejabberd_frontend_socket,
-         streamid = <<"">>           :: binary(),
-         shaper = none               :: shaper:shaper(),
-         tls = false                 :: boolean(),
-        tls_enabled = false         :: boolean(),
-         tls_required = false        :: boolean(),
-        tls_certverify = false      :: boolean(),
-         tls_options = []            :: list(),
-         server = <<"">>             :: binary(),
-        authenticated = false       :: boolean(),
-         auth_domain = <<"">>        :: binary(),
-        connections = (?DICT):new() :: ?TDICT,
-         timer = make_ref()          :: reference()}).
-
--type state_name() :: wait_for_stream | wait_for_feature_request | stream_established.
--type state() :: #state{}.
--type fsm_next() :: {next_state, state_name(), state()}.
--type fsm_stop() :: {stop, normal, state()}.
--type fsm_transition() :: fsm_stop() | fsm_next().
-
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
+-type state() :: map().
+-export_type([state/0]).
 
+%%%===================================================================
+%%% API
+%%%===================================================================
 start(SockData, Opts) ->
-    supervisor:start_child(ejabberd_s2s_in_sup,
-                            [SockData, Opts]).
+    xmpp_stream_in:start(?MODULE, [SockData, Opts],
+                        ejabberd_config:fsm_limit_opts(Opts)).
 
 start_link(SockData, Opts) ->
-    p1_fsm:start_link(ejabberd_s2s_in, [SockData, Opts],
-                      ?FSMOPTS ++ fsm_limit_opts(Opts)).
+    xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
+                             ejabberd_config:fsm_limit_opts(Opts)).
+
+close(Ref) ->
+    xmpp_stream_in:close(Ref).
+
+stop(Ref) ->
+    xmpp_stream_in:stop(Ref).
+
+socket_type() ->
+    xml_stream.
+
+-spec send(pid(), xmpp_element()) -> ok;
+         (state(), xmpp_element()) -> state().
+send(Stream, Pkt) ->
+    xmpp_stream_in:send(Stream, Pkt).
+
+-spec establish(state()) -> state().
+establish(State) ->
+    xmpp_stream_in:establish(State).
+
+-spec update_state(pid(), fun((state()) -> state()) |
+                  {module(), atom(), list()}) -> ok.
+update_state(Ref, Callback) ->
+    xmpp_stream_in:cast(Ref, {update_state, Callback}).
+
+-spec add_hooks() -> ok.
+add_hooks() ->
+    lists:foreach(
+      fun(Host) ->
+             ejabberd_hooks:add(s2s_in_closed, Host, ?MODULE,
+                                process_closed, 100),
+             ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
+                                reject_unauthenticated_packet, 100),
+             ejabberd_hooks:add(s2s_in_handle_info, Host, ?MODULE,
+                                handle_unexpected_info, 100),
+             ejabberd_hooks:add(s2s_in_handle_cast, Host, ?MODULE,
+                                handle_unexpected_cast, 100)
+      end, ?MYHOSTS).
+
+%%%===================================================================
+%%% Hooks
+%%%===================================================================
+handle_unexpected_info(State, Info) ->
+    ?WARNING_MSG("got unexpected info: ~p", [Info]),
+    State.
+
+handle_unexpected_cast(State, Msg) ->
+    ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
+    State.
+
+reject_unauthenticated_packet(State, Pkt) ->
+    Err = xmpp:err_not_authorized(),
+    xmpp_stream_in:send_error(State, Pkt, Err).
+
+process_closed(State, _Reason) ->
+    stop(State).
+
+%%%===================================================================
+%%% xmpp_stream_in callbacks
+%%%===================================================================
+tls_options(#{tls_compression := Compression, server_host := LServer}) ->
+    Opts = case Compression of
+              false -> [compression_none];
+              true -> []
+          end,
+    ejabberd_s2s:tls_options(LServer, Opts).
+
+tls_required(#{server_host := LServer}) ->
+    ejabberd_s2s:tls_required(LServer).
 
-socket_type() -> xml_stream.
+tls_verify(#{server_host := LServer}) ->
+    ejabberd_s2s:tls_verify(LServer).
 
-%%%----------------------------------------------------------------------
-%%% Callback functions from gen_fsm
-%%%----------------------------------------------------------------------
+tls_enabled(#{server_host := LServer}) ->
+    ejabberd_s2s:tls_enabled(LServer).
 
-init([{SockMod, Socket}, Opts]) ->
-    ?DEBUG("started: ~p", [{SockMod, Socket}]),
-    Shaper = case lists:keysearch(shaper, 1, Opts) of
-              {value, {_, S}} -> S;
-              _ -> none
+compress_methods(#{server_host := LServer}) ->
+    case ejabberd_s2s:zlib_enabled(LServer) of
+       true -> [<<"zlib">>];
+       false -> []
+    end.
+
+unauthenticated_stream_features(#{server_host := LServer}) ->
+    ejabberd_hooks:run_fold(s2s_in_pre_auth_features, LServer, [], [LServer]).
+
+authenticated_stream_features(#{server_host := LServer}) ->
+    ejabberd_hooks:run_fold(s2s_in_post_auth_features, LServer, [], [LServer]).
+
+handle_stream_start(_StreamStart, #{lserver := LServer} = State) ->
+    case check_to(jid:make(LServer), State) of
+       false ->
+           send(State, xmpp:serr_host_unknown());
+       true ->
+           ServerHost = ejabberd_router:host_of_route(LServer),
+           State#{server_host => ServerHost}
+    end.
+
+handle_stream_end(Reason, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]).
+
+handle_stream_close(_Reason, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]).
+
+handle_stream_established(State) ->
+    set_idle_timeout(State#{established => true}).
+
+handle_auth_success(RServer, Mech, _AuthModule,
+                   #{socket := Socket, ip := IP,
+                     auth_domains := AuthDomains,
+                     server_host := ServerHost,
+                     lserver := LServer} = State) ->
+    ?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)",
+             [ejabberd_socket:pp(Socket), Mech, RServer, LServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+    State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of
+                true ->
+                    AuthDomains1 = sets:add_element(RServer, AuthDomains),
+                    State#{auth_domains => AuthDomains1};
+                false ->
+                    State
             end,
-    {StartTLS, TLSRequired, TLSCertverify} =
-        case ejabberd_config:get_option(
-               s2s_use_starttls,
-               fun(false) -> false;
-                  (true) -> true;
-                  (optional) -> optional;
-                  (required) -> required;
-                  (required_trusted) -> required_trusted
-               end,
-               false) of
-            UseTls
-              when (UseTls == undefined) or
-                   (UseTls == false) ->
-                {false, false, false};
-            UseTls
-              when (UseTls == true) or
-                   (UseTls ==
-                        optional) ->
-                {true, false, false};
-            required -> {true, true, false};
-            required_trusted ->
-                {true, true, true}
-        end,
-    TLSOpts1 = case ejabberd_config:get_option(
-                     s2s_certfile,
-                     fun iolist_to_binary/1) of
-                  undefined -> [];
-                  CertFile -> [{certfile, CertFile}]
-             end,
-    TLSOpts2 = case ejabberd_config:get_option(
-                      s2s_ciphers, fun iolist_to_binary/1) of
-                   undefined -> TLSOpts1;
-                   Ciphers -> [{ciphers, Ciphers} | TLSOpts1]
-               end,
-    TLSOpts3 = case ejabberd_config:get_option(
-                      s2s_protocol_options,
-                      fun (Options) ->
-                              [_|O] = lists:foldl(
-                                           fun(X, Acc) -> X ++ Acc end, [],
-                                           [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)]
-                                          ),
-                              iolist_to_binary(O)
-                      end) of
-                   undefined -> TLSOpts2;
-                   ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2]
-               end,
-    TLSOpts4 = case ejabberd_config:get_option(
-                      s2s_dhfile, fun iolist_to_binary/1) of
-                   undefined -> TLSOpts3;
-                   DHFile -> [{dhfile, DHFile} | TLSOpts3]
-               end,
-    TLSOpts = case proplists:get_bool(tls_compression, Opts) of
-                  false -> [compression_none | TLSOpts4];
-                  true -> TLSOpts4
-              end,
-    Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-    {ok, wait_for_stream,
-     #state{socket = Socket, sockmod = SockMod,
-           streamid = new_id(), shaper = Shaper, tls = StartTLS,
-           tls_enabled = false, tls_required = TLSRequired,
-           tls_certverify = TLSCertverify, tls_options = TLSOpts,
-           timer = Timer}}.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/2
-%% Returns: {next_state, NextStateName, NextStateData}          |
-%%          {next_state, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}
-%%----------------------------------------------------------------------
-wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
-    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
-       #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
-         when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
-           send_header(StateData, {1,0}),
-           send_element(StateData, xmpp:serr_invalid_namespace()),
-           {stop, normal, StateData};
-       #stream_start{to = #jid{lserver = Server},
-                     from = From, version = {1,0}}
-         when StateData#state.tls and not StateData#state.authenticated ->
-           send_header(StateData, {1,0}),
-           Auth = if StateData#state.tls_enabled ->
-                          case From of
-                              #jid{} ->
-                                  {Result, Message} =
-                                      ejabberd_s2s:check_peer_certificate(
-                                        StateData#state.sockmod,
-                                        StateData#state.socket,
-                                        From#jid.lserver),
-                                  {Result, From#jid.lserver, Message};
-                              undefined ->
-                                  {error, <<"(unknown)">>,
-                                   <<"Got no valid 'from' attribute">>}
-                          end;
-                     true ->
-                          {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>}
-                  end,
-           StartTLS = if StateData#state.tls_enabled -> [];
-                         not StateData#state.tls_enabled and
-                         not StateData#state.tls_required ->
-                              [#starttls{required = false}];
-                         not StateData#state.tls_enabled and
-                         StateData#state.tls_required ->
-                              [#starttls{required = true}]
-                      end,
-           case Auth of
-               {error, RemoteServer, CertError}
-                 when StateData#state.tls_certverify ->
-                   ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
-                             [StateData#state.server, RemoteServer, CertError]),
-                   send_element(StateData,
-                                xmpp:serr_policy_violation(CertError, ?MYLANG)),
-                   {stop, normal, StateData};
-               {VerifyResult, RemoteServer, Msg} ->
-                   {SASL, NewStateData} =
-                       case VerifyResult of
-                           ok ->
-                               {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}],
-                                StateData#state{auth_domain = RemoteServer}};
-                           error ->
-                               ?DEBUG("Won't accept certificate of ~s: ~s",
-                                      [RemoteServer, Msg]),
-                               {[], StateData};
-                           no_verify ->
-                               {[], StateData}
-                       end,
-                   send_element(NewStateData,
-                                #stream_features{
-                                   sub_els = SASL ++ StartTLS ++
-                                       ejabberd_hooks:run_fold(
-                                         s2s_stream_features, Server, [],
-                                         [Server])}),
-                   {next_state, wait_for_feature_request,
-                    NewStateData#state{server = Server}}
-           end;
-       #stream_start{to = #jid{lserver = Server},
-                     version = {1,0}} when StateData#state.authenticated ->
-           send_header(StateData, {1,0}),
-           send_element(StateData,
-                        #stream_features{
-                           sub_els = ejabberd_hooks:run_fold(
-                                       s2s_stream_features, Server, [],
-                                       [Server])}),
-           {next_state, stream_established, StateData};
-       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK}
-         when (StateData#state.tls_required and StateData#state.tls_enabled)
-              or (not StateData#state.tls_required) ->
-           send_header(StateData, undefined),
-           {next_state, stream_established, StateData};
-       #stream_start{} ->
-           send_header(StateData, {1,0}),
-           send_element(StateData, xmpp:serr_undefined_condition()),
-           {stop, normal, StateData};
-       _ ->
-           send_header(StateData, {1,0}),
-            send_element(StateData, xmpp:serr_invalid_xml()),
-            {stop, normal, StateData}
-    catch _:{xmpp_codec, Why} ->
-           Txt = xmpp:format_error(Why),
-           send_header(StateData, {1,0}),
-           send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
-           {stop, normal, StateData}
-    end;
-wait_for_stream({xmlstreamerror, _}, StateData) ->
-    send_header(StateData, {1,0}),
-    send_element(StateData, xmpp:serr_not_well_formed()),
-    {stop, normal, StateData};
-wait_for_stream(timeout, StateData) ->
-    send_header(StateData, {1,0}),
-    send_element(StateData, xmpp:serr_connection_timeout()),
-    {stop, normal, StateData};
-wait_for_stream(closed, StateData) ->
-    {stop, normal, StateData}.
-
-wait_for_feature_request({xmlstreamelement, El}, StateData) ->
-    decode_element(El, wait_for_feature_request, StateData);
-wait_for_feature_request(#starttls{},
-                        #state{tls = true, tls_enabled = false} = StateData) ->
-    case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of
-       gen_tcp ->
-           ?DEBUG("starttls", []),
-           Socket = StateData#state.socket,
-           TLSOpts1 = case
-                          ejabberd_config:get_option(
-                            {domain_certfile, StateData#state.server},
-                            fun iolist_to_binary/1) of
-                          undefined -> StateData#state.tls_options;
-                          CertFile ->
-                              lists:keystore(certfile, 1,
-                                             StateData#state.tls_options,
-                                             {certfile, CertFile})
-                      end,
-           TLSOpts2 = case ejabberd_config:get_option(
-                             {s2s_cafile, StateData#state.server},
-                             fun iolist_to_binary/1) of
-                          undefined -> TLSOpts1;
-                          CAFile ->
-                              lists:keystore(cafile, 1, TLSOpts1,
-                                             {cafile, CAFile})
-                      end,
-           TLSOpts = case ejabberd_config:get_option(
-                            {s2s_tls_compression, StateData#state.server},
-                            fun(true) -> true;
-                               (false) -> false
-                            end, false) of
-                         true -> lists:delete(compression_none, TLSOpts2);
-                         false -> [compression_none | TLSOpts2]
-                     end,
-           TLSSocket = (StateData#state.sockmod):starttls(
-                         Socket, TLSOpts,
-                         fxml:element_to_binary(
-                           xmpp:encode(#starttls_proceed{}))),
-           {next_state, wait_for_stream,
-            StateData#state{socket = TLSSocket, streamid = new_id(),
-                            tls_enabled = true, tls_options = TLSOpts}};
-       _ ->
-            send_element(StateData, #starttls_failure{}),
-            {stop, normal, StateData}
-    end;
-wait_for_feature_request(#sasl_auth{mechanism = Mech},
-                        #state{tls_enabled = true} = StateData) ->
-    case Mech of
-       <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
-           AuthDomain = StateData#state.auth_domain,
-           AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain),
-           if AllowRemoteHost ->
-                   (StateData#state.sockmod):reset_stream(StateData#state.socket),
-                   send_element(StateData, #sasl_success{}),
-                   ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
-                             [AuthDomain, StateData#state.tls_enabled]),
-                   change_shaper(StateData, <<"">>, jid:make(AuthDomain)),
-                   {next_state, wait_for_stream,
-                    StateData#state{streamid = new_id(),
-                                    authenticated = true}};
-              true ->
-                   Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG),
-                   send_element(StateData,
-                                #sasl_failure{reason = 'not-authorized',
-                                              text = Txt}),
-                   {stop, normal, StateData}
-           end;
-       _ ->
-           send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}),
-           {stop, normal, StateData}
-    end;
-wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
-    {stop, normal, StateData};
-wait_for_feature_request({xmlstreamerror, _}, StateData) ->
-    send_element(StateData, xmpp:serr_not_well_formed()),
-    {stop, normal, StateData};
-wait_for_feature_request(closed, StateData) ->
-    {stop, normal, StateData};
-wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired,
-                                     tls_enabled = TLSEnabled} = StateData)
-  when TLSRequired and not TLSEnabled ->
-    Txt = <<"Use of STARTTLS required">>,
-    send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)),
-    {stop, normal, StateData};
-wait_for_feature_request(El, StateData) ->
-    stream_established({xmlstreamelement, El}, StateData).
-
-stream_established({xmlstreamelement, El}, StateData) ->
-    cancel_timer(StateData#state.timer),
-    Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-    decode_element(El, stream_established, StateData#state{timer = Timer});
-stream_established(#db_result{to = To, from = From, key = Key},
-                  StateData) ->
-    ?DEBUG("GET KEY: ~p", [{To, From, Key}]),
-    case {ejabberd_s2s:allow_host(To, From),
-         lists:member(To, ejabberd_router:dirty_get_all_domains())} of
-       {true, true} ->
-           ejabberd_s2s_out:terminate_if_waiting_delay(To, From),
-           ejabberd_s2s_out:start(To, From,
-                                  {verify, self(), Key,
-                                   StateData#state.streamid}),
-           Conns = (?DICT):store({From, To},
-                                 wait_for_verification,
-                                 StateData#state.connections),
-           change_shaper(StateData, To, jid:make(From)),
-           {next_state, stream_established,
-            StateData#state{connections = Conns}};
-       {_, false} ->
-           send_element(StateData, xmpp:serr_host_unknown()),
-           {stop, normal, StateData};
-       {false, _} ->
-           send_element(StateData, xmpp:serr_invalid_from()),
-           {stop, normal, StateData}
-    end;
-stream_established(#db_verify{to = To, from = From, id = Id, key = Key},
-                  StateData) ->
-    ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
-    Type = case ejabberd_s2s:make_key({To, From}, Id) of
-              Key -> valid;
-              _ -> invalid
-          end,
-    send_element(StateData,
-                #db_verify{from = To, to = From, id = Id, type = Type}),
-    {next_state, stream_established, StateData};
-stream_established(Pkt, StateData) when ?is_stanza(Pkt) ->
+    ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]).
+
+handle_auth_failure(RServer, Mech, Reason,
+                   #{socket := Socket, ip := IP,
+                     server_host := ServerHost,
+                     lserver := LServer} = State) ->
+    ?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s",
+             [ejabberd_socket:pp(Socket), Mech, RServer, LServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
+    ejabberd_hooks:run_fold(s2s_in_auth_result,
+                           ServerHost, State, [false, RServer]).
+
+handle_unauthenticated_packet(Pkt, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_unauthenticated_packet,
+                           LServer, State, [Pkt]).
+
+handle_authenticated_packet(Pkt, #{server_host := LServer} = State) when not ?is_stanza(Pkt) ->
+    ejabberd_hooks:run_fold(s2s_in_authenticated_packet, LServer, State, [Pkt]);
+handle_authenticated_packet(Pkt, State) ->
     From = xmpp:get_from(Pkt),
     To = xmpp:get_to(Pkt),
-    if To /= undefined, From /= undefined ->
-           LFrom = From#jid.lserver,
-           LTo = To#jid.lserver,
-           if StateData#state.authenticated ->
-                   case LFrom == StateData#state.auth_domain andalso
-                       lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of
-                       true ->
-                           ejabberd_hooks:run(s2s_receive_packet, LTo,
-                                              [From, To, Pkt]),
-                           ejabberd_router:route(From, To, Pkt);
-                       false ->
-                           send_error(StateData, Pkt, xmpp:err_not_authorized())
-                   end;
-              true ->
-                   case (?DICT):find({LFrom, LTo}, StateData#state.connections) of
-                       {ok, established} ->
-                           ejabberd_hooks:run(s2s_receive_packet, LTo,
-                                              [From, To, Pkt]),
-                           ejabberd_router:route(From, To, Pkt);
-                       _ ->
-                           send_error(StateData, Pkt, xmpp:err_not_authorized())
-                   end
-           end;
-       true ->
-           send_error(StateData, Pkt, xmpp:err_jid_malformed())
-    end,
-    ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
-    {next_state, stream_established, StateData};
-stream_established({valid, From, To}, StateData) ->
-    send_element(StateData,
-                #db_result{from = To, to = From, type = valid}),
-    ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)",
-             [From, StateData#state.tls_enabled]),
-    NSD = StateData#state{connections =
-                             (?DICT):store({From, To}, established,
-                                           StateData#state.connections)},
-    {next_state, stream_established, NSD};
-stream_established({invalid, From, To}, StateData) ->
-    send_element(StateData,
-                #db_result{from = To, to = From, type = invalid}),
-    NSD = StateData#state{connections =
-                             (?DICT):erase({From, To},
-                                           StateData#state.connections)},
-    {next_state, stream_established, NSD};
-stream_established({xmlstreamend, _Name}, StateData) ->
-    {stop, normal, StateData};
-stream_established({xmlstreamerror, _}, StateData) ->
-    send_element(StateData, xmpp:serr_not_well_formed()),
-    {stop, normal, StateData};
-stream_established(timeout, StateData) ->
-    send_element(StateData, xmpp:serr_connection_timeout()),
-    {stop, normal, StateData};
-stream_established(closed, StateData) ->
-    {stop, normal, StateData};
-stream_established(Pkt, StateData) ->
-    ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
-    {next_state, stream_established, StateData}.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/3
-%% Returns: {next_state, NextStateName, NextStateData}            |
-%%          {next_state, NextStateName, NextStateData, Timeout}   |
-%%          {reply, Reply, NextStateName, NextStateData}          |
-%%          {reply, Reply, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}                          |
-%%          {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-%state_name(Event, From, StateData) ->
-%    Reply = ok,
-%    {reply, Reply, state_name, StateData}.
-
-handle_event(_Event, StateName, StateData) ->
-    {next_state, StateName, StateData}.
-
-handle_sync_event(get_state_infos, _From, StateName,
-                 StateData) ->
-    SockMod = StateData#state.sockmod,
-    {Addr, Port} = try
-                    SockMod:peername(StateData#state.socket)
-                  of
-                    {ok, {A, P}} -> {A, P};
-                    {error, _} -> {unknown, unknown}
-                  catch
-                    _:_ -> {unknown, unknown}
-                  end,
-    Domains = get_external_hosts(StateData),
-    Infos = [{direction, in}, {statename, StateName},
-            {addr, Addr}, {port, Port},
-            {streamid, StateData#state.streamid},
-            {tls, StateData#state.tls},
-            {tls_enabled, StateData#state.tls_enabled},
-            {tls_options, StateData#state.tls_options},
-            {authenticated, StateData#state.authenticated},
-            {shaper, StateData#state.shaper}, {sockmod, SockMod},
-            {domains, Domains}],
-    Reply = {state_infos, Infos},
-    {reply, Reply, StateName, StateData};
-%%----------------------------------------------------------------------
-%% Func: handle_sync_event/4
-%% Returns: {next_state, NextStateName, NextStateData}            |
-%%          {next_state, NextStateName, NextStateData, Timeout}   |
-%%          {reply, Reply, NextStateName, NextStateData}          |
-%%          {reply, Reply, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}                          |
-%%          {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-handle_sync_event(_Event, _From, StateName,
-                 StateData) ->
-    Reply = ok, {reply, Reply, StateName, StateData}.
-
-code_change(_OldVsn, StateName, StateData, _Extra) ->
-    {ok, StateName, StateData}.
-
-handle_info({send_text, Text}, StateName, StateData) ->
-    send_text(StateData, Text),
-    {next_state, StateName, StateData};
-handle_info({timeout, Timer, _}, StateName,
-           #state{timer = Timer} = StateData) ->
-    if StateName == wait_for_stream ->
-           send_header(StateData, undefined);
-       true ->
-           ok
-    end,
-    send_element(StateData, xmpp:serr_connection_timeout()),
-    {stop, normal, StateData};
-handle_info(_, StateName, StateData) ->
-    {next_state, StateName, StateData}.
-
-terminate(Reason, _StateName, StateData) ->
-    ?DEBUG("terminated: ~p", [Reason]),
-    case Reason of
-      {process_limit, _} ->
-         [ejabberd_s2s:external_host_overloaded(Host)
-          || Host <- get_external_hosts(StateData)];
-      _ -> ok
-    end,
-    catch send_trailer(StateData),
-    (StateData#state.sockmod):close(StateData#state.socket),
-    ok.
-
-get_external_hosts(StateData) ->
-    case StateData#state.authenticated of
-      true -> [StateData#state.auth_domain];
-      false ->
-         Connections = StateData#state.connections,
-         [D
-          || {{D, _}, established} <- dict:to_list(Connections)]
+    case check_from_to(From, To, State) of
+       ok ->
+           LServer = ejabberd_router:host_of_route(To#jid.lserver),
+           State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet,
+                                            LServer, State, [Pkt]),
+           Pkt1 = ejabberd_hooks:run_fold(s2s_receive_packet, LServer,
+                                          Pkt, [State1]),
+           ejabberd_router:route(From, To, Pkt1),
+           State1;
+       {error, Err} ->
+           send(State, Err)
     end.
 
-print_state(State) -> State.
+handle_cdata(Data, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_handle_cdata, LServer, State, [Data]).
+
+handle_recv(El, Pkt, #{server_host := LServer} = State) ->
+    State1 = set_idle_timeout(State),
+    ejabberd_hooks:run_fold(s2s_in_handle_recv, LServer, State1, [El, Pkt]).
+
+handle_send(Pkt, Result, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_handle_send, LServer,
+                           State, [Pkt, Result]).
+
+init([State, Opts]) ->
+    Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none),
+    TLSCompression = proplists:get_bool(tls_compression, Opts),
+    State1 = State#{tls_compression => TLSCompression,
+                   auth_domains => sets:new(),
+                   xmlns => ?NS_SERVER,
+                   lang => ?MYLANG,
+                   server => ?MYNAME,
+                   lserver => ?MYNAME,
+                   server_host => ?MYNAME,
+                   established => false,
+                   shaper => Shaper},
+    ejabberd_hooks:run_fold(s2s_in_init, {ok, State1}, [Opts]).
+
+handle_call(Request, From, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_handle_call, LServer, State, [Request, From]).
+
+handle_cast({update_state, Fun}, State) ->
+    case Fun of
+       {M, F, A} -> erlang:apply(M, F, [State|A]);
+       _ when is_function(Fun) -> Fun(State)
+    end;
+handle_cast(Msg, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_handle_cast, LServer, State, [Msg]).
+
+handle_info(Info, #{server_host := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]).
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
 
-%%%----------------------------------------------------------------------
+%%%===================================================================
 %%% Internal functions
-%%%----------------------------------------------------------------------
-
--spec send_text(state(), iodata()) -> ok.
-send_text(StateData, Text) ->
-    (StateData#state.sockmod):send(StateData#state.socket,
-                                  Text).
-
--spec send_element(state(), xmpp_element()) -> ok.
-send_element(StateData, El) ->
-    El1 = xmpp:encode(El, ?NS_SERVER),
-    send_text(StateData, fxml:element_to_binary(El1)).
-
--spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok.
-send_error(StateData, Stanza, Error) ->
-    Type = xmpp:get_type(Stanza),
-    if Type == error; Type == result;
-       Type == <<"error">>; Type == <<"result">> ->
-           ok;
-       true ->
-           send_element(StateData, xmpp:make_error(Stanza, Error))
+%%%===================================================================
+-spec check_from_to(jid(), jid(), state()) -> ok | {error, stream_error()}.
+check_from_to(From, To, State) ->
+    case check_from(From, State) of
+       true ->
+           case check_to(To, State) of
+               true ->
+                   ok;
+               false ->
+                   {error, xmpp:serr_improper_addressing()}
+           end;
+       false ->
+           {error, xmpp:serr_invalid_from()}
     end.
 
--spec send_trailer(state()) -> ok.
-send_trailer(StateData) ->
-    send_text(StateData, <<"</stream:stream>">>).
-
--spec send_header(state(), undefined | {integer(), integer()}) -> ok.
-send_header(StateData, Version) ->
-    Header = xmpp:encode(
-              #stream_start{xmlns = ?NS_SERVER,
-                            stream_xmlns = ?NS_STREAM,
-                            db_xmlns = ?NS_SERVER_DIALBACK,
-                            id = StateData#state.streamid,
-                            version = Version}),
-    send_text(StateData, fxml:element_to_header(Header)).
-
--spec change_shaper(state(), binary(), jid()) -> ok.
-change_shaper(StateData, Host, JID) ->
-    Shaper = acl:match_rule(Host, StateData#state.shaper,
-                           JID),
-    (StateData#state.sockmod):change_shaper(StateData#state.socket,
-                                           Shaper).
-
--spec new_id() -> binary().
-new_id() -> randoms:get_string().
-
--spec cancel_timer(reference()) -> ok.
-cancel_timer(Timer) ->
-    erlang:cancel_timer(Timer),
-    receive {timeout, Timer, _} -> ok after 0 -> ok end.
-
-fsm_limit_opts(Opts) ->
-    case lists:keysearch(max_fsm_queue, 1, Opts) of
-      {value, {_, N}} when is_integer(N) -> [{max_queue, N}];
-      _ ->
-         case ejabberd_config:get_option(
-                 max_fsm_queue,
-                 fun(I) when is_integer(I), I > 0 -> I end) of
-            undefined -> [];
-           N -> [{max_queue, N}]
-         end
-    end.
+-spec check_from(jid(), state()) -> boolean().
+check_from(#jid{lserver = S1}, #{auth_domains := AuthDomains}) ->
+    sets:is_element(S1, AuthDomains).
+
+-spec check_to(jid(), state()) -> boolean().
+check_to(#jid{lserver = LServer}, _State) ->
+    ejabberd_router:is_my_route(LServer).
+
+-spec set_idle_timeout(state()) -> state().
+set_idle_timeout(#{server_host := LServer,
+                  established := true} = State) ->
+    Timeout = ejabberd_s2s:get_idle_timeout(LServer),
+    xmpp_stream_in:set_timeout(State, Timeout);
+set_idle_timeout(State) ->
+    State.
 
--spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition().
-decode_element(#xmlel{} = El, StateName, StateData) ->
-    Opts = if StateName == stream_established ->
-                  [ignore_els];
-             true ->
-                  []
-          end,
-    try xmpp:decode(El, ?NS_SERVER, Opts) of
-       Pkt -> ?MODULE:StateName(Pkt, StateData)
-    catch error:{xmpp_codec, Why} ->
-            case xmpp:is_stanza(El) of
-                true ->
-                   Lang = xmpp:get_lang(El),
-                    Txt = xmpp:format_error(Why),
-                   send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
-                false ->
-                    ok
-            end,
-            {next_state, StateName, StateData}
-    end;
-decode_element(Pkt, StateName, StateData) ->
-    ?MODULE:StateName(Pkt, StateData).
-
-opt_type(domain_certfile) -> fun iolist_to_binary/1;
-opt_type(max_fsm_queue) ->
-    fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(s2s_certfile) -> fun iolist_to_binary/1;
-opt_type(s2s_cafile) -> fun iolist_to_binary/1;
-opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
-opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
-opt_type(s2s_protocol_options) ->
-    fun (Options) ->
-           [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [],
-                                 [["|" | binary_to_list(Opt)]
-                                  || Opt <- Options, is_binary(Opt)]),
-           iolist_to_binary(O)
-    end;
-opt_type(s2s_tls_compression) ->
-    fun (true) -> true;
-       (false) -> false
-    end;
-opt_type(s2s_use_starttls) ->
-    fun (false) -> false;
-       (true) -> true;
-       (optional) -> optional;
-       (required) -> required;
-       (required_trusted) -> required_trusted
-    end;
 opt_type(_) ->
-    [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile,
-     s2s_ciphers, s2s_dhfile, s2s_protocol_options,
-     s2s_tls_compression, s2s_use_starttls].
+    [].
index b9ce47830309f7394a02aafe82aa641daf1cad92..72d9dfea8594c50f5c58f509c5e7172dfb3df830 100644 (file)
-%%%----------------------------------------------------------------------
-%%% File    : ejabberd_s2s_out.erl
-%%% Author  : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Manage outgoing server-to-server connections
-%%% Created :  6 Dec 2002 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2016, Evgeny Khramtsov
+%%% @doc
 %%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
+%%% @end
+%%% Created : 16 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%-------------------------------------------------------------------
 -module(ejabberd_s2s_out).
-
+-behaviour(xmpp_stream_out).
 -behaviour(ejabberd_config).
 
--author('alexey@process-one.net').
-
--behaviour(p1_fsm).
-
-%% External exports
--export([start/3,
-        start_link/3,
-        start_connection/1,
-        terminate_if_waiting_delay/2,
-        stop_connection/1,
-        transform_options/1]).
-
--export([init/1, open_socket/2, wait_for_stream/2,
-        wait_for_validation/2, wait_for_features/2,
-        wait_for_auth_result/2, wait_for_starttls_proceed/2,
-        relay_to_bridge/2, reopen_socket/2, wait_before_retry/2,
-        stream_established/2, handle_event/3,
-        handle_sync_event/4, handle_info/3, terminate/3,
-        print_state/1, code_change/4, test_get_addr_port/1,
-        get_addr_port/1, opt_type/1]).
+%% ejabberd_config callbacks
+-export([opt_type/1, transform_options/1]).
+%% xmpp_stream_out callbacks
+-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+        handle_auth_success/2, handle_auth_failure/3, handle_packet/2,
+        handle_stream_end/2, handle_stream_close/2,
+        handle_recv/3, handle_send/4, handle_cdata/2,
+        handle_stream_established/1, handle_timeout/1]).
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+        terminate/2, code_change/3]).
+%% Hooks
+-export([process_auth_result/2, process_closed/2, handle_unexpected_info/2,
+        handle_unexpected_cast/2]).
+%% API
+-export([start/3, start_link/3, connect/1, close/1, stop/1, send/2,
+        route/2, establish/1, update_state/2, add_hooks/0]).
 
 -include("ejabberd.hrl").
--include("logger.hrl").
 -include("xmpp.hrl").
+-include("logger.hrl").
 
--record(state,
-       {socket                           :: ejabberd_socket:socket_state(),
-         streamid = <<"">>                :: binary(),
-        remote_streamid = <<"">>         :: binary(),
-         use_v10 = true                   :: boolean(),
-         tls = false                      :: boolean(),
-        tls_required = false             :: boolean(),
-        tls_certverify = false           :: boolean(),
-         tls_enabled = false              :: boolean(),
-        tls_options = [connect]          :: list(),
-         authenticated = false            :: boolean(),
-        db_enabled = true                :: boolean(),
-         try_auth = true                  :: boolean(),
-         myname = <<"">>                  :: binary(),
-         server = <<"">>                  :: binary(),
-        queue = queue:new()              :: ?TQUEUE,
-         delay_to_retry = undefined_delay :: undefined_delay | non_neg_integer(),
-         new = false                      :: boolean(),
-        verify = false                   :: false | {pid(), binary(), binary()},
-         bridge                           :: {atom(), atom()},
-         timer = make_ref()               :: reference()}).
-
--type state_name() :: open_socket | wait_for_stream |
-                     wait_for_validation | wait_for_features |
-                     wait_for_auth_result | wait_for_starttls_proceed |
-                     relay_to_bridge | reopen_socket | wait_before_retry |
-                     stream_established.
--type state() :: #state{}.
--type fsm_stop() :: {stop, normal, state()}.
--type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()} |
-                   {next_state, state_name(), state()}.
--type fsm_transition() :: fsm_stop() | fsm_next().
-
-%%-define(DBGFSM, true).
-
--ifdef(DBGFSM).
-
--define(FSMOPTS, [{debug, [trace]}]).
-
--else.
-
--define(FSMOPTS, []).
-
--endif.
-
--define(FSMTIMEOUT, 30000).
-
-%% We do not block on send anymore.
--define(TCP_SEND_TIMEOUT, 15000).
-
-%% Maximum delay to wait before retrying to connect after a failed attempt.
-%% Specified in miliseconds. Default value is 5 minutes.
--define(MAX_RETRY_DELAY, 300000).
-
--define(SOCKET_DEFAULT_RESULT, {error, badarg}).
+-type state() :: map().
+-export_type([state/0]).
 
-%%%----------------------------------------------------------------------
+%%%===================================================================
 %%% API
-%%%----------------------------------------------------------------------
-start(From, Host, Type) ->
-    supervisor:start_child(ejabberd_s2s_out_sup,
-                          [From, Host, Type]).
-
-start_link(From, Host, Type) ->
-    p1_fsm:start_link(ejabberd_s2s_out, [From, Host, Type],
-                     fsm_limit_opts() ++ (?FSMOPTS)).
-
-start_connection(Pid) -> p1_fsm:send_event(Pid, init).
-
-stop_connection(Pid) -> p1_fsm:send_event(Pid, closed).
-
-%%%----------------------------------------------------------------------
-%%% Callback functions from p1_fsm
-%%%----------------------------------------------------------------------
-
-init([From, Server, Type]) ->
-    process_flag(trap_exit, true),
-    ?DEBUG("started: ~p", [{From, Server, Type}]),
-    {TLS, TLSRequired, TLSCertverify} =
-       case ejabberd_config:get_option(
-              s2s_use_starttls,
-              fun(true) -> true;
-                 (false) -> false;
-                 (optional) -> optional;
-                 (required) -> required;
-                 (required_trusted) -> required_trusted
-              end)
-           of
-         UseTls
-             when (UseTls == undefined) or (UseTls == false) ->
-             {false, false, false};
-         UseTls
-             when (UseTls == true) or (UseTls == optional) ->
-             {true, false, false};
-         required ->
-             {true, true, false};
-         required_trusted ->
-             {true, true, true}
-       end,
-    UseV10 = TLS,
-    TLSOpts1 = case
-               ejabberd_config:get_option(
-                  s2s_certfile, fun iolist_to_binary/1)
-                 of
-               undefined -> [connect];
-               CertFile -> [{certfile, CertFile}, connect]
-             end,
-    TLSOpts2 = case ejabberd_config:get_option(
-                      s2s_ciphers, fun iolist_to_binary/1) of
-                   undefined -> TLSOpts1;
-                   Ciphers -> [{ciphers, Ciphers} | TLSOpts1]
-               end,
-    TLSOpts3 = case ejabberd_config:get_option(
-                      s2s_protocol_options,
-                      fun (Options) ->
-                              [_|O] = lists:foldl(
-                                           fun(X, Acc) -> X ++ Acc end, [],
-                                           [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)]
-                                          ),
-                              iolist_to_binary(O)
-                      end) of
-                   undefined -> TLSOpts2;
-                   ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2]
-               end,
-    TLSOpts4 = case ejabberd_config:get_option(
-                      s2s_dhfile, fun iolist_to_binary/1) of
-                   undefined -> TLSOpts3;
-                   DHFile -> [{dhfile, DHFile} | TLSOpts3]
-               end,
-    TLSOpts = case ejabberd_config:get_option(
-                     {s2s_tls_compression, From},
-                     fun(true) -> true;
-                        (false) -> false
-                     end, false) of
-                  false -> [compression_none | TLSOpts4];
-                  true -> TLSOpts4
-              end,
-    {New, Verify} = case Type of
-                     new -> {true, false};
-                     {verify, Pid, Key, SID} ->
-                         start_connection(self()), {false, {Pid, Key, SID}}
-                   end,
-    Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-    {ok, open_socket,
-     #state{use_v10 = UseV10, tls = TLS,
-           tls_required = TLSRequired, tls_certverify = TLSCertverify,
-           tls_options = TLSOpts, queue = queue:new(), myname = From,
-           server = Server, new = New, verify = Verify, timer = Timer}}.
-
-open_socket(init, StateData) ->
-    log_s2s_out(StateData#state.new, StateData#state.myname,
-               StateData#state.server, StateData#state.tls),
-    ?DEBUG("open_socket: ~p",
-          [{StateData#state.myname, StateData#state.server,
-            StateData#state.new, StateData#state.verify}]),
-    AddrList = case
-                ejabberd_idna:domain_utf8_to_ascii(StateData#state.server)
-                  of
-                false -> [];
-                ASCIIAddr -> get_addr_port(ASCIIAddr)
+%%%===================================================================
+start(From, To, Opts) ->
+    xmpp_stream_out:start(?MODULE, [ejabberd_socket, From, To, Opts],
+                         ejabberd_config:fsm_limit_opts([])).
+
+start_link(From, To, Opts) ->
+    xmpp_stream_out:start_link(?MODULE, [ejabberd_socket, From, To, Opts],
+                              ejabberd_config:fsm_limit_opts([])).
+
+connect(Ref) ->
+    xmpp_stream_out:connect(Ref).
+
+close(Ref) ->
+    xmpp_stream_out:close(Ref).
+
+stop(Ref) ->
+    xmpp_stream_out:stop(Ref).
+
+-spec send(pid(), xmpp_element()) -> ok;
+         (state(), xmpp_element()) -> state().
+send(Stream, Pkt) ->
+    xmpp_stream_out:send(Stream, Pkt).
+
+-spec route(pid(), xmpp_element()) -> ok.
+route(Ref, Pkt) ->
+    Ref ! {route, Pkt}.
+
+-spec establish(state()) -> state().
+establish(State) ->
+    xmpp_stream_out:establish(State).
+
+-spec update_state(pid(), fun((state()) -> state()) |
+                  {module(), atom(), list()}) -> ok.
+update_state(Ref, Callback) ->
+    xmpp_stream_out:cast(Ref, {update_state, Callback}).
+
+-spec add_hooks() -> ok.
+add_hooks() ->
+    lists:foreach(
+      fun(Host) ->
+             ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
+                                process_auth_result, 100),
+             ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE,
+                                process_closed, 100),
+             ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE,
+                                handle_unexpected_info, 100),
+             ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE,
+                                handle_unexpected_cast, 100)
+      end, ?MYHOSTS).
+
+%%%===================================================================
+%%% Hooks
+%%%===================================================================
+process_auth_result(#{server := LServer, remote_server := RServer} = State,
+                   false) ->
+    Delay = get_delay(),
+    ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;"
+             " bouncing for ~p seconds",
+             [LServer, RServer, Delay]),
+    State1 = close(State),
+    State2 = bounce_queue(State1),
+    xmpp_stream_out:set_timeout(State2, timer:seconds(Delay));
+process_auth_result(State, true) ->
+    State.
+
+process_closed(#{server := LServer, remote_server := RServer} = State,
+              _Reason) ->
+    Delay = get_delay(),
+    ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; "
+             "bouncing for ~p seconds",
+             [LServer, RServer,
+              try maps:get(stop_reason, State) of
+                  {error, Why} -> xmpp_stream_out:format_error(Why)
+              catch _:undef -> <<"unexplained reason">>
               end,
-    case lists:foldl(fun ({Addr, Port}, Acc) ->
-                            case Acc of
-                              {ok, Socket} -> {ok, Socket};
-                              _ -> open_socket1(Addr, Port)
-                            end
-                    end,
-                    ?SOCKET_DEFAULT_RESULT, AddrList)
-       of
-      {ok, Socket} ->
-         Version = if StateData#state.use_v10 -> {1,0};
-                      true -> undefined
-                   end,
-         NewStateData = StateData#state{socket = Socket,
-                                        tls_enabled = false,
-                                        streamid = new_id()},
-         send_header(NewStateData, Version),
-         {next_state, wait_for_stream, NewStateData,
-          ?FSMTIMEOUT};
-      {error, Reason} ->
-         ?INFO_MSG("s2s connection: ~s -> ~s (remote server "
-                   "not found: ~p)",
-                   [StateData#state.myname, StateData#state.server, Reason]),
-         case ejabberd_hooks:run_fold(find_s2s_bridge, undefined,
-                                      [StateData#state.myname,
-                                       StateData#state.server])
-             of
-           {Mod, Fun, Type} ->
-               ?INFO_MSG("found a bridge to ~s for: ~s -> ~s",
-                         [Type, StateData#state.myname,
-                          StateData#state.server]),
-               NewStateData = StateData#state{bridge = {Mod, Fun}},
-               {next_state, relay_to_bridge, NewStateData};
-           _ -> wait_before_reconnect(StateData)
-         end
-    end;
-open_socket(Event, StateData) ->
-    handle_unexpected_event(Event, open_socket, StateData).
-
-open_socket1({_, _, _, _} = Addr, Port) ->
-    open_socket2(inet, Addr, Port);
-%% IPv6
-open_socket1({_, _, _, _, _, _, _, _} = Addr, Port) ->
-    open_socket2(inet6, Addr, Port);
-%% Hostname
-open_socket1(Host, Port) ->
-    lists:foldl(fun (_Family, {ok, _Socket} = R) -> R;
-                   (Family, _) ->
-                       Addrs = get_addrs(Host, Family),
-                       lists:foldl(fun (_Addr, {ok, _Socket} = R) -> R;
-                                       (Addr, _) -> open_socket1(Addr, Port)
-                                   end,
-                                   ?SOCKET_DEFAULT_RESULT, Addrs)
-               end,
-               ?SOCKET_DEFAULT_RESULT, outgoing_s2s_families()).
-
-open_socket2(Type, Addr, Port) ->
-    ?DEBUG("s2s_out: connecting to ~p:~p~n", [Addr, Port]),
-    Timeout = outgoing_s2s_timeout(),
-    case catch ejabberd_socket:connect(Addr, Port,
-                                      [binary, {packet, 0},
-                                       {send_timeout, ?TCP_SEND_TIMEOUT},
-                                       {send_timeout_close, true},
-                                       {active, false}, Type],
-                                      Timeout)
-       of
-      {ok, _Socket} = R -> R;
-      {error, Reason} = R ->
-         ?DEBUG("s2s_out: connect return ~p~n", [Reason]), R;
-      {'EXIT', Reason} ->
-         ?DEBUG("s2s_out: connect crashed ~p~n", [Reason]),
-         {error, Reason}
+              Delay]),
+    State1 = bounce_queue(State),
+    xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)).
+
+handle_unexpected_info(State, Info) ->
+    ?WARNING_MSG("got unexpected info: ~p", [Info]),
+    State.
+
+handle_unexpected_cast(State, Msg) ->
+    ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
+    State.
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+tls_options(#{server := LServer}) ->
+    ejabberd_s2s:tls_options(LServer, []).
+
+tls_required(#{server := LServer}) ->
+    ejabberd_s2s:tls_required(LServer).
+
+tls_verify(#{server := LServer}) ->
+    ejabberd_s2s:tls_verify(LServer).
+
+tls_enabled(#{server := LServer}) ->
+    ejabberd_s2s:tls_enabled(LServer).
+
+handle_auth_success(Mech, #{socket := Socket, ip := IP,
+                           remote_server := RServer,
+                           server := LServer} = State) ->
+    ?INFO_MSG("(~s) Accepted outbound s2s ~s authentication ~s -> ~s (~s)",
+             [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+    ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [true]).
+
+handle_auth_failure(Mech, Reason,
+                   #{socket := Socket, ip := IP,
+                     remote_server := RServer,
+                     server := LServer} = State) ->
+    ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s",
+             [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
+    State1 = State#{on_route => bounce,
+                   stop_reason => {error, {auth, Reason}}},
+    ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]).
+
+handle_packet(Pkt, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]).
+
+handle_stream_end(Reason, #{server := LServer} = State) ->
+    State1 = State#{on_route => bounce, stop_reason => Reason},
+    ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]).
+
+handle_stream_close(Reason, #{server := LServer} = State) ->
+    State1 = State#{on_route => bounce, stop_reason => Reason},
+    ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]).
+
+handle_stream_established(State) ->
+    State1 = State#{on_route => send},
+    State2 = resend_queue(State1),
+    set_idle_timeout(State2).
+
+handle_cdata(Data, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_cdata, LServer, State, [Data]).
+
+handle_recv(El, Pkt, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_recv, LServer, State, [El, Pkt]).
+
+handle_send(Pkt, El, Data, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_send, LServer,
+                           State, [Pkt, El, Data]).
+
+handle_timeout(#{server := LServer, remote_server := RServer,
+                on_route := Action} = State) ->
+    case Action of
+       bounce -> stop(State);
+       queue -> send(State, xmpp:serr_connection_timeout());
+       send ->
+           ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive",
+                     [LServer, RServer]),
+           stop(State)
     end.
 
-%%----------------------------------------------------------------------
-
-wait_for_stream({xmlstreamstart, Name, Attrs}, StateData0) ->
-    {CertCheckRes, CertCheckMsg, StateData} =
-       if StateData0#state.tls_certverify, StateData0#state.tls_enabled ->
-               {Res, Msg} =
-                   ejabberd_s2s:check_peer_certificate(ejabberd_socket,
-                                                       StateData0#state.socket,
-                                                       StateData0#state.server),
-               ?DEBUG("Certificate verification result for ~s: ~s",
-                      [StateData0#state.server, Msg]),
-               {Res, Msg, StateData0#state{tls_certverify = false}};
-          true ->
-               {no_verify, <<"Not verified">>, StateData0}
-       end,
-    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
-       _ when CertCheckRes == error ->
-           send_element(StateData,
-                        xmpp:serr_policy_violation(CertCheckMsg, ?MYLANG)),
-           ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)",
-                     [StateData#state.myname, StateData#state.server,
-                      CertCheckMsg]),
-           {stop, normal, StateData};
-       #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
-         when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
-           send_element(StateData, xmpp:serr_invalid_namespace()),
-           {stop, normal, StateData};
-       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID,
-                     version = V} when V /= {1,0} ->
-           send_db_request(StateData#state{remote_streamid = ID});
-       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID}
-         when StateData#state.use_v10 ->
-           {next_state, wait_for_features,
-            StateData#state{remote_streamid = ID}, ?FSMTIMEOUT};
-       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID}
-         when not StateData#state.use_v10 ->
-           %% Handle Tigase's workaround for an old ejabberd bug:
-           send_db_request(StateData#state{remote_streamid = ID});
-       #stream_start{id = ID} when StateData#state.use_v10 ->
-           {next_state, wait_for_features,
-            StateData#state{db_enabled = false, remote_streamid = ID},
-            ?FSMTIMEOUT};
-       #stream_start{} ->
-           send_element(StateData, xmpp:serr_invalid_namespace()),
-           {stop, normal, StateData};
-       _ ->
-           send_element(StateData, xmpp:serr_invalid_xml()),
-            {stop, normal, StateData}
-    catch _:{xmpp_codec, Why} ->
-           Txt = xmpp:format_error(Why),
-           send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
-           {stop, normal, StateData}
-    end;
-wait_for_stream(Event, StateData) ->
-    handle_unexpected_event(Event, wait_for_stream, StateData).
-
-wait_for_validation({xmlstreamelement, El}, StateData) ->
-    decode_element(El, wait_for_validation, StateData);
-wait_for_validation(#db_result{to = To, from = From, type = Type}, StateData) ->
-    ?DEBUG("recv result: ~p", [{From, To, Type}]),
-    case {Type, StateData#state.tls_enabled, StateData#state.tls_required} of
-       {valid, Enabled, Required} when (Enabled == true) or (Required == false) ->
-           send_queue(StateData, StateData#state.queue),
-           ?INFO_MSG("Connection established: ~s -> ~s with "
-                     "TLS=~p",
-                     [StateData#state.myname, StateData#state.server,
-                      StateData#state.tls_enabled]),
-           ejabberd_hooks:run(s2s_connect_hook,
-                              [StateData#state.myname,
-                               StateData#state.server]),
-           {next_state, stream_established, StateData#state{queue = queue:new()}};
-       {valid, Enabled, Required} when (Enabled == false) and (Required == true) ->
-           ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS "
-                     "is required but unavailable)",
-                     [StateData#state.myname, StateData#state.server]),
-           {stop, normal, StateData};
-       _ ->
-           ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
-                     "dialback key result)",
-                     [StateData#state.myname, StateData#state.server]),
-           {stop, normal, StateData}
-    end;
-wait_for_validation(#db_verify{to = To, from = From, id = Id, type = Type},
-                   StateData) ->
-    ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]),
-    case StateData#state.verify of
-       false ->
-           NextState = wait_for_validation,
-           {next_state, NextState, StateData, get_timeout_interval(NextState)};
-       {Pid, _Key, _SID} ->
-           case Type of
-               valid ->
-                   p1_fsm:send_event(Pid,
-                                     {valid, StateData#state.server,
-                                      StateData#state.myname});
-               _ ->
-                   p1_fsm:send_event(Pid,
-                                     {invalid, StateData#state.server,
-                                      StateData#state.myname})
-           end,
-           if StateData#state.verify == false ->
-                   {stop, normal, StateData};
-              true ->
-                   NextState = wait_for_validation,
-                   {next_state, NextState, StateData, get_timeout_interval(NextState)}
-           end
-    end;
-wait_for_validation(timeout,
-                   #state{verify = {VPid, VKey, SID}} = StateData)
-  when is_pid(VPid) and is_binary(VKey) and is_binary(SID) ->
-    ?DEBUG("wait_for_validation: ~s -> ~s (timeout in verify connection)",
-          [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-wait_for_validation(Event, StateData) ->
-    handle_unexpected_event(Event, wait_for_validation, StateData).
-
-wait_for_features({xmlstreamelement, El}, StateData) ->
-    decode_element(El, wait_for_features, StateData);
-wait_for_features(#stream_features{sub_els = Els}, StateData) ->
-    {SASLEXT, StartTLS, StartTLSRequired} =
-       lists:foldl(
-         fun(#sasl_mechanisms{list = Mechs}, {_, STLS, STLSReq}) ->
-                 {lists:member(<<"EXTERNAL">>, Mechs), STLS, STLSReq};
-            (#starttls{required = Required}, {SEXT, _, _}) ->
-                 {SEXT, true, Required};
-            (_, Acc) ->
-                 Acc
-         end, {false, false, false}, Els),
-    if not SASLEXT and not StartTLS and StateData#state.authenticated ->
-           send_queue(StateData, StateData#state.queue),
-           ?INFO_MSG("Connection established: ~s -> ~s with "
-                     "SASL EXTERNAL and TLS=~p",
-                     [StateData#state.myname, StateData#state.server,
-                      StateData#state.tls_enabled]),
-           ejabberd_hooks:run(s2s_connect_hook,
-                              [StateData#state.myname,
-                               StateData#state.server]),
-           {next_state, stream_established,
-            StateData#state{queue = queue:new()}};
-       SASLEXT and StateData#state.try_auth and
-       (StateData#state.new /= false) and
-       (StateData#state.tls_enabled or
-       not StateData#state.tls_required) ->
-           send_element(StateData,
-                        #sasl_auth{mechanism = <<"EXTERNAL">>,
-                                   text = StateData#state.myname}),
-           {next_state, wait_for_auth_result,
-            StateData#state{try_auth = false}, ?FSMTIMEOUT};
-       StartTLS and StateData#state.tls and
-       not StateData#state.tls_enabled ->
-           send_element(StateData, #starttls{}),
-           {next_state, wait_for_starttls_proceed, StateData, ?FSMTIMEOUT};
-       StartTLSRequired and not StateData#state.tls ->
-           ?DEBUG("restarted: ~p",
-                  [{StateData#state.myname, StateData#state.server}]),
-           ejabberd_socket:close(StateData#state.socket),
-           {next_state, reopen_socket,
-            StateData#state{socket = undefined, use_v10 = false},
-            ?FSMTIMEOUT};
-       StateData#state.db_enabled ->
-           send_db_request(StateData);
-       true ->
-           ?DEBUG("restarted: ~p",
-                  [{StateData#state.myname, StateData#state.server}]),
-           ejabberd_socket:close(StateData#state.socket),
-           {next_state, reopen_socket,
-            StateData#state{socket = undefined, use_v10 = false}, ?FSMTIMEOUT}
+init([#{server := LServer, remote_server := RServer} = State, Opts]) ->
+    State1 = State#{on_route => queue,
+                   queue => queue:new(),
+                   xmlns => ?NS_SERVER,
+                   lang => ?MYLANG,
+                   shaper => none},
+    ?INFO_MSG("Outbound s2s connection started: ~s -> ~s",
+             [LServer, RServer]),
+    ejabberd_hooks:run_fold(s2s_out_init, LServer, {ok, State1}, [Opts]).
+
+handle_call(Request, From, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_call, LServer, State, [Request, From]).
+
+handle_cast({update_state, Fun}, State) ->
+    case Fun of
+       {M, F, A} -> erlang:apply(M, F, [State|A]);
+       _ when is_function(Fun) -> Fun(State)
     end;
-wait_for_features(Event, StateData) ->
-    handle_unexpected_event(Event, wait_for_features, StateData).
-
-wait_for_auth_result({xmlstreamelement, El}, StateData) ->
-    decode_element(El, wait_for_auth_result, StateData);
-wait_for_auth_result(#sasl_success{}, StateData) ->
-    ?DEBUG("auth: ~p", [{StateData#state.myname, StateData#state.server}]),
-    ejabberd_socket:reset_stream(StateData#state.socket),
-    send_header(StateData, {1,0}),
-    {next_state, wait_for_stream,
-     StateData#state{streamid = new_id(), authenticated = true},
-     ?FSMTIMEOUT};
-wait_for_auth_result(#sasl_failure{}, StateData) ->
-    ?DEBUG("restarted: ~p", [{StateData#state.myname, StateData#state.server}]),
-    ejabberd_socket:close(StateData#state.socket),
-    {next_state, reopen_socket,
-     StateData#state{socket = undefined}, ?FSMTIMEOUT};
-wait_for_auth_result(Event, StateData) ->
-    handle_unexpected_event(Event, wait_for_auth_result, StateData).
-
-wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
-    decode_element(El, wait_for_starttls_proceed, StateData);
-wait_for_starttls_proceed(#starttls_proceed{}, StateData) ->
-    ?DEBUG("starttls: ~p", [{StateData#state.myname, StateData#state.server}]),
-    Socket = StateData#state.socket,
-    TLSOpts = case ejabberd_config:get_option(
-                    {domain_certfile, StateData#state.myname},
-                    fun iolist_to_binary/1) of
-                 undefined -> StateData#state.tls_options;
-                 CertFile ->
-                     [{certfile, CertFile}
-                      | lists:keydelete(certfile, 1,
-                                        StateData#state.tls_options)]
-             end,
-    TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts),
-    NewStateData = StateData#state{socket = TLSSocket,
-                                  streamid = new_id(),
-                                  tls_enabled = true,
-                                  tls_options = TLSOpts},
-    send_header(NewStateData, {1,0}),
-    {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT};
-wait_for_starttls_proceed(Event, StateData) ->
-    handle_unexpected_event(Event, wait_for_starttls_proceed, StateData).
-
-reopen_socket({xmlstreamelement, _El}, StateData) ->
-    {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
-reopen_socket({xmlstreamend, _Name}, StateData) ->
-    {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
-reopen_socket({xmlstreamerror, _}, StateData) ->
-    {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
-reopen_socket(timeout, StateData) ->
-    ?INFO_MSG("reopen socket: timeout", []),
-    {stop, normal, StateData};
-reopen_socket(closed, StateData) ->
-    p1_fsm:send_event(self(), init),
-    {next_state, open_socket, StateData, ?FSMTIMEOUT}.
-
-%% This state is use to avoid reconnecting to often to bad sockets
-wait_before_retry(_Event, StateData) ->
-    {next_state, wait_before_retry, StateData, ?FSMTIMEOUT}.
-
-relay_to_bridge(stop, StateData) ->
-    wait_before_reconnect(StateData);
-relay_to_bridge(closed, StateData) ->
-    ?INFO_MSG("relay to bridge: ~s -> ~s (closed)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-relay_to_bridge(_Event, StateData) ->
-    {next_state, relay_to_bridge, StateData}.
-
-stream_established({xmlstreamelement, El}, StateData) ->
-    decode_element(El, stream_established, StateData);
-stream_established(#db_verify{to = VTo, from = VFrom, id = VId, type = VType},
-                  StateData) ->
-    ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]),
-    case StateData#state.verify of
-       {VPid, _VKey, _SID} ->
-           case VType of
-               valid ->
-                   p1_fsm:send_event(VPid,
-                                     {valid, StateData#state.server,
-                                      StateData#state.myname});
-               _ ->
-                   p1_fsm:send_event(VPid,
-                                     {invalid, StateData#state.server,
-                                      StateData#state.myname})
-           end;
-       _ -> ok
-    end,
-    {next_state, stream_established, StateData};
-stream_established(Event, StateData) ->
-    handle_unexpected_event(Event, stream_established, StateData).
-
--spec handle_unexpected_event(term(), state_name(), state()) -> fsm_transition().
-handle_unexpected_event(Event, StateName, StateData) ->
-    case Event of
-       {xmlstreamerror, _} ->
-           send_element(StateData, xmpp:serr_not_well_formed()),
-           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
-                     "got invalid XML from peer",
-                     [StateData#state.myname, StateData#state.server,
-                      StateName]),
-           {stop, normal, StateData};
-       {xmlstreamend, _} ->
-           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
-                     "XML stream closed by peer",
-                     [StateData#state.myname, StateData#state.server,
-                      StateName]),
-           {stop, normal, StateData};
-       timeout ->
-           send_element(StateData, xmpp:serr_connection_timeout()),
-           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
-                     "timed out during establishing an XML stream",
-                     [StateData#state.myname, StateData#state.server,
-                      StateName]),
-           {stop, normal, StateData};
-       closed ->
-           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
-                     "connection socket closed",
-                     [StateData#state.myname, StateData#state.server,
-                      StateName]),
-           {stop, normal, StateData};
-       Pkt when StateName == wait_for_stream;
-                StateName == wait_for_features;
-                StateName == wait_for_auth_result;
-                StateName == wait_for_starttls_proceed ->
-           send_element(StateData, xmpp:serr_bad_format()),
-           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
-                     "got unexpected event ~p",
-                     [StateData#state.myname, StateData#state.server,
-                      StateName, Pkt]),
-           {stop, normal, StateData};
-       _ ->
-           {next_state, StateName, StateData, get_timeout_interval(StateName)}
-    end.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/3
-%% Returns: {next_state, NextStateName, NextStateData}            |
-%%          {next_state, NextStateName, NextStateData, Timeout}   |
-%%          {reply, Reply, NextStateName, NextStateData}          |
-%%          {reply, Reply, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}                          |
-%%          {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-%%state_name(Event, From, StateData) ->
-%%    Reply = ok,
-%%    {reply, Reply, state_name, StateData}.
-
-handle_event(_Event, StateName, StateData) ->
-    {next_state, StateName, StateData,
-     get_timeout_interval(StateName)}.
-
-handle_sync_event(get_state_infos, _From, StateName,
-                 StateData) ->
-    {Addr, Port} = try
-                    ejabberd_socket:peername(StateData#state.socket)
-                  of
-                    {ok, {A, P}} -> {A, P};
-                    {error, _} -> {unknown, unknown}
-                  catch
-                    _:_ -> {unknown, unknown}
-                  end,
-    Infos = [{direction, out}, {statename, StateName},
-            {addr, Addr}, {port, Port},
-            {streamid, StateData#state.streamid},
-            {use_v10, StateData#state.use_v10},
-            {tls, StateData#state.tls},
-            {tls_required, StateData#state.tls_required},
-            {tls_enabled, StateData#state.tls_enabled},
-            {tls_options, StateData#state.tls_options},
-            {authenticated, StateData#state.authenticated},
-            {db_enabled, StateData#state.db_enabled},
-            {try_auth, StateData#state.try_auth},
-            {myname, StateData#state.myname},
-            {server, StateData#state.server},
-            {delay_to_retry, StateData#state.delay_to_retry},
-            {verify, StateData#state.verify}],
-    Reply = {state_infos, Infos},
-    {reply, Reply, StateName, StateData};
-%%----------------------------------------------------------------------
-%% Func: handle_sync_event/4
-%% Returns: {next_state, NextStateName, NextStateData}            |
-%%          {next_state, NextStateName, NextStateData, Timeout}   |
-%%          {reply, Reply, NextStateName, NextStateData}          |
-%%          {reply, Reply, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}                          |
-%%          {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-handle_sync_event(_Event, _From, StateName,
-                 StateData) ->
-    Reply = ok,
-    {reply, Reply, StateName, StateData,
-     get_timeout_interval(StateName)}.
-
-code_change(_OldVsn, StateName, StateData, _Extra) ->
-    {ok, StateName, StateData}.
-
-handle_info({send_text, Text}, StateName, StateData) ->
-    send_text(StateData, Text),
-    cancel_timer(StateData#state.timer),
-    Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-    {next_state, StateName, StateData#state{timer = Timer},
-     get_timeout_interval(StateName)};
-handle_info({send_element, El}, StateName, StateData) ->
-    case StateName of
-      stream_established ->
-         cancel_timer(StateData#state.timer),
-         Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-         send_element(StateData, El),
-         {next_state, StateName, StateData#state{timer = Timer}};
-      %% In this state we bounce all message: We are waiting before
-      %% trying to reconnect
-      wait_before_retry ->
-         bounce_element(El, xmpp:err_remote_server_not_found()),
-         {next_state, StateName, StateData};
-      relay_to_bridge ->
-         {Mod, Fun} = StateData#state.bridge,
-         ?DEBUG("relaying stanza via ~p:~p/1", [Mod, Fun]),
-         case catch Mod:Fun(El) of
-           {'EXIT', Reason} ->
-               ?ERROR_MSG("Error while relaying to bridge: ~p",
-                          [Reason]),
-               bounce_element(El, xmpp:err_internal_server_error()),
-               wait_before_reconnect(StateData);
-           _ -> {next_state, StateName, StateData}
-         end;
-      _ ->
-         Q = queue:in(El, StateData#state.queue),
-         {next_state, StateName, StateData#state{queue = Q},
-          get_timeout_interval(StateName)}
+handle_cast(Msg, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_cast, LServer, State, [Msg]).
+
+handle_info({route, Pkt}, #{queue := Q, on_route := Action} = State) ->
+    case Action of
+       queue -> State#{queue => queue:in(Pkt, Q)};
+       bounce -> bounce_packet(Pkt, State);
+       send -> set_idle_timeout(send(State, Pkt))
     end;
-handle_info({timeout, Timer, _}, wait_before_retry,
-           #state{timer = Timer} = StateData) ->
-    ?INFO_MSG("Reconnect delay expired: Will now retry "
-             "to connect to ~s when needed.",
-             [StateData#state.server]),
-    {stop, normal, StateData};
-handle_info({timeout, Timer, _}, _StateName,
-           #state{timer = Timer} = StateData) ->
-    ?INFO_MSG("Closing connection with ~s: timeout",
-             [StateData#state.server]),
-    {stop, normal, StateData};
-handle_info(terminate_if_waiting_before_retry,
-           wait_before_retry, StateData) ->
-    {stop, normal, StateData};
-handle_info(terminate_if_waiting_before_retry,
-           StateName, StateData) ->
-    {next_state, StateName, StateData,
-     get_timeout_interval(StateName)};
-handle_info(_, StateName, StateData) ->
-    {next_state, StateName, StateData,
-     get_timeout_interval(StateName)}.
-
-terminate(Reason, StateName, StateData) ->
-    ?DEBUG("terminated: ~p", [{Reason, StateName}]),
-    case StateData#state.new of
-      false -> ok;
-      true ->
-         ejabberd_s2s:remove_connection({StateData#state.myname,
-                                         StateData#state.server},
-                                        self())
-    end,
-    bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()),
-    bounce_messages(xmpp:err_remote_server_not_found()),
-    case StateData#state.socket of
-      undefined -> ok;
-      _Socket ->
-           catch send_trailer(StateData),
-           ejabberd_socket:close(StateData#state.socket)
-    end,
-    ok.
-
-print_state(State) -> State.
-
-%%%----------------------------------------------------------------------
+handle_info(Info, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_info, LServer, State, [Info]).
+
+terminate(Reason, #{server := LServer,
+                   remote_server := RServer} = State) ->
+    ejabberd_s2s:remove_connection({LServer, RServer}, self()),
+    State1 = case Reason of
+                normal -> State;
+                _ -> State#{stop_reason => {error, internal_failure}}
+            end,
+    bounce_queue(State1),
+    bounce_message_queue(State1).
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
+%%%===================================================================
 %%% Internal functions
-%%%----------------------------------------------------------------------
-
--spec send_text(state(), iodata()) -> ok.
-send_text(StateData, Text) ->
-    ?DEBUG("Send Text on stream = ~s", [Text]),
-    ejabberd_socket:send(StateData#state.socket, Text).
-
--spec send_element(state(), xmpp_element()) -> ok.
-send_element(StateData, El) ->
-    El1 = xmpp:encode(El, ?NS_SERVER),
-    send_text(StateData, fxml:element_to_binary(El1)).
-
--spec send_header(state(), undefined | {integer(), integer()}) -> ok.
-send_header(StateData, Version) ->
-    Header = xmpp:encode(
-              #stream_start{xmlns = ?NS_SERVER,
-                            stream_xmlns = ?NS_STREAM,
-                            db_xmlns = ?NS_SERVER_DIALBACK,
-                            from = jid:make(StateData#state.myname),
-                            to = jid:make(StateData#state.server),
-                            version = Version}),
-    send_text(StateData, fxml:element_to_header(Header)).
-
--spec send_trailer(state()) -> ok.
-send_trailer(StateData) ->
-    send_text(StateData, <<"</stream:stream>">>).
-
--spec send_queue(state(), queue:queue()) -> ok.
-send_queue(StateData, Q) ->
-    case queue:out(Q) of
-      {{value, El}, Q1} ->
-         send_element(StateData, El), send_queue(StateData, Q1);
-      {empty, _Q1} -> ok
-    end.
-
-%% Bounce a single message (xmlelement)
--spec bounce_element(stanza(), stanza_error()) -> ok.
-bounce_element(El, Error) ->
-    From = xmpp:get_from(El),
-    To = xmpp:get_to(El),
-    ejabberd_router:route_error(To, From, El, Error).
-
--spec bounce_queue(queue:queue(), stanza_error()) -> ok.
-bounce_queue(Q, Error) ->
-    case queue:out(Q) of
-      {{value, El}, Q1} ->
-         bounce_element(El, Error), bounce_queue(Q1, Error);
-      {empty, _} -> ok
-    end.
-
--spec new_id() -> binary().
-new_id() -> randoms:get_string().
-
--spec cancel_timer(reference()) -> ok.
-cancel_timer(Timer) ->
-    erlang:cancel_timer(Timer),
-    receive {timeout, Timer, _} -> ok after 0 -> ok end.
-
--spec bounce_messages(stanza_error()) -> ok.
-bounce_messages(Error) ->
+%%%===================================================================
+-spec resend_queue(state()) -> state().
+resend_queue(#{queue := Q} = State) ->
+    State1 = State#{queue => queue:new()},
+    jlib:queue_foldl(
+      fun(Pkt, AccState) ->
+             send(AccState, Pkt)
+      end, State1, Q).
+
+-spec bounce_queue(state()) -> state().
+bounce_queue(#{queue := Q} = State) ->
+    State1 = State#{queue => queue:new()},
+    jlib:queue_foldl(
+      fun(Pkt, AccState) ->
+             bounce_packet(Pkt, AccState)
+      end, State1, Q).
+
+-spec bounce_message_queue(state()) -> state().
+bounce_message_queue(State) ->
     receive
-      {send_element, El} ->
-         bounce_element(El, Error), bounce_messages(Error)
-      after 0 -> ok
-    end.
-
--spec send_db_request(state()) -> fsm_transition().
-send_db_request(StateData) ->
-    Server = StateData#state.server,
-    New = case StateData#state.new of
-             false ->
-                 ejabberd_s2s:try_register({StateData#state.myname, Server});
-             true ->
-                 true
-         end,
-    NewStateData = StateData#state{new = New},
-    try case New of
-           false -> ok;
-           true ->
-             Key1 = ejabberd_s2s:make_key(
-                      {StateData#state.myname, Server},
-                      StateData#state.remote_streamid),
-             send_element(StateData,
-                          #db_result{from = StateData#state.myname,
-                                     to = Server,
-                                     key = Key1})
-       end,
-       case StateData#state.verify of
-         false -> ok;
-         {_Pid, Key2, SID} ->
-             send_element(StateData,
-                          #db_verify{from = StateData#state.myname,
-                                     to = StateData#state.server,
-                                     id = SID,
-                                     key = Key2})
-       end,
-       {next_state, wait_for_validation, NewStateData,
-        (?FSMTIMEOUT) * 6}
-    catch
-      _:_ -> {stop, normal, NewStateData}
+       {route, Pkt} ->
+           State1 = bounce_packet(Pkt, State),
+           bounce_message_queue(State1)
+    after 0 ->
+           State
     end.
 
-%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-%% SRV support
-
--include_lib("kernel/include/inet.hrl").
-
--spec get_addr_port(binary()) -> [{binary(), inet:port_number()}].
-get_addr_port(Server) ->
-    Res = srv_lookup(Server),
-    case Res of
-      {error, Reason} ->
-         ?DEBUG("srv lookup of '~s' failed: ~p~n",
-                [Server, Reason]),
-         [{Server, outgoing_s2s_port()}];
-      {ok, HEnt} ->
-         ?DEBUG("srv lookup of '~s': ~p~n",
-                [Server, HEnt#hostent.h_addr_list]),
-         AddrList = HEnt#hostent.h_addr_list,
-         case catch lists:map(fun ({Priority, Weight, Port,
-                                    Host}) ->
-                                      N = case Weight of
-                                            0 -> 0;
-                                            _ ->
-                                                (Weight + 1) * randoms:uniform()
-                                          end,
-                                      {Priority * 65536 - N, Host, Port}
-                              end,
-                              AddrList)
-             of
-           SortedList = [_ | _] ->
-               List = lists:map(fun ({_, Host, Port}) ->
-                                         {list_to_binary(Host), Port}
-                                end,
-                                lists:keysort(1, SortedList)),
-               ?DEBUG("srv lookup of '~s': ~p~n", [Server, List]),
-               List;
-           _ -> [{Server, outgoing_s2s_port()}]
-         end
-    end.
-
-srv_lookup(Server) ->
-    TimeoutMs = timer:seconds(
-                  ejabberd_config:get_option(
-                    s2s_dns_timeout,
-                    fun(I) when is_integer(I), I>=0 -> I end,
-                    10)),
-    Retries = ejabberd_config:get_option(
-                s2s_dns_retries,
-                fun(I) when is_integer(I), I>=0 -> I end,
-                2),
-    srv_lookup(binary_to_list(Server), TimeoutMs, Retries).
-
-%% XXX - this behaviour is suboptimal in the case that the domain
-%% has a "_xmpp-server._tcp." but not a "_jabber._tcp." record and
-%% we don't get a DNS reply for the "_xmpp-server._tcp." lookup. In this
-%% case we'll give up when we get the "_jabber._tcp." nxdomain reply.
-srv_lookup(_Server, _Timeout, Retries)
-    when Retries < 1 ->
-    {error, timeout};
-srv_lookup(Server, Timeout, Retries) ->
-    case inet_res:getbyname("_xmpp-server._tcp." ++ Server,
-                           srv, Timeout)
-       of
-      {error, _Reason} ->
-         case inet_res:getbyname("_jabber._tcp." ++ Server, srv,
-                                 Timeout)
-             of
-           {error, timeout} ->
-               ?ERROR_MSG("The DNS servers~n  ~p~ntimed out on "
-                          "request for ~p IN SRV. You should check "
-                          "your DNS configuration.",
-                          [inet_db:res_option(nameserver), Server]),
-               srv_lookup(Server, Timeout, Retries - 1);
-           R -> R
-         end;
-      {ok, _HEnt} = R -> R
-    end.
-
-test_get_addr_port(Server) ->
-    lists:foldl(fun (_, Acc) ->
-                       [HostPort | _] = get_addr_port(Server),
-                       case lists:keysearch(HostPort, 1, Acc) of
-                         false -> [{HostPort, 1} | Acc];
-                         {value, {_, Num}} ->
-                             lists:keyreplace(HostPort, 1, Acc,
-                                              {HostPort, Num + 1})
-                       end
-               end,
-               [], lists:seq(1, 100000)).
-
-get_addrs(Host, Family) ->
-    Type = case Family of
-            inet4 -> inet;
-            ipv4 -> inet;
-            inet6 -> inet6;
-            ipv6 -> inet6
-          end,
-    case inet:gethostbyname(binary_to_list(Host), Type) of
-      {ok, #hostent{h_addr_list = Addrs}} ->
-         ?DEBUG("~s of ~s resolved to: ~p~n",
-                [Type, Host, Addrs]),
-         Addrs;
-      {error, Reason} ->
-         ?DEBUG("~s lookup of '~s' failed: ~p~n",
-                [Type, Host, Reason]),
-         []
+-spec bounce_packet(xmpp_element(), state()) -> state().
+bounce_packet(Pkt, State) when ?is_stanza(Pkt) ->
+    From = xmpp:get_from(Pkt),
+    To = xmpp:get_to(Pkt),
+    Lang = xmpp:get_lang(Pkt),
+    Err = mk_bounce_error(Lang, State),
+    ejabberd_router:route_error(To, From, Pkt, Err),
+    State;
+bounce_packet(_, State) ->
+    State.
+
+-spec mk_bounce_error(binary(), state()) -> stanza_error().
+mk_bounce_error(Lang, State) ->
+    try maps:get(stop_reason, State) of
+       {error, internal_failure} ->
+           xmpp:err_internal_server_error();
+       {error, Why} ->
+           Reason = xmpp_stream_out:format_error(Why),
+           case Why of
+               {dns, _} ->
+                   xmpp:err_remote_server_timeout(Reason, Lang);
+               _ ->
+                   xmpp:err_remote_server_not_found(Reason, Lang)
+           end
+    catch _:{badkey, _} ->
+           xmpp:err_remote_server_not_found()
     end.
 
--spec outgoing_s2s_port() -> pos_integer().
-outgoing_s2s_port() ->
-    ejabberd_config:get_option(
-      outgoing_s2s_port,
-      fun(I) when is_integer(I), I > 0, I =< 65536 -> I end,
-      5269).
-
--spec outgoing_s2s_families() -> [ipv4 | ipv6].
-outgoing_s2s_families() ->
-    ejabberd_config:get_option(
-      outgoing_s2s_families,
-      fun(Families) ->
-              true = lists:all(
-                       fun(ipv4) -> true;
-                          (ipv6) -> true
-                       end, Families),
-              Families
-      end, [ipv4, ipv6]).
-
--spec outgoing_s2s_timeout() -> pos_integer().
-outgoing_s2s_timeout() ->
-    ejabberd_config:get_option(
-      outgoing_s2s_timeout,
-      fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 ->
-              TimeOut;
-         (infinity) ->
-              infinity
-      end, 10000).
+-spec get_delay() -> non_neg_integer().
+get_delay() ->
+    MaxDelay = ejabberd_config:get_option(
+                s2s_max_retry_delay,
+                fun(I) when is_integer(I), I > 0 -> I end,
+                300),
+    crypto:rand_uniform(0, MaxDelay).
+
+-spec set_idle_timeout(state()) -> state().
+set_idle_timeout(#{on_route := send, server := LServer} = State) ->
+    Timeout = ejabberd_s2s:get_idle_timeout(LServer),
+    xmpp_stream_out:set_timeout(State, Timeout);
+set_idle_timeout(State) ->
+    State.
 
 transform_options(Opts) ->
     lists:foldl(fun transform_options/2, [], Opts).
@@ -998,100 +338,6 @@ transform_options({s2s_dns_options, S2SDNSOpts}, AllOpts) ->
 transform_options(Opt, Opts) ->
     [Opt|Opts].
 
-%% Human readable S2S logging: Log only new outgoing connections as INFO
-%% Do not log dialback
-log_s2s_out(false, _, _, _) -> ok;
-%% Log new outgoing connections:
-log_s2s_out(_, Myname, Server, Tls) ->
-    ?INFO_MSG("Trying to open s2s connection: ~s -> "
-             "~s with TLS=~p",
-             [Myname, Server, Tls]).
-
-%% Calculate timeout depending on which state we are in:
-%% Can return integer > 0 | infinity
--spec get_timeout_interval(state_name()) -> pos_integer() | infinity.
-get_timeout_interval(StateName) ->
-    case StateName of
-      %% Validation implies dialback: Networking can take longer:
-      wait_for_validation -> (?FSMTIMEOUT) * 6;
-      %% When stream is established, we only rely on S2S Timeout timer:
-      stream_established -> infinity;
-      relay_to_bridge -> infinity;
-      open_socket -> infinity;
-      _ -> ?FSMTIMEOUT
-    end.
-
-%% This function is intended to be called at the end of a state
-%% function that want to wait for a reconnect delay before stopping.
--spec wait_before_reconnect(state()) -> fsm_next().
-wait_before_reconnect(StateData) ->
-    bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()),
-    bounce_messages(xmpp:err_remote_server_not_found()),
-    cancel_timer(StateData#state.timer),
-    Delay = case StateData#state.delay_to_retry of
-             undefined_delay ->
-                 {_, _, MicroSecs} = p1_time_compat:timestamp(), MicroSecs rem 14000 + 1000;
-             D1 -> lists:min([D1 * 2, get_max_retry_delay()])
-           end,
-    Timer = erlang:start_timer(Delay, self(), []),
-    {next_state, wait_before_retry,
-     StateData#state{timer = Timer, delay_to_retry = Delay,
-                    queue = queue:new()}}.
-
--spec get_max_retry_delay() -> pos_integer().
-get_max_retry_delay() ->
-    case ejabberd_config:get_option(
-           s2s_max_retry_delay,
-           fun(I) when is_integer(I), I > 0 -> I end) of
-        undefined -> ?MAX_RETRY_DELAY;
-        Seconds -> Seconds * 1000
-    end.
-
-%% Terminate s2s_out connections that are in state wait_before_retry
--spec terminate_if_waiting_delay(binary(), binary()) -> ok.
-terminate_if_waiting_delay(From, To) ->
-    FromTo = {From, To},
-    Pids = ejabberd_s2s:get_connections_pids(FromTo),
-    lists:foreach(fun (Pid) ->
-                         Pid ! terminate_if_waiting_before_retry
-                 end,
-                 Pids).
-
--spec fsm_limit_opts() -> [{max_queue, pos_integer()}].
-fsm_limit_opts() ->
-    case ejabberd_config:get_option(
-           max_fsm_queue,
-           fun(I) when is_integer(I), I > 0 -> I end) of
-        undefined -> [];
-        N -> [{max_queue, N}]
-    end.
-
--spec decode_element(xmlel(), state_name(), state()) -> fsm_next().
-decode_element(#xmlel{} = El, StateName, StateData) ->
-    Opts = if StateName == stream_established ->
-                  [ignore_els];
-             true ->
-                  []
-          end,
-    try xmpp:decode(El, ?NS_SERVER, Opts) of
-       Pkt -> ?MODULE:StateName(Pkt, StateData)
-    catch error:{xmpp_codec, Why} ->
-           Type = xmpp:get_type(El),
-           case xmpp:is_stanza(El) of
-               true when Type /= <<"result">>, Type /= <<"error">> ->
-                   Lang = xmpp:get_lang(El),
-                   Txt = xmpp:format_error(Why),
-                   Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)),
-                   send_element(StateData, Err);
-               false ->
-                   ok
-           end,
-           {next_state, StateName, StateData, get_timeout_interval(StateName)}
-    end.
-
-opt_type(domain_certfile) -> fun iolist_to_binary/1;
-opt_type(max_fsm_queue) ->
-    fun (I) when is_integer(I), I > 0 -> I end;
 opt_type(outgoing_s2s_families) ->
     fun (Families) ->
            true = lists:all(fun (ipv4) -> true;
@@ -1107,36 +353,12 @@ opt_type(outgoing_s2s_timeout) ->
            TimeOut;
        (infinity) -> infinity
     end;
-opt_type(s2s_certfile) -> fun iolist_to_binary/1;
-opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
-opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
 opt_type(s2s_dns_retries) ->
     fun (I) when is_integer(I), I >= 0 -> I end;
 opt_type(s2s_dns_timeout) ->
     fun (I) when is_integer(I), I >= 0 -> I end;
 opt_type(s2s_max_retry_delay) ->
     fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(s2s_protocol_options) ->
-    fun (Options) ->
-           [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [],
-                                 [["|" | binary_to_list(Opt)]
-                                  || Opt <- Options, is_binary(Opt)]),
-           iolist_to_binary(O)
-    end;
-opt_type(s2s_tls_compression) ->
-    fun (true) -> true;
-       (false) -> false
-    end;
-opt_type(s2s_use_starttls) ->
-    fun (true) -> true;
-       (false) -> false;
-       (optional) -> optional;
-       (required) -> required;
-       (required_trusted) -> required_trusted
-    end;
 opt_type(_) ->
-    [domain_certfile, max_fsm_queue, outgoing_s2s_families,
-     outgoing_s2s_port, outgoing_s2s_timeout, s2s_certfile,
-     s2s_ciphers, s2s_dhfile, s2s_dns_retries, s2s_dns_timeout,
-     s2s_max_retry_delay, s2s_protocol_options,
-     s2s_tls_compression, s2s_use_starttls].
+    [outgoing_s2s_families, outgoing_s2s_port, outgoing_s2s_timeout,
+     s2s_dns_retries, s2s_dns_timeout, s2s_max_retry_delay].
index c48cd536c6a43a15e4c3124163972ec9a4b1093e..13efd15e7917e09bb23460f32f19517d6897f327 100644 (file)
 -module(ejabberd_service).
 -behaviour(xmpp_stream_in).
 -behaviour(ejabberd_config).
+-behaviour(ejabberd_socket).
 
 -protocol({xep, 114, '1.6'}).
 
 %% ejabberd_socket callbacks
--export([start/2, socket_type/0]).
+-export([start/2, start_link/2, socket_type/0]).
 %% ejabberd_config callbacks
 -export([opt_type/1, transform_listen_option/2]).
 %% xmpp_stream_in callbacks
--export([init/1, handle_call/3, handle_cast/2, handle_info/2,
-        terminate/2, code_change/3]).
--export([handshake/2, handle_stream_start/1, handle_authenticated_packet/2]).
+-export([init/1, handle_info/2, terminate/2, code_change/3]).
+-export([handle_stream_start/2, handle_auth_success/4, handle_auth_failure/4,
+        handle_authenticated_packet/2, get_password_fun/1]).
 %% API
 -export([send/2]).
 
 -include("xmpp.hrl").
 -include("logger.hrl").
 
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
-
 -type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
--export_type([state/0, next_state/0]).
+-export_type([state/0]).
 
 %%%===================================================================
 %%% API
 %%%===================================================================
 start(SockData, Opts) ->
     xmpp_stream_in:start(?MODULE, [SockData, Opts],
-                        fsm_limit_opts(Opts) ++ ?FSMOPTS).
+                        ejabberd_config:fsm_limit_opts(Opts)).
+
+start_link(SockData, Opts) ->
+    xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
+                             ejabberd_config:fsm_limit_opts(Opts)).
 
 socket_type() ->
     xml_stream.
 
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
-    xmpp_stream_in:send(State, Pkt).
+-spec send(pid(), xmpp_element()) -> ok;
+         (state(), xmpp_element()) -> state().
+send(Stream, Pkt) ->
+    xmpp_stream_in:send(Stream, Pkt).
 
 %%%===================================================================
 %%% xmpp_stream_in callbacks
 %%%===================================================================
-init([#{socket := Socket} = State, Opts]) ->
-    ?INFO_MSG("(~w) External service connected", [Socket]),
+init([State, Opts]) ->
     Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
     Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none),
     HostOpts = case lists:keyfind(hosts, 1, Opts) of
@@ -96,66 +93,85 @@ init([#{socket := Socket} = State, Opts]) ->
                    server => ?MYNAME,
                    host_opts => HostOpts,
                    check_from => CheckFrom},
-    ejabberd_hooks:run_fold(component_init, {ok, State1}, []).
+    ejabberd_hooks:run_fold(component_init, {ok, State1}, [Opts]).
 
-handle_stream_start(#{remote_server := RemoteServer,
+handle_stream_start(_StreamStart,
+                   #{remote_server := RemoteServer,
+                     lang := Lang,
                      host_opts := HostOpts} = State) ->
-    NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
-                     true ->
-                         HostOpts;
-                     false ->
-                         case dict:find(global, HostOpts) of
-                             {ok, GlobalPass} ->
-                                 dict:from_list([{RemoteServer, GlobalPass}]);
-                             error ->
-                                 HostOpts
-                         end
-                 end,
-    {noreply, State#{host_opts => NewHostOpts}}.
-
-handshake(Digest, #{remote_server := RemoteServer,
-                   stream_id := StreamID,
-                   host_opts := HostOpts} = State) ->
-    case dict:find(RemoteServer, HostOpts) of
-       {ok, Password} ->
-           case p1_sha:sha(<<StreamID/binary, Password/binary>>) of
-               Digest ->
-                   lists:foreach(
-                     fun (H) ->
-                             ejabberd_router:register_route(H, ?MYNAME),
-                             ?INFO_MSG("Route registered for service ~p~n", [H]),
-                             ejabberd_hooks:run(component_connected, [H])
-                     end, dict:fetch_keys(HostOpts)),
-                   {ok, State};
-               _ ->
-                   ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
-                   {error, xmpp:serr_not_authorized(), State}
-           end;
-       _ ->
-           ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
-           {error, xmpp:serr_not_authorized(), State}
+    case lists:member(RemoteServer, ?MYHOSTS) of
+       true ->
+           Txt = <<"Unable to register route on existing local domain">>,
+           xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang));
+       false ->
+           NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
+                             true ->
+                                 HostOpts;
+                             false ->
+                                 case dict:find(global, HostOpts) of
+                                     {ok, GlobalPass} ->
+                                         dict:from_list([{RemoteServer, GlobalPass}]);
+                                     error ->
+                                         HostOpts
+                                 end
+                         end,
+           State#{host_opts => NewHostOpts}
+    end.
+
+get_password_fun(#{remote_server := RemoteServer,
+                  socket := Socket,
+                  ip := IP,
+                  host_opts := HostOpts}) ->
+    fun(_) ->
+           case dict:find(RemoteServer, HostOpts) of
+               {ok, Password} ->
+                   {Password, undefined};
+               error ->
+                   ?ERROR_MSG("(~s) Domain ~s is unconfigured for "
+                              "external component from ~s",
+                              [ejabberd_socket:pp(Socket), RemoteServer,
+                               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+                   {false, undefined}
+           end
     end.
 
+handle_auth_success(_, Mech, _,
+                   #{remote_server := RemoteServer, host_opts := HostOpts,
+                     socket := Socket, ip := IP} = State) ->
+    ?INFO_MSG("(~s) Accepted external component ~s authentication "
+             "for ~s from ~s",
+             [ejabberd_socket:pp(Socket), Mech, RemoteServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+    lists:foreach(
+      fun (H) ->
+             ejabberd_router:register_route(H, ?MYNAME),
+             ejabberd_hooks:run(component_connected, [H])
+      end, dict:fetch_keys(HostOpts)),
+    State.
+
+handle_auth_failure(_, Mech, Reason,
+                   #{remote_server := RemoteServer,
+                     socket := Socket, ip := IP} = State) ->
+    ?ERROR_MSG("(~s) Failed external component ~s authentication "
+              "for ~s from ~s: ~s",
+              [ejabberd_socket:pp(Socket), Mech, RemoteServer,
+               ejabberd_config:may_hide_data(jlib:ip_to_list(IP)),
+               Reason]),
+    State.
+
 handle_authenticated_packet(Pkt, #{lang := Lang} = State) ->
     From = xmpp:get_from(Pkt),
     case check_from(From, State) of
        true ->
            To = xmpp:get_to(Pkt),
            ejabberd_router:route(From, To, Pkt),
-           {noreply, State};
+           State;
        false ->
            Txt = <<"Improper domain part of 'from' attribute">>,
            Err = xmpp:serr_invalid_from(Txt, Lang),
            xmpp_stream_in:send(State, Err)
     end.
 
-handle_call(_Request, _From, State) ->
-    Reply = ok,
-    {reply, Reply, State}.
-
-handle_cast(_Msg, State) ->
-    {noreply, State}.
-
 handle_info({route, From, To, Packet}, #{access := Access} = State) ->
     case acl:match_rule(global, Access, From) of
        allow ->
@@ -165,16 +181,15 @@ handle_info({route, From, To, Packet}, #{access := Access} = State) ->
            Lang = xmpp:get_lang(Packet),
            Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang),
            ejabberd_router:route_error(To, From, Packet, Err),
-           {noreply, State}
+           State
     end;
 handle_info(Info, State) ->
     ?ERROR_MSG("Unexpected info: ~p", [Info]),
-    {noreply, State}.
+    State.
 
 terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) ->
-    ?INFO_MSG("External service disconnected: ~p", [Reason]),
     case StreamState of
-       session_established ->
+       established ->
            lists:foreach(
              fun(H) ->
                      ejabberd_router:unregister_route(H),
@@ -220,19 +235,4 @@ transform_listen_option({host, Host, Os}, Opts) ->
 transform_listen_option(Opt, Opts) ->
     [Opt|Opts].
 
-fsm_limit_opts(Opts) ->
-    case lists:keysearch(max_fsm_queue, 1, Opts) of
-        {value, {_, N}} when is_integer(N) ->
-            [{max_queue, N}];
-        _ ->
-            case ejabberd_config:get_option(
-                   max_fsm_queue,
-                   fun(I) when is_integer(I), I > 0 -> I end) of
-                undefined -> [];
-                N -> [{max_queue, N}]
-            end
-    end.
-
-opt_type(max_fsm_queue) ->
-    fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(_) -> [max_fsm_queue].
+opt_type(_) -> [].
index 11b829a94bb2a56c9a94a998f71be0ec401a020f..46008bec47e1b9f25682d6efce53b6efac5e8452 100644 (file)
@@ -359,20 +359,20 @@ unregister_iq_handler(Host, XMLNS) ->
     ejabberd_sm ! {unregister_iq_handler, Host, XMLNS}.
 
 %% Why the hell do we have so many similar kicks?
-c2s_handle_info({noreply, #{lang := Lang} = State}, replaced) ->
+c2s_handle_info(#{lang := Lang} = State, replaced) ->
     State1 = State#{replaced => true},
     Err = xmpp:serr_conflict(<<"Replaced by new connection">>, Lang),
-    ejabberd_c2s:send(State1, Err);
-c2s_handle_info({noreply, #{lang := Lang} = State}, kick) ->
+    {stop, ejabberd_c2s:send(State1, Err)};
+c2s_handle_info(#{lang := Lang} = State, kick) ->
     Err = xmpp:serr_policy_violation(<<"has been kicked">>, Lang),
-    c2s_handle_info({noreply, State}, {kick, kicked_by_admin, Err});
-c2s_handle_info({noreply, State}, {kick, _Reason, Err}) ->
-    ejabberd_c2s:send(State, Err);
-c2s_handle_info({noreply, #{lang := Lang} = State}, {exit, Reason}) ->
+    c2s_handle_info(State, {kick, kicked_by_admin, Err});
+c2s_handle_info(State, {kick, _Reason, Err}) ->
+    {stop, ejabberd_c2s:send(State, Err)};
+c2s_handle_info(#{lang := Lang} = State, {exit, Reason}) ->
     Err = xmpp:serr_conflict(Reason, Lang),
-    ejabberd_c2s:send(State, Err);
-c2s_handle_info(Acc, _) ->
-    Acc.
+    {stop, ejabberd_c2s:send(State, Err)};
+c2s_handle_info(State, _) ->
+    State.
 
 %%====================================================================
 %% gen_server callbacks
index 3f01dae85e08ff8dd87c5e5b02ddd70b6d87e3e0..4e523a7e525745531a1bd52b2b20be2cc1370cd8 100644 (file)
@@ -46,6 +46,7 @@
         get_peer_certificate/1,
         get_verify_result/1,
         close/1,
+        pp/1,
         sockname/1, peername/1]).
 
 -include("ejabberd.hrl").
 
 -export_type([socket/0, socket_state/0, sockmod/0]).
 
+-callback start({module(), socket_state()},
+               [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
+-callback start_link({module(), socket_state()},
+                    [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
+-callback socket_type() -> xml_stream | independent | raw.
 
 %%====================================================================
 %% API
@@ -109,7 +115,7 @@ start(Module, SockMod, Socket, Opts) ->
                  {error, _Reason} -> SockMod:close(Socket)
                end,
                ReceiverMod:become_controller(Receiver, Pid);
-           {error, _Reason} ->
+           _ ->
                SockMod:close(Socket),
                case ReceiverMod of
                  ejabberd_receiver -> ReceiverMod:close(Receiver);
@@ -190,6 +196,7 @@ reset_stream(SocketData)
 -spec send(socket_state(), iodata()) -> ok.
 
 send(SocketData, Data) ->
+    ?DEBUG("Send XML on stream = ~p", [Data]),
     case catch (SocketData#socket_state.sockmod):send(
             SocketData#socket_state.socket, Data) of
         ok -> ok;
@@ -238,8 +245,8 @@ get_transport(#socket_state{sockmod = SockMod,
        fast_tls -> tls;
        ezlib ->
            case ezlib:get_sockmod(Socket) of
-               tcp -> tcp_zlib;
-               tls -> tls_zlib
+               gen_tcp -> tcp_zlib;
+               fast_tls -> tls_zlib
            end;
        ejabberd_bosh -> http_bind;
        ejabberd_http_bind -> http_bind;
@@ -268,3 +275,7 @@ peername(#socket_state{sockmod = SockMod,
       gen_tcp -> inet:peername(Socket);
       _ -> SockMod:peername(Socket)
     end.
+
+pp(#socket_state{receiver = Receiver} = State) ->
+    Transport = get_transport(State),
+    io_lib:format("~s|~w", [Transport, Receiver]).
index 096ef40127644d291fdc814085845896d0292dd3..939baae841dc84f1d2db773053ba9ae81dbc1f6a 100644 (file)
@@ -38,8 +38,8 @@
 -export([tolower/1, term_to_base64/1, base64_to_term/1,
         decode_base64/1, encode_base64/1, ip_to_list/1,
         atom_to_binary/1, binary_to_atom/1, tuple_to_binary/1,
-        l2i/1, i2l/1, i2l/2, queue_drop_while/2,
-        expr_to_term/1, term_to_expr/1]).
+        l2i/1, i2l/1, i2l/2, expr_to_term/1, term_to_expr/1,
+        queue_drop_while/2, queue_foldl/3, queue_foldr/3, queue_foreach/2]).
 
 %% The following functions are used by gen_iq_handler.erl for providing backward
 %% compatibility and must not be used in other parts of the code
@@ -974,3 +974,33 @@ queue_drop_while(F, Q) ->
       empty ->
          Q
     end.
+
+-spec queue_foldl(fun((term(), T) -> T), T, ?TQUEUE) -> T.
+queue_foldl(F, Acc, Q) ->
+    case queue:out(Q) of
+       {{value, Item}, Q1} ->
+           Acc1 = F(Item, Acc),
+           queue_foldl(F, Acc1, Q1);
+       {empty, _} ->
+           Acc
+    end.
+
+-spec queue_foldr(fun((term(), T) -> T), T, ?TQUEUE) -> T.
+queue_foldr(F, Acc, Q) ->
+    case queue:out_r(Q) of
+       {{value, Item}, Q1} ->
+           Acc1 = F(Item, Acc),
+           queue_foldr(F, Acc1, Q1);
+       {empty, _} ->
+           Acc
+    end.
+
+-spec queue_foreach(fun((_) -> _), ?TQUEUE) -> ok.
+queue_foreach(F, Q) ->
+    case queue:out(Q) of
+       {{value, Item}, Q1} ->
+           F(Item),
+           queue_foreach(F, Q1);
+       {empty, _} ->
+           ok
+    end.
index 826a7bba356e5ca507ad65491b7a59e6837b2ccc..45564daf4ced02aa47b675b08b86ab14ad055996 100644 (file)
@@ -54,8 +54,6 @@ start(Host, Opts) ->
                       process_iq_set, 40),
     ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
                       c2s_handle_info, 40),
-    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
-                      c2s_handle_info, 40),
     mod_disco:register_feature(Host, ?NS_BLOCKING),
     gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
                                  ?NS_BLOCKING, ?MODULE, process_iq, IQDisc).
@@ -65,6 +63,8 @@ stop(Host) ->
                          process_iq_get, 40),
     ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE,
                          process_iq_set, 40),
+    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
+                         c2s_handle_info, 40),
     mod_disco:unregister_feature(Host, ?NS_BLOCKING),
     gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
                                     ?NS_BLOCKING).
@@ -253,8 +253,8 @@ process_blocklist_get(LUser, LServer, Lang) ->
          {result, #block_list{items = Items}}
     end.
 
--spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
-c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State},
+-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+c2s_handle_info(#{user := U, server := S, resource := R} = State,
                {blocking, Action}) ->
     SubEl = case Action of
                {block, JIDs} ->
@@ -272,7 +272,9 @@ c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State},
     %% No need to replace active privacy list here,
     %% blocking pushes are always accompanied by
     %% Privacy List pushes
-    ejabberd_c2s:send(State, PushIQ).
+    {stop, ejabberd_c2s:send(State, PushIQ)};
+c2s_handle_info(State, _) ->
+    State.
 
 -spec db_mod(binary()) -> module().
 db_mod(LServer) ->
index 3e83680e5d5ff905e457688c6fa4c6681aba8a01..f93b67e05a71db7969b1cbb2509c32cfb38f44c5 100644 (file)
@@ -52,16 +52,16 @@ depends(_Host, _Opts) ->
 mod_opt_type(_) ->
     [].
 
-c2s_unauthenticated_packet({noreply, State}, #iq{type = T, sub_els = [_]} = IQ)
+c2s_unauthenticated_packet(State, #iq{type = T, sub_els = [_]} = IQ)
   when T == get; T == set ->
     case xmpp:get_subtag(IQ, #legacy_auth{}) of
        #legacy_auth{} = Auth ->
            {stop, authenticate(State, xmpp:set_els(IQ, [Auth]))};
        false ->
-           {noreply, State}
+           State
     end;
-c2s_unauthenticated_packet(Acc, _) ->
-    Acc.
+c2s_unauthenticated_packet(State, _) ->
+    State.
 
 c2s_stream_features(Acc, LServer) ->
     case gen_mod:is_loaded(LServer, ?MODULE) of
@@ -112,14 +112,10 @@ authenticate(#{stream_id := StreamID, server := Server,
            case ejabberd_auth:check_password_with_authmodule(
                   U, U, JID#jid.lserver, P, D, DGen) of
                {true, AuthModule} ->
-                   case ejabberd_c2s:handle_auth_success(
-                          U, <<"legacy">>, AuthModule, State) of
-                       {noreply, State1} ->
-                           State2 = State1#{user := U},
-                           open_session(State2, IQ, R);
-                       Err ->
-                           Err
-                   end;
+                   State1 = ejabberd_c2s:handle_auth_success(
+                              U, <<"legacy">>, AuthModule, State),
+                   State2 = State1#{user := U},
+                   open_session(State2, IQ, R);
                _ ->
                    Err = xmpp:make_error(IQ, xmpp:err_not_authorized()),
                    process_auth_failure(State, U, Err, 'not-authorized')
@@ -137,23 +133,13 @@ open_session(State, IQ, R) ->
     case ejabberd_c2s:bind(R, State) of
        {ok, State1} ->
            Res = xmpp:make_iq_result(IQ),
-           case ejabberd_c2s:send(State1, Res) of
-               {noreply, State2} ->
-                   {noreply, State2#{stream_authenticated := true,
-                                     stream_state := session_established}};
-               Err ->
-                   Err
-           end;
+           State2 = ejabberd_c2s:send(State1, Res),
+           ejabberd_c2s:establish(State2);
        {error, Err, State1} ->
            Res = xmpp:make_error(IQ, Err),
            ejabberd_c2s:send(State1, Res)
     end.
 
 process_auth_failure(State, User, StanzaErr, Reason) ->
-    case ejabberd_c2s:send(State, StanzaErr) of
-       {noreply, State1} ->
-           ejabberd_c2s:handle_auth_failure(
-             User, <<"legacy">>, Reason, State1);
-       Err ->
-           Err
-    end.
+    State1 = ejabberd_c2s:send(State, StanzaErr),
+    ejabberd_c2s:handle_auth_failure(User, <<"legacy">>, Reason, State1).
index e0e36a1da9c3fbbc779b5b487dbe16a3decd76db..8d58b14c9606028e6c4dbb26f2712c7e76651ced 100644 (file)
@@ -309,11 +309,11 @@ get_info(_Acc, #jid{luser = U, lserver = S} = JID,
 get_info(Acc, _From, _To, _Node, _Lang) ->
     Acc.
 
--spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
-c2s_handle_info({noreply, State}, {resend_offline, Flag}) ->
-    {noreply, State#{resend_offline => Flag}};
-c2s_handle_info(Acc, _) ->
-    Acc.
+-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+c2s_handle_info(State, {resend_offline, Flag}) ->
+    {stop, State#{resend_offline => Flag}};
+c2s_handle_info(State, _) ->
+    State.
 
 -spec handle_offline_query(iq()) -> iq().
 handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1},
index d76edc91d17753b4a079f8cff70de52f82fc4e64..b28bbcea2fefcfb15d0bb0deb98061b4e24e68fb 100644 (file)
@@ -535,8 +535,8 @@ remove_user(User, Server) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:remove_user(LUser, LServer).
 
-c2s_handle_info({noreply, #{privacy_list := Old,
-                           user := U, server := S, resource := R} = State},
+c2s_handle_info(#{privacy_list := Old,
+                 user := U, server := S, resource := R} = State,
                {privacy_list, New, Name}) ->
     List = if Old#userlist.name == New#userlist.name -> New;
              true -> Old
@@ -548,9 +548,9 @@ c2s_handle_info({noreply, #{privacy_list := Old,
                 sub_els = [#privacy_query{
                               lists = [#privacy_list{name = Name}]}]},
     State1 = State#{privacy_list => List},
-    ejabberd_c2s:send(State1, PushIQ);
-c2s_handle_info(Acc, _) ->
-    Acc.
+    {stop, ejabberd_c2s:send(State1, PushIQ)};
+c2s_handle_info(State, _) ->
+    State.
 
 -spec updated_list(userlist(), userlist(), userlist()) -> userlist().
 updated_list(_, #userlist{name = OldName} = Old,
index 98d50660c80e9df74e22451fb27f97611778bd4e..8819e3a994d5134daa4b6edc79677ac1207ce1e8 100644 (file)
@@ -3026,8 +3026,8 @@ broadcast_stanza({LUser, LServer, LResource}, Publisher, Node, Nidx, Type, NodeO
 broadcast_stanza(Host, _Publisher, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM) ->
     broadcast_stanza(Host, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM).
 
--spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
-c2s_handle_info({noreply, #{server := Server} = C2SState},
+-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+c2s_handle_info(#{server := Server} = C2SState,
                {pep_message, Feature, From, Packet}) ->
     LServer = jid:nameprep(Server),
     lists:foreach(
@@ -3042,8 +3042,8 @@ c2s_handle_info({noreply, #{server := Server} = C2SState},
                      ok
              end
       end, mod_caps:list_features(C2SState)),
-    {noreply, C2SState};
-c2s_handle_info({noreply, #{server := Server} = C2SState},
+    {stop, C2SState};
+c2s_handle_info(#{server := Server} = C2SState,
                {send_filtered, {pep_message, Feature}, From, To, Packet}) ->
     LServer = jid:nameprep(Server),
     case mod_caps:get_user_caps(To, C2SState) of
@@ -3059,9 +3059,9 @@ c2s_handle_info({noreply, #{server := Server} = C2SState},
        error ->
            ok
     end,
-    {noreply, C2SState};
-c2s_handle_info(Acc, _) ->
-    Acc.
+    {stop, C2SState};
+c2s_handle_info(C2SState, _) ->
+    C2SState.
 
 subscribed_nodes_by_jid(NotifyType, SubsByDepth) ->
     NodesToDeliver = fun (Depth, Node, Subs, Acc) ->
index 515cb10661b73f8902647451f2228f6bf1a506c7..8917d4c5c27cf6dd15f9c9fa525681daca997f88 100644 (file)
@@ -86,7 +86,7 @@ stream_feature_register(Acc, Host) ->
            Acc
     end.
 
-c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State},
+c2s_unauthenticated_packet(#{ip := IP, server := Server} = State,
                           #iq{type = T, sub_els = [_]} = IQ)
   when T == set; T == get ->
     case xmpp:get_subtag(IQ, #register{}) of
@@ -97,10 +97,10 @@ c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State},
            ResIQ1 = xmpp:set_from_to(ResIQ, jid:make(Server), undefined),
            {stop, ejabberd_c2s:send(State, ResIQ1)};
        false ->
-           {noreply, State}
+           State
     end;
-c2s_unauthenticated_packet(Acc, _) ->
-    Acc.
+c2s_unauthenticated_packet(State, _) ->
+    State.
 
 process_iq(#iq{from = From} = IQ) ->
     process_iq(IQ, jid:tolower(From)).
index a896ef05596f49c57fdc3d89745b3842afe08511..5c207f3a484e5dd94f7753246b49a84c5d4d4e53 100644 (file)
@@ -464,10 +464,10 @@ push_item_version(Server, User, From, Item,
                  end,
                  ejabberd_sm:get_user_resources(User, Server)).
 
-c2s_handle_info({noreply, State}, {item, JID, Sub}) ->
-    {noreply, roster_change(State, JID, Sub)};
-c2s_handle_info(Acc, _) ->
-    Acc.
+c2s_handle_info(State, {item, JID, Sub}) ->
+    {stop, roster_change(State, JID, Sub)};
+c2s_handle_info(State, _) ->
+    State.
 
 -spec roster_change(ejabberd_c2s:state(), jid(), subscription()) -> ejabberd_c2s:state().
 roster_change(#{user := U, server := S, resource := R} = State,
diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl
new file mode 100644 (file)
index 0000000..ce9d270
--- /dev/null
@@ -0,0 +1,273 @@
+%%%-------------------------------------------------------------------
+%%% Created : 16 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%-------------------------------------------------------------------
+-module(mod_s2s_dialback).
+-behaviour(gen_mod).
+
+-protocol({xep, 220, '1.1.1'}).
+-protocol({xep, 185, '1.0'}).
+
+%% gen_mod API
+-export([start/2, stop/1, depends/2, mod_opt_type/1]).
+%% Hooks
+-export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2,
+        s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
+
+-include("ejabberd.hrl").
+-include("xmpp.hrl").
+-include("logger.hrl").
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(Host, _Opts) ->
+    case ejabberd_s2s:tls_verify(Host) of
+       true ->
+           ?ERROR_MSG("disabling ~s for host ~s because option "
+                      "'s2s_use_starttls' is set to 'required_trusted'",
+                      [?MODULE, Host]);
+       false ->
+           ejabberd_hooks:add(s2s_out_init, Host, ?MODULE, s2s_out_init, 50),
+           ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50),
+           ejabberd_hooks:add(s2s_in_pre_auth_features, Host, ?MODULE,
+                              s2s_in_features, 50),
+           ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE,
+                              s2s_in_features, 50),
+           ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
+                              s2s_in_packet, 50),
+           ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE,
+                              s2s_in_packet, 50),
+           ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE,
+                              s2s_out_packet, 50),
+           ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
+                              s2s_out_auth_result, 50)
+    end.
+
+stop(Host) ->
+    ejabberd_hooks:delete(s2s_out_init, Host, ?MODULE, s2s_out_init, 50),
+    ejabberd_hooks:delete(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50),
+    ejabberd_hooks:delete(s2s_in_pre_auth_features, Host, ?MODULE,
+                         s2s_in_features, 50),
+    ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE,
+                         s2s_in_features, 50),
+    ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE,
+                         s2s_in_packet, 50),
+    ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE,
+                         s2s_in_packet, 50),
+    ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE,
+                         s2s_out_packet, 50),
+    ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE,
+                         s2s_out_auth_result, 50).
+
+depends(_Host, _Opts) ->
+    [].
+
+mod_opt_type(_) ->
+    [].
+
+s2s_in_features(Acc, _) ->
+    [#db_feature{errors = true}|Acc].
+
+s2s_out_init({ok, State}, Opts) ->
+    case proplists:get_value(db_verify, Opts) of
+       {StreamID, Key, Pid} ->
+           %% This is an outbound s2s connection created at step 1.
+           %% The purpose of this connection is to verify dialback key ONLY.
+           %% The connection is not registered in s2s table and thus is not
+           %% seen by anyone.
+           %% The connection will be closed immediately after receiving the
+           %% verification response (at step 3)
+           {ok, State#{db_verify => {StreamID, Key, Pid}}};
+       undefined ->
+           {ok, State#{db_enabled => true}}
+    end;
+s2s_out_init(Acc, _Opts) ->
+    Acc.
+
+s2s_out_closed(#{server := LServer,
+                remote_server := RServer,
+                db_verify := {StreamID, _Key, _Pid}} = State, _Reason) ->
+    %% Outbound s2s verificating connection (created at step 1) is
+    %% closed suddenly without receiving the response.
+    %% Building a response on our own
+    Response = #db_verify{from = RServer, to = LServer,
+                         id = StreamID, type = error,
+                         sub_els = [mk_error(internal_server_error)]},
+    s2s_out_packet(State, Response);
+s2s_out_closed(State, _Reason) ->
+    State.
+
+s2s_out_auth_result(#{server := LServer,
+                     remote_server := RServer,
+                     db_verify := {StreamID, Key, _Pid}} = State,
+                   _) ->
+    %% The temporary outbound s2s connect (intended for verification)
+    %% has passed authentication state (either successfully or not, no matter)
+    %% and at this point we can send verification request as described
+    %% in section 2.1.2, step 2
+    Request = #db_verify{from = LServer, to = RServer,
+                        key = Key, id = StreamID},
+    {stop, ejabberd_s2s_out:send(State, Request)};
+s2s_out_auth_result(#{db_enabled := true,
+                     socket := Socket, ip := IP,
+                     server := LServer,
+                     remote_server := RServer,
+                     stream_remote_id := StreamID} = State, false) ->
+    %% SASL authentication has failed, retrying with dialback
+    %% Sending dialback request, section 2.1.1, step 1
+    ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
+             [ejabberd_socket:pp(Socket), LServer, RServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+    Key = make_key(LServer, RServer, StreamID),
+    State1 = maps:remove(stop_reason, State#{on_route => queue}),
+    State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer,
+                                                     to = RServer,
+                                                     key = Key}),
+    {stop, State2};
+s2s_out_auth_result(State, _) ->
+    State.
+
+s2s_in_packet(#{stream_id := StreamID} = State,
+             #db_result{from = From, to = To, key = Key, type = undefined}) ->
+    %% Received dialback request, section 2.2.1, step 1
+    try
+       ok = check_from_to(From, To),
+       %% We're creating a temporary outbound s2s connection to
+       %% send verification request and to receive verification response
+       {ok, Pid} = ejabberd_s2s_out:start(
+                     To, From, [{db_verify, {StreamID, Key, self()}}]),
+       ejabberd_s2s_out:connect(Pid),
+       State
+    catch _:{badmatch, {error, Reason}} ->
+           send_db_result(State,
+                          #db_verify{from = From, to = To, type = error,
+                                     sub_els = [mk_error(Reason)]})
+    end;
+s2s_in_packet(State, #db_verify{to = To, from = From, key = Key,
+                               id = StreamID, type = undefined}) ->
+    %% Received verification request, section 2.2.2, step 2
+    Type = case make_key(To, From, StreamID) of
+              Key -> valid;
+              _ -> invalid
+          end,
+    Response = #db_verify{from = To, to = From, id = StreamID, type = Type},
+    ejabberd_s2s_in:send(State, Response);
+s2s_in_packet(State, Pkt) when is_record(Pkt, db_result);
+                              is_record(Pkt, db_verify) ->
+    ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]),
+    State;
+s2s_in_packet(State, _) ->
+    State.
+
+s2s_out_packet(#{server := LServer,
+                remote_server := RServer,
+                db_verify := {StreamID, _Key, Pid}} = State,
+              #db_verify{from = RServer, to = LServer,
+                         id = StreamID, type = Type} = Response)
+  when Type /= undefined ->
+    %% Received verification response, section 2.1.2, step 3
+    %% This is a response for the request sent at step 2
+    ejabberd_s2s_in:update_state(
+      Pid, fun(S) -> send_db_result(S, Response) end),
+    %% At this point the connection is no longer needed and we can terminate it
+    ejabberd_s2s_out:stop(State);
+s2s_out_packet(#{server := LServer, remote_server := RServer} = State,
+              #db_result{to = LServer, from = RServer,
+                         type = Type} = Result) when Type /= undefined ->
+    %% Received dialback response, section 2.1.1, step 4
+    %% This is a response to the request sent at step 1
+    State1 = maps:remove(db_enabled, State),
+    case Type of
+       valid ->
+           State2 = ejabberd_s2s_out:handle_auth_success(<<"dialback">>, State1),
+           ejabberd_s2s_out:establish(State2);
+       _ ->
+           Reason = format_error(Result),
+           ejabberd_s2s_out:handle_auth_failure(<<"dialback">>, Reason, State1)
+    end;
+s2s_out_packet(State, Pkt) when is_record(Pkt, db_result);
+                               is_record(Pkt, db_verify) ->
+    ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]),
+    State;
+s2s_out_packet(State, _) ->
+    State.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+-spec make_key(binary(), binary(), binary()) -> binary().
+make_key(From, To, StreamID) ->
+    Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
+    p1_sha:to_hexlist(
+      crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
+                 [To, " ", From, " ", StreamID])).
+
+-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state().
+send_db_result(State, #db_verify{from = From, to = To,
+                                type = Type, sub_els = Els}) ->
+    %% Sending dialback response, section 2.2.1, step 4
+    %% This is a response to the request received at step 1
+    Response = #db_result{from = To, to = From, type = Type, sub_els = Els},
+    State1 = ejabberd_s2s_in:send(State, Response),
+    case Type of
+       valid ->
+           State2 = ejabberd_s2s_in:handle_auth_success(
+                      From, <<"dialback">>, undefined, State1),
+           ejabberd_s2s_in:establish(State2);
+       _ ->
+           Reason = format_error(Response),
+           ejabberd_s2s_in:handle_auth_failure(
+             From, <<"dialback">>, Reason, State1)
+    end.
+
+-spec check_from_to(binary(), binary()) -> ok | {error, forbidden | host_unknown}.
+check_from_to(From, To) ->
+    case ejabberd_router:is_my_route(To) of
+       false -> {error, host_unknown};
+       true ->
+           LServer = ejabberd_router:host_of_route(To),
+           case ejabberd_s2s:allow_host(LServer, From) of
+               true -> ok;
+               false -> {error, forbidden}
+           end
+    end.
+
+-spec mk_error(term()) -> stanza_error().
+mk_error(forbidden) ->
+    xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
+mk_error(host_unknown) ->
+    xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
+mk_error(_) ->
+    xmpp:err_internal_server_error().
+
+-spec format_error(db_result()) -> binary().
+format_error(#db_result{type = invalid}) ->
+    <<"invalid dialback key">>;
+format_error(#db_result{type = error, sub_els = Els}) ->
+    %% TODO: improve xmpp.erl
+    case xmpp:get_error(#message{sub_els = Els}) of
+       #stanza_error{reason = Reason} ->
+           erlang:atom_to_binary(Reason, latin1);
+       undefined ->
+           <<"unrecognized error">>
+    end;
+format_error(_) ->
+    <<"unexpected dialback result">>.
diff --git a/src/mod_sm.erl b/src/mod_sm.erl
new file mode 100644 (file)
index 0000000..82d6870
--- /dev/null
@@ -0,0 +1,660 @@
+%%%-------------------------------------------------------------------
+%%% Author  : Holger Weiss <holger@zedat.fu-berlin.de>
+%%% Created : 25 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%-------------------------------------------------------------------
+-module(mod_sm).
+-behaviour(gen_mod).
+-author('holger@zedat.fu-berlin.de').
+-protocol({xep, 198, '1.5.2'}).
+
+%% gen_mod API
+-export([start/2, stop/1, depends/2, mod_opt_type/1]).
+%% hooks
+-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
+        c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
+        c2s_unbinded_packet/2, c2s_closed/2,
+        c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]).
+
+-include("xmpp.hrl").
+-include("logger.hrl").
+
+-define(is_sm_packet(Pkt),
+       is_record(Pkt, sm_enable) or
+       is_record(Pkt, sm_resume) or
+       is_record(Pkt, sm_a) or
+       is_record(Pkt, sm_r)).
+
+-type state() :: ejabberd_c2s:state().
+-type lqueue() :: {non_neg_integer(), queue:queue()}.
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(Host, _Opts) ->
+    ejabberd_hooks:add(c2s_init, ?MODULE, c2s_stream_init, 50),
+    ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE,
+                      c2s_stream_started, 50),
+    ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE,
+                      c2s_stream_features, 50),
+    ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
+                      c2s_unauthenticated_packet, 50),
+    ejabberd_hooks:add(c2s_unbinded_packet, Host, ?MODULE,
+                      c2s_unbinded_packet, 50),
+    ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
+                      c2s_authenticated_packet, 50),
+    ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE,
+                      c2s_handle_send, 50),
+    ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
+                      c2s_filter_send, 50),
+    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
+                      c2s_handle_info, 50),
+    ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50).
+
+stop(Host) ->
+    %% TODO: do something with global 'c2s_init' hook
+    ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE,
+                         c2s_stream_started, 50),
+    ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE,
+                         c2s_stream_features, 50),
+    ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, ?MODULE,
+                         c2s_unauthenticated_packet, 50),
+    ejabberd_hooks:delete(c2s_unbinded_packet, Host, ?MODULE,
+                         c2s_unbinded_packet, 50),
+    ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
+                         c2s_authenticated_packet, 50),
+    ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE,
+                         c2s_handle_send, 50),
+    ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
+                         c2s_filter_send, 50),
+    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
+                         c2s_handle_info, 50),
+    ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50).
+
+depends(_Host, _Opts) ->
+    [].
+
+c2s_stream_init({ok, State}, Opts) ->
+    MgmtOpts = lists:filter(
+                fun({stream_management, _}) -> true;
+                   ({max_ack_queue, _}) -> true;
+                   ({resume_timeout, _}) -> true;
+                   ({max_resume_timeout, _}) -> true;
+                   ({ack_timeout, _}) -> true;
+                   ({resend_on_timeout, _}) -> true;
+                   (_) -> false
+                end, Opts),
+    {ok, State#{mgmt_options => MgmtOpts}};
+c2s_stream_init(Acc, _Opts) ->
+    Acc.
+
+c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State,
+                  _StreamStart) ->
+    State1 = maps:remove(mgmt_options, State),
+    ResumeTimeout = get_resume_timeout(LServer, Opts),
+    MaxResumeTimeout = get_max_resume_timeout(LServer, Opts, ResumeTimeout),
+    State1#{mgmt_state => inactive,
+           mgmt_max_queue => get_max_ack_queue(LServer, Opts),
+           mgmt_timeout => ResumeTimeout,
+           mgmt_max_timeout => MaxResumeTimeout,
+           mgmt_ack_timeout => get_ack_timeout(LServer, Opts),
+           mgmt_resend => get_resend_on_timeout(LServer, Opts)};
+c2s_stream_started(State, _StreamStart) ->
+    State.
+
+c2s_stream_features(Acc, Host) ->
+    case gen_mod:is_loaded(Host, ?MODULE) of
+       true ->
+           [#feature_sm{xmlns = ?NS_STREAM_MGMT_2},
+            #feature_sm{xmlns = ?NS_STREAM_MGMT_3}|Acc];
+       false ->
+           Acc
+    end.
+
+c2s_unauthenticated_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
+    %% XEP-0198 says: "For client-to-server connections, the client MUST NOT
+    %% attempt to enable stream management until after it has completed Resource
+    %% Binding unless it is resuming a previous session".  However, it also
+    %% says: "Stream management errors SHOULD be considered recoverable", so we
+    %% won't bail out.
+    Err = #sm_failed{reason = 'unexpected-request', xmlns = ?NS_STREAM_MGMT_3},
+    {stop, send(State, Err)};
+c2s_unauthenticated_packet(State, _Pkt) ->
+    State.
+
+c2s_unbinded_packet(State, #sm_resume{} = Pkt) ->
+    case handle_resume(State, Pkt) of
+       {ok, ResumedState} ->
+           {stop, ResumedState};
+       error ->
+           {stop, State}
+    end;
+c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
+    c2s_unauthenticated_packet(State, Pkt);
+c2s_unbinded_packet(State, _Pkt) ->
+    State.
+
+c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt)
+  when ?is_sm_packet(Pkt) ->
+    if MgmtState == pending; MgmtState == active ->
+           {stop, perform_stream_mgmt(Pkt, State)};
+       true ->
+           {stop, negotiate_stream_mgmt(Pkt, State)}
+    end;
+c2s_authenticated_packet(State, Pkt) ->
+    update_num_stanzas_in(State, Pkt).
+
+c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
+  when MgmtState == pending; MgmtState == active ->
+    State1 = mgmt_queue_add(State, Pkt),
+    case Result of
+       ok when ?is_stanza(Pkt) ->
+           send_ack(State1);
+       ok ->
+           State1;
+       {error, _} ->
+           transition_to_pending(State1)
+    end;
+c2s_handle_send(State, _Pkt, _Result) ->
+    State.
+
+c2s_filter_send(Pkt, _State) ->
+    Pkt.
+
+c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State,
+               {timeout, T, ack_timeout}) ->
+    ?DEBUG("Timeout waiting for stream management acknowledgement of ~s",
+          [jid:to_string(JID)]),
+    State1 = ejabberd_c2s:close(State, _SendTrailer = false),
+    c2s_closed(State1, ack_timeout);
+c2s_handle_info(State, _) ->
+    State.
+
+c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal ->
+    {stop, transition_to_pending(State)};
+c2s_closed(State, _) ->
+    State.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+-spec negotiate_stream_mgmt(xmpp_element(), state()) -> state().
+negotiate_stream_mgmt(Pkt, State) ->
+    Xmlns = xmpp:get_ns(Pkt),
+    case Pkt of
+       #sm_enable{} ->
+           handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt);
+       _ ->
+           Res = if is_record(Pkt, sm_a);
+                    is_record(Pkt, sm_r);
+                    is_record(Pkt, sm_resume) ->
+                         #sm_failed{reason = 'unexpected-request',
+                                    xmlns = Xmlns};
+                    true ->
+                         #sm_failed{reason = 'bad-request',
+                                    xmlns = Xmlns}
+                 end,
+           send(State, Res)
+    end.
+
+-spec perform_stream_mgmt(xmpp_element(), state()) -> state().
+perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
+    case xmpp:get_ns(Pkt) of
+       Xmlns ->
+           case Pkt of
+               #sm_r{} ->
+                   handle_r(State);
+               #sm_a{} ->
+                   handle_a(State, Pkt);
+               _ ->
+                   Res = if is_record(Pkt, sm_enable);
+                            is_record(Pkt, sm_resume) ->
+                                 #sm_failed{reason = 'unexpected-request',
+                                            xmlns = Xmlns};
+                            true ->
+                                 #sm_failed{reason = 'bad-request',
+                                            xmlns = Xmlns}
+                         end,
+                   send(State, Res)
+           end;
+       _ ->
+           send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns})
+    end.
+
+-spec handle_enable(state(), sm_enable()) -> state().
+handle_enable(#{mgmt_timeout := DefaultTimeout,
+               mgmt_max_timeout := MaxTimeout,
+               xmlns := Xmlns, jid := JID} = State,
+             #sm_enable{resume = Resume, max = Max}) ->
+    Timeout = if Resume == false ->
+                     0;
+                Max /= undefined, Max > 0, Max =< MaxTimeout ->
+                     Max;
+                true ->
+                     DefaultTimeout
+             end,
+    Res = if Timeout > 0 ->
+                 ?INFO_MSG("Stream management with resumption enabled for ~s",
+                           [jid:to_string(JID)]),
+                 #sm_enabled{xmlns = Xmlns,
+                             id = make_resume_id(State),
+                             resume = true,
+                             max = Timeout};
+            true ->
+                 ?INFO_MSG("Stream management without resumption enabled for ~s",
+                           [jid:to_string(JID)]),
+                 #sm_enabled{xmlns = Xmlns}
+         end,
+    State1 = State#{mgmt_state => active,
+                   mgmt_queue => queue_new(),
+                   mgmt_timeout => Timeout * 1000},
+    send(State1, Res).
+
+-spec handle_r(state()) -> state().
+handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
+    Res = #sm_a{xmlns = Xmlns, h = H},
+    send(State, Res).
+
+-spec handle_a(state(), sm_a()) -> state().
+handle_a(State, #sm_a{h = H}) ->
+    State1 = check_h_attribute(State, H),
+    resend_ack(State1).
+
+-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
+handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
+             #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
+    R = case inherit_session_state(State, PrevID) of
+           {ok, InheritedState} ->
+               {ok, InheritedState, H};
+           {error, Err, InH} ->
+               {error, #sm_failed{reason = 'item-not-found',
+                                  h = InH, xmlns = Xmlns}, Err};
+           {error, Err} ->
+               {error, #sm_failed{reason = 'item-not-found',
+                                  xmlns = Xmlns}, Err}
+       end,
+    case R of
+       {ok, ResumedState, NumHandled} ->
+           State1 = check_h_attribute(ResumedState, NumHandled),
+           #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
+           AttrId = make_resume_id(State1),
+           State2 = send(State1, #sm_resumed{xmlns = AttrXmlns,
+                                             h = AttrH,
+                                             previd = AttrId}),
+           State3 = resend_unacked_stanzas(State2),
+           State4 = send(State3, #sm_r{xmlns = AttrXmlns}),
+           %% TODO: move this to mod_client_state
+           %% csi_flush_queue(State4),
+           State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []),
+           ?INFO_MSG("(~s) Resumed session for ~s",
+                     [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+           {ok, State5};
+       {error, El, Msg} ->
+           ?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]),
+           {error, send(State, El)}
+    end.
+
+-spec transition_to_pending(state()) -> state().
+transition_to_pending(#{mgmt_state := active} = State) ->
+    %% TODO
+    State;
+transition_to_pending(State) ->
+    State.
+
+-spec check_h_attribute(state(), non_neg_integer()) -> state().
+check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H)
+  when H > NumStanzasOut ->
+    ?DEBUG("~s acknowledged ~B stanzas, but only ~B were sent",
+          [jid:to_string(JID), H, NumStanzasOut]),
+    mgmt_queue_drop(State#{mgmt_stanzas_out => H}, NumStanzasOut);
+check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) ->
+    ?DEBUG("~s acknowledged ~B of ~B stanzas",
+          [jid:to_string(JID), H, NumStanzasOut]),
+    mgmt_queue_drop(State, H).
+
+-spec update_num_stanzas_in(state(), xmpp_element()) -> state().
+update_num_stanzas_in(#{mgmt_state := MgmtState,
+                       mgmt_stanzas_in := NumStanzasIn} = State, El)
+  when MgmtState == active; MgmtState == pending ->
+    NewNum = case {xmpp:is_stanza(El), NumStanzasIn} of
+                {true, 4294967295} ->
+                    0;
+                {true, Num} ->
+                    Num + 1;
+                {false, Num} ->
+                    Num
+            end,
+    State#{mgmt_stanzas_in => NewNum};
+update_num_stanzas_in(State, _El) ->
+    State.
+
+send_ack(#{mgmt_ack_timer := _} = State) ->
+    State;
+send_ack(#{mgmt_xmlns := Xmlns,
+          mgmt_stanzas_out := NumStanzasOut,
+          mgmt_ack_timeout := AckTimeout} = State) ->
+    State1 = send(State, #sm_r{xmlns = Xmlns}),
+    TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
+    State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
+
+resend_ack(#{mgmt_ack_timer := _,
+            mgmt_queue := Queue,
+            mgmt_stanzas_out := NumStanzasOut,
+            mgmt_stanzas_req := NumStanzasReq} = State) ->
+    State1 = cancel_ack_timer(State),
+    case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of
+       true -> send_ack(State1);
+       false -> State1
+    end;
+resend_ack(State) ->
+    State.
+
+-spec mgmt_queue_add(state(), xmpp_element()) -> state().
+mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut,
+                mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) ->
+    NewNum = case NumStanzasOut of
+              4294967295 -> 0;
+              Num -> Num + 1
+            end,
+    Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue),
+    State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
+    check_queue_length(State1);
+mgmt_queue_add(State, _Nonza) ->
+    State.
+
+-spec mgmt_queue_drop(state(), non_neg_integer()) -> state().
+mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) ->
+    NewQueue = queue_dropwhile(
+                fun({N, _T, _E}) -> N =< NumHandled end, Queue),
+    State#{mgmt_queue => NewQueue}.
+
+-spec check_queue_length(state()) -> state().
+check_queue_length(#{mgmt_max_queue := Limit} = State)
+  when Limit == infinity; Limit == exceeded ->
+    State;
+check_queue_length(#{mgmt_queue := Queue, mgmt_max_queue := Limit} = State) ->
+    case queue_len(Queue) > Limit of
+       true ->
+           State#{mgmt_max_queue => exceeded};
+       false ->
+           State
+    end.
+
+-spec resend_unacked_stanzas(state()) -> state().
+resend_unacked_stanzas(#{mgmt_state := MgmtState,
+                        mgmt_queue := {QueueLen, _} = Queue,
+                        jid := JID} = State)
+  when (MgmtState == active orelse
+       MgmtState == pending orelse
+       MgmtState == timeout) andalso QueueLen > 0 ->
+    ?DEBUG("Resending ~B unacknowledged stanza(s) to ~s",
+          [QueueLen, jid:to_string(JID)]),
+    queue_foldl(
+      fun({_, Time, Pkt}, AccState) ->
+             NewPkt = add_resent_delay_info(AccState, Pkt, Time),
+             send(AccState, NewPkt)
+      end, State, Queue);
+resend_unacked_stanzas(State) ->
+    State.
+
+-spec route_unacked_stanzas(state()) -> ok.
+route_unacked_stanzas(#{mgmt_state := MgmtState,
+                       mgmt_resend := MgmtResend,
+                       lang := Lang, user := User,
+                       jid := JID, lserver := LServer,
+                       mgmt_queue := {QueueLen, _} = Queue,
+                       resource := Resource} = State)
+  when (MgmtState == active orelse
+       MgmtState == pending orelse
+       MgmtState == timeout) andalso QueueLen > 0 ->
+    ResendOnTimeout = case MgmtResend of
+                         Resend when is_boolean(Resend) ->
+                             Resend;
+                         if_offline ->
+                             case ejabberd_sm:get_user_resources(User, Resource) of
+                                 [Resource] ->
+                                     %% Same resource opened new session
+                                     true;
+                                 [] -> true;
+                                 _ -> false
+                             end
+                     end,
+    ?DEBUG("Re-routing ~B unacknowledged stanza(s) to ~s",
+          [QueueLen, jid:to_string(JID)]),
+    queue_foreach(
+      fun({_, _Time, #presence{from = From}}) ->
+             ?DEBUG("Dropping presence stanza from ~s", [jid:to_string(From)]);
+        ({_, _Time, #iq{} = El}) ->
+             Txt = <<"User session terminated">>,
+             route_error(El, xmpp:err_service_unavailable(Txt, Lang));
+        ({_, _Time, #message{from = From, meta = #{carbon_copy := true}}}) ->
+             %% XEP-0280 says: "When a receiving server attempts to deliver a
+             %% forked message, and that message bounces with an error for
+             %% any reason, the receiving server MUST NOT forward that error
+             %% back to the original sender."  Resending such a stanza could
+             %% easily lead to unexpected results as well.
+             ?DEBUG("Dropping forwarded message stanza from ~s",
+                    [jid:to_string(From)]);
+        ({_, Time, El}) ->
+             case ejabberd_hooks:run_fold(message_is_archived,
+                                          LServer, false,
+                                          [State, El]) of
+                 true ->
+                     ?DEBUG("Dropping archived message stanza from ~s",
+                            [jid:to_string(xmpp:get_from(El))]);
+                 false when ResendOnTimeout ->
+                     NewEl = add_resent_delay_info(State, El, Time),
+                     route(NewEl);
+                 false ->
+                     Txt = <<"User session terminated">>,
+                     route_error(El, xmpp:err_service_unavailable(Txt, Lang))
+             end
+      end, Queue);
+route_unacked_stanzas(_State) ->
+    ok.
+
+-spec inherit_session_state(state(), binary()) -> {ok, state()} |
+                                                 {error, binary()} |
+                                                 {error, binary(), non_neg_integer()}.
+inherit_session_state(#{user := U, server := S} = State, ResumeID) ->
+    case jlib:base64_to_term(ResumeID) of
+       {term, {R, Time}} ->
+           case ejabberd_sm:get_session_pid(U, S, R) of
+               none ->
+                   case ejabberd_sm:get_offline_info(Time, U, S, R) of
+                       none ->
+                           {error, <<"Previous session PID not found">>};
+                       Info ->
+                           case proplists:get_value(num_stanzas_in, Info) of
+                               undefined ->
+                                   {error, <<"Previous session timed out">>};
+                               H ->
+                                   {error, <<"Previous session timed out">>, H}
+                           end
+                   end;
+               OldPID ->
+                   OldSID = {Time, OldPID},
+                   try resume_session(OldSID, State) of
+                       {resume, OldState} ->
+                           State1 = ejabberd_c2s:copy_state(State, OldState),
+                           State2 = ejabberd_c2s:open_session(State1),
+                           {ok, State2};
+                       {error, Msg} ->
+                           {error, Msg}
+                   catch exit:{noproc, _} ->
+                           {error, <<"Previous session PID is dead">>};
+                         exit:{timeout, _} ->
+                           {error, <<"Session state copying timed out">>}
+                   end
+           end;
+       _ ->
+           {error, <<"Invalid 'previd' value">>}
+    end.
+
+-spec resume_session({integer(), pid()}, state()) -> {resume, state()} |
+                                                    {error, binary()}.
+resume_session({Time, Pid}, _State) ->
+    ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)).
+
+-spec make_resume_id(state()) -> binary().
+make_resume_id(#{sid := {Time, _}, resource := Resource}) ->
+    jlib:term_to_base64({Resource, Time}).
+
+-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza().
+add_resent_delay_info(_State, #iq{} = El, _Time) ->
+    El;
+add_resent_delay_info(#{lserver := LServer}, El, Time) ->
+    xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>).
+
+-spec route(stanza()) -> ok.
+route(Pkt) ->
+    From = xmpp:get_from(Pkt),
+    To = xmpp:get_to(Pkt),
+    ejabberd_router:route(From, To, Pkt).
+
+-spec route_error(stanza(), stanza_error()) -> ok.
+route_error(Pkt, Err) ->
+    From = xmpp:get_from(Pkt),
+    To = xmpp:get_to(Pkt),
+    ejabberd_router:route_error(To, From, Pkt, Err).
+
+-spec send(state(), xmpp_element()) -> state().
+send(#{mod := Mod} = State, Pkt) ->
+    Mod:send(State, Pkt).
+
+-spec queue_new() -> lqueue().
+queue_new() ->
+    {0, queue:new()}.
+
+-spec queue_in(term(), lqueue()) -> lqueue().
+queue_in(Elem, {N, Q}) ->
+    {N+1, queue:in(Elem, Q)}.
+
+-spec queue_len(lqueue()) -> non_neg_integer().
+queue_len({N, _}) ->
+    N.
+
+-spec queue_foldl(fun((term(), T) -> T), T, lqueue()) -> T.
+queue_foldl(F, Acc, {_N, Q}) ->
+    jlib:queue_foldl(F, Acc, Q).
+
+-spec queue_foreach(fun((_) -> _), lqueue()) -> ok.
+queue_foreach(F, {_N, Q}) ->
+    jlib:queue_foreach(F, Q).
+
+-spec queue_dropwhile(fun((term()) -> boolean()), lqueue()) -> lqueue().
+queue_dropwhile(F, {N, Q}) ->
+    case queue:peek(Q) of
+       {value, Item} ->
+           case F(Item) of
+               true ->
+                   queue_dropwhile(F, {N-1, queue:drop(Q)});
+               false ->
+                   {N, Q}
+           end;
+       empty ->
+           {N, Q}
+    end.
+
+-spec queue_is_empty(lqueue()) -> boolean().
+queue_is_empty({N, _Q}) ->
+    N == 0.
+
+-spec cancel_ack_timer(state()) -> state().
+cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) ->
+    case erlang:cancel_timer(TRef) of
+        false -> 
+            receive {timeout, TRef, _} -> ok
+            after 0 -> ok
+            end;
+        _ ->
+            ok
+    end,
+    maps:remove(mgmt_ack_timer, State);
+cancel_ack_timer(State) ->
+    State.
+
+%%%===================================================================
+%%% Configuration processing
+%%%===================================================================
+get_max_ack_queue(Host, Opts) ->
+    VFun = mod_opt_type(max_ack_queue),
+    case gen_mod:get_module_opt(Host, ?MODULE, max_ack_queue, VFun) of
+       undefined -> gen_mod:get_opt(max_ack_queue, Opts, VFun, 1000);
+       Limit -> Limit
+    end.
+
+get_resume_timeout(Host, Opts) ->
+    VFun = mod_opt_type(resume_timeout),
+    case gen_mod:get_module_opt(Host, ?MODULE, resume_timeout, VFun) of
+       undefined -> gen_mod:get_opt(resume_timeout, Opts, VFun, 300);
+       Timeout -> Timeout
+    end.
+
+get_max_resume_timeout(Host, Opts, ResumeTimeout) ->
+    VFun = mod_opt_type(max_resume_timeout),
+    case gen_mod:get_module_opt(Host, ?MODULE, max_resume_timeout, VFun) of
+       undefined ->
+           case gen_mod:get_opt(max_resume_timeout, Opts, VFun) of
+               undefined -> ResumeTimeout;
+               Max when Max >= ResumeTimeout -> Max;
+               _ -> ResumeTimeout
+           end;
+       Max when Max >= ResumeTimeout -> Max;
+       _ -> ResumeTimeout
+    end.
+
+get_ack_timeout(Host, Opts) ->
+    VFun = mod_opt_type(ack_timeout),
+    T = case gen_mod:get_module_opt(Host, ?MODULE, ack_timeout, VFun) of
+           undefined -> gen_mod:get_opt(ack_timeout, Opts, VFun, 60);
+           AckTimeout -> AckTimeout
+       end,
+    case T of
+       infinity -> infinity;
+       _ -> timer:seconds(T)
+    end.
+
+get_resend_on_timeout(Host, Opts) ->
+    VFun = mod_opt_type(resend_on_timeout),
+    case gen_mod:get_module_opt(Host, ?MODULE, resend_on_timeout, VFun) of
+       undefined -> gen_mod:get_opt(resend_on_timeout, Opts, VFun, false);
+       Resend -> Resend
+    end.
+
+mod_opt_type(max_ack_queue) ->
+    fun(I) when is_integer(I), I > 0 -> I;
+       (infinity) -> infinity
+    end;
+mod_opt_type(resume_timeout) ->
+    fun(I) when is_integer(I), I >= 0 -> I end;
+mod_opt_type(max_resume_timeout) ->
+    fun(I) when is_integer(I), I >= 0 -> I end;
+mod_opt_type(ack_timeout) ->
+    fun(I) when is_integer(I), I > 0 -> I;
+       (infinity) -> infinity
+    end;
+mod_opt_type(resend_on_timeout) ->
+    fun(B) when is_boolean(B) -> B;
+       (if_offline) -> if_offline
+    end;
+mod_opt_type(_) ->
+    [max_ack_queue, resume_timeout, max_resume_timeout, ack_timeout,
+     resend_on_timeout].
index 1307f9da42ac780ebd7686ff9bed03d880b6f6e5..e9c1b333903b8b4cdafa88baf0daff7763bdfe7e 100644 (file)
 -protocol({rfc, 6120}).
 
 %% API
--export([start/3, call/3, cast/2, reply/2, send/2, send_error/3,
-        get_transport/1, change_shaper/2]).
+-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
+        send/2, close/1, close/2, send_error/3, establish/1,
+        get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
 
 %% gen_server callbacks
 -export([init/1, handle_cast/2, handle_call/3, handle_info/2,
         terminate/2, code_change/3]).
 
+%%-define(DBGFSM, true).
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
 -include("xmpp.hrl").
 -type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
+-type stop_reason() :: {stream, reset | stream_error()} |
+                      {tls, term()} |
+                      {socket, inet:posix() | closed | timeout}.
 
 -callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
--callback handle_stream_start(state()) -> next_state().
--callback handle_stream_end(state()) -> next_state().
--callback handle_stream_close(state()) -> next_state().
--callback handle_cdata(binary(), state()) -> next_state().
--callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state().
--callback handle_authenticated_packet(xmpp_element(), state()) -> next_state().
--callback handle_unbinded_packet(xmpp_element(), state()) -> next_state().
--callback handle_auth_success(binary(), binary(), module(), state()) -> next_state().
--callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state().
--callback handle_send(ok | {error, atom()},
-                     xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state().
--callback init_sasl(state()) -> cyrsasl:sasl_state().
+-callback handle_cast(term(), state()) -> state().
+-callback handle_call(term(), term(), state()) -> state().
+-callback handle_info(term(), state()) -> state().
+-callback terminate(term(), state()) -> any().
+-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
+-callback handle_stream_start(state()) -> state().
+-callback handle_stream_end(stop_reason(), state()) -> state().
+-callback handle_stream_close(stop_reason(), state()) -> state().
+-callback handle_cdata(binary(), state()) -> state().
+-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
+-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
+-callback handle_unbinded_packet(xmpp_element(), state()) -> state().
+-callback handle_auth_success(binary(), binary(), module(), state()) -> state().
+-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state().
+-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
+-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
+-callback get_password_fun(state()) -> fun().
+-callback check_password_fun(state()) -> fun().
+-callback check_password_digest_fun(state()) -> fun().
 -callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
--callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}.
 -callback compress_methods(state()) -> [binary()].
 -callback tls_options(state()) -> [proplists:property()].
 -callback tls_required(state()) -> boolean().
--callback sasl_mechanisms(state()) -> [binary()].
+-callback tls_verify(state()) -> boolean().
 -callback unauthenticated_stream_features(state()) -> [xmpp_element()].
 -callback authenticated_stream_features(state()) -> [xmpp_element()].
 
 %% All callbacks are optional
 -optional_callbacks([init/1,
+                    handle_cast/2,
+                    handle_call/3,
+                    handle_info/2,
+                    terminate/2,
+                    code_change/3,
                     handle_stream_start/1,
-                    handle_stream_end/1,
-                    handle_stream_close/1,
+                    handle_stream_end/2,
+                    handle_stream_close/2,
                     handle_cdata/2,
                     handle_authenticated_packet/2,
                     handle_unauthenticated_packet/2,
                     handle_unbinded_packet/2,
                     handle_auth_success/4,
                     handle_auth_failure/4,
-                    handle_send/5,
-                    init_sasl/1,
+                    handle_send/3,
+                    handle_recv/3,
+                    get_password_fun/1,
+                    check_password_fun/1,
+                    check_password_digest_fun/1,
                     bind/2,
-                    handshake/2,
                     compress_methods/1,
                     tls_options/1,
                     tls_required/1,
-                    sasl_mechanisms/1,
+                    tls_verify/1,
                     unauthenticated_stream_features/1,
                     authenticated_stream_features/1]).
 
 %%% API
 %%%===================================================================
 start(Mod, Args, Opts) ->
-    gen_server:start(?MODULE, [Mod|Args], Opts).
+    gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+start_link(Mod, Args, Opts) ->
+    gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
 
 call(Ref, Msg, Timeout) ->
     gen_server:call(Ref, Msg, Timeout).
@@ -95,16 +121,80 @@ cast(Ref, Msg) ->
 reply(Ref, Reply) ->
     gen_server:reply(Ref, Reply).
 
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
-    send_element(State, Pkt).
+-spec stop(pid()) -> ok;
+         (state()) -> no_return().
+stop(Pid) when is_pid(Pid) ->
+    cast(Pid, stop);
+stop(#{owner := Owner} = State) when Owner == self() ->
+    terminate(normal, State),
+    exit(normal);
+stop(_) ->
+    erlang:error(badarg).
 
-get_transport(#{sockmod := SockMod, socket := Socket}) ->
-    SockMod:get_transport(Socket).
+-spec send(pid(), xmpp_element()) -> ok;
+         (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+    cast(Pid, {send, Pkt});
+send(#{owner := Owner} = State, Pkt) when Owner == self() ->
+    send_element(State, Pkt);
+send(_, _) ->
+    erlang:error(badarg).
+
+-spec close(pid()) -> ok;
+          (state()) -> state().
+close(Ref) ->
+    close(Ref, true).
+
+-spec close(pid(), boolean()) -> ok;
+          (state(), boolean()) -> state().
+close(Pid, SendTrailer) when is_pid(Pid) ->
+    cast(Pid, {close, SendTrailer});
+close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
+    if SendTrailer -> send_trailer(State);
+       true -> close_socket(State)
+    end;
+close(_, _) ->
+    erlang:error(badarg).
+
+-spec establish(state()) -> state().
+establish(State) ->
+    process_stream_established(State).
+
+-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
+    case Timeout of
+       infinity -> State#{stream_timeout => infinity};
+       _ ->
+           Time = p1_time_compat:monotonic_time(milli_seconds),
+           State#{stream_timeout => {Timeout, Time}}
+    end;
+set_timeout(_, _) ->
+    erlang:error(badarg).
+
+get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
+  when Owner == self() ->
+    SockMod:get_transport(Socket);
+get_transport(_) ->
+    erlang:error(badarg).
 
 -spec change_shaper(state(), shaper:shaper()) -> ok.
-change_shaper(#{sockmod := SockMod, socket := Socket}, Shaper) ->
-    SockMod:change_shaper(Socket, Shaper).
+change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
+  when Owner == self() ->
+    SockMod:change_shaper(Socket, Shaper);
+change_shaper(_, _) ->
+    erlang:error(badarg).
+
+-spec format_error(stop_reason()) ->  binary().
+format_error({socket, Reason}) ->
+    format("Connection failed: ~s", [format_inet_error(Reason)]);
+format_error({stream, reset}) ->
+    <<"Stream reset by peer">>;
+format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
+    format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({tls, Reason}) ->
+    format("TLS failed: ~w", [Reason]);
+format_error(Err) ->
+    format("Unrecognized error: ~w", [Err]).
 
 %%%===================================================================
 %%% gen_server callbacks
@@ -114,19 +204,24 @@ init([Module, {SockMod, Socket}, Opts]) ->
                    {_, XS} -> XS;
                    false -> false
                end,
-    TLSEnabled = proplists:get_bool(tls, Opts),
+    Encrypted = proplists:get_bool(tls, Opts),
     SocketMonitor = SockMod:monitor(Socket),
     case peername(SockMod, Socket) of
        {ok, IP} ->
-           State = #{mod => Module,
+           Time = p1_time_compat:monotonic_time(milli_seconds),
+           State = #{owner => self(),
+                     mod => Module,
                      socket => Socket,
                      sockmod => SockMod,
                      socket_monitor => SocketMonitor,
+                     stream_timeout => {timer:seconds(30), Time},
+                     stream_direction => in,
                      stream_id => new_id(),
                      stream_state => wait_for_stream,
+                     stream_header_sent => false,
                      stream_restarted => false,
                      stream_compressed => false,
-                     stream_tlsed => TLSEnabled,
+                     stream_encrypted => Encrypted,
                      stream_version => {1,0},
                      stream_authenticated => false,
                      xml_socket => XMLSocket,
@@ -137,97 +232,133 @@ init([Module, {SockMod, Socket}, Opts]) ->
                      resource => <<"">>,
                      lserver => <<"">>,
                      ip => IP},
-           try Module:init([State, Opts])
-           catch _:undef -> {ok, State}
+           case try Module:init([State, Opts])
+                catch _:undef -> {ok, State}
+                end of
+               {ok, State1} ->
+                   {_, State2, Timeout} = noreply(State1),
+                   {ok, State2, Timeout};
+               Err ->
+                   Err
            end;
        {error, Reason} ->
            {stop, Reason}
     end.
 
+handle_cast({send, Pkt}, State) ->
+    noreply(send_element(State, Pkt));
+handle_cast(stop, State) ->
+    {stop, normal, State};
 handle_cast(Cast, #{mod := Mod} = State) ->
-    try Mod:handle_cast(Cast, State)
-    catch _:undef -> {noreply, State}
-    end.
+    noreply(try Mod:handle_cast(Cast, State)
+             catch _:undef -> State
+             end).
 
 handle_call(Call, From, #{mod := Mod} = State) ->
-    try Mod:handle_call(Call, From, State)
-    catch _:undef -> {reply, ok, State}
-    end.
+    noreply(try Mod:handle_call(Call, From, State)
+           catch _:undef -> State
+           end).
 
 handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
-           #{stream_state := wait_for_stream, xmlns := XMLNS} = State) ->
-    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of
+           #{stream_state := wait_for_stream,
+             xmlns := XMLNS, lang := MyLang} = State) ->
+    El = #xmlel{name = Name, attrs = Attrs},
+    try xmpp:decode(El, XMLNS, []) of
        #stream_start{} = Pkt ->
-           case send_header(State, Pkt) of
-               {noreply, State1} ->
-                   process_stream(Pkt, State1);
-               Err ->
-                   Err
+           State1 = send_header(State, Pkt),
+           case is_disconnected(State1) of
+               true -> State1;
+               false -> noreply(process_stream(Pkt, State1))
            end;
        _ ->
-           case send_header(State) of
-               {noreply, State1} ->
-                   send_element(State1, xmpp:serr_invalid_xml());
-               Err ->
-                   Err
+           State1 = send_header(State),
+           case is_disconnected(State1) of
+               true -> State1;
+               false -> noreply(send_element(State1, xmpp:serr_invalid_xml()))
            end
     catch _:{xmpp_codec, Why} ->
-           case send_header(State) of
-               {noreply, State1} -> process_invalid_xml(Why, State1);
-               Err -> Err
+           State1 = send_header(State),
+           case is_disconnected(State1) of
+               true -> State1;
+               false ->
+                   Txt = xmpp:io_format_error(Why),
+                   Lang = select_lang(MyLang, xmpp:get_lang(El)),
+                   Err = xmpp:serr_invalid_xml(Txt, Lang),
+                   noreply(send_element(State1, Err))
            end
     end;
-handle_info({'$gen_event', {xmlstreamend, _}}, #{mod := Mod} = State) ->
-    try Mod:handle_stream_end(State)
-    catch _:undef -> {stop, normal, State}
-    end;
 handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
-    case send_header(State) of
-       {noreply, State1} ->
+    State1 = send_header(State),
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
            Err = case Reason of
                      <<"XML stanza is too big">> ->
                          xmpp:serr_policy_violation(Reason, Lang);
                      _ ->
                          xmpp:serr_not_well_formed()
                  end,
-           send_element(State1, Err);
-       Err ->
-           Err
+           noreply(send_element(State1, Err))
     end;
 handle_info({'$gen_event', {xmlstreamelement, El}},
-           #{xmlns := NS} = State) ->
+           #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
     try xmpp:decode(El, NS, [ignore_els]) of
        Pkt ->
-           process_element(Pkt, State)
+           State1 = try Mod:handle_recv(El, Pkt, State)
+                    catch _:undef -> State
+                    end,
+           case is_disconnected(State1) of
+               true -> State1;
+               false -> noreply(process_element(Pkt, State1))
+           end
     catch _:{xmpp_codec, Why} ->
-           process_invalid_xml(Why, State)
+           State1 = try Mod:handle_recv(El, {error, Why}, State)
+                    catch _:undef -> State
+                    end,
+           case is_disconnected(State1) of
+               true -> State1;
+               false ->
+                   Txt = xmpp:io_format_error(Why),
+                   Lang = select_lang(MyLang, xmpp:get_lang(El)),
+                   noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
+           end
     end;
 handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
            #{mod := Mod} = State) ->
-    try Mod:handle_cdata(Data, State)
-    catch _:undef -> {noreply, State}
-    end;
-handle_info(closed, #{mod := Mod} = State) ->
-    try Mod:handle_stream_close(State)
-    catch _:undef -> {stop, normal, State}
-    end;
+    noreply(try Mod:handle_cdata(Data, State)
+           catch _:undef -> State
+           end);
+handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
+    noreply(process_stream_end({error, {stream, reset}}, State));
+handle_info({'$gen_event', closed}, State) ->
+    noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(timeout, #{mod := Mod} = State) ->
+    Disconnected = is_disconnected(State),
+    noreply(try Mod:handle_timeout(State)
+           catch _:undef when not Disconnected ->
+                   send_element(State, xmpp:serr_connection_timeout());
+                 _:undef ->
+                   stop(State)
+           end);
 handle_info({'DOWN', MRef, _Type, _Object, _Info},
-           #{socket_monitor := MRef, mod := Mod} = State) ->
-    try Mod:handle_stream_close(State)
-    catch _:undef -> {stop, normal, State}
-    end;
+           #{socket_monitor := MRef} = State) ->
+    noreply(process_stream_close({error, {socket, closed}}, State));
 handle_info(Info, #{mod := Mod} = State) ->
-    try Mod:handle_info(Info, State)
-    catch _:undef -> {noreply, State}
-    end.
+    noreply(try Mod:handle_info(Info, State)
+           catch _:undef -> State
+           end).
 
-terminate(Reason, #{mod := Mod, socket := Socket,
-                   sockmod := SockMod} = State) ->
-    try Mod:terminate(Reason, State)
-    catch _:undef -> ok
-    end,
-    send_text(State, <<"</stream:stream>">>),
-    SockMod:close(Socket).
+terminate(Reason, #{mod := Mod} = State) ->
+    case get(already_terminated) of
+       true ->
+           State;
+       _ ->
+           put(already_terminated, true),
+           try Mod:terminate(Reason, State)
+           catch _:undef -> ok
+           end,
+           send_trailer(State)
+    end.
 
 code_change(OldVsn, #{mod := Mod} = State, Extra) ->
     Mod:code_change(OldVsn, State, Extra).
@@ -235,20 +366,49 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) ->
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
+-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+noreply(#{stream_timeout := infinity} = State) ->
+    {noreply, State, infinity};
+noreply(#{stream_timeout := {MSecs, StartTime}} = State) ->
+    CurrentTime = p1_time_compat:monotonic_time(milli_seconds),
+    Timeout = max(0, MSecs - CurrentTime + StartTime),
+    {noreply, State, Timeout}.
+
 -spec new_id() -> binary().
 new_id() ->
     randoms:get_string().
 
+-spec is_disconnected(state()) -> boolean().
+is_disconnected(#{stream_state := StreamState}) ->
+    StreamState == disconnected.
+
+-spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}|
+                                 {error, inet:posix()}.
 peername(SockMod, Socket) ->
     case SockMod of
        gen_tcp -> inet:peername(Socket);
        _ -> SockMod:peername(Socket)
     end.
 
-process_invalid_xml(Reason, #{lang := Lang} = State) ->
-    Txt = xmpp:io_format_error(Reason),
-    send_element(State, xmpp:serr_invalid_xml(Txt, Lang)).
+-spec process_stream_close(stop_reason(), state()) -> state().
+process_stream_close(_, #{stream_state := disconnected} = State) ->
+    State;
+process_stream_close(Reason, #{mod := Mod} = State) ->
+    State1 = send_trailer(State),
+    try Mod:handle_stream_close(Reason, State1)
+    catch _:undef -> stop(State1)
+    end.
+
+-spec process_stream_end(stop_reason(), state()) -> state().
+process_stream_end(_, #{stream_state := disconnected} = State) ->
+    State;
+process_stream_end(Reason, #{mod := Mod} = State) ->
+    State1 = send_trailer(State),
+    try Mod:handle_stream_end(Reason, State1)
+    catch _:undef -> stop(State1)
+    end.
 
+-spec process_stream(stream_start(), state()) -> state().
 process_stream(#stream_start{xmlns = XML_NS,
                             stream_xmlns = STREAM_NS},
               #{xmlns := NS} = State)
@@ -268,73 +428,67 @@ process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
     send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
 process_stream(#stream_start{from = undefined, version = {1,0}},
               #{lang := Lang, xmlns := ?NS_SERVER,
-                stream_tlsed := true} = State) ->
+                stream_encrypted := true} = State) ->
     Txt = <<"Missing 'from' attribute">>,
     send_element(State, xmpp:serr_invalid_from(Txt, Lang));
 process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
               #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
     Txt = <<"Improper 'to' attribute">>,
     send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{to = #jid{lserver = RemoteServer}},
+process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
               #{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
-    State1 = State#{remote_server => RemoteServer},
-    case try Mod:handle_stream_start(State1)
-        catch _:undef -> {noreply, State1}
-        end of
-       {noreply, State2} ->
-           {noreply, State2#{stream_state => wait_for_handshake}};
-       Err ->
-           Err
+    State1 = State#{remote_server => RemoteServer,
+                   stream_state => wait_for_handshake},
+    try Mod:handle_stream_start(StreamStart, State1)
+    catch _:undef -> State1
     end;
 process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
-                            from = From},
+                            from = From} = StreamStart,
               #{stream_authenticated := Authenticated,
                 stream_restarted := StreamWasRestarted,
                 mod := Mod, xmlns := NS, resource := Resource,
-                stream_tlsed := TLSEnabled} = State) ->
-    case if not StreamWasRestarted ->
-                State1 = State#{server => Server, lserver => LServer},
-                try Mod:handle_stream_start(State1)
-                catch _:undef -> {noreply, State1}
-                end;
-           true ->
-                {noreply, State}
-        end of
-       {noreply, State2} ->
-           State3 = if NS == ?NS_SERVER andalso TLSEnabled ->
-                            State2#{remote_server => From#jid.lserver};
-                       true ->
-                            State2
-                    end,
-           case send_features(State3) of
-               {noreply, State4} ->
+                stream_encrypted := Encrypted} = State) ->
+    State1 = if not StreamWasRestarted ->
+                    State#{server => Server, lserver => LServer};
+               true ->
+                    State
+            end,
+    State2 = if NS == ?NS_SERVER andalso Encrypted ->
+                    State1#{remote_server => From#jid.lserver};
+               true ->
+                    State1
+            end,
+    State3 = try Mod:handle_stream_start(StreamStart, State2)
+            catch _:undef -> State2
+            end,
+    case is_disconnected(State3) of
+       true -> State3;
+       false ->
+           State4 = send_features(State3),
+           case is_disconnected(State4) of
+               true -> State4;
+               false ->
                    TLSRequired = is_starttls_required(State4),
-                   NewStreamState =
-                       if not Authenticated and
-                          (not TLSEnabled and TLSRequired) ->
-                               wait_for_starttls;
-                          not Authenticated ->
-                               wait_for_sasl_request;
-                          (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
-                               wait_for_bind;
-                          true ->
-                               session_established
-                       end,
-                   {noreply, State4#{stream_state => NewStreamState}};
-               Err ->
-                   Err
-           end;
-       Err ->
-           Err
+                   if not Authenticated and (TLSRequired and not Encrypted) ->
+                           State4#{stream_state => wait_for_starttls};
+                      not Authenticated ->
+                           State4#{stream_state => wait_for_sasl_request};
+                      (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
+                           State4#{stream_state => wait_for_bind};
+                      true ->
+                           process_stream_established(State4)
+                   end
+           end
     end.
 
+-spec process_element(xmpp_element(), state()) -> state().
 process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
     case Pkt of
        #starttls{} when StateName == wait_for_starttls;
                         StateName == wait_for_sasl_request ->
            process_starttls(State);
        #starttls{} ->
-           send_element(State, #starttls_failure{});
+           process_starttls_failure(unexpected_starttls_request, State);
        #sasl_auth{} when StateName == wait_for_starttls ->
            send_element(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_auth{} when StateName == wait_for_sasl_request ->
@@ -356,7 +510,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
        #sasl_abort{} ->
            send_element(State, #sasl_failure{reason = 'aborted'});
        #sasl_success{} ->
-           {noreply, State};
+           State;
        #compress{} when StateName == wait_for_sasl_response ->
            send_element(State, #compress_failure{reason = 'setup-failed'});
        #compress{} ->
@@ -364,7 +518,9 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
        #handshake{} when StateName == wait_for_handshake ->
            process_handshake(Pkt, State);
        #handshake{} ->
-           {noreply, State};
+           State;
+       #stream_error{} ->
+           process_stream_end({error, {stream, Pkt}}, State);
        _ when StateName == wait_for_sasl_request;
               StateName == wait_for_handshake;
               StateName == wait_for_sasl_response ->
@@ -375,10 +531,11 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
            send_error(State, Pkt, Err);
        _ when StateName == wait_for_bind ->
            process_bind(Pkt, State);
-       _ when StateName == session_established ->
+       _ when StateName == established ->
            process_authenticated_packet(Pkt, State)
     end.
 
+-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
 process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
     NewPkt = set_lang(Pkt, State),
     try Mod:handle_unauthenticated_packet(NewPkt, State)
@@ -387,6 +544,7 @@ process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
            send_error(State, Pkt, Err)
     end.
 
+-spec process_authenticated_packet(xmpp_element(), state()) -> state().
 process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
     Pkt1 = set_lang(Pkt, State),
     case set_from_to(Pkt1, State) of
@@ -411,6 +569,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
            send_element(State, Err)
     end.
 
+-spec process_bind(xmpp_element(), state()) -> state().
 process_bind(#iq{type = set, sub_els = [_]} = Pkt,
             #{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) ->
     case xmpp:get_subtag(Pkt, #bind{}) of
@@ -426,8 +585,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
                               server := S,
                               resource := NewR} = State1} when NewR /= <<"">> ->
                            Reply = #bind{jid = jid:make(U, S, NewR)},
-                           State2 = State1#{stream_state => session_established},
-                           send_element(State2, xmpp:make_iq_result(Pkt, Reply));
+                           State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)),
+                           process_stream_established(State2);
                        {error, #stanza_error{}, State1} = Err ->
                            send_error(State1, Pkt, Err)
                    end
@@ -446,16 +605,55 @@ process_bind(Pkt, #{mod := Mod} = State) ->
            send_error(State, Pkt, Err)
     end.
 
-process_handshake(#handshake{data = Data}, #{mod := Mod} = State) ->
-    case Mod:handshake(Data, State) of
-       {ok, State1} ->
-           State2 = State1#{stream_state => session_established,
-                            stream_authenticated => true},
-           send_element(State2, #handshake{});
-       {error, #stream_error{} = Err, State1} ->
-           send_element(State1, Err)
+-spec process_handshake(handshake(), state()) -> state().
+process_handshake(#handshake{data = Digest},
+                 #{mod := Mod, stream_id := StreamID,
+                   remote_server := RemoteServer} = State) ->
+    GetPW = try Mod:get_password_fun(State)
+           catch _:undef -> fun(_) -> {false, undefined} end
+           end,
+    AuthRes = case GetPW(<<"">>) of
+                 {false, _} ->
+                     false;
+                 {Password, _} ->
+                     p1_sha:sha(<<StreamID/binary, Password/binary>>) == Digest
+             end,
+    case AuthRes of
+       true ->
+           State1 = try Mod:handle_auth_success(
+                          RemoteServer, <<"handshake">>, undefined, State)
+                    catch _:undef -> State
+                    end,
+           case is_disconnected(State1) of
+               true -> State1;
+               false ->
+                   State2 = send_element(State1, #handshake{}),
+                   process_stream_established(State2)
+           end;
+       false ->
+           State1 = try Mod:handle_auth_failure(
+                          RemoteServer, <<"handshake">>, 'not-authorized', State)
+                    catch _:undef -> State
+                    end,
+           case is_disconnected(State1) of
+               true -> State1;
+               false -> send_element(State1, xmpp:serr_not_authorized())
+           end
+    end.
+
+-spec process_stream_established(state()) -> state().
+process_stream_established(#{stream_state := StateName} = State)
+  when StateName == disconnected; StateName == established ->
+    State;
+process_stream_established(#{mod := Mod} = State) ->
+    State1 = State#{stream_authenticated := true,
+                   stream_state => established,
+                   stream_timeout => infinity},
+    try Mod:handle_stream_established(State1)
+    catch _:undef -> State1
     end.
 
+-spec process_compress(compress(), state()) -> state().
 process_compress(#compress{}, #{stream_compressed := true} = State) ->
     send_element(State, #compress_failure{reason = 'setup-failed'});
 process_compress(#compress{methods = HisMethods},
@@ -468,16 +666,17 @@ process_compress(#compress{methods = HisMethods},
        true ->
            BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})),
            ZlibSocket = SockMod:compress(Socket, BCompressed),
-           State1 = State#{socket => ZlibSocket,
-                           stream_id => new_id(),
-                           stream_restarted => true,
-                           stream_state => wait_for_stream,
-                           stream_compressed => true},
-           {noreply, State1};
+           State#{socket => ZlibSocket,
+                  stream_id => new_id(),
+                  stream_header_sent => false,
+                  stream_restarted => true,
+                  stream_state => wait_for_stream,
+                  stream_compressed => true};
        false ->
            send_element(State, #compress_failure{reason = 'unsupported-method'})
     end.
 
+-spec process_starttls(state()) -> state().
 process_starttls(#{socket := Socket,
                   sockmod := SockMod, mod := Mod} = State) ->
     TLSOpts = try Mod:tls_options(State)
@@ -485,38 +684,69 @@ process_starttls(#{socket := Socket,
              end,
     case SockMod:starttls(Socket, TLSOpts) of
        {ok, TLSSocket} ->
-           case send_element(State, #starttls_proceed{}) of
-               {noreply, State1} ->
-                   {noreply, State1#{socket => TLSSocket,
-                                     stream_id => new_id(),
-                                     stream_restarted => true,
-                                     stream_state => wait_for_stream,
-                                     stream_tlsed => true}};
-               Err ->
-                   Err
+           State1 = send_element(State, #starttls_proceed{}),
+           case is_disconnected(State1) of
+               true -> State1;
+               false ->
+                   State1#{socket => TLSSocket,
+                           stream_id => new_id(),
+                           stream_header_sent => false,
+                           stream_restarted => true,
+                           stream_state => wait_for_stream,
+                           stream_encrypted => true}
            end;
-       {error, _Reason} ->
-           send_element(State, #starttls_failure{})
+       {error, Reason} ->
+           process_starttls_failure(Reason, State)
     end.
 
-process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>},
-                    #{stream_tlsed := false} = State) ->
-    process_sasl_failure('encryption-required', <<"">>, State);
-process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
-                    #{mod := Mod} = State) ->
-    try Mod:init_sasl(State) of
-       SASLState ->
-           SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn),
-           process_sasl_result(SASLResult, State)
-    catch _:undef ->
-           process_sasl_failure('temporary-auth-failure', <<"">>, State)
+-spec process_starttls_failure(term(), state()) -> state().
+process_starttls_failure(Why, State) ->
+    State1 = send_element(State, #starttls_failure{}),
+    case is_disconnected(State1) of
+       true -> State1;
+       false -> process_stream_end({error, {tls, Why}}, State1)
     end.
 
+-spec process_sasl_request(sasl_auth(), state()) -> state().
+process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
+                    #{mod := Mod, lserver := LServer} = State) ->
+    GetPW = try Mod:get_password_fun(State)
+           catch _:undef -> fun(_) -> false end
+           end,
+    CheckPW = try Mod:check_password_fun(State)
+             catch _:undef -> fun(_, _, _) -> false end
+             end,
+    CheckPWDigest = try Mod:check_password_digest_fun(State)
+                   catch _:undef -> fun(_, _, _, _, _) -> false end
+                   end,
+    SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
+                                  GetPW, CheckPW, CheckPWDigest),
+    State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
+    Mechs = get_sasl_mechanisms(State1),
+    SASLResult = case lists:member(Mech, Mechs) of
+                    true when Mech == <<"EXTERNAL">> ->
+                        case xmpp_stream_pkix:authenticate(State1, ClientIn) of
+                            {ok, Peer} ->
+                                {ok, [{auth_module, pkix},
+                                      {username, Peer}]};
+                            {error, _Reason, Peer} ->
+                                %% TODO: return meaningful error
+                                {error, 'not-authorized', Peer}
+                        end;
+                    true ->
+                        cyrsasl:server_start(SASLState, Mech, ClientIn);
+                    false ->
+                        {error, 'invalid-mechanism'}
+                end,
+    process_sasl_result(SASLResult, State1).
+
+-spec process_sasl_response(sasl_response(), state()) -> state().
 process_sasl_response(#sasl_response{text = ClientIn},
                      #{sasl_state := SASLState} = State) ->
     SASLResult = cyrsasl:server_step(SASLState, ClientIn),
     process_sasl_result(SASLResult, State).
 
+-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
 process_sasl_result({ok, Props}, State) ->
     process_sasl_success(Props, <<"">>, State);
 process_sasl_result({ok, Props, ServerOut}, State) ->
@@ -528,58 +758,59 @@ process_sasl_result({error, Reason, User}, State) ->
 process_sasl_result({error, Reason}, State) ->
     process_sasl_failure(Reason, <<"">>, State).
 
+-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
 process_sasl_success(Props, ServerOut,
                     #{socket := Socket, sockmod := SockMod,
-                      mod := Mod, sasl_state := SASLState} = State) ->
-    Mech = cyrsasl:get_mech(SASLState),
+                      mod := Mod, sasl_mech := Mech} = State) ->
     User = identity(Props),
     AuthModule = proplists:get_value(auth_module, Props),
-    case try Mod:handle_auth_success(User, Mech, AuthModule, State)
-        catch _:undef -> {noreply, State}
-        end of
-       {noreply, State1} ->
+    State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State)
+            catch _:undef -> State
+            end,
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
            SockMod:reset_stream(Socket),
-           case send_element(State1, #sasl_success{text = ServerOut}) of
-               {noreply, State2} ->
-                   State3 = maps:remove(sasl_state, State2),
-                   {noreply, State3#{stream_id => new_id(),
-                                     stream_authenticated => true,
-                                     stream_restarted => true,
-                                     stream_state => wait_for_stream,
-                                     user => User}};
-               Err ->
-                   Err
-           end;
-       Err ->
-           Err
+           State2 = send_element(State1, #sasl_success{text = ServerOut}),
+           case is_disconnected(State2) of
+               true -> State2;
+               false ->
+                   State3 = maps:remove(sasl_state,
+                                        maps:remove(sasl_mech, State2)),
+                   State3#{stream_id => new_id(),
+                           stream_authenticated => true,
+                           stream_header_sent => false,
+                           stream_restarted => true,
+                           stream_state => wait_for_stream,
+                           user => User}
+           end
     end.
 
+-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
 process_sasl_continue(ServerOut, NewSASLState, State) ->
-    send_element(State, #sasl_challenge{text = ServerOut}),
-    {noreply, State#{sasl_state => NewSASLState,
-                    stream_state => wait_for_sasl_response}}.
+    State1 = State#{sasl_state => NewSASLState,
+                   stream_state => wait_for_sasl_response},
+    send_element(State1, #sasl_challenge{text = ServerOut}).
 
+-spec process_sasl_failure(atom(), binary(), state()) -> state().
 process_sasl_failure(Reason, User,
-                    #{mod := Mod, sasl_state := SASLState} = State) ->
-    Mech = cyrsasl:get_mech(SASLState),
-    case try Mod:handle_auth_failure(User, Mech, Reason, State)
-        catch _:undef -> {noreply, State}
-        end of
-       {noreply, State1} ->
-           State2 = maps:remove(sasl_state, State1),
-           State3 = State2#{stream_state => wait_for_sasl_request},
-           send_element(State3, #sasl_failure{reason = Reason});
-       Err ->
-           Err
-    end.
+                    #{mod := Mod, sasl_mech := Mech} = State) ->
+    State1 = try Mod:handle_auth_failure(User, Mech, Reason, State)
+            catch _:undef -> State
+            end,
+    State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
+    State3 = State2#{stream_state => wait_for_sasl_request},
+    send_element(State3, #sasl_failure{reason = Reason}).
 
+-spec process_sasl_abort(state()) -> state().
 process_sasl_abort(State) ->
     process_sasl_failure('aborted', <<"">>, State).
 
+-spec send_features(state()) -> state().
 send_features(#{stream_version := {1,0},
-               stream_tlsed := TLSEnabled} = State) ->
+               stream_encrypted := Encrypted} = State) ->
     TLSRequired = is_starttls_required(State),
-    Features = if TLSRequired and not TLSEnabled ->
+    Features = if TLSRequired and not Encrypted ->
                       get_tls_feature(State);
                  true ->
                       get_sasl_feature(State) ++ get_compress_feature(State)
@@ -588,26 +819,38 @@ send_features(#{stream_version := {1,0},
               end,
     send_element(State, #stream_features{sub_els = Features});
 send_features(State) ->
-    %% clients from stone age
-    {noreply, State}.
+    %% clients and servers from stone age
+    State.
 
+-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
+get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
+                     xmlns := NS, lserver := LServer} = State) ->
+    Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
+              true -> []
+           end,
+    TLSVerify = try Mod:tls_verify(State)
+               catch _:undef -> false
+               end,
+    if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
+           [<<"EXTERNAL">>|Mechs];
+       true ->
+           Mechs
+    end.
+
+-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
 get_sasl_feature(#{stream_authenticated := false,
-                  mod := Mod,
-                  stream_tlsed := TLSEnabled} = State) ->
+                  stream_encrypted := Encrypted} = State) ->
     TLSRequired = is_starttls_required(State),
-    if TLSEnabled or not TLSRequired ->
-           try Mod:sasl_mechanisms(State) of
-               [] -> [];
-               List -> [#sasl_mechanisms{list = List}]
-           catch _:undef ->
-                   []
-           end;
+    if Encrypted or not TLSRequired ->
+           Mechs = get_sasl_mechanisms(State),
+           [#sasl_mechanisms{list = Mechs}];
        true ->
            []
     end;
 get_sasl_feature(_) ->
     [].
 
+-spec get_compress_feature(state()) -> [compression()].
 get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
     try Mod:compress_methods(State) of
        [] -> [];
@@ -618,23 +861,31 @@ get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
 get_compress_feature(_) ->
     [].
 
+-spec get_tls_feature(state()) -> [starttls()].
 get_tls_feature(#{stream_authenticated := false,
-                 stream_tlsed := false} = State) ->
+                 stream_encrypted := false} = State) ->
     TLSRequired = is_starttls_required(State),
     [#starttls{required = TLSRequired}];
 get_tls_feature(_) ->
     [].
 
-get_bind_feature(#{stream_authenticated := true, resource := <<"">>}) ->
+-spec get_bind_feature(state()) -> [bind()].
+get_bind_feature(#{xmlns := ?NS_CLIENT,
+                  stream_authenticated := true,
+                  resource := <<"">>}) ->
     [#bind{}];
 get_bind_feature(_) ->
     [].
 
-get_session_feature(#{stream_authenticated := true, resource := <<"">>}) ->
+-spec get_session_feature(state()) -> [xmpp_session()].
+get_session_feature(#{xmlns := ?NS_CLIENT,
+                     stream_authenticated := true,
+                     resource := <<"">>}) ->
     [#xmpp_session{optional = true}];
 get_session_feature(_) ->
     [].
 
+-spec get_other_features(state()) -> [xmpp_element()].
 get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
     try
        if Auth -> Mod:authenticated_stream_features(State);
@@ -644,15 +895,18 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
            []
     end.
 
+-spec is_starttls_required(state()) -> boolean().
 is_starttls_required(#{mod := Mod} = State) ->
     try Mod:tls_required(State)
     catch _:undef -> false
     end.
 
+-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
+                                             {error, stream_error()}.
 set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
     {ok, Pkt};
 set_from_to(Pkt, #{user := U, server := S, resource := R,
-                    xmlns := ?NS_CLIENT}) ->
+                  lang := Lang, xmlns := ?NS_CLIENT}) ->
     JID = jid:make(U, S, R),
     From = case xmpp:get_from(Pkt) of
               undefined -> JID;
@@ -668,7 +922,8 @@ set_from_to(Pkt, #{user := U, server := S, resource := R,
                 end,
            {ok, xmpp:set_from_to(Pkt, JID, To)};
        true ->
-           {error, xmpp:serr_invalid_from()}
+           Txt = <<"Improper 'from' attribute">>,
+           {error, xmpp:serr_invalid_from(Txt, Lang)}
     end;
 set_from_to(Pkt, #{lang := Lang}) ->
     From = xmpp:get_from(Pkt),
@@ -683,17 +938,22 @@ set_from_to(Pkt, #{lang := Lang}) ->
            {ok, Pkt}
     end.
 
+-spec send_header(state()) -> state().
 send_header(State) ->
     send_header(State, #stream_start{}).
 
-send_header(#{stream_state := wait_for_stream,
-             stream_id := StreamID,
+-spec send_header(state(), stream_start()) -> state().
+send_header(#{stream_id := StreamID,
              stream_version := MyVersion,
+             stream_header_sent := false,
              lang := MyLang,
              xmlns := NS,
              server := DefaultServer} = State,
            #stream_start{to = To, lang = HisLang, version = HisVersion}) ->
-    Lang = choose_lang(MyLang, HisLang),
+    Lang = select_lang(MyLang, HisLang),
+    NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
+              true -> <<"">>
+           end,
     From = case To of
               #jid{} -> To;
               undefined -> jid:make(DefaultServer)
@@ -706,63 +966,114 @@ send_header(#{stream_state := wait_for_stream,
                                       lang = Lang,
                                       xmlns = NS,
                                       stream_xmlns = ?NS_STREAM,
+                                      db_xmlns = NS_DB,
                                       id = StreamID,
                                       from = From}),
-    State1 = State#{lang => Lang},
+    State1 = State#{lang => Lang, stream_header_sent => true},
     case send_text(State1, fxml:element_to_header(Header)) of
-       ok -> {noreply, State1};
-       {error, _} -> {stop, normal, State1}
+       ok -> State1;
+       {error, Why} -> process_stream_close({error, {socket, Why}}, State1)
     end;
 send_header(State, _) ->
-    {noreply, State}.
+    State.
 
+-spec send_element(state(), xmpp_element()) -> state().
 send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
     El = xmpp:encode(Pkt, NS),
     Data = fxml:element_to_binary(El),
-    case send_text(State, Data) of
-       ok when is_record(Pkt, stream_error) ->
-           {stop, normal, State};
-       ok when is_record(Pkt, starttls_failure) ->
-           {stop, normal, State};
-       Res ->
-           try Mod:handle_send(Res, Pkt, El, Data, State)
-           catch _:undef when Res == ok ->
-                   {noreply, State};
-                 _:undef ->
-                   {stop, normal, State}
-           end
+    Result = send_text(State, Data),
+    State1 = try Mod:handle_send(Pkt, Result, State)
+            catch _:undef -> State
+            end,
+    case Result of
+       _ when is_record(Pkt, stream_error) ->
+           process_stream_end({error, {stream, Pkt}}, State1);
+       ok ->
+           State1;
+       {error, Why} ->
+           process_stream_close({error, {socket, Why}}, State1)
     end.
 
-send_error(State, Pkt, Err) when ?is_stanza(Pkt) ->
-    case xmpp:get_type(Pkt) of
-       result -> {noreply, State};
-       error -> {noreply, State};
-       _ ->
-           ErrPkt = xmpp:make_error(Pkt, Err),
-           send_element(State, ErrPkt)
-    end;
-send_error(State, _, _) ->
-    {noreply, State}.
+-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+send_error(State, Pkt, Err) ->
+    case xmpp:is_stanza(Pkt) of
+       true ->
+           case xmpp:get_type(Pkt) of
+               result -> State;
+               error -> State;
+               <<"result">> -> State;
+               <<"error">> -> State;
+               _ ->
+                   ErrPkt = xmpp:make_error(Pkt, Err),
+                   send_element(State, ErrPkt)
+           end;
+       false ->
+           State
+    end.
+
+-spec send_trailer(state()) -> state().
+send_trailer(State) ->
+    send_text(State, <<"</stream:stream>">>),
+    close_socket(State).
 
-send_text(#{socket := Sock, sockmod := SockMod}, Data) ->
-    SockMod:send(Sock, Data).
+-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
+send_text(#{socket := Sock, sockmod := SockMod,
+           stream_state := StateName,
+           stream_header_sent := true}, Data) when StateName /= disconnected ->
+    SockMod:send(Sock, Data);
+send_text(_, _) ->
+    {error, einval}.
 
-choose_lang(Lang, <<"">>) -> Lang;
-choose_lang(_, Lang) -> Lang.
+-spec close_socket(state()) -> state().
+close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
+    SockMod:close(Socket),
+    State#{stream_timeout => infinity,
+          stream_state => disconnected}.
 
+-spec select_lang(binary(), binary()) -> binary().
+select_lang(Lang, <<"">>) -> Lang;
+select_lang(_, Lang) -> Lang.
+
+-spec set_lang(xmpp_element(), state()) -> xmpp_element().
 set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
     HisLang = xmpp:get_lang(Pkt),
-    Lang = choose_lang(MyLang, HisLang),
+    Lang = select_lang(MyLang, HisLang),
     xmpp:set_lang(Pkt, Lang);
 set_lang(Pkt, _) ->
     Pkt.
 
+-spec format_inet_error(atom()) -> string().
+format_inet_error(Reason) ->
+    case inet:format_error(Reason) of
+       "unknown POSIX error" -> atom_to_list(Reason);
+       Txt -> Txt
+    end.
+
+-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
+format_stream_error(Reason, Txt) ->
+    Slogan = case Reason of
+                #'see-other-host'{} -> "see-other-host";
+                _ -> atom_to_list(Reason)
+            end,
+    case Txt of
+       undefined -> Slogan;
+       #text{data = <<"">>} -> Slogan;
+       #text{data = Data} ->
+           binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
+    end.
+
+-spec format(io:format(), list()) -> binary().
+format(Fmt, Args) ->
+    iolist_to_binary(io_lib:format(Fmt, Args)).
+
+-spec lists_intersection(list(), list()) -> list().
 lists_intersection(L1, L2) ->
     lists:filter(
       fun(E) ->
              lists:member(E, L2)
       end, L1).
 
+-spec identity([cyrsasl:sasl_property()]) -> binary().
 identity(Props) ->
     case proplists:get_value(authzid, Props, <<>>) of
        <<>> -> proplists:get_value(username, Props, <<>>);
diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl
new file mode 100644 (file)
index 0000000..fc373ff
--- /dev/null
@@ -0,0 +1,856 @@
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2016, Evgeny Khramtsov
+%%% @doc
+%%%
+%%% @end
+%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%-------------------------------------------------------------------
+-module(xmpp_stream_out).
+-behaviour(gen_server).
+
+-protocol({rfc, 6120}).
+
+%% API
+-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
+        stop/1, send/2, close/1, close/2, establish/1, format_error/1,
+        set_timeout/2, get_transport/1, change_shaper/2]).
+%% gen_server callbacks
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+        terminate/2, code_change/3]).
+
+%%-define(DBGFSM, true).
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
+-define(TCP_SEND_TIMEOUT, 15000).
+
+-include("xmpp.hrl").
+-include("logger.hrl").
+-include_lib("kernel/include/inet.hrl").
+
+-type state() :: map().
+-type host_port() :: {inet:hostname(), inet:port_number()}.
+-type ip_port() :: {inet:ip_address(), inet:port_number()}.
+-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
+-type stop_reason() :: {idna, bad_string} |
+                      {dns, inet:posix() | inet_res:res_error()} |
+                      {stream, reset | stream_error()} |
+                      {tls, term()} |
+                      {pkix, binary()} |
+                      {auth, atom() | binary() | string()} |
+                      {socket, inet:posix() | closed | timeout}.
+
+-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(Mod, Args, Opts) ->
+    gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+start_link(Mod, Args, Opts) ->
+    gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+call(Ref, Msg, Timeout) ->
+    gen_server:call(Ref, Msg, Timeout).
+
+cast(Ref, Msg) ->
+    gen_server:cast(Ref, Msg).
+
+reply(Ref, Reply) ->
+    gen_server:reply(Ref, Reply).
+
+-spec connect(pid()) -> ok.
+connect(Ref) ->
+    cast(Ref, connect).
+
+-spec stop(pid()) -> ok;
+         (state()) -> no_return().
+stop(Pid) when is_pid(Pid) ->
+    cast(Pid, stop);
+stop(#{owner := Owner} = State) when Owner == self() ->
+    terminate(normal, State),
+    exit(normal);
+stop(_) ->
+    erlang:error(badarg).
+
+-spec send(pid(), xmpp_element()) -> ok;
+         (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+    cast(Pid, {send, Pkt});
+send(#{owner := Owner} = State, Pkt) when Owner == self() ->
+    send_element(State, Pkt);
+send(_, _) ->
+    erlang:error(badarg).
+
+-spec close(pid()) -> ok;
+          (state()) -> state().
+close(Ref) ->
+    close(Ref, true).
+
+-spec close(pid(), boolean()) -> ok;
+          (state(), boolean()) -> state().
+close(Pid, SendTrailer) when is_pid(Pid) ->
+    cast(Pid, {close, SendTrailer});
+close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
+    if SendTrailer -> send_trailer(State);
+       true -> close_socket(State)
+    end;
+close(_, _) ->
+    erlang:error(badarg).
+
+-spec establish(state()) -> state().
+establish(State) ->
+    process_stream_established(State).
+
+-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
+    case Timeout of
+       infinity -> State#{stream_timeout => infinity};
+       _ ->
+           Time = p1_time_compat:monotonic_time(milli_seconds),
+           State#{stream_timeout => {Timeout, Time}}
+    end;
+set_timeout(_, _) ->
+    erlang:error(badarg).
+
+get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
+  when Owner == self() ->
+    SockMod:get_transport(Socket);
+get_transport(_) ->
+    erlang:error(badarg).
+
+-spec change_shaper(state(), shaper:shaper()) -> ok.
+change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
+  when Owner == self() ->
+    SockMod:change_shaper(Socket, Shaper);
+change_shaper(_, _) ->
+    erlang:error(badarg).
+
+-spec format_error(stop_reason()) ->  binary().
+format_error({idna, _}) ->
+    <<"Not an IDN hostname">>;
+format_error({dns, Reason}) ->
+    format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
+format_error({socket, Reason}) ->
+    format("Connection failed: ~s", [format_inet_error(Reason)]);
+format_error({pkix, Reason}) ->
+    format("Peer certificate rejected: ~s", [Reason]);
+format_error({stream, reset}) ->
+    <<"Stream reset by peer">>;
+format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
+    format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({tls, Reason}) ->
+    format("TLS failed: ~w", [Reason]);
+format_error({auth, Reason}) ->
+    format("Authentication failed: ~s", [Reason]);
+format_error(Err) ->
+    format("Unrecognized error: ~w", [Err]).
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([Mod, SockMod, From, To, Opts]) ->
+    Time = p1_time_compat:monotonic_time(milli_seconds),
+    State = #{owner => self(),
+             mod => Mod,
+             sockmod => SockMod,
+             server => From,
+             user => <<"">>,
+             resource => <<"">>,
+             lang => <<"">>,
+             remote_server => To,
+             xmlns => ?NS_SERVER,
+             stream_direction => out,
+             stream_timeout => {timer:seconds(30), Time},
+             stream_id => new_id(),
+             stream_encrypted => false,
+             stream_verified => false,
+             stream_authenticated => false,
+             stream_restarted => false,
+             stream_state => connecting},
+    case try Mod:init([State, Opts])
+        catch _:undef -> {ok, State}
+        end of
+       {ok, State1} ->
+           {_, State2, Timeout} = noreply(State1),
+           {ok, State2, Timeout};
+       Err ->
+           Err
+    end.
+
+handle_call(Call, From, #{mod := Mod} = State) ->
+    noreply(try Mod:handle_call(Call, From, State)
+           catch _:undef -> State
+           end).
+
+handle_cast(connect, #{remote_server := RemoteServer,
+                      sockmod := SockMod,
+                      stream_state := connecting} = State) ->
+    case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
+       false ->
+           noreply(process_stream_close({error, {idna, bad_string}}, State));
+       ASCIIName ->
+           case resolve(binary_to_list(ASCIIName), State) of
+               {ok, AddrPorts} ->
+                   case connect(AddrPorts, State) of
+                       {ok, Socket, AddrPort} ->
+                           SocketMonitor = SockMod:monitor(Socket),
+                           State1 = State#{ip => AddrPort,
+                                           socket => Socket,
+                                           socket_monitor => SocketMonitor},
+                           State2 = State1#{stream_state => wait_for_stream},
+                           noreply(send_header(State2));
+                       {error, Why} ->
+                           Err = {error, {socket, Why}},
+                           noreply(process_stream_close(Err, State))
+                   end;
+               {error, Why} ->
+                   noreply(process_stream_close({error, {dns, Why}}, State))
+           end
+    end;
+handle_cast(connect, State) ->
+    %% Ignoring connection attempts in other states
+    noreply(State);
+handle_cast({send, Pkt}, State) ->
+    noreply(send_element(State, Pkt));
+handle_cast(stop, State) ->
+    {stop, normal, State};
+handle_cast(Cast, #{mod := Mod} = State) ->
+    noreply(try Mod:handle_cast(Cast, State)
+           catch _:undef -> State
+           end).
+
+handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
+           #{stream_state := wait_for_stream,
+             xmlns := XMLNS, lang := MyLang} = State) ->
+    El = #xmlel{name = Name, attrs = Attrs},
+    try xmpp:decode(El, XMLNS, []) of
+       #stream_start{} = Pkt ->
+           noreply(process_stream(Pkt, State));
+       _ ->
+           noreply(send_element(State, xmpp:serr_invalid_xml()))
+    catch _:{xmpp_codec, Why} ->
+           Txt = xmpp:io_format_error(Why),
+           Lang = select_lang(MyLang, xmpp:get_lang(El)),
+           Err = xmpp:serr_invalid_xml(Txt, Lang),
+           noreply(send_element(State, Err))
+    end;
+handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
+    State1 = send_header(State),
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
+           Err = case Reason of
+                     <<"XML stanza is too big">> ->
+                         xmpp:serr_policy_violation(Reason, Lang);
+                     _ ->
+                         xmpp:serr_not_well_formed()
+                 end,
+           noreply(send_element(State1, Err))
+    end;
+handle_info({'$gen_event', {xmlstreamelement, El}},
+           #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+    try xmpp:decode(El, NS, [ignore_els]) of
+       Pkt ->
+           State1 = try Mod:handle_recv(El, Pkt, State)
+                    catch _:undef -> State
+                    end,
+           case is_disconnected(State1) of
+               true -> State1;
+               false -> noreply(process_element(Pkt, State1))
+           end
+    catch _:{xmpp_codec, Why} ->
+           State1 = try Mod:handle_recv(El, undefined, State)
+                    catch _:undef -> State
+                    end,
+           case is_disconnected(State1) of
+               true -> State1;
+               false ->
+                   Txt = xmpp:io_format_error(Why),
+                   Lang = select_lang(MyLang, xmpp:get_lang(El)),
+                   noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
+           end
+    end;
+handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
+           #{mod := Mod} = State) ->
+    noreply(try Mod:handle_cdata(Data, State)
+           catch _:undef -> State
+           end);
+handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
+    noreply(process_stream_end({error, {stream, reset}}, State));
+handle_info({'$gen_event', closed}, State) ->
+    noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(timeout, #{mod := Mod} = State) ->
+    Disconnected = is_disconnected(State),
+    noreply(try Mod:handle_timeout(State)
+           catch _:undef when not Disconnected ->
+                   send_element(State, xmpp:serr_connection_timeout());
+                 _:undef ->
+                   stop(State)
+           end);
+handle_info({'DOWN', MRef, _Type, _Object, _Info},
+           #{socket_monitor := MRef} = State) ->
+    noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(Info, #{mod := Mod} = State) ->
+    noreply(try Mod:handle_info(Info, State)
+           catch _:undef -> State
+           end).
+
+terminate(Reason, #{mod := Mod} = State) ->
+    case get(already_terminated) of
+       true ->
+           State;
+       _ ->
+           put(already_terminated, true),
+           try Mod:terminate(Reason, State)
+           catch _:undef -> ok
+           end,
+           send_trailer(State)
+    end.
+
+code_change(OldVsn, #{mod := Mod} = State, Extra) ->
+    Mod:code_change(OldVsn, State, Extra).
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+noreply(#{stream_timeout := infinity} = State) ->
+    {noreply, State, infinity};
+noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
+    NewTime = p1_time_compat:monotonic_time(milli_seconds),
+    Timeout = max(0, MSecs - NewTime + OldTime),
+    {noreply, State, Timeout}.
+
+-spec new_id() -> binary().
+new_id() ->
+    randoms:get_string().
+
+-spec is_disconnected(state()) -> boolean().
+is_disconnected(#{stream_state := StreamState}) ->
+    StreamState == disconnected.
+
+-spec process_stream_close(stop_reason(), state()) -> state().
+process_stream_close(_, #{stream_state := disconnected} = State) ->
+    State;
+process_stream_close(Reason, #{mod := Mod} = State) ->
+    State1 = send_trailer(State),
+    try Mod:handle_stream_close(Reason, State1)
+    catch _:undef -> stop(State1)
+    end.
+
+-spec process_stream_end(stop_reason(), state()) -> state().
+process_stream_end(_, #{stream_state := disconnected} = State) ->
+    State;
+process_stream_end(Reason, #{mod := Mod} = State) ->
+    State1 = send_trailer(State),
+    try Mod:handle_stream_end(Reason, State1)
+    catch _:undef -> stop(State1)
+    end.
+
+-spec process_stream(stream_start(), state()) -> state().
+process_stream(#stream_start{xmlns = XML_NS,
+                            stream_xmlns = STREAM_NS},
+              #{xmlns := NS} = State)
+  when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
+    send_element(State, xmpp:serr_invalid_namespace());
+process_stream(#stream_start{lang = Lang, id = ID,
+                            version = Version} = StreamStart,
+              #{mod := Mod} = State) ->
+    State1 = State#{stream_remote_id => ID, lang => Lang},
+    State2 = try Mod:handle_stream_start(StreamStart, State1)
+            catch _:undef -> State1
+            end,
+    case is_disconnected(State2) of
+       true -> State2;
+       false ->
+           case Version of
+               {1,0} -> State2#{stream_state => wait_for_features};
+               _ -> process_stream_downgrade(StreamStart, State)
+           end
+    end.
+
+-spec process_element(xmpp_element(), state()) -> state().
+process_element(Pkt, #{stream_state := StateName} = State) ->
+    case Pkt of
+       #stream_features{} when StateName == wait_for_features ->
+           process_features(Pkt, State);
+       #starttls_proceed{} when StateName == wait_for_starttls_response ->
+           process_starttls(State);
+       #sasl_success{} when StateName == wait_for_sasl_response ->
+           process_sasl_success(State);
+       #sasl_failure{} when StateName == wait_for_sasl_response ->
+           process_sasl_failure(Pkt, State);
+       #stream_error{} ->
+           process_stream_end({error, {stream, Pkt}}, State);
+       _ when is_record(Pkt, stream_features);
+              is_record(Pkt, starttls_proceed);
+              is_record(Pkt, starttls);
+              is_record(Pkt, sasl_auth);
+              is_record(Pkt, sasl_success);
+              is_record(Pkt, sasl_failure);
+              is_record(Pkt, sasl_response);
+              is_record(Pkt, sasl_abort);
+              is_record(Pkt, compress);
+              is_record(Pkt, handshake) ->
+           %% Do not pass this crap upstream
+           State;
+       _ ->
+           process_packet(Pkt, State)
+    end.
+
+-spec process_features(stream_features(), state()) -> state().
+process_features(StreamFeatures,
+                #{stream_authenticated := true, mod := Mod} = State) ->
+    State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
+            catch _:undef -> State
+            end,
+    process_stream_established(State1);
+process_features(#stream_features{sub_els = Els} = StreamFeatures,
+                #{stream_encrypted := Encrypted,
+                  mod := Mod, lang := Lang} = State) ->
+    State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
+            catch _:undef -> State
+            end,
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
+           TLSRequired = is_starttls_required(State1),
+           %% TODO: improve xmpp.erl
+           Msg = #message{sub_els = Els},
+           case xmpp:get_subtag(Msg, #starttls{}) of
+               false when TLSRequired and not Encrypted ->
+                   Txt = <<"Use of STARTTLS required">>,
+                   send_element(State1, xmpp:err_policy_violation(Txt, Lang));
+               #starttls{} when not Encrypted ->
+                   State2 = State1#{stream_state => wait_for_starttls_response},
+                   send_element(State2, #starttls{});
+               _ ->
+                   State2 = process_cert_verification(State1),
+                   case is_disconnected(State2) of
+                       true -> State2;
+                       false ->
+                           case xmpp:get_subtag(Msg, #sasl_mechanisms{}) of
+                               #sasl_mechanisms{list = Mechs} ->
+                                   process_sasl_mechanisms(Mechs, State2);
+                               false ->
+                                   process_sasl_failure(
+                                     #sasl_failure{reason = 'invalid-mechanism'},
+                                     State2)
+                           end
+                   end
+           end
+    end.
+
+-spec process_stream_established(state()) -> state().
+process_stream_established(#{stream_state := StateName} = State)
+  when StateName == disconnected; StateName == established ->
+    State;
+process_stream_established(#{mod := Mod} = State) ->
+    State1 = State#{stream_authenticated := true,
+                   stream_state => established,
+                   stream_timeout => infinity},
+    try Mod:handle_stream_established(State1)
+    catch _:undef -> State1
+    end.
+
+-spec process_sasl_mechanisms([binary()], state()) -> state().
+process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
+    %% TODO: support other mechanisms
+    Mech = <<"EXTERNAL">>,
+    case lists:member(<<"EXTERNAL">>, Mechs) of
+       true ->
+           State1 = State#{stream_state => wait_for_sasl_response},
+           Authzid = jid:to_string(jid:make(User, Server)),
+           send_element(State1, #sasl_auth{mechanism = Mech, text = Authzid});
+       false ->
+           process_sasl_failure(
+             #sasl_failure{reason = 'invalid-mechanism'}, State)
+    end.
+
+-spec process_starttls(state()) -> state().
+process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
+    TLSOpts = try Mod:tls_options(State)
+             catch _:undef -> []
+             end,
+    case SockMod:starttls(Socket, [connect|TLSOpts]) of
+       {ok, TLSSocket} ->
+           State1 = State#{socket => TLSSocket,
+                           stream_id => new_id(),
+                           stream_restarted => true,
+                           stream_state => wait_for_stream,
+                           stream_encrypted => true},
+           send_header(State1);
+       {error, Why} ->
+           process_stream_close({error, {tls, Why}}, State)
+    end.
+
+-spec process_stream_downgrade(stream_start(), state()) -> state().
+process_stream_downgrade(StreamStart, #{mod := Mod} = State) ->
+    try Mod:downgrade_stream(StreamStart, State)
+    catch _:undef ->
+           send_element(State, xmpp:serr_unsupported_version())
+    end.
+
+-spec process_cert_verification(state()) -> state().
+process_cert_verification(#{stream_encrypted := true,
+                           stream_verified := false,
+                           mod := Mod} = State) ->
+    case try Mod:tls_verify(State)
+        catch _:undef -> true
+        end of
+       true ->
+           case xmpp_stream_pkix:authenticate(State) of
+               {ok, _} ->
+                   State#{stream_verified => true};
+               {error, Why, _Peer} ->
+                   process_stream_close({error, {pkix, Why}}, State)
+           end;
+       false ->
+           State#{stream_verified => true}
+    end;
+process_cert_verification(State) ->
+    State.
+
+-spec process_sasl_success(state()) -> state().
+process_sasl_success(#{mod := Mod,
+                      sockmod := SockMod,
+                      socket := Socket} = State) ->
+    State1 = try Mod:handle_auth_success(<<"EXTERNAL">>, State)
+            catch _:undef -> State
+            end,
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
+           SockMod:reset_stream(Socket),
+           State2 = State1#{stream_id => new_id(),
+                            stream_restarted => true,
+                            stream_state => wait_for_stream,
+                            stream_authenticated => true},
+           send_header(State2)
+    end.
+
+-spec process_sasl_failure(sasl_failure(), state()) -> state().
+process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
+    try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
+    catch _:undef -> process_stream_close({error, {auth, Reason}}, State)
+    end.
+
+-spec process_packet(xmpp_element(), state()) -> state().
+process_packet(Pkt, #{mod := Mod} = State) ->
+    try Mod:handle_packet(Pkt, State)
+    catch _:undef -> State
+    end.
+
+-spec is_starttls_required(state()) -> boolean().
+is_starttls_required(#{mod := Mod} = State) ->
+    try Mod:tls_required(State)
+    catch _:undef -> false
+    end.
+
+-spec send_header(state()) -> state().
+send_header(#{remote_server := RemoteServer,
+             stream_encrypted := Encrypted,
+             lang := Lang,
+             xmlns := NS,
+             user := User,
+             resource := Resource,
+             server := Server} = State) ->
+    NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
+              true -> <<"">>
+           end,
+    From = if Encrypted ->
+                  jid:make(User, Server, Resource);
+             NS == ?NS_SERVER ->
+                  jid:make(Server);
+             true ->
+                  undefined
+          end,
+    Header = xmpp:encode(
+              #stream_start{xmlns = NS,
+                            lang = Lang,
+                            stream_xmlns = ?NS_STREAM,
+                            db_xmlns = NS_DB,
+                            from = From,
+                            to = jid:make(RemoteServer),
+                            version = {1,0}}),
+    case send_text(State, fxml:element_to_header(Header)) of
+       ok -> State;
+       {error, Why} -> process_stream_close({error, {socket, Why}}, State)
+    end.
+
+-spec send_element(state(), xmpp_element()) -> state().
+send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
+    El = xmpp:encode(Pkt, NS),
+    Data = fxml:element_to_binary(El),
+    State1 = try Mod:handle_send(Pkt, El, Data, State)
+            catch _:undef -> State
+            end,
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
+           case send_text(State1, Data) of
+               _ when is_record(Pkt, stream_error) ->
+                   process_stream_end({error, {stream, Pkt}}, State1);
+               ok ->
+                   State1;
+               {error, Why} ->
+                   process_stream_close({error, {socket, Why}}, State1)
+           end
+    end.
+
+-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+send_error(State, Pkt, Err) ->
+    case xmpp:is_stanza(Pkt) of
+       true ->
+           case xmpp:get_type(Pkt) of
+               result -> State;
+               error -> State;
+               <<"result">> -> State;
+               <<"error">> -> State;
+               _ ->
+                   ErrPkt = xmpp:make_error(Pkt, Err),
+                   send_element(State, ErrPkt)
+           end;
+       false ->
+           State
+    end.
+
+-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
+send_text(#{sockmod := SockMod, socket := Socket,
+           stream_state := StateName}, Data) when StateName /= disconnected ->
+    SockMod:send(Socket, Data);
+send_text(_, _) ->
+    {error, einval}.
+
+-spec send_trailer(state()) -> state().
+send_trailer(State) ->
+    send_text(State, <<"</stream:stream>">>),
+    close_socket(State).
+
+-spec close_socket(state()) -> state().
+close_socket(State) ->
+    case State of
+       #{sockmod := SockMod, socket := Socket} ->
+           SockMod:close(Socket);
+       _ ->
+           ok
+    end,
+    State#{stream_timeout => infinity,
+          stream_state => disconnected}.
+
+-spec select_lang(binary(), binary()) -> binary().
+select_lang(Lang, <<"">>) -> Lang;
+select_lang(_, Lang) -> Lang.
+
+-spec format_inet_error(atom()) -> string().
+format_inet_error(Reason) ->
+    case inet:format_error(Reason) of
+       "unknown POSIX error" -> atom_to_list(Reason);
+       Txt -> Txt
+    end.
+
+-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
+format_stream_error(Reason, Txt) ->
+    Slogan = case Reason of
+                #'see-other-host'{} -> "see-other-host";
+                _ -> atom_to_list(Reason)
+            end,
+    case Txt of
+       undefined -> Slogan;
+       #text{data = <<"">>} -> Slogan;
+       #text{data = Data} ->
+           binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
+    end.
+
+-spec format(io:format(), list()) -> binary().
+format(Fmt, Args) ->
+    iolist_to_binary(io_lib:format(Fmt, Args)).
+
+%%%===================================================================
+%%% Connection stuff
+%%%===================================================================
+-spec resolve(string(), state()) -> {ok, [host_port()]} | network_error().
+resolve(Host, State) ->
+    case srv_lookup(Host, State) of
+       {error, _Reason} ->
+           DefaultPort = get_default_port(State),
+           a_lookup([{Host, DefaultPort}], State);
+       {ok, HostPorts} ->
+           a_lookup(HostPorts, State)
+    end.
+
+-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
+srv_lookup(Host, State) ->
+    %% Only perform SRV lookups for FQDN names
+    case string:chr(Host, $.) of
+       0 ->
+           {error, nxdomain};
+       _ ->
+           case inet_parse:address(Host) of
+               {ok, _} ->
+                   {error, nxdomain};
+               {error, _} ->
+                   Timeout = get_dns_timeout(State),
+                   Retries = get_dns_retries(State),
+                   srv_lookup(Host, Timeout, Retries)
+           end
+    end.
+
+-spec srv_lookup(string(), non_neg_integer(), integer()) ->
+                       {ok, [host_port()]} | network_error().
+srv_lookup(_Host, _Timeout, Retries) when Retries < 1 ->
+    {error, timeout};
+srv_lookup(Host, Timeout, Retries) ->
+    SRVName = "_xmpp-server._tcp." ++ Host,
+    case inet_res:getbyname(SRVName, srv, Timeout) of
+       {ok, HostEntry} ->
+           host_entry_to_host_ports(HostEntry);
+       {error, _} ->
+           LegacySRVName = "_jabber._tcp." ++ Host,
+           case inet_res:getbyname(LegacySRVName, srv, Timeout) of
+               {error, timeout} ->
+                   srv_lookup(Host, Timeout, Retries - 1);
+               {error, _} = Err ->
+                   Err;
+               {ok, HostEntry} ->
+                   host_entry_to_host_ports(HostEntry)
+           end
+    end.
+
+-spec a_lookup([{inet:hostname(), inet:port_number()}], state()) ->
+                     {ok, [ip_port()]} | network_error().
+a_lookup(HostPorts, State) ->
+    HostPortFamilies = [{Host, Port, Family}
+                       || {Host, Port} <- HostPorts,
+                          Family <- get_address_families(State)],
+    a_lookup(HostPortFamilies, State, {error, nxdomain}).
+
+-spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}],
+              state(), network_error()) -> {ok, [ip_port()]} | network_error().
+a_lookup([{Host, Port, Family}|HostPortFamilies], State, _) ->
+    Timeout = get_dns_timeout(State),
+    Retries = get_dns_retries(State),
+    case a_lookup(Host, Port, Family, Timeout, Retries) of
+       {error, _} = Err ->
+           a_lookup(HostPortFamilies, State, Err);
+       {ok, AddrPorts} ->
+           {ok, AddrPorts}
+    end;
+a_lookup([], _State, Err) ->
+    Err.
+
+-spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(),
+              non_neg_integer(), integer()) -> {ok, [ip_port()]} | network_error().
+a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 ->
+    {error, timeout};
+a_lookup(Host, Port, Family, Timeout, Retries) ->
+    case inet:gethostbyname(Host, Family, Timeout) of
+       {error, timeout} ->
+           a_lookup(Host, Port, Family, Timeout, Retries - 1);
+       {error, _} = Err ->
+           Err;
+       {ok, HostEntry} ->
+           host_entry_to_addr_ports(HostEntry, Port)
+    end.
+
+-spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} |
+                                                 {error, nxdomain}.
+host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) ->
+    PrioHostPorts = lists:flatmap(
+                     fun({Priority, Weight, Port, Host}) ->
+                             N = case Weight of
+                                     0 -> 0;
+                                     _ -> (Weight + 1) * randoms:uniform()
+                                 end,
+                             [{Priority * 65536 - N, Host, Port}];
+                        (_) ->
+                             []
+                     end, AddrList),
+    HostPorts = [{Host, Port}
+                || {_Priority, Host, Port} <- lists:usort(PrioHostPorts)],
+    case HostPorts of
+       [] -> {error, nxdomain};
+       _ -> {ok, HostPorts}
+    end.
+
+-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) ->
+                                     {ok, [ip_port()]} | {error, nxdomain}.
+host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) ->
+    AddrPorts = lists:flatmap(
+                 fun(Addr) ->
+                         try get_addr_type(Addr) of
+                             _ -> [{Addr, Port}]
+                         catch _:_ ->
+                                 []
+                         end
+                 end, AddrList),
+    case AddrPorts of
+       [] -> {error, nxdomain};
+       _ -> {ok, AddrPorts}
+    end.
+
+-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | network_error().
+connect(AddrPorts, #{sockmod := SockMod} = State) ->
+    Timeout = get_connect_timeout(State),
+    connect(AddrPorts, SockMod, Timeout, {error, nxdomain}).
+
+-spec connect([ip_port()], module(), non_neg_integer(), network_error()) ->
+                    {ok, term(), ip_port()} | network_error().
+connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
+    Type = get_addr_type(Addr),
+    case SockMod:connect(Addr, Port,
+                        [binary, {packet, 0},
+                         {send_timeout, ?TCP_SEND_TIMEOUT},
+                         {send_timeout_close, true},
+                         {active, false}, Type],
+                        Timeout) of
+       {ok, Socket} ->
+           {ok, Socket, {Addr, Port}};
+       Err ->
+           connect(AddrPorts, SockMod, Timeout, Err)
+    end;
+connect([], _SockMod, _Timeout, Err) ->
+    Err.
+
+-spec get_addr_type(inet:ip_address()) -> inet:address_family().
+get_addr_type({_, _, _, _}) -> inet;
+get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
+
+-spec get_dns_timeout(state()) -> non_neg_integer().
+get_dns_timeout(#{mod := Mod} = State) ->
+    timer:seconds(
+      try Mod:dns_timeout(State)
+      catch _:undef -> 10
+      end).
+
+-spec get_dns_retries(state()) -> non_neg_integer().
+get_dns_retries(#{mod := Mod} = State) ->
+    try Mod:dns_retries(State)
+    catch _:undef -> 2
+    end.
+
+-spec get_default_port(state()) -> inet:port_number().
+get_default_port(#{mod := Mod, xmlns := NS} = State) ->
+    try Mod:default_port(State)
+    catch _:undef when NS == ?NS_SERVER -> 5269;
+         _:undef when NS == ?NS_CLIENT -> 5222
+    end.
+
+-spec get_address_families(state()) -> [inet:address_family()].
+get_address_families(#{mod := Mod} = State) ->
+    try Mod:address_families(State)
+    catch _:undef -> [inet, inet6]
+    end.
+
+-spec get_connect_timeout(state()) -> non_neg_integer().
+get_connect_timeout(#{mod := Mod} = State) ->
+    timer:seconds(
+      try Mod:connect_timeout(State)
+      catch _:undef -> 10
+      end).
diff --git a/src/xmpp_stream_pkix.erl b/src/xmpp_stream_pkix.erl
new file mode 100644 (file)
index 0000000..59f5d82
--- /dev/null
@@ -0,0 +1,159 @@
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2016, Evgeny Khramtsov
+%%% @doc
+%%%
+%%% @end
+%%% Created : 13 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%-------------------------------------------------------------------
+-module(xmpp_stream_pkix).
+
+%% API
+-export([authenticate/1, authenticate/2]).
+
+-include("xmpp.hrl").
+-include_lib("public_key/include/public_key.hrl").
+-include("XmppAddr.hrl").
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state())
+      -> {ok, binary()} | {error, binary(), binary()}.
+authenticate(State) ->
+    authenticate(State, <<"">>).
+
+-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary())
+      -> {ok, binary()} | {error, binary(), binary()}.
+authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
+              sockmod := SockMod, socket := Socket}, _Authzid) ->
+    case SockMod:get_peer_certificate(Socket) of
+       {ok, Cert} ->
+           case SockMod:get_verify_result(Socket) of
+               0 ->
+                   case ejabberd_idna:domain_utf8_to_ascii(Peer) of
+                       false ->
+                           {error, <<"Cannot decode remote server name">>, Peer};
+                       AsciiPeer ->
+                           case lists:any(
+                                  fun(D) -> match_domain(AsciiPeer, D) end,
+                                  get_cert_domains(Cert)) of
+                               true ->
+                                   {ok, Peer};
+                               false ->
+                                   {error, <<"Certificate host name mismatch">>, Peer}
+                           end
+                   end;
+               VerifyRes ->
+                   {error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer}
+           end;
+       {error, _Reason} ->
+           {error, <<"Cannot get peer certificate">>, Peer};
+       error ->
+           {error, <<"Cannot get peer certificate">>, Peer}
+    end;
+authenticate(_State, _Authzid) ->
+    %% TODO: client PKIX authentication
+    {error, <<"Client certificate verification not implemented">>, <<"">>}.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+get_cert_domains(Cert) ->
+    TBSCert = Cert#'Certificate'.tbsCertificate,
+    Subject = case TBSCert#'TBSCertificate'.subject of
+                 {rdnSequence, Subj} -> lists:flatten(Subj);
+                 _ -> []
+             end,
+    Extensions = case TBSCert#'TBSCertificate'.extensions of
+                    Exts when is_list(Exts) -> Exts;
+                    _ -> []
+                end,
+    lists:flatmap(
+      fun(#'AttributeTypeAndValue'{type = ?'id-at-commonName',value = Val}) ->
+             case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
+                 {ok, {_, D1}} ->
+                     D = if is_binary(D1) -> D1;
+                            is_list(D1) -> list_to_binary(D1);
+                            true -> error
+                         end,
+                     if D /= error ->
+                             case jid:from_string(D) of
+                                 #jid{luser = <<"">>, lserver = LD,
+                                      lresource = <<"">>} ->
+                                     [LD];
+                                 _ -> []
+                             end;
+                        true -> []
+                     end;
+                 _ -> []
+             end;
+        (_) -> []
+      end, Subject) ++
+       lists:flatmap(
+         fun(#'Extension'{extnID = ?'id-ce-subjectAltName',
+                          extnValue = Val}) ->
+                 BVal = if is_list(Val) -> list_to_binary(Val);
+                           true -> Val
+                        end,
+                 case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) of
+                     {ok, SANs} ->
+                         lists:flatmap(
+                           fun({otherName, #'AnotherName'{'type-id' = ?'id-on-xmppAddr',
+                                                          value = XmppAddr}}) ->
+                                   case 'XmppAddr':decode('XmppAddr', XmppAddr) of
+                                       {ok, D} when is_binary(D) ->
+                                           case jid:from_string(D) of
+                                               #jid{luser = <<"">>,
+                                                    lserver = LD,
+                                                    lresource = <<"">>} ->
+                                                   case ejabberd_idna:domain_utf8_to_ascii(LD) of
+                                                       false ->
+                                                           [];
+                                                       PCLD ->
+                                                           [PCLD]
+                                                   end;
+                                               _ -> []
+                                           end;
+                                       _ -> []
+                                   end;
+                              ({dNSName, D}) when is_list(D) ->
+                                   case jid:from_string(list_to_binary(D)) of
+                                       #jid{luser = <<"">>,
+                                            lserver = LD,
+                                            lresource = <<"">>} ->
+                                           [LD];
+                                       _ -> []
+                                   end;
+                              (_) -> []
+                           end, SANs);
+                     _ -> []
+                 end;
+            (_) -> []
+         end, Extensions).
+
+match_domain(Domain, Domain) -> true;
+match_domain(Domain, Pattern) ->
+    DLabels = str:tokens(Domain, <<".">>),
+    PLabels = str:tokens(Pattern, <<".">>),
+    match_labels(DLabels, PLabels).
+
+match_labels([], []) -> true;
+match_labels([], [_ | _]) -> false;
+match_labels([_ | _], []) -> false;
+match_labels([DL | DLabels], [PL | PLabels]) ->
+    case lists:all(fun (C) ->
+                          $a =< C andalso C =< $z orelse
+                            $0 =< C andalso C =< $9 orelse
+                              C == $- orelse C == $*
+                  end,
+                  binary_to_list(PL))
+       of
+      true ->
+         Regexp = ejabberd_regexp:sh_to_awk(PL),
+         case ejabberd_regexp:run(DL, Regexp) of
+           match -> match_labels(DLabels, PLabels);
+           nomatch -> false
+         end;
+      false -> false
+    end.