]> granicus.if.org Git - ejabberd/commitdiff
Rewrite ejabberd_service to use new XMPP stream API
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Sun, 11 Dec 2016 15:24:51 +0000 (18:24 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Sun, 11 Dec 2016 15:24:51 +0000 (18:24 +0300)
src/cyrsasl.erl
src/ejabberd_c2s.erl
src/ejabberd_service.erl
src/xmpp_stream_in.erl

index e2319647515a7b9a6c8cb31b0af671cefb5359cf..c49f8a3cbc51066137bce34d0b06f913f1b5f370 100644 (file)
@@ -71,6 +71,7 @@
     mech_state
 }).
 -type sasl_state() :: #sasl_state{}.
+-export_type([sasl_state/0]).
 
 -callback mech_new(binary(), fun(), fun(), fun()) -> any().
 -callback mech_step(any(), binary()) -> {ok, props()} |
index 1568d5db69de308e8e24b3035cc7f4abccb3f71f..b5113c34bf943659b1e7d6ebcd04b08920bedba2 100644 (file)
 %%%-------------------------------------------------------------------
 -module(ejabberd_c2s).
 -behaviour(xmpp_stream_in).
+-behaviour(ejabberd_config).
 
 -protocol({rfc, 6121}).
 
 %% ejabberd_socket callbacks
 -export([start/2, socket_type/0]).
+%% ejabberd_config callbacks
+-export([opt_type/1, transform_listen_option/2]).
 %% xmpp_stream_in callbacks
 -export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3]).
--export([tls_options/1, tls_required/1, sasl_mechanisms/1, init_sasl/1, bind/2,
+-export([tls_options/1, tls_required/1, compress_methods/1,
+        sasl_mechanisms/1, init_sasl/1, bind/2, handshake/2,
         unauthenticated_stream_features/1, authenticated_stream_features/1,
         handle_stream_start/1, handle_stream_end/1, handle_stream_close/1,
         handle_unauthenticated_packet/2, handle_authenticated_packet/2,
-        handle_auth_success/4, handle_auth_failure/4,
-        handle_unbinded_packet/2]).
+        handle_auth_success/4, handle_auth_failure/4, handle_send/5,
+        handle_unbinded_packet/2, handle_cdata/2]).
 %% API
 -export([get_presence/1, get_subscription/2, get_subscribed/1,
         send/2, close/1]).
@@ -99,8 +103,7 @@ send(State, Pkt) ->
 %%%===================================================================
 %%% xmpp_stream_in callbacks
 %%%===================================================================
-tls_options(#{server := Server, tls_options := TLSOpts}) ->
-    LServer = jid:nameprep(Server),
+tls_options(#{lserver := LServer, tls_options := TLSOpts}) ->
     case ejabberd_config:get_option({domain_certfile, LServer},
                                    fun iolist_to_binary/1) of
        undefined ->
@@ -112,19 +115,21 @@ tls_options(#{server := Server, tls_options := TLSOpts}) ->
 tls_required(#{tls_required := TLSRequired}) ->
     TLSRequired.
 
-unauthenticated_stream_features(#{server := Server}) ->
-    LServer = jid:nameprep(Server),
+compress_methods(#{zlib := true}) ->
+    [<<"zlib">>];
+compress_methods(_) ->
+    [].
+
+sasl_mechanisms(#{lserver := LServer}) ->
+    cyrsasl:listmech(LServer).
+
+unauthenticated_stream_features(#{lserver := LServer}) ->
     ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]).
 
-authenticated_stream_features(#{server := Server}) ->
-    LServer = jid:nameprep(Server),
+authenticated_stream_features(#{lserver := LServer}) ->
     ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]).
 
-sasl_mechanisms(#{server := Server}) ->
-    cyrsasl:listmech(jid:nameprep(Server)).
-
-init_sasl(#{server := Server}) ->
-    LServer = jid:nameprep(Server),
+init_sasl(#{lserver := LServer}) ->
     cyrsasl:server_new(
       <<"jabber">>, LServer, <<"">>, [],
       fun(U) ->
@@ -147,8 +152,11 @@ bind(R, #{user := U, server := S} = State) ->
            open_session(State, Resource)
     end.
 
-handle_stream_start(#{server := Server, ip := IP, lang := Lang} = State) ->
-    LServer = jid:nameprep(Server),
+handshake(_Data, State) ->
+    %% This is only for jabber component
+    {ok, State}.
+
+handle_stream_start(#{lserver := LServer, ip := IP, lang := Lang} = State) ->
     case lists:member(LServer, ?MYHOSTS) of
        false ->
            xmpp_stream_in:send(State, xmpp:serr_host_unknown());
@@ -172,8 +180,7 @@ handle_stream_close(State) ->
     {stop, normal, State}.
 
 handle_auth_success(User, Mech, AuthModule,
-                   #{socket := Socket, ip := IP, server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+                   #{socket := Socket, ip := IP, lserver := LServer} = State) ->
     ?INFO_MSG("(~w) Accepted ~s authentication for ~s@~s by ~p from ~s",
              [Socket, Mech, User, LServer, AuthModule,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
@@ -182,8 +189,7 @@ handle_auth_success(User, Mech, AuthModule,
                            {noreply, State1}, [true, User]).
 
 handle_auth_failure(User, Mech, Reason,
-                   #{socket := Socket, ip := IP, server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+                   #{socket := Socket, ip := IP, lserver := LServer} = State) ->
     ?INFO_MSG("(~w) Failed ~s authentication ~sfrom ~s: ~s",
              [Socket, Mech,
               if User /= <<"">> -> ["for ", User, "@", LServer, " "];
@@ -193,22 +199,18 @@ handle_auth_failure(User, Mech, Reason,
     ejabberd_hooks:run_fold(c2s_auth_result, LServer,
                            {noreply, State}, [false, User]).
 
-handle_unbinded_packet(Pkt, #{server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer,
                            {noreply, State}, [Pkt]).
 
-handle_unauthenticated_packet(Pkt, #{server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_unauthenticated_packet,
                            LServer, {noreply, State}, [Pkt]).
 
-handle_authenticated_packet(Pkt, #{server := Server} = State) when not ?is_stanza(Pkt) ->
-    LServer = jid:nameprep(Server),
+handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
     ejabberd_hooks:run_fold(c2s_authenticated_packet,
                            LServer, {noreply, State}, [Pkt]);
-handle_authenticated_packet(Pkt, #{server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+handle_authenticated_packet(Pkt, #{lserver := LServer} = State) ->
     case ejabberd_hooks:run_fold(c2s_authenticated_packet,
                                 LServer, {noreply, State}, [Pkt]) of
        {noreply, State1} ->
@@ -228,6 +230,14 @@ handle_authenticated_packet(Pkt, #{server := Server} = State) ->
            Err
     end.
 
+handle_cdata(Data, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_handle_cdata, LServer,
+                           {noreply, State}, [Data]).
+
+handle_send(Reason, Pkt, El, Data, #{lserver := LServer} = State) ->
+    ejabberd_hooks:run_fold(c2s_handle_send, LServer,
+                           {noreply, State}, [Reason, Pkt, El, Data]).
+
 init([State, Opts]) ->
     Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
     Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none),
@@ -239,6 +249,7 @@ init([State, Opts]) ->
                end, Opts),
     TLSRequired = proplists:get_bool(starttls_required, Opts),
     TLSVerify = proplists:get_bool(tls_verify, Opts),
+    Zlib = proplists:get_bool(zlib, Opts),
     State1 = State#{tls_options => TLSOpts,
                    tls_required => TLSRequired,
                    tls_verify => TLSVerify,
@@ -246,19 +257,18 @@ init([State, Opts]) ->
                    pres_f => ?SETS:new(),
                    pres_t => ?SETS:new(),
                    sid => ejabberd_sm:make_sid(),
+                   zlib => Zlib,
                    lang => ?MYLANG,
                    server => ?MYNAME,
                    access => Access,
                    shaper => Shaper},
     ejabberd_hooks:run_fold(c2s_init, {ok, State1}, []).
 
-handle_call(get_presence, _From,
-           #{user := U, server := S, resource := R} = State) ->
+handle_call(get_presence, _From, #{jid := JID} = State) ->
     Pres = case maps:get(pres_last, State, undefined) of
               undefined ->
-                  From = jid:make(U, S, R),
-                  To = jid:remove_resource(From),
-                  #presence{from = From, to = To, type = unavailable};
+                  BareJID = jid:remove_resource(JID),
+                  #presence{from = JID, to = BareJID, type = unavailable};
               P ->
                   P
           end,
@@ -266,20 +276,17 @@ handle_call(get_presence, _From,
 handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
     Subscribed = ?SETS:to_list(PresF),
     {reply, Subscribed, State};
-handle_call(Request, From, #{server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+handle_call(Request, From, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(
       c2s_handle_call, LServer, {noreply, State}, [Request, From]).
 
 handle_cast(closed, State) ->
     handle_stream_close(State);
-handle_cast(Msg, #{server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+handle_cast(Msg, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_handle_cast, LServer, {noreply, State}, [Msg]).
 
-handle_info({route, From, To, Packet0}, #{server := Server} = State) ->
+handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
     Packet = xmpp:set_from_to(Packet0, From, To),
-    LServer = jid:nameprep(Server),
     {Pass, NewState} = case Packet of
                           #presence{} ->
                               process_presence_in(State, Packet);
@@ -289,7 +296,6 @@ handle_info({route, From, To, Packet0}, #{server := Server} = State) ->
                               process_iq_in(State, Packet)
                       end,
     if Pass ->
-           LServer = jid:nameprep(Server),
            Packet1 = ejabberd_hooks:run_fold(
                        user_receive_packet, LServer, Packet, [NewState]),
            ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
@@ -300,8 +306,7 @@ handle_info({route, From, To, Packet0}, #{server := Server} = State) ->
     end;
 handle_info(system_shutdown, State) ->
     xmpp_stream_in:send(State, xmpp:serr_system_shutdown());
-handle_info(Info, #{server := Server} = State) ->
-    LServer = jid:nameprep(Server),
+handle_info(Info, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_handle_info, LServer, {noreply, State}, [Info]).
 
 terminate(_Reason, _State) ->
@@ -319,11 +324,10 @@ check_bl_c2s({IP, _Port}, Lang) ->
     ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]).
 
 -spec open_session(state(), binary()) -> {ok, state()} | {error, stanza_error(), state()}.
-open_session(#{user := U, server := S, sid := SID,
+open_session(#{user := U, server := S, lserver := LServer, sid := SID,
               socket := Socket, ip := IP, auth_module := AuthMod,
               access := Access, lang := Lang} = State, R) ->
     JID = jid:make(U, S, R),
-    LServer = JID#jid.lserver,
     case acl:access_matches(Access,
                            #{usr => jid:split(JID), ip => IP},
                            LServer) of
@@ -374,9 +378,8 @@ process_message_in(State, #message{type = T} = Msg) ->
     end.
 
 -spec process_presence_in(state(), presence()) -> {boolean(), state()}.
-process_presence_in(#{server := Server, pres_a := PresA} = State0,
+process_presence_in(#{lserver := LServer, pres_a := PresA} = State0,
                    #presence{from = From, to = To, type = T} = Pres) ->
-    LServer = jid:nameprep(Server),
     State = ejabberd_hooks:run_fold(c2s_presence_in, LServer, State0, [Pres]),
     case T of
        probe ->
@@ -399,7 +402,7 @@ process_presence_in(#{server := Server, pres_a := PresA} = State0,
     end.
 
 -spec route_probe_reply(jid(), jid(), state()) -> ok.
-route_probe_reply(From, To, #{server := Server, pres_f := PresF,
+route_probe_reply(From, To, #{lserver := LServer, pres_f := PresF,
                              pres_last := LastPres,
                              pres_timestamp := TS} = State) ->
     LFrom = jid:tolower(From),
@@ -413,7 +416,6 @@ route_probe_reply(From, To, #{server := Server, pres_f := PresF,
                deny ->
                    ok;
                allow ->
-                   LServer = jid:nameprep(Server),
                    ejabberd_hooks:run(presence_probe_hook,
                                       LServer,
                                       [From, To, self()]),
@@ -432,10 +434,9 @@ route_probe_reply(_, _, _) ->
     ok.
 
 -spec process_presence_out(state(), presence()) -> next_state().
-process_presence_out(#{user := User, server := Server,
-                      lang := Lang, pres_a := PresA} = State,
+process_presence_out(#{user := User, server := Server, lserver := LServer,
+                      jid := JID, lang := Lang, pres_a := PresA} = State,
                     #presence{from = From, to = To, type = Type} = Pres) ->
-    LServer = jid:nameprep(Server),
     LTo = jid:tolower(To),
     case privacy_check_packet(State, Pres, out) of
        deny ->
@@ -448,7 +449,7 @@ process_presence_out(#{user := User, server := Server,
            Access = gen_mod:get_module_opt(LServer, mod_roster, access,
                                            fun(A) when is_atom(A) -> A end,
                                            all),
-           MyBareJID = jid:make(User, Server, <<"">>),
+           MyBareJID = jid:remove_resource(JID),
            case acl:match_rule(LServer, Access, MyBareJID) of
                deny ->
                    ErrText = <<"Denied by ACL">>,
@@ -485,9 +486,8 @@ process_self_presence(#{ip := IP, conn := Conn,
     State1 = broadcast_presence_unavailable(State, Pres),
     State2 = maps:remove(pres_last, maps:remove(pres_timestamp, State1)),
     {noreply, State2};
-process_self_presence(#{server := Server} = State,
+process_self_presence(#{lserver := LServer} = State,
                      #presence{type = available} = Pres) ->
-    LServer = jid:nameprep(Server),
     PreviousPres = maps:get(pres_last, State, undefined),
     update_priority(State, Pres),
     State1 = ejabberd_hooks:run_fold(user_available_hook, LServer, State, [Pres]),
@@ -543,8 +543,7 @@ check_privacy_then_route(#{lang := Lang} = State, Pkt) ->
     end.
 
 -spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny.
-privacy_check_packet(#{server := Server} = State, Pkt, Dir) ->
-    LServer = jid:nameprep(Server),
+privacy_check_packet(#{lserver := LServer} = State, Pkt, Dir) ->
     ejabberd_hooks:run_fold(privacy_check_packet, LServer, allow, [State, Pkt, Dir]).
 
 -spec get_priority_from_presence(presence()) -> integer().
@@ -555,9 +554,7 @@ get_priority_from_presence(#presence{priority = Prio}) ->
     end.
 
 -spec filter_blocked(state(), presence(), ?SETS:set()) -> [jid()].
-filter_blocked(#{user := U, server := S, resource := R} = State,
-              Pres, LJIDSet) ->
-    From = jid:make(U, S, R),
+filter_blocked(#{jid := From} = State, Pres, LJIDSet) ->
     ?SETS:fold(
        fun(LJID, Acc) ->
               To = jid:make(LJID),
@@ -581,8 +578,7 @@ route_error(Pkt, Err) ->
     ejabberd_router:route_error(To, From, Pkt, Err).
 
 -spec route_multiple(state(), [jid()], stanza()) -> ok.
-route_multiple(#{server := Server}, JIDs, Pkt) ->
-    LServer = jid:nameprep(Server),
+route_multiple(#{lserver := LServer}, JIDs, Pkt) ->
     From = xmpp:get_from(Pkt),
     ejabberd_router_multicast:route_multicast(From, LServer, JIDs, Pkt).
 
@@ -636,9 +632,9 @@ get_conn_type(State) ->
     end.
 
 -spec change_shaper(state()) -> ok.
-change_shaper(#{shaper := ShaperName, ip := IP,
+change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer,
                user := U, server := S, resource := R} = State) ->
-    #jid{lserver = LServer} = JID = jid:make(U, S, R),
+    JID = jid:make(U, S, R),
     Shaper = acl:access_matches(ShaperName,
                                #{usr => jid:split(JID), ip => IP},
                                LServer),
@@ -680,3 +676,18 @@ fsm_limit_opts(Opts) ->
                N -> [{max_queue, N}]
            end
     end.
+
+transform_listen_option(Opt, Opts) ->
+    [Opt|Opts].
+
+opt_type(domain_certfile) -> fun iolist_to_binary/1;
+opt_type(max_fsm_queue) ->
+    fun (I) when is_integer(I), I > 0 -> I end;
+opt_type(resource_conflict) ->
+    fun (setresource) -> setresource;
+       (closeold) -> closeold;
+       (closenew) -> closenew;
+       (acceptnew) -> acceptnew
+    end;
+opt_type(_) ->
+    [domain_certfile, max_fsm_queue, resource_conflict].
index 35cfe15af522558c152fdb34a61ced7f46746eb6..c48cd536c6a43a15e4c3124163972ec9a4b1093e 100644 (file)
@@ -1,8 +1,5 @@
-%%%----------------------------------------------------------------------
-%%% File    : ejabberd_service.erl
-%%% Author  : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : External component management (XEP-0114)
-%%% Created :  6 Dec 2002 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+%%% Created : 11 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
 %%%
 %%%
 %%% ejabberd, Copyright (C) 2002-2016   ProcessOne
 %%% with this program; if not, write to the Free Software Foundation, Inc.,
 %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 %%%
-%%%----------------------------------------------------------------------
-
+%%%-------------------------------------------------------------------
 -module(ejabberd_service).
-
+-behaviour(xmpp_stream_in).
 -behaviour(ejabberd_config).
 
--author('alexey@process-one.net').
-
 -protocol({xep, 114, '1.6'}).
 
--define(GEN_FSM, p1_fsm).
-
--behaviour(?GEN_FSM).
-
-%% External exports
--export([start/2, start_link/2, send_text/2,
-        send_element/2, socket_type/0, transform_listen_option/2]).
-
--export([init/1, wait_for_stream/2,
-        wait_for_handshake/2, stream_established/2,
-        handle_event/3, handle_sync_event/4, code_change/4,
-        handle_info/3, terminate/3, print_state/1, opt_type/1]).
+%% ejabberd_socket callbacks
+-export([start/2, socket_type/0]).
+%% ejabberd_config callbacks
+-export([opt_type/1, transform_listen_option/2]).
+%% xmpp_stream_in callbacks
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+        terminate/2, code_change/3]).
+-export([handshake/2, handle_stream_start/1, handle_authenticated_packet/2]).
+%% API
+-export([send/2]).
 
 -include("ejabberd.hrl").
--include("logger.hrl").
 -include("xmpp.hrl").
+-include("logger.hrl").
 
--record(state,
-       {socket                    :: ejabberd_socket:socket_state(),
-         sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket,
-         streamid = <<"">>         :: binary(),
-         host_opts = dict:new()    :: ?TDICT,
-         host = <<"">>             :: binary(),
-         access                    :: atom(),
-        check_from = true         :: boolean()}).
-
--type state_name() :: wait_for_stream | wait_for_handshake | stream_established.
--type state() :: #state{}.
--type fsm_next() :: {next_state, state_name(), state()}.
--type fsm_stop() :: {stop, normal, state()}.
--type fsm_transition() :: fsm_stop() | fsm_next().
-
-%-define(DBGFSM, true).
+%%-define(DBGFSM, true).
 -ifdef(DBGFSM).
 -define(FSMOPTS, [{debug, [trace]}]).
 -else.
 -define(FSMOPTS, []).
 -endif.
 
-%%%----------------------------------------------------------------------
+-type state() :: map().
+-type next_state() :: {noreply, state()} | {stop, term(), state()}.
+-export_type([state/0, next_state/0]).
+
+%%%===================================================================
 %%% API
-%%%----------------------------------------------------------------------
+%%%===================================================================
 start(SockData, Opts) ->
-    supervisor:start_child(ejabberd_service_sup,
-                          [SockData, Opts]).
+    xmpp_stream_in:start(?MODULE, [SockData, Opts],
+                        fsm_limit_opts(Opts) ++ ?FSMOPTS).
 
-start_link(SockData, Opts) ->
-    (?GEN_FSM):start_link(ejabberd_service,
-                         [SockData, Opts], fsm_limit_opts(Opts) ++ (?FSMOPTS)).
+socket_type() ->
+    xml_stream.
 
-socket_type() -> xml_stream.
+-spec send(state(), xmpp_element()) -> next_state().
+send(State, Pkt) ->
+    xmpp_stream_in:send(State, Pkt).
 
-%%%----------------------------------------------------------------------
-%%% Callback functions from gen_fsm
-%%%----------------------------------------------------------------------
-init([{SockMod, Socket}, Opts]) ->
+%%%===================================================================
+%%% xmpp_stream_in callbacks
+%%%===================================================================
+init([#{socket := Socket} = State, Opts]) ->
     ?INFO_MSG("(~w) External service connected", [Socket]),
-    Access = case lists:keysearch(access, 1, Opts) of
-              {value, {_, A}} -> A;
-              _ -> all
-            end,
+    Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
+    Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none),
     HostOpts = case lists:keyfind(hosts, 1, Opts) of
                   {hosts, HOpts} ->
                       lists:foldl(
@@ -107,252 +87,120 @@ init([{SockMod, Socket}, Opts]) ->
                                p1_sha:sha(randoms:bytes(20))),
                       dict:from_list([{global, Pass}])
               end,
-    Shaper = case lists:keysearch(shaper_rule, 1, Opts) of
-              {value, {_, S}} -> S;
-              _ -> none
-            end,
-    CheckFrom = case lists:keysearch(service_check_from, 1,
-                                    Opts)
-                   of
-                 {value, {_, CF}} -> CF;
-                 _ -> true
-               end,
-    SockMod:change_shaper(Socket, Shaper),
-    {ok, wait_for_stream,
-     #state{socket = Socket, sockmod = SockMod,
-           streamid = new_id(), host_opts = HostOpts,
-           access = Access, check_from = CheckFrom}}.
-
-wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
-    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
-       #stream_start{xmlns = NS_COMPONENT, stream_xmlns = NS_STREAM}
-          when NS_COMPONENT /= ?NS_COMPONENT; NS_STREAM /= ?NS_STREAM ->
-            send_header(StateData, ?MYNAME),
-            send_element(StateData, xmpp:serr_invalid_namespace()),
-            {stop, normal, StateData};
-       #stream_start{to = To} when is_record(To, jid) ->
-           Host = To#jid.lserver,
-           send_header(StateData, Host),
-           HostOpts = case dict:is_key(Host, StateData#state.host_opts) of
-                          true ->
-                              StateData#state.host_opts;
-                          false ->
-                              case dict:find(global, StateData#state.host_opts) of
-                                  {ok, GlobalPass} ->
-                                      dict:from_list([{Host, GlobalPass}]);
-                                  error ->
-                                      StateData#state.host_opts
-                              end
-                      end,
-           {next_state, wait_for_handshake,
-            StateData#state{host = Host, host_opts = HostOpts}};
-       #stream_start{} ->
-           send_header(StateData, ?MYNAME),
-           send_element(StateData, xmpp:serr_improper_addressing()),
-           {stop, normal, StateData};
-       _ ->
-           send_header(StateData, ?MYNAME),
-           send_element(StateData, xmpp:serr_invalid_xml()),
-           {stop, normal, StateData}
-    catch _:{xmpp_codec, Why} ->
-           Txt = xmpp:format_error(Why),
-           send_header(StateData, ?MYNAME),
-           send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
-           {stop, normal, StateData}
-    end;
-wait_for_stream({xmlstreamerror, _}, StateData) ->
-    send_header(StateData, ?MYNAME),
-    send_element(StateData, xmpp:serr_not_well_formed()),
-    {stop, normal, StateData};
-wait_for_stream(closed, StateData) ->
-    {stop, normal, StateData}.
-
-wait_for_handshake({xmlstreamelement, El}, StateData) ->
-    decode_element(El, wait_for_handshake, StateData);
-wait_for_handshake(#handshake{data = Digest}, StateData) ->
-    case dict:find(StateData#state.host, StateData#state.host_opts) of
+    CheckFrom = gen_mod:get_opt(check_from, Opts,
+                               fun(Flag) when is_boolean(Flag) -> Flag end),
+    xmpp_stream_in:change_shaper(State, Shaper),
+    State1 = State#{access => Access,
+                   xmlns => ?NS_COMPONENT,
+                   lang => ?MYLANG,
+                   server => ?MYNAME,
+                   host_opts => HostOpts,
+                   check_from => CheckFrom},
+    ejabberd_hooks:run_fold(component_init, {ok, State1}, []).
+
+handle_stream_start(#{remote_server := RemoteServer,
+                     host_opts := HostOpts} = State) ->
+    NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
+                     true ->
+                         HostOpts;
+                     false ->
+                         case dict:find(global, HostOpts) of
+                             {ok, GlobalPass} ->
+                                 dict:from_list([{RemoteServer, GlobalPass}]);
+                             error ->
+                                 HostOpts
+                         end
+                 end,
+    {noreply, State#{host_opts => NewHostOpts}}.
+
+handshake(Digest, #{remote_server := RemoteServer,
+                   stream_id := StreamID,
+                   host_opts := HostOpts} = State) ->
+    case dict:find(RemoteServer, HostOpts) of
        {ok, Password} ->
-           case p1_sha:sha(<<(StateData#state.streamid)/binary,
-                             Password/binary>>) of
+           case p1_sha:sha(<<StreamID/binary, Password/binary>>) of
                Digest ->
-                   send_element(StateData, #handshake{}),
                    lists:foreach(
                      fun (H) ->
                              ejabberd_router:register_route(H, ?MYNAME),
-                             ?INFO_MSG("Route registered for service ~p~n",
-                                       [H]),
+                             ?INFO_MSG("Route registered for service ~p~n", [H]),
                              ejabberd_hooks:run(component_connected, [H])
-                     end, dict:fetch_keys(StateData#state.host_opts)),
-                   {next_state, stream_established, StateData};
-               _ ->
-                   send_element(StateData, xmpp:serr_not_authorized()),
-                   {stop, normal, StateData}
-           end;
-       _ ->
-           send_element(StateData, xmpp:serr_not_authorized()),
-           {stop, normal, StateData}
-    end;
-wait_for_handshake({xmlstreamend, _Name}, StateData) ->
-    {stop, normal, StateData};
-wait_for_handshake({xmlstreamerror, _}, StateData) ->
-    send_element(StateData, xmpp:serr_not_well_formed()),
-    {stop, normal, StateData};
-wait_for_handshake(closed, StateData) ->
-    {stop, normal, StateData};
-wait_for_handshake(_Pkt, StateData) ->
-    {next_state, wait_for_handshake, StateData}.
-
-stream_established({xmlstreamelement, El}, StateData) ->
-    decode_element(El, stream_established, StateData);
-stream_established(El, StateData) when ?is_stanza(El) ->
-    From = xmpp:get_from(El),
-    To = xmpp:get_to(El),
-    Lang = xmpp:get_lang(El),
-    if From == undefined orelse To == undefined ->
-           Txt = <<"Missing 'from' or 'to' attribute">>,
-           send_error(StateData, El, xmpp:err_jid_malformed(Txt, Lang));
-       true ->
-           case check_from(From, StateData) of
-               true ->
-                   ejabberd_router:route(From, To, El);
-               false ->
-                   Txt = <<"Improper domain part of 'from' attribute">>,
-                   send_error(StateData, El, xmpp:err_not_allowed(Txt, Lang))
-           end
-    end,
-    {next_state, stream_established, StateData};
-stream_established({xmlstreamend, _Name}, StateData) ->
-    {stop, normal, StateData};
-stream_established({xmlstreamerror, _}, StateData) ->
-    send_element(StateData, xmpp:serr_not_well_formed()),
-    {stop, normal, StateData};
-stream_established(closed, StateData) ->
-    {stop, normal, StateData};
-stream_established(_Event, StateData) ->
-    {next_state, stream_established, StateData}.
+                     end, dict:fetch_keys(HostOpts)),
+                   {ok, State};
+               _ ->
+                   ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
+                   {error, xmpp:serr_not_authorized(), State}
+           end;
+       _ ->
+           ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
+           {error, xmpp:serr_not_authorized(), State}
+    end.
 
-handle_event(_Event, StateName, StateData) ->
-    {next_state, StateName, StateData}.
+handle_authenticated_packet(Pkt, #{lang := Lang} = State) ->
+    From = xmpp:get_from(Pkt),
+    case check_from(From, State) of
+       true ->
+           To = xmpp:get_to(Pkt),
+           ejabberd_router:route(From, To, Pkt),
+           {noreply, State};
+       false ->
+           Txt = <<"Improper domain part of 'from' attribute">>,
+           Err = xmpp:serr_invalid_from(Txt, Lang),
+           xmpp_stream_in:send(State, Err)
+    end.
 
-handle_sync_event(_Event, _From, StateName,
-                 StateData) ->
-    Reply = ok, {reply, Reply, StateName, StateData}.
+handle_call(_Request, _From, State) ->
+    Reply = ok,
+    {reply, Reply, State}.
 
-code_change(_OldVsn, StateName, StateData, _Extra) ->
-    {ok, StateName, StateData}.
+handle_cast(_Msg, State) ->
+    {noreply, State}.
 
-handle_info({send_text, Text}, StateName, StateData) ->
-    send_text(StateData, Text),
-    {next_state, StateName, StateData};
-handle_info({send_element, El}, StateName, StateData) ->
-    send_element(StateData, El),
-    {next_state, StateName, StateData};
-handle_info({route, From, To, Packet}, StateName,
-           StateData) ->
-    case acl:match_rule(global, StateData#state.access, From) of
-      allow ->
+handle_info({route, From, To, Packet}, #{access := Access} = State) ->
+    case acl:match_rule(global, Access, From) of
+       allow ->
            Pkt = xmpp:set_from_to(Packet, From, To),
-           send_element(StateData, Pkt);
+           xmpp_stream_in:send(State, Pkt);
        deny ->
            Lang = xmpp:get_lang(Packet),
            Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang),
-           ejabberd_router:route_error(To, From, Packet, Err)
-    end,
-    {next_state, StateName, StateData};
-handle_info(Info, StateName, StateData) ->
+           ejabberd_router:route_error(To, From, Packet, Err),
+           {noreply, State}
+    end;
+handle_info(Info, State) ->
     ?ERROR_MSG("Unexpected info: ~p", [Info]),
-    {next_state, StateName, StateData}.
-
-terminate(Reason, StateName, StateData) ->
-    ?INFO_MSG("terminated: ~p", [Reason]),
-    case StateName of
-      stream_established ->
-         lists:foreach(fun (H) ->
-                               ejabberd_router:unregister_route(H),
-                               ejabberd_hooks:run(component_disconnected,
-                                                  [H, Reason])
-                       end,
-                       dict:fetch_keys(StateData#state.host_opts));
-      _ -> ok
-    end,
-    catch send_trailer(StateData),
-    (StateData#state.sockmod):close(StateData#state.socket),
-    ok.
-
-%%----------------------------------------------------------------------
-%% Func: print_state/1
-%% Purpose: Prepare the state to be printed on error log
-%% Returns: State to print
-%%----------------------------------------------------------------------
-print_state(State) -> State.
-
-%%%----------------------------------------------------------------------
-%%% Internal functions
-%%%----------------------------------------------------------------------
-
--spec send_text(state(), iodata()) -> ok.
-send_text(StateData, Text) ->
-    (StateData#state.sockmod):send(StateData#state.socket,
-                                  Text).
-
--spec send_element(state(), xmpp_element()) -> ok.
-send_element(StateData, El) ->
-    El1 = xmpp:encode(El, ?NS_COMPONENT),
-    send_text(StateData, fxml:element_to_binary(El1)).
-
--spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok.
-send_error(StateData, Stanza, Error) ->
-    Type = xmpp:get_type(Stanza),
-    if Type == error; Type == result;
-       Type == <<"error">>; Type == <<"result">> ->
-           ok;
-       true ->
-           send_element(StateData, xmpp:make_error(Stanza, Error))
+    {noreply, State}.
+
+terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) ->
+    ?INFO_MSG("External service disconnected: ~p", [Reason]),
+    case StreamState of
+       session_established ->
+           lists:foreach(
+             fun(H) ->
+                     ejabberd_router:unregister_route(H),
+                     ejabberd_hooks:run(component_disconnected, [H, Reason])
+             end, dict:fetch_keys(HostOpts));
+       _ ->
+           ok
     end.
 
--spec send_header(state(), binary()) -> ok.
-send_header(StateData, Host) ->
-    Header = xmpp:encode(
-              #stream_start{xmlns = ?NS_COMPONENT,
-                            stream_xmlns = ?NS_STREAM,
-                            from = jid:make(Host),
-                            id = StateData#state.streamid}),
-    send_text(StateData, fxml:element_to_header(Header)).
-
--spec send_trailer(state()) -> ok.
-send_trailer(StateData) ->
-    send_text(StateData, <<"</stream:stream>">>).
-
--spec decode_element(xmlel(), state_name(), state()) -> fsm_transition().
-decode_element(#xmlel{} = El, StateName, StateData) ->
-    try xmpp:decode(El, ?NS_COMPONENT, [ignore_els]) of
-       Pkt -> ?MODULE:StateName(Pkt, StateData)
-    catch error:{xmpp_codec, Why} ->
-            case xmpp:is_stanza(El) of
-                true ->
-                    Lang = xmpp:get_lang(El),
-                    Txt = xmpp:format_error(Why),
-                    send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
-                false ->
-                    ok
-            end,
-            {next_state, StateName, StateData}
-    end.
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
 
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
 -spec check_from(jid(), state()) -> boolean().
-check_from(_From, #state{check_from = false}) ->
+check_from(_From, #{check_from := false}) ->
     %% If the admin does not want to check the from field
     %% when accept packets from any address.
     %% In this case, the component can send packet of
     %% behalf of the server users.
     true;
-check_from(From, StateData) ->
+check_from(From, #{host_opts := HostOpts}) ->
     %% The default is the standard behaviour in XEP-0114
     Server = From#jid.lserver,
-    dict:is_key(Server, StateData#state.host_opts).
-
--spec new_id() -> binary().
-new_id() -> randoms:get_string().
+    dict:is_key(Server, HostOpts).
 
 transform_listen_option({hosts, Hosts, O}, Opts) ->
     case lists:keyfind(hosts, 1, Opts) of
index 6294a7893ee3683b99c92d77f876729e95b36095..1307f9da42ac780ebd7686ff9bed03d880b6f6e5 100644 (file)
 -type next_state() :: {noreply, state()} | {stop, term(), state()}.
 
 -callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
+-callback handle_stream_start(state()) -> next_state().
+-callback handle_stream_end(state()) -> next_state().
+-callback handle_stream_close(state()) -> next_state().
+-callback handle_cdata(binary(), state()) -> next_state().
+-callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state().
 -callback handle_authenticated_packet(xmpp_element(), state()) -> next_state().
+-callback handle_unbinded_packet(xmpp_element(), state()) -> next_state().
+-callback handle_auth_success(binary(), binary(), module(), state()) -> next_state().
+-callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state().
+-callback handle_send(ok | {error, atom()},
+                     xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state().
+-callback init_sasl(state()) -> cyrsasl:sasl_state().
+-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
+-callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}.
+-callback compress_methods(state()) -> [binary()].
+-callback tls_options(state()) -> [proplists:property()].
+-callback tls_required(state()) -> boolean().
+-callback sasl_mechanisms(state()) -> [binary()].
+-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
+-callback authenticated_stream_features(state()) -> [xmpp_element()].
+
+%% All callbacks are optional
+-optional_callbacks([init/1,
+                    handle_stream_start/1,
+                    handle_stream_end/1,
+                    handle_stream_close/1,
+                    handle_cdata/2,
+                    handle_authenticated_packet/2,
+                    handle_unauthenticated_packet/2,
+                    handle_unbinded_packet/2,
+                    handle_auth_success/4,
+                    handle_auth_failure/4,
+                    handle_send/5,
+                    init_sasl/1,
+                    bind/2,
+                    handshake/2,
+                    compress_methods/1,
+                    tls_options/1,
+                    tls_required/1,
+                    sasl_mechanisms/1,
+                    unauthenticated_stream_features/1,
+                    authenticated_stream_features/1]).
 
 %%%===================================================================
 %%% API
@@ -94,21 +135,28 @@ init([Module, {SockMod, Socket}, Opts]) ->
                      user => <<"">>,
                      server => <<"">>,
                      resource => <<"">>,
+                     lserver => <<"">>,
                      ip => IP},
-           Module:init([State, Opts]);
+           try Module:init([State, Opts])
+           catch _:undef -> {ok, State}
+           end;
        {error, Reason} ->
            {stop, Reason}
     end.
 
 handle_cast(Cast, #{mod := Mod} = State) ->
-    Mod:handle_cast(Cast, State).
+    try Mod:handle_cast(Cast, State)
+    catch _:undef -> {noreply, State}
+    end.
 
 handle_call(Call, From, #{mod := Mod} = State) ->
-    Mod:handle_call(Call, From, State).
+    try Mod:handle_call(Call, From, State)
+    catch _:undef -> {reply, ok, State}
+    end.
 
 handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
-           #{stream_state := wait_for_stream} = State) ->
-    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
+           #{stream_state := wait_for_stream, xmlns := XMLNS} = State) ->
+    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of
        #stream_start{} = Pkt ->
            case send_header(State, Pkt) of
                {noreply, State1} ->
@@ -169,11 +217,15 @@ handle_info({'DOWN', MRef, _Type, _Object, _Info},
     catch _:undef -> {stop, normal, State}
     end;
 handle_info(Info, #{mod := Mod} = State) ->
-    Mod:handle_info(Info, State).
+    try Mod:handle_info(Info, State)
+    catch _:undef -> {noreply, State}
+    end.
 
 terminate(Reason, #{mod := Mod, socket := Socket,
                    sockmod := SockMod} = State) ->
-    Mod:terminate(Reason, State),
+    try Mod:terminate(Reason, State)
+    catch _:undef -> ok
+    end,
     send_text(State, <<"</stream:stream>">>),
     SockMod:close(Socket).
 
@@ -234,13 +286,14 @@ process_stream(#stream_start{to = #jid{lserver = RemoteServer}},
        Err ->
            Err
     end;
-process_stream(#stream_start{to = #jid{server = Server}, from = From},
+process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
+                            from = From},
               #{stream_authenticated := Authenticated,
                 stream_restarted := StreamWasRestarted,
                 mod := Mod, xmlns := NS, resource := Resource,
                 stream_tlsed := TLSEnabled} = State) ->
     case if not StreamWasRestarted ->
-                State1 = State#{server => Server},
+                State1 = State#{server => Server, lserver => LServer},
                 try Mod:handle_stream_start(State1)
                 catch _:undef -> {noreply, State1}
                 end;
@@ -342,10 +395,18 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
                #xmpp_session{} ->
                    send_element(State, xmpp:make_iq_result(Pkt2));
                _ ->
-                   Mod:handle_authenticated_packet(Pkt2, State)
+                   try Mod:handle_authenticated_packet(Pkt2, State)
+                   catch _:undef ->
+                           Err = xmpp:err_service_unavailable(),
+                           send_error(State, Pkt, Err)
+                   end
            end;
        {ok, Pkt2} ->
-           Mod:handle_authenticated_packet(Pkt2, State);
+           try Mod:handle_authenticated_packet(Pkt2, State)
+           catch _:undef ->
+                   Err = xmpp:err_service_unavailable(),
+                   send_error(State, Pkt, Err)
+           end;
        {error, Err} ->
            send_element(State, Err)
     end.
@@ -385,8 +446,15 @@ process_bind(Pkt, #{mod := Mod} = State) ->
            send_error(State, Pkt, Err)
     end.
 
-process_handshake(#handshake{} = Pkt, #{mod := Mod} = State) ->
-    Mod:handle_handshake(Pkt, State).
+process_handshake(#handshake{data = Data}, #{mod := Mod} = State) ->
+    case Mod:handshake(Data, State) of
+       {ok, State1} ->
+           State2 = State1#{stream_state => session_established,
+                            stream_authenticated => true},
+           send_element(State2, #handshake{});
+       {error, #stream_error{} = Err, State1} ->
+           send_element(State1, Err)
+    end.
 
 process_compress(#compress{}, #{stream_compressed := true} = State) ->
     send_element(State, #compress_failure{reason = 'setup-failed'});
@@ -436,9 +504,13 @@ process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>},
     process_sasl_failure('encryption-required', <<"">>, State);
 process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
                     #{mod := Mod} = State) ->
-    SASLState = Mod:init_sasl(State),
-    SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn),
-    process_sasl_result(SASLResult, State).
+    try Mod:init_sasl(State) of
+       SASLState ->
+           SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn),
+           process_sasl_result(SASLResult, State)
+    catch _:undef ->
+           process_sasl_failure('temporary-auth-failure', <<"">>, State)
+    end.
 
 process_sasl_response(#sasl_response{text = ClientIn},
                      #{sasl_state := SASLState} = State) ->