]> granicus.if.org Git - ejabberd/commitdiff
Adopt remaining code to support new hooks
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Mon, 9 Jan 2017 14:02:17 +0000 (17:02 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Mon, 9 Jan 2017 14:02:17 +0000 (17:02 +0300)
42 files changed:
src/cyrsasl.erl
src/cyrsasl_digest.erl
src/ejabberd_c2s.erl
src/ejabberd_hooks.erl
src/ejabberd_http_ws.erl
src/ejabberd_listener.erl
src/ejabberd_local.erl
src/ejabberd_piefxis.erl
src/ejabberd_receiver.erl
src/ejabberd_router.erl
src/ejabberd_s2s.erl
src/ejabberd_s2s_in.erl
src/ejabberd_s2s_out.erl
src/ejabberd_service.erl
src/ejabberd_sm.erl
src/ejabberd_socket.erl
src/mod_announce.erl
src/mod_blocking.erl
src/mod_caps.erl
src/mod_carboncopy.erl
src/mod_client_state.erl
src/mod_disco.erl
src/mod_fail2ban.erl
src/mod_http_fileserver.erl
src/mod_last.erl
src/mod_mam.erl
src/mod_metrics.erl
src/mod_offline.erl
src/mod_ping.erl
src/mod_pres_counter.erl
src/mod_privacy.erl
src/mod_privilege.erl
src/mod_pubsub.erl
src/mod_roster.erl
src/mod_s2s_dialback.erl
src/mod_service_log.erl
src/mod_shared_roster.erl
src/mod_shared_roster_ldap.erl
src/mod_sm.erl
src/mod_vcard_xupdate.erl
src/xmpp_stream_in.erl
src/xmpp_stream_out.erl

index 1edf44678fc90fc1ffd448f36eb04ffe9e1cd670..5c7eb7edb51146c3860061e15bd1f22b6a0c733b 100644 (file)
 
 -module(cyrsasl).
 
--behaviour(ejabberd_config).
-
 -author('alexey@process-one.net').
 
 -export([start/0, register_mechanism/3, listmech/1,
         server_new/7, server_start/3, server_step/2,
-        get_mech/1, format_error/2, opt_type/1]).
+        get_mech/1, format_error/2]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
@@ -113,15 +111,9 @@ format_error(Mech, Reason) ->
                         PasswordType :: password_type()) -> any().
 
 register_mechanism(Mechanism, Module, PasswordType) ->
-    case is_disabled(Mechanism) of
-      false ->
-         ets:insert(sasl_mechanism,
-                    #sasl_mechanism{mechanism = Mechanism, module = Module,
-                                    password_type = PasswordType});
-      true ->
-         ?DEBUG("SASL mechanism ~p is disabled", [Mechanism]),
-         true
-    end.
+    ets:insert(sasl_mechanism,
+              #sasl_mechanism{mechanism = Mechanism, module = Module,
+                              password_type = PasswordType}).
 
 check_credentials(_State, Props) ->
     User = proplists:get_value(authzid, Props, <<>>),
@@ -134,20 +126,19 @@ check_credentials(_State, Props) ->
 -spec listmech(Host ::binary()) -> Mechanisms::mechanisms().
 
 listmech(Host) ->
-    Mechs = ets:select(sasl_mechanism,
-                      [{#sasl_mechanism{mechanism = '$1',
-                                        password_type = '$2', _ = '_'},
-                        case catch ejabberd_auth:store_type(Host) of
-                          external -> [{'==', '$2', plain}];
-                          scram -> [{'/=', '$2', digest}];
-                          {'EXIT', {undef, [{Module, store_type, []} | _]}} ->
-                              ?WARNING_MSG("~p doesn't implement the function store_type/0",
-                                           [Module]),
-                              [];
-                          _Else -> []
-                        end,
-                        ['$1']}]),
-    filter_anonymous(Host, Mechs).
+    ets:select(sasl_mechanism,
+              [{#sasl_mechanism{mechanism = '$1',
+                                password_type = '$2', _ = '_'},
+                case catch ejabberd_auth:store_type(Host) of
+                    external -> [{'==', '$2', plain}];
+                    scram -> [{'/=', '$2', digest}];
+                    {'EXIT', {undef, [{Module, store_type, []} | _]}} ->
+                        ?WARNING_MSG("~p doesn't implement the function store_type/0",
+                                     [Module]),
+                        [];
+                    _Else -> []
+                end,
+                ['$1']}]).
 
 -spec server_new(binary(), binary(), binary(), term(),
                 fun(), fun(), fun()) -> sasl_state().
@@ -206,33 +197,3 @@ server_step(State, ClientIn) ->
 -spec get_mech(sasl_state()) -> binary().
 get_mech(#sasl_state{mech_name = Mech}) ->
     Mech.
-
-%% Remove the anonymous mechanism from the list if not enabled for the given
-%% host
-%%
--spec filter_anonymous(Host :: binary(), Mechs :: mechanisms()) -> mechanisms().
-
-filter_anonymous(Host, Mechs) ->
-    case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(Host) of
-      true  -> Mechs;
-      false -> Mechs -- [<<"ANONYMOUS">>]
-    end.
-
--spec is_disabled(Mechanism :: mechanism()) -> boolean().
-
-is_disabled(Mechanism) ->
-    Disabled = ejabberd_config:get_option(
-                disable_sasl_mechanisms,
-                fun(V) when is_list(V) ->
-                        lists:map(fun(M) -> str:to_upper(M) end, V);
-                   (V) ->
-                        [str:to_upper(V)]
-                end, []),
-    lists:member(Mechanism, Disabled).
-
-opt_type(disable_sasl_mechanisms) ->
-    fun (V) when is_list(V) ->
-           lists:map(fun (M) -> str:to_upper(M) end, V);
-       (V) -> [str:to_upper(V)]
-    end;
-opt_type(_) -> [disable_sasl_mechanisms].
index 9b4faca204ef87e2c216a28a73ae2e568d450d9e..39055f2b177d4a0eb5ce3a67d9ed9a7c542eda5b 100644 (file)
@@ -59,7 +59,7 @@
 
 start(_Opts) ->
     Fqdn = get_local_fqdn(),
-    ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~p",
+    ?INFO_MSG("FQDN used to check DIGEST-MD5 SASL authentication: ~s",
              [Fqdn]),
     cyrsasl:register_mechanism(<<"DIGEST-MD5">>, ?MODULE,
                               digest).
index a10ee59a5f73bd2588ddc6478d5716e4c8a23306..007a94dc944e13123ece41e7e089175dc5994ac8 100644 (file)
@@ -33,9 +33,9 @@
 %% xmpp_stream_in callbacks
 -export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3]).
--export([tls_options/1, tls_required/1, tls_verify/1,
-        compress_methods/1, bind/2, get_password_fun/1,
-        check_password_fun/1, check_password_digest_fun/1,
+-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+        compress_methods/1, bind/2, sasl_mechanisms/2,
+        get_password_fun/1, check_password_fun/1, check_password_digest_fun/1,
         unauthenticated_stream_features/1, authenticated_stream_features/1,
         handle_stream_start/2, handle_stream_end/2,
         handle_unauthenticated_packet/2, handle_authenticated_packet/2,
@@ -47,7 +47,7 @@
         process_terminated/2, process_info/2]).
 %% API
 -export([get_presence/1, get_subscription/2, get_subscribed/1,
-        open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
+        open_session/1, call/3, send/2, close/1, close/2, stop/1,
         reply/2, copy_state/2, set_timeout/2, add_hooks/1]).
 
 -include("ejabberd.hrl").
@@ -73,6 +73,9 @@ start_link(SockData, Opts) ->
 socket_type() ->
     xml_stream.
 
+%%%===================================================================
+%%% Common API
+%%%===================================================================
 -spec call(pid(), term(), non_neg_integer() | infinity) -> term().
 call(Ref, Msg, Timeout) ->
     xmpp_stream_in:call(Ref, Msg, Timeout).
@@ -116,19 +119,16 @@ stop(Ref) ->
 send(Pid, Pkt) when is_pid(Pid) ->
     xmpp_stream_in:send(Pid, Pkt);
 send(#{lserver := LServer} = State, Pkt) ->
-    case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt, State}, []) of
+    Pkt1 = fix_from_to(Pkt, State),
+    case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt1, State}, []) of
        {drop, State1} -> State1;
-       {Pkt1, State1} -> xmpp_stream_in:send(State1, Pkt1)
+       {Pkt2, State1} -> xmpp_stream_in:send(State1, Pkt2)
     end.
 
 -spec set_timeout(state(), timeout()) -> state().
 set_timeout(State, Timeout) ->
     xmpp_stream_in:set_timeout(State, Timeout).
 
--spec establish(state()) -> state().
-establish(State) ->
-    xmpp_stream_in:establish(State).
-
 -spec add_hooks(binary()) -> ok.
 add_hooks(Host) ->
     ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
@@ -162,7 +162,7 @@ copy_state(#{owner := Owner} = NewState,
                     auth_module => AuthModule,
                     pres_t => PresT, pres_a => PresA,
                     pres_f => PresF},
-    ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
+    ejabberd_hooks:run_fold(c2s_copy_session, LServer, State2, [OldState]).
 
 -spec open_session(state()) -> {ok, state()} | state().
 open_session(#{user := U, server := S, resource := R,
@@ -195,14 +195,22 @@ process_info(#{lserver := LServer} = State,
                             process_iq_in(State, Packet)
                     end,
     if Pass ->
-           Packet1 = ejabberd_hooks:run_fold(
-                       user_receive_packet, LServer, Packet, [State1]),
-           ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
-           send(State1, Packet1);
+           {Packet1, State2} = ejabberd_hooks:run_fold(
+                                 user_receive_packet, LServer,
+                                 {Packet, State1}, []),
+           case Packet1 of
+               drop -> State2;
+               _ -> send(State2, Packet1)
+           end;
        true ->
-           ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
            State1
     end;
+process_info(State, force_update_presence) ->
+    try maps:get(pres_last, State) of
+       Pres -> process_self_presence(State, Pres)
+    catch _:{badkey, _} ->
+           State
+    end;
 process_info(State, Info) ->
     ?WARNING_MSG("got unexpected info: ~p", [Info]),
     State.
@@ -218,15 +226,21 @@ reject_unauthenticated_packet(State, Pkt) ->
 process_closed(State, Reason) ->
     stop(State#{stop_reason => Reason}).
 
-process_terminated(#{socket := Socket, jid := JID} = State,
+process_terminated(#{sockmod := SockMod, socket := Socket, jid := JID} = State,
                   Reason) ->
     Status = format_reason(State, Reason),
     ?INFO_MSG("(~s) Closing c2s session for ~s: ~s",
-             [ejabberd_socket:pp(Socket), jid:to_string(JID), Status]),
-    Pres = #presence{type = unavailable,
-                    status = xmpp:mk_text(Status),
-                    from = JID, to = jid:remove_resource(JID)},
-    State1 = broadcast_presence_unavailable(State, Pres),
+             [SockMod:pp(Socket), jid:to_string(JID), Status]),
+    State1 = case maps:is_key(pres_last, State) of
+                true ->
+                    Pres = #presence{type = unavailable,
+                                     status = xmpp:mk_text(Status),
+                                     from = JID,
+                                     to = jid:remove_resource(JID)},
+                    broadcast_presence_unavailable(State, Pres);
+                false ->
+                    State
+            end,
     bounce_message_queue(),
     State1;
 process_terminated(State, _Reason) ->
@@ -235,13 +249,51 @@ process_terminated(State, _Reason) ->
 %%%===================================================================
 %%% xmpp_stream_in callbacks
 %%%===================================================================
-tls_options(#{lserver := LServer, tls_options := TLSOpts}) ->
-    case ejabberd_config:get_option({domain_certfile, LServer},
-                                   fun iolist_to_binary/1) of
-       undefined ->
-           TLSOpts;
-       CertFile ->
-           lists:keystore(certfile, 1, TLSOpts, {certfile, CertFile})
+tls_options(#{lserver := LServer, tls_options := DefaultOpts}) ->
+    TLSOpts1 = case ejabberd_config:get_option(
+                     {c2s_certfile, LServer},
+                     fun iolist_to_binary/1,
+                     ejabberd_config:get_option(
+                       {domain_certfile, LServer},
+                       fun iolist_to_binary/1)) of
+                  undefined -> [];
+                  CertFile -> lists:keystore(certfile, 1, DefaultOpts,
+                                             {certfile, CertFile})
+              end,
+    TLSOpts2 = case ejabberd_config:get_option(
+                      {c2s_ciphers, LServer},
+                     fun iolist_to_binary/1) of
+                   undefined -> TLSOpts1;
+                   Ciphers -> lists:keystore(ciphers, 1, TLSOpts1,
+                                            {ciphers, Ciphers})
+               end,
+    TLSOpts3 = case ejabberd_config:get_option(
+                      {c2s_protocol_options, LServer},
+                      fun (Options) -> str:join(Options, <<$|>>) end) of
+                   undefined -> TLSOpts2;
+                   ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2,
+                                              {protocol_options, ProtoOpts})
+               end,
+    TLSOpts4 = case ejabberd_config:get_option(
+                      {c2s_dhfile, LServer},
+                     fun iolist_to_binary/1) of
+                   undefined -> TLSOpts3;
+                   DHFile -> lists:keystore(dhfile, 1, TLSOpts3,
+                                           {dhfile, DHFile})
+               end,
+    TLSOpts5 = case ejabberd_config:get_option(
+                     {c2s_cafile, LServer},
+                     fun iolist_to_binary/1) of
+                  undefined -> TLSOpts4;
+                  CAFile -> lists:keystore(cafile, 1, TLSOpts4,
+                                           {cafile, CAFile})
+              end,
+    case ejabberd_config:get_option(
+          {c2s_tls_compression, LServer},
+          fun(B) when is_boolean(B) -> B end) of
+       undefined -> TLSOpts5;
+       false -> [compression_none | TLSOpts5];
+       true -> lists:delete(compression_none, TLSOpts5)
     end.
 
 tls_required(#{tls_required := TLSRequired}) ->
@@ -250,6 +302,11 @@ tls_required(#{tls_required := TLSRequired}) ->
 tls_verify(#{tls_verify := TLSVerify}) ->
     TLSVerify.
 
+tls_enabled(#{tls_enabled := TLSEnabled,
+             tls_required := TLSRequired,
+             tls_verify := TLSVerify}) ->
+    TLSEnabled or TLSRequired or TLSVerify.
+
 compress_methods(#{zlib := true}) ->
     [<<"zlib">>];
 compress_methods(_) ->
@@ -261,6 +318,20 @@ unauthenticated_stream_features(#{lserver := LServer}) ->
 authenticated_stream_features(#{lserver := LServer}) ->
     ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]).
 
+sasl_mechanisms(Mechs, #{lserver := LServer}) ->
+    Mechs1 = ejabberd_config:get_option(
+              {disable_sasl_mechanisms, LServer},
+              fun(V) when is_list(V) ->
+                      lists:map(fun(M) -> str:to_upper(M) end, V);
+                 (V) ->
+                      [str:to_upper(V)]
+              end, []),
+    Mechs2 = case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer) of
+                true -> Mechs1;
+                false -> [<<"ANONYMOUS">>|Mechs1]
+            end,
+    Mechs -- Mechs2.
+
 get_password_fun(#{lserver := LServer}) ->
     fun(U) ->
            ejabberd_auth:get_password_with_authmodule(U, LServer)
@@ -279,7 +350,8 @@ check_password_digest_fun(#{lserver := LServer}) ->
 bind(<<"">>, State) ->
     bind(new_uniq_id(), State);
 bind(R, #{user := U, server := S, access := Access, lang := Lang,
-         lserver := LServer, socket := Socket, ip := IP} = State) ->
+         lserver := LServer, sockmod := SockMod, socket := Socket,
+         ip := IP} = State) ->
     case resource_conflict_action(U, S, R) of
        closenew ->
            {error, xmpp:err_conflict(), State};
@@ -289,38 +361,30 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang,
                                    #{usr => jid:split(JID), ip => IP},
                                    LServer) of
                allow ->
-                   State1 = open_session(State#{resource => Resource}),
+                   State1 = open_session(State#{resource => Resource,
+                                                sid => ejabberd_sm:make_sid()}),
                    State2 = ejabberd_hooks:run_fold(
                               c2s_session_opened, LServer, State1, []),
                    ?INFO_MSG("(~s) Opened c2s session for ~s",
-                             [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+                             [SockMod:pp(Socket), jid:to_string(JID)]),
                    {ok, State2};
                deny ->
                    ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
                    ?INFO_MSG("(~s) Forbidden c2s session for ~s",
-                             [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+                             [SockMod:pp(Socket), jid:to_string(JID)]),
                    Txt = <<"Denied by ACL">>,
                    {error, xmpp:err_not_allowed(Txt, Lang), State}
            end
     end.
 
-handle_stream_start(StreamStart,
-                   #{lserver := LServer, ip := IP, lang := Lang} = State) ->
+handle_stream_start(StreamStart, #{lserver := LServer} = State) ->
     case ejabberd_router:is_my_host(LServer) of
        false ->
            send(State, xmpp:serr_host_unknown());
        true ->
-           case check_bl_c2s(IP, Lang) of
-               false ->
-                   change_shaper(State),
-                   ejabberd_hooks:run_fold(
-                     c2s_stream_started, LServer, State, [StreamStart]);
-               {true, LogReason, ReasonT} ->
-                   ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s",
-                             [jlib:ip_to_list(IP), LogReason]),
-                   Err = xmpp:serr_policy_violation(ReasonT, Lang),
-                   send(State, Err)
-           end
+           change_shaper(State),
+           ejabberd_hooks:run_fold(
+             c2s_stream_started, LServer, State, [StreamStart])
     end.
 
 handle_stream_end(Reason, #{lserver := LServer} = State) ->
@@ -328,18 +392,20 @@ handle_stream_end(Reason, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]).
 
 handle_auth_success(User, Mech, AuthModule,
-                   #{socket := Socket, ip := IP, lserver := LServer} = State) ->
+                   #{socket := Socket, sockmod := SockMod,
+                     ip := IP, lserver := LServer} = State) ->
     ?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s",
-             [ejabberd_socket:pp(Socket), Mech, User, LServer,
+             [SockMod:pp(Socket), Mech, User, LServer,
               ejabberd_auth:backend_type(AuthModule),
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     State1 = State#{auth_module => AuthModule},
     ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]).
 
 handle_auth_failure(User, Mech, Reason,
-                   #{socket := Socket, ip := IP, lserver := LServer} = State) ->
+                   #{socket := Socket, sockmod := SockMod,
+                     ip := IP, lserver := LServer} = State) ->
     ?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s",
-             [ejabberd_socket:pp(Socket), Mech,
+             [SockMod:pp(Socket), Mech,
               if User /= <<"">> -> ["for ", User, "@", LServer, " "];
                  true -> ""
               end,
@@ -355,17 +421,22 @@ handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
 handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
     ejabberd_hooks:run_fold(c2s_authenticated_packet,
                            LServer, State, [Pkt]);
-handle_authenticated_packet(Pkt, #{lserver := LServer} = State) ->
+handle_authenticated_packet(Pkt, #{lserver := LServer, jid := JID} = State) ->
     State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet,
                                     LServer, State, [Pkt]),
-    Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
+    #jid{luser = LUser} = JID,
+    {Pkt1, State2} = ejabberd_hooks:run_fold(
+                      user_send_packet, LServer, {Pkt, State1}, []),
     case Pkt1 of
-       #presence{to = #jid{lresource = <<"">>}} ->
-           process_self_presence(State1, Pkt1);
+       drop ->
+           State2;
+       #presence{to = #jid{luser = LUser, lserver = LServer,
+                           lresource = <<"">>}} ->
+           process_self_presence(State2, Pkt1);
        #presence{} ->
-           process_presence_out(State1, Pkt1);
+           process_presence_out(State2, Pkt1);
        _ ->
-           check_privacy_then_route(State1, Pkt1)
+           check_privacy_then_route(State2, Pkt1)
     end.
 
 handle_cdata(Data, #{lserver := LServer} = State) ->
@@ -381,22 +452,34 @@ handle_send(Pkt, Result, #{lserver := LServer} = State) ->
 init([State, Opts]) ->
     Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
     Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none),
-    TLSOpts = lists:filter(
-               fun({certfile, _}) -> true;
-                  ({ciphers, _}) -> true;
-                  ({dhfile, _}) -> true;
-                  (_) -> false
-               end, Opts),
+    TLSOpts1 = lists:filter(
+                fun({certfile, _}) -> true;
+                   ({ciphers, _}) -> true;
+                   ({dhfile, _}) -> true;
+                   ({cafile, _}) -> true;
+                   (_) -> false
+                end, Opts),
+    TLSOpts2 = case lists:keyfind(protocol_options, 1, Opts) of
+                  false -> TLSOpts1;
+                  {_, OptString} ->
+                      ProtoOpts = str:join(OptString, <<$|>>),
+                      [{protocol_options, ProtoOpts}|TLSOpts1]
+              end,
+    TLSOpts3 = case proplists:get_bool(tls_compression, Opts) of
+                   false -> [compression_none | TLSOpts2];
+                   true -> TLSOpts2
+               end,
+    TLSEnabled = proplists:get_bool(starttls, Opts),
     TLSRequired = proplists:get_bool(starttls_required, Opts),
     TLSVerify = proplists:get_bool(tls_verify, Opts),
     Zlib = proplists:get_bool(zlib, Opts),
-    State1 = State#{tls_options => TLSOpts,
+    State1 = State#{tls_options => TLSOpts3,
                    tls_required => TLSRequired,
+                   tls_enabled => TLSEnabled,
                    tls_verify => TLSVerify,
                    pres_a => ?SETS:new(),
                    pres_f => ?SETS:new(),
                    pres_t => ?SETS:new(),
-                   sid => ejabberd_sm:make_sid(),
                    zlib => Zlib,
                    lang => ?MYLANG,
                    server => ?MYNAME,
@@ -426,12 +509,12 @@ handle_cast(Msg, #{lserver := LServer} = State) ->
 handle_info(Info, #{lserver := LServer} = State) ->
     ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
 
-terminate(Reason, #{sid := SID, jid := _,
+terminate(Reason, #{sid := SID,
                    user := U, server := S, resource := R,
                    lserver := LServer} = State) ->
-    Status = format_reason(State, Reason),
     case maps:is_key(pres_last, State) of
        true ->
+           Status = format_reason(State, Reason),
            ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status);
        false ->
            ejabberd_sm:close_session(SID, U, S, R)
@@ -446,11 +529,6 @@ code_change(_OldVsn, State, _Extra) ->
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
--spec check_bl_c2s({inet:ip_address(), non_neg_integer()}, binary())
-      -> false | {true, binary(), binary()}.
-check_bl_c2s({IP, _Port}, Lang) ->
-    ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]).
-
 -spec process_iq_in(state(), iq()) -> {boolean(), state()}.
 process_iq_in(State, #iq{} = IQ) ->
     case privacy_check_packet(State, IQ, in) of
@@ -484,7 +562,7 @@ process_presence_in(#{lserver := LServer, pres_a := PresA} = State0,
     State = ejabberd_hooks:run_fold(c2s_presence_in, LServer, State0, [Pres]),
     case T of
        probe ->
-           NewState = do_some_magic(State, From),
+           NewState = add_to_pres_a(State, From),
            route_probe_reply(From, To, NewState),
            {false, NewState};
        error ->
@@ -495,7 +573,7 @@ process_presence_in(#{lserver := LServer, pres_a := PresA} = State0,
                allow when T == error ->
                    {true, State};
                allow ->
-                   NewState = do_some_magic(State, From),
+                   NewState = add_to_pres_a(State, From),
                    {true, NewState};
                deny ->
                    {false, State}
@@ -577,24 +655,27 @@ process_presence_out(#{user := User, server := Server, lserver := LServer,
     end.
 
 -spec process_self_presence(state(), presence()) -> state().
-process_self_presence(#{ip := IP, conn := Conn,
+process_self_presence(#{ip := IP, conn := Conn, lserver := LServer,
                        auth_module := AuthMod, sid := SID,
                        user := U, server := S, resource := R} = State,
                      #presence{type = unavailable} = Pres) ->
     Status = xmpp:get_text(Pres#presence.status),
     Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
     ejabberd_sm:unset_presence(SID, U, S, R, Status, Info),
-    State1 = broadcast_presence_unavailable(State, Pres),
-    maps:remove(pres_last, maps:remove(pres_timestamp, State1));
+    {Pres1, State1} = ejabberd_hooks:run_fold(
+                       c2s_self_presence, LServer, {Pres, State}, []),
+    State2 = broadcast_presence_unavailable(State1, Pres1),
+    maps:remove(pres_last, maps:remove(pres_timestamp, State2));
 process_self_presence(#{lserver := LServer} = State,
                      #presence{type = available} = Pres) ->
     PreviousPres = maps:get(pres_last, State, undefined),
     update_priority(State, Pres),
-    State1 = ejabberd_hooks:run_fold(user_available_hook, LServer, State, [Pres]),
-    State2 = State1#{pres_last => Pres,
+    {Pres1, State1} = ejabberd_hooks:run_fold(
+                       c2s_self_presence, LServer, {Pres, State}, []),
+    State2 = State1#{pres_last => Pres1,
                     pres_timestamp => p1_time_compat:timestamp()},
     FromUnavailable = PreviousPres == undefined,
-    broadcast_presence_available(State2, Pres, FromUnavailable);
+    broadcast_presence_available(State2, Pres1, FromUnavailable);
 process_self_presence(State, _Pres) ->
     State.
 
@@ -614,9 +695,9 @@ broadcast_presence_unavailable(#{pres_a := PresA} = State, Pres) ->
 
 -spec broadcast_presence_available(state(), presence(), boolean()) -> state().
 broadcast_presence_available(#{pres_a := PresA, pres_f := PresF,
-                              pres_t := PresT} = State,
+                              pres_t := PresT, jid := JID} = State,
                             Pres, _FromUnavailable = true) ->
-    Probe = #presence{type = probe},
+    Probe = #presence{from = JID, type = probe},
     TJIDs = filter_blocked(State, Probe, PresT),
     FJIDs = filter_blocked(State, Pres, PresF),
     route_multiple(State, TJIDs, Probe),
@@ -739,6 +820,19 @@ get_conn_type(State) ->
        websocket -> websocket
     end.
 
+-spec fix_from_to(xmpp_element(), state()) -> stanza().
+fix_from_to(Pkt, #{jid := JID}) when ?is_stanza(Pkt) ->
+    #jid{luser = U, lserver = S, lresource = R} = JID,
+    From = xmpp:get_from(Pkt),
+    From1 = case jid:tolower(From) of
+               {U, S, R} -> JID;
+               {U, S, _} -> jid:replace_resource(JID, From#jid.resource);
+               _ -> From
+           end,
+    xmpp:set_from_to(Pkt, From1, JID);
+fix_from_to(Pkt, _State) ->
+    Pkt.
+
 -spec change_shaper(state()) -> ok.
 change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer,
                user := U, server := S, resource := R} = State) ->
@@ -748,8 +842,8 @@ change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer,
                                LServer),
     xmpp_stream_in:change_shaper(State, Shaper).
 
--spec do_some_magic(state(), jid()) -> state().
-do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
+-spec add_to_pres_a(state(), jid()) -> state().
+add_to_pres_a(#{pres_a := PresA, pres_f := PresF} = State, From) ->
     LFrom = jid:tolower(From),
     LBFrom = jid:remove_resource(LFrom),
     case (?SETS):is_element(LFrom, PresA) orelse
@@ -775,20 +869,41 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) ->
 -spec format_reason(state(), term()) -> binary().
 format_reason(#{stop_reason := Reason}, _) ->
     xmpp_stream_in:format_error(Reason);
-format_reason(_, Reason) when Reason /= normal ->
-    <<"internal server error">>;
+format_reason(_, normal) ->
+    <<"unknown reason">>;
+format_reason(_, shutdown) ->
+    <<"stopped by supervisor">>;
+format_reason(_, {shutdown, _}) ->
+    <<"stopped by supervisor">>;
 format_reason(_, _) ->
-    <<"">>.
+    <<"internal server error">>.
 
 transform_listen_option(Opt, Opts) ->
     [Opt|Opts].
 
 opt_type(domain_certfile) -> fun iolist_to_binary/1;
+opt_type(c2s_certfile) -> fun iolist_to_binary/1;
+opt_type(c2s_ciphers) -> fun iolist_to_binary/1;
+opt_type(c2s_dhfile) -> fun iolist_to_binary/1;
+opt_type(c2s_cafile) -> fun iolist_to_binary/1;
+opt_type(c2s_protocol_options) ->
+    fun (Options) -> str:join(Options, <<"|">>) end;
+opt_type(c2s_tls_compression) ->
+    fun (true) -> true;
+       (false) -> false
+    end;
 opt_type(resource_conflict) ->
     fun (setresource) -> setresource;
        (closeold) -> closeold;
        (closenew) -> closenew;
        (acceptnew) -> acceptnew
     end;
+opt_type(disable_sasl_mechanisms) ->
+    fun (V) when is_list(V) ->
+           lists:map(fun (M) -> str:to_upper(M) end, V);
+       (V) -> [str:to_upper(V)]
+    end;
 opt_type(_) ->
-    [domain_certfile, resource_conflict].
+    [domain_certfile, c2s_certfile, c2s_ciphers, c2s_cafile,
+     c2s_protocol_options, c2s_tls_compression, resource_conflict,
+     disable_sasl_mechanisms].
index 612d5afe5e7c1a88d930a36f1a4bf90a49b2b38f..f63d1d75cb9890f89fa1c7c5ee408d65f1726c61 100644 (file)
@@ -326,10 +326,9 @@ run1([{_Seq, Node, Module, Function} | Ls], Hook, Args) ->
            run1(Ls, Hook, Args)
     end;
 run1([{_Seq, Module, Function} | Ls], Hook, Args) ->
-    Res = safe_apply(Module, Function, Args),
+    Res = safe_apply(Hook, Module, Function, Args),
     case Res of
-       {'EXIT', Reason} ->
-           ?ERROR_MSG("~p~nrunning hook: ~p", [Reason, {Hook, Args}]),
+       'EXIT' ->
            run1(Ls, Hook, Args);
        stop ->
            ok;
@@ -362,10 +361,9 @@ run_fold1([{_Seq, Node, Module, Function} | Ls], Hook, Val, Args) ->
            run_fold1(Ls, Hook, NewVal, Args)
     end;
 run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) ->
-    Res = safe_apply(Module, Function, [Val | Args]),
+    Res = safe_apply(Hook, Module, Function, [Val | Args]),
     case Res of
-       {'EXIT', Reason} ->
-           ?ERROR_MSG("~p~nrunning hook: ~p", [Reason, {Hook, Args}]),
+       'EXIT' ->
            run_fold1(Ls, Hook, Val, Args);
        stop ->
            stopped;
@@ -375,12 +373,20 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) ->
            run_fold1(Ls, Hook, NewVal, Args)
     end.
 
-safe_apply(Module, Function, Args) ->
+safe_apply(Hook, Module, Function, Args) ->
     try if is_function(Function) ->
                apply(Function, Args);
           true ->
                apply(Module, Function, Args)
        end
     catch E:R when E /= exit, R /= normal ->
-           {'EXIT', {E, {R, erlang:get_stacktrace()}}}
+           ?ERROR_MSG("Hook ~p crashed when running ~p:~p/~p:~n"
+                      "** Reason = ~p~n"
+                      "** Arguments = ~p",
+                      [Hook, Module, Function, length(Args),
+                       {E, R, get_stacktrace()}, Args]),
+           'EXIT'
     end.
+
+get_stacktrace() ->
+    [{Mod, Fun, Loc, Args} || {Mod, Fun, Args, Loc} <- erlang:get_stacktrace()].
index b92345dd447a82c877da50a7390aeee1fe286cc6..6d90dba4b5724e818f82adf115900269d792b140 100644 (file)
@@ -120,7 +120,7 @@ init([{#ws{ip = IP, http_opts = HOpts}, _} = WS]) ->
                                ({resend_on_timeout, _}) -> true;
                                (_) -> false
                             end, HOpts),
-    Opts = [{xml_socket, true} | ejabberd_c2s_config:get_c2s_limits() ++ SOpts],
+    Opts = ejabberd_c2s_config:get_c2s_limits() ++ SOpts,
     PingInterval = ejabberd_config:get_option(
                      {websocket_ping_interval, ?MYNAME},
                      fun(I) when is_integer(I), I>=0 -> I end,
index f720fc5850c0198215d6dd241a9c12b6e3c7d570..4191b19585fc54f84677622190b6678757ab4628 100644 (file)
@@ -186,7 +186,9 @@ init_tcp(PortIP, Module, Opts, SockOpts, Port, IPS) ->
 listen_tcp(PortIP, Module, SockOpts, Port, IPS) ->
     case ets:lookup(listen_sockets, PortIP) of
        [{PortIP, ListenSocket}] ->
-           ?INFO_MSG("Reusing listening port for ~p", [PortIP]),
+           {_, _, Transport} = PortIP,
+           ?INFO_MSG("Reusing listening ~s port ~p at ~s",
+                     [Transport, Port, IPS]),
            ets:delete(listen_sockets, PortIP),
            ListenSocket;
        _ ->
@@ -330,21 +332,26 @@ accept(ListenSocket, Module, Opts, Interval) ->
        {ok, Socket} ->
            case {inet:sockname(Socket), inet:peername(Socket)} of
                {{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} ->
-                   ?INFO_MSG("Accepted connection ~s:~p -> ~s:~p",
-                             [ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)),
+                   CallMod = case is_frontend(Module) of
+                                 true -> ejabberd_frontend_socket;
+                                 false -> ejabberd_socket
+                             end,
+                   Receiver = case CallMod:start(strip_frontend(Module),
+                                                 gen_tcp, Socket, Opts) of
+                                  {ok, RecvPid} -> RecvPid;
+                                  _ -> none
+                              end,
+                   ?INFO_MSG("(~p) Accepted connection ~s:~p -> ~s:~p",
+                             [Receiver,
+                              ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)),
                               PPort, inet_parse:ntoa(Addr), Port]);
                _ ->
                    ok
            end,
-           CallMod = case is_frontend(Module) of
-                         true -> ejabberd_frontend_socket;
-                         false -> ejabberd_socket
-                     end,
-           CallMod:start(strip_frontend(Module), gen_tcp, Socket, Opts),
            accept(ListenSocket, Module, Opts, NewInterval);
        {error, Reason} ->
-           ?ERROR_MSG("(~w) Failed TCP accept: ~w",
-                       [ListenSocket, Reason]),
+           ?ERROR_MSG("(~w) Failed TCP accept: ~s",
+                       [ListenSocket, inet:format_error(Reason)]),
            accept(ListenSocket, Module, Opts, NewInterval)
     end.
 
index a5ee6a2428123ba00185c3ccb294a5833869e6e8..48c4e863c8cfc402cb0fd73a0c2074244851c75a 100644 (file)
@@ -36,8 +36,7 @@
         process_iq_reply/3, register_iq_handler/4,
         register_iq_handler/5, register_iq_response_handler/4,
         register_iq_response_handler/5, unregister_iq_handler/2,
-        unregister_iq_response_handler/2, refresh_iq_handlers/0,
-        bounce_resource_packet/3]).
+        unregister_iq_response_handler/2, bounce_resource_packet/3]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2,
@@ -90,8 +89,13 @@ process_iq(From, To, #iq{type = T, lang = Lang, sub_els = [El]} = Packet)
            Err = xmpp:err_service_unavailable(Txt, Lang),
            ejabberd_router:route_error(To, From, Packet, Err)
     end;
-process_iq(From, To, #iq{type = T} = Packet) when T == get; T == set ->
-    Err = xmpp:err_bad_request(),
+process_iq(From, To, #iq{type = T, lang = Lang, sub_els = SubEls} = Packet)
+  when T == get; T == set ->
+    Txt = case SubEls of
+             [] -> <<"No child elements found">>;
+             _ -> <<"Too many child elements">>
+         end,
+    Err = xmpp:err_bad_request(Txt, Lang),
     ejabberd_router:route_error(To, From, Packet, Err);
 process_iq(From, To, #iq{type = T} = Packet) when T == result; T == error ->
     process_iq_reply(From, To, Packet).
@@ -171,10 +175,6 @@ unregister_iq_response_handler(_Host, ID) ->
 unregister_iq_handler(Host, XMLNS) ->
     ejabberd_local ! {unregister_iq_handler, Host, XMLNS}.
 
--spec refresh_iq_handlers() -> any().
-refresh_iq_handlers() ->
-    ejabberd_local ! refresh_iq_handlers.
-
 -spec bounce_resource_packet(jid(), jid(), stanza()) -> stop.
 bounce_resource_packet(_From, #jid{lresource = <<"">>}, #presence{}) ->
     ok;
@@ -228,14 +228,12 @@ handle_info({register_iq_handler, Host, XMLNS, Module,
             Function},
            State) ->
     ets:insert(?IQTABLE, {{XMLNS, Host}, Module, Function}),
-    catch mod_disco:register_feature(Host, XMLNS),
     {noreply, State};
 handle_info({register_iq_handler, Host, XMLNS, Module,
             Function, Opts},
            State) ->
     ets:insert(?IQTABLE,
               {{XMLNS, Host}, Module, Function, Opts}),
-    catch mod_disco:register_feature(Host, XMLNS),
     {noreply, State};
 handle_info({unregister_iq_handler, Host, XMLNS},
            State) ->
@@ -245,19 +243,6 @@ handle_info({unregister_iq_handler, Host, XMLNS},
       _ -> ok
     end,
     ets:delete(?IQTABLE, {XMLNS, Host}),
-    catch mod_disco:unregister_feature(Host, XMLNS),
-    {noreply, State};
-handle_info(refresh_iq_handlers, State) ->
-    lists:foreach(fun (T) ->
-                         case T of
-                           {{XMLNS, Host}, _Module, _Function, _Opts} ->
-                               catch mod_disco:register_feature(Host, XMLNS);
-                           {{XMLNS, Host}, _Module, _Function} ->
-                               catch mod_disco:register_feature(Host, XMLNS);
-                           _ -> ok
-                         end
-                 end,
-                 ets:tab2list(?IQTABLE)),
     {noreply, State};
 handle_info({timeout, _TRef, ID}, State) ->
     process_iq_timeout(ID),
index 36d734004264ca982c4bbaf125a78a32c117fd5d..9e6cbd71583089d6dcdbdded9e2e402b41d2058e 100644 (file)
@@ -484,18 +484,17 @@ process_privacy(#privacy_query{lists = Lists,
     JID = jid:make(U, S),
     IQ = #iq{type = set, id = randoms:get_string(),
             from = JID, to = JID, sub_els = [PrivacyQuery]},
-    Txt = <<"No module is handling this query">>,
-    Error = {error, xmpp:err_feature_not_implemented(Txt, ?MYLANG)},
-    case mod_privacy:process_iq_set(Error, IQ, #userlist{}) of
-        {error, #stanza_error{reason = Reason}} = Err ->
+    case mod_privacy:process_iq(IQ) of
+       #iq{type = error} = ResIQ ->
+           #stanza_error{reason = Reason} = xmpp:get_error(ResIQ),
            if Reason == 'item-not-found', Lists == [],
               Active == undefined, Default /= undefined ->
                    %% Failed to set default list because there is no
                    %% list with such name. We shouldn't stop here.
                    {ok, State};
               true ->
-                   stop("Failed to write privacy: ~p", [Err])
-            end;
+                   stop("Failed to write privacy: ~p", [Reason])
+           end;
         _ ->
             {ok, State}
     end.
index 0a33e30ec26210475a8a41f1c1f3efb37b3e5e3a..ffa55806f63662886c64084adef4bd9ca9552614 100644 (file)
@@ -135,8 +135,8 @@ handle_call({starttls, TLSSocket}, _From, State) ->
        {ok, TLSData} ->
            {reply, ok,
                process_data(TLSData, NewState), ?HIBERNATE_TIMEOUT};
-       {error, _Reason} ->
-           {stop, normal, ok, NewState}
+       {error, _} = Err ->
+           {stop, normal, Err, NewState}
     end;
 handle_call({compress, Data}, _From,
            #state{socket = Socket, sock_mod = SockMod} =
index 5ce8a8afb89dacd103359ef8d7f91e072f010762..b1c9c9e48f3268b833ee0982f94fdea8cdab5a73 100644 (file)
@@ -76,8 +76,17 @@ start_link() ->
 
 -spec route(jid(), jid(), xmlel() | stanza()) -> ok.
 
-route(From, To, Packet) ->
-    case catch do_route(From, To, Packet) of
+route(#jid{} = From, #jid{} = To, #xmlel{} = El) ->
+    try xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of
+       Pkt -> route(From, To, xmpp:set_from_to(Pkt, From, To))
+    catch _:{xmpp_codec, Why} ->
+           ?ERROR_MSG("failed to decode xml element ~p when "
+                      "routing from ~s to ~s: ~s",
+                      [El, jid:to_string(From), jid:to_string(To),
+                       xmpp:format_error(Why)])
+    end;
+route(#jid{} = From, #jid{} = To, Packet) ->
+    case catch do_route(From, To, xmpp:set_from_to(Packet, From, To)) of
        {'EXIT', Reason} ->
            ?ERROR_MSG("~p~nwhen processing: ~p",
                       [Reason, {From, To, Packet}]);
@@ -169,7 +178,7 @@ register_route(Domain, ServerHost, LocalHint) ->
                mnesia:transaction(F)
          end,
          if LocalHint == undefined ->
-                 ?INFO_MSG("Route registered: ~s", [LDomain]);
+                 ?DEBUG("Route registered: ~s", [LDomain]);
             true ->
                  ok
          end
@@ -218,7 +227,7 @@ unregister_route(Domain) ->
                    end,
                mnesia:transaction(F)
          end,
-         ?INFO_MSG("Route unregistered: ~s", [LDomain])
+         ?DEBUG("Route unregistered: ~s", [LDomain])
     end.
 
 -spec unregister_routes([binary()]) -> ok.
@@ -283,9 +292,9 @@ process_iq(From, To, #iq{} = IQ) ->
        true ->
            ejabberd_sm:process_iq(From, To, IQ)
     end;
-process_iq(From, To, El) ->
+process_iq(From, To, #xmlel{} = El) ->
     try xmpp:decode(El, ?NS_CLIENT, [ignore_els]) of
-       IQ -> process_iq(From, To, IQ)
+       IQ -> process_iq(From, To, xmpp:set_from_to(IQ, From, To))
     catch _:{xmpp_codec, Why} ->
            Type = xmpp:get_type(El),
            if Type == <<"get">>; Type == <<"set">> ->
@@ -409,70 +418,56 @@ code_change(_OldVsn, State, _Extra) ->
 %%--------------------------------------------------------------------
 %%% Internal functions
 %%--------------------------------------------------------------------
--spec do_route(jid(), jid(), xmlel() | xmpp_element()) -> any().
+-spec do_route(jid(), jid(), stanza()) -> any().
 do_route(OrigFrom, OrigTo, OrigPacket) ->
-    ?DEBUG("route~n\tfrom ~p~n\tto ~p~n\tpacket "
-          "~p~n",
-          [OrigFrom, OrigTo, OrigPacket]),
+    ?DEBUG("route:~n~s", [xmpp:pp(OrigPacket)]),
     case ejabberd_hooks:run_fold(filter_packet,
-                                {OrigFrom, OrigTo, OrigPacket}, [])
-       of
-      {From, To, Packet} ->
-         LDstDomain = To#jid.lserver,
-         case mnesia:dirty_read(route, LDstDomain) of
-           [] ->
-                 try xmpp:decode(Packet, ?NS_CLIENT, [ignore_els]) of
-                     Pkt ->
-                         ejabberd_s2s:route(From, To, Pkt)
-                 catch _:{xmpp_codec, Why} ->
-                         log_decoding_error(From, To, Packet, Why)
-                 end;
-           [R] ->
-               do_route(From, To, Packet, R);
-           Rs ->
-               Value = get_domain_balancing(From, To, LDstDomain),
-               case get_component_number(LDstDomain) of
-                 undefined ->
-                     case [R || R <- Rs, node(R#route.pid) == node()] of
-                       [] ->
-                           R = lists:nth(erlang:phash(Value, length(Rs)), Rs),
-                           do_route(From, To, Packet, R);
-                       LRs ->
-                           R = lists:nth(erlang:phash(Value, length(LRs)), LRs),
-                           do_route(From, To, Packet, R)
-                     end;
-                 _ ->
-                     SRs = lists:ukeysort(#route.local_hint, Rs),
-                     R = lists:nth(erlang:phash(Value, length(SRs)), SRs),
-                     do_route(From, To, Packet, R)
-               end
-         end;
-      drop -> ok
+                                {OrigFrom, OrigTo, OrigPacket}, []) of
+       {From, To, Packet} ->
+           LDstDomain = To#jid.lserver,
+           case mnesia:dirty_read(route, LDstDomain) of
+               [] ->
+                   ejabberd_s2s:route(From, To, Packet);
+               [Route] ->
+                   do_route(From, To, Packet, Route);
+               Routes ->
+                   balancing_route(From, To, Packet, Routes)
+           end;
+       drop ->
+           ok
     end.
 
--spec do_route(jid(), jid(), xmlel() | xmpp_element(), #route{}) -> any().
-do_route(From, To, Packet, #route{local_hint = LocalHint,
-                                 pid = Pid}) when is_pid(Pid) ->
-    try xmpp:decode(Packet, ?NS_CLIENT, [ignore_els]) of
-       Pkt ->
-           case LocalHint of
-               {apply, Module, Function} when node(Pid) == node() ->
-                   Module:Function(From, To, Pkt);
-               _ ->
-                   Pid ! {route, From, To, Pkt}
-           end
-    catch error:{xmpp_codec, Why} ->
-           log_decoding_error(From, To, Packet, Why)
+-spec do_route(jid(), jid(), stanza(), #route{}) -> any().
+do_route(From, To, Pkt, #route{local_hint = LocalHint,
+                              pid = Pid}) when is_pid(Pid) ->
+    case LocalHint of
+       {apply, Module, Function} when node(Pid) == node() ->
+           Module:Function(From, To, Pkt);
+       _ ->
+           Pid ! {route, From, To, Pkt}
     end;
-do_route(_From, _To, _Packet, _Route) ->
+do_route(_From, _To, _Pkt, _Route) ->
     drop.
 
--spec log_decoding_error(jid(), jid(), xmlel() | xmpp_element(), term()) -> ok.
-log_decoding_error(From, To, Packet, Reason) ->
-    ?ERROR_MSG("failed to decode xml element ~p when "
-              "routing from ~s to ~s: ~s",
-              [Packet, jid:to_string(From), jid:to_string(To),
-               xmpp:format_error(Reason)]).
+-spec balancing_route(jid(), jid(), stanza(), [#route{}]) -> any().
+balancing_route(From, To, Packet, Rs) ->
+    LDstDomain = To#jid.lserver,
+    Value = get_domain_balancing(From, To, LDstDomain),
+    case get_component_number(LDstDomain) of
+       undefined ->
+           case [R || R <- Rs, node(R#route.pid) == node()] of
+               [] ->
+                   R = lists:nth(erlang:phash(Value, length(Rs)), Rs),
+                   do_route(From, To, Packet, R);
+               LRs ->
+                   R = lists:nth(erlang:phash(Value, length(LRs)), LRs),
+                   do_route(From, To, Packet, R)
+           end;
+       _ ->
+           SRs = lists:ukeysort(#route.local_hint, Rs),
+           R = lists:nth(erlang:phash(Value, length(SRs)), SRs),
+           do_route(From, To, Packet, R)
+    end.
 
 -spec get_component_number(binary()) -> pos_integer() | undefined.
 get_component_number(LDomain) ->
index af4d6a66218058e80b232cdf54ba5211901d7757..d57c91ed22e61f13b68453dafb5eaf12e7f1d7e7 100644 (file)
@@ -257,7 +257,7 @@ tls_verify(LServer) ->
 -spec tls_enabled(binary()) -> boolean().
 tls_enabled(LServer) ->
     TLS = use_starttls(LServer),
-    TLS == true orelse TLS == optional.
+    TLS /= false.
 
 -spec zlib_enabled(binary()) -> boolean().
 zlib_enabled(LServer) ->
index a31af337e6af3f0f93ca1daacaf1bceb188ea33e..cca8438c6b549879a847927bffcd5682f28e8e34 100644 (file)
@@ -120,12 +120,8 @@ process_closed(State, _Reason) ->
 %%%===================================================================
 %%% xmpp_stream_in callbacks
 %%%===================================================================
-tls_options(#{tls_compression := Compression, server_host := LServer}) ->
-    Opts = case Compression of
-              false -> [compression_none];
-              true -> []
-          end,
-    ejabberd_s2s:tls_options(LServer, Opts).
+tls_options(#{tls_options := TLSOpts, server_host := LServer}) ->
+    ejabberd_s2s:tls_options(LServer, TLSOpts).
 
 tls_required(#{server_host := LServer}) ->
     ejabberd_s2s:tls_required(LServer).
@@ -164,16 +160,18 @@ handle_stream_established(State) ->
     set_idle_timeout(State#{established => true}).
 
 handle_auth_success(RServer, Mech, _AuthModule,
-                   #{socket := Socket, ip := IP,
+                   #{sockmod := SockMod,
+                     socket := Socket, ip := IP,
                      auth_domains := AuthDomains,
                      server_host := ServerHost,
                      lserver := LServer} = State) ->
     ?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)",
-             [ejabberd_socket:pp(Socket), Mech, RServer, LServer,
+             [SockMod:pp(Socket), Mech, RServer, LServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of
                 true ->
                     AuthDomains1 = sets:add_element(RServer, AuthDomains),
+                    change_shaper(State, RServer),
                     State#{auth_domains => AuthDomains1};
                 false ->
                     State
@@ -181,11 +179,12 @@ handle_auth_success(RServer, Mech, _AuthModule,
     ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]).
 
 handle_auth_failure(RServer, Mech, Reason,
-                   #{socket := Socket, ip := IP,
+                   #{sockmod := SockMod,
+                     socket := Socket, ip := IP,
                      server_host := ServerHost,
                      lserver := LServer} = State) ->
     ?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s",
-             [ejabberd_socket:pp(Socket), Mech, RServer, LServer,
+             [SockMod:pp(Socket), Mech, RServer, LServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
     ejabberd_hooks:run_fold(s2s_in_auth_result,
                            ServerHost, State, [false, RServer]).
@@ -204,10 +203,13 @@ handle_authenticated_packet(Pkt, State) ->
            LServer = ejabberd_router:host_of_route(To#jid.lserver),
            State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet,
                                             LServer, State, [Pkt]),
-           Pkt1 = ejabberd_hooks:run_fold(s2s_receive_packet, LServer,
-                                          Pkt, [State1]),
-           ejabberd_router:route(From, To, Pkt1),
-           State1;
+           {Pkt1, State2} = ejabberd_hooks:run_fold(s2s_receive_packet, LServer,
+                                                    {Pkt, State1}, []),
+           case Pkt1 of
+               drop -> ok;
+               _ -> ejabberd_router:route(From, To, Pkt1)
+           end,
+           State2;
        {error, Err} ->
            send(State, Err)
     end.
@@ -225,8 +227,24 @@ handle_send(Pkt, Result, #{server_host := LServer} = State) ->
 
 init([State, Opts]) ->
     Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none),
-    TLSCompression = proplists:get_bool(tls_compression, Opts),
-    State1 = State#{tls_compression => TLSCompression,
+    TLSOpts1 = lists:filter(
+                fun({certfile, _}) -> true;
+                   ({ciphers, _}) -> true;
+                   ({dhfile, _}) -> true;
+                   ({cafile, _}) -> true;
+                   (_) -> false
+                end, Opts),
+    TLSOpts2 = case lists:keyfind(protocol_options, 1, Opts) of
+                  false -> TLSOpts1;
+                  {_, OptString} ->
+                      ProtoOpts = str:join(OptString, <<$|>>),
+                      [{protocol_options, ProtoOpts}|TLSOpts1]
+              end,
+    TLSOpts3 = case proplists:get_bool(tls_compression, Opts) of
+                   false -> [compression_none | TLSOpts2];
+                   true -> TLSOpts2
+               end,
+    State1 = State#{tls_options => TLSOpts3,
                    auth_domains => sets:new(),
                    xmlns => ?NS_SERVER,
                    lang => ?MYLANG,
@@ -251,8 +269,16 @@ handle_cast(Msg, #{server_host := LServer} = State) ->
 handle_info(Info, #{server_host := LServer} = State) ->
     ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]).
 
-terminate(_Reason, _State) ->
-    ok.
+terminate(Reason, #{auth_domains := AuthDomains}) ->
+    case Reason of
+       {process_limit, _} ->
+           sets:fold(
+             fun(Host, _) ->
+                     ejabberd_s2s:external_host_overloaded(Host)
+             end, ok, AuthDomains);
+       _ ->
+           ok
+    end.
 
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
@@ -290,5 +316,11 @@ set_idle_timeout(#{server_host := LServer,
 set_idle_timeout(State) ->
     State.
 
+-spec change_shaper(state(), binary()) -> ok.
+change_shaper(#{shaper := ShaperName, server_host := ServerHost} = State,
+             RServer) ->
+    Shaper = acl:match_rule(ServerHost, ShaperName, jid:make(RServer)),
+    xmpp_stream_in:change_shaper(State, Shaper).
+
 opt_type(_) ->
     [].
index 6069c786ca465c5a71427d64d003d6a5cf08351f..5188d269b198d30cb2c8822006293153d2545084 100644 (file)
@@ -1,10 +1,23 @@
 %%%-------------------------------------------------------------------
-%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
-%%% @copyright (C) 2016, Evgeny Khramtsov
-%%% @doc
-%%%
-%%% @end
 %%% Created : 16 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
 %%%-------------------------------------------------------------------
 -module(ejabberd_s2s_out).
 -behaviour(xmpp_stream_out).
 -export([opt_type/1, transform_options/1]).
 %% xmpp_stream_out callbacks
 -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+        connect_timeout/1, address_families/1, default_port/1,
+        dns_retries/1, dns_timeout/1,
         handle_auth_success/2, handle_auth_failure/3, handle_packet/2,
         handle_stream_end/2, handle_stream_downgraded/2,
-        handle_recv/3, handle_send/4, handle_cdata/2,
+        handle_recv/3, handle_send/3, handle_cdata/2,
         handle_stream_established/1, handle_timeout/1]).
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
         terminate/2, code_change/3]).
@@ -92,12 +107,12 @@ add_hooks() ->
 %%% Hooks
 %%%===================================================================
 process_auth_result(#{server := LServer, remote_server := RServer} = State,
-                   false) ->
+                   {false, Reason}) ->
     Delay = get_delay(),
-    ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;"
-             " bouncing for ~p seconds",
+    ?INFO_MSG("Failed to establish outbound s2s connection ~s -> ~s: "
+             "authentication failed; bouncing for ~p seconds",
              [LServer, RServer, Delay]),
-    State1 = State#{on_route => bounce},
+    State1 = State#{on_route => bounce, stop_reason => Reason},
     State2 = close(State1),
     State3 = bounce_queue(State2),
     xmpp_stream_out:set_timeout(State3, timer:seconds(Delay));
@@ -113,7 +128,7 @@ process_closed(#{server := LServer, remote_server := RServer,
 process_closed(#{server := LServer, remote_server := RServer} = State,
               Reason) ->
     Delay = get_delay(),
-    ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; "
+    ?INFO_MSG("Failed to establish outbound s2s connection ~s -> ~s: ~s; "
              "bouncing for ~p seconds",
              [LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]),
     State1 = State#{on_route => bounce},
@@ -146,23 +161,65 @@ tls_verify(#{server := LServer}) ->
 tls_enabled(#{server := LServer}) ->
     ejabberd_s2s:tls_enabled(LServer).
 
-handle_auth_success(Mech, #{socket := Socket, ip := IP,
+connect_timeout(#{server := LServer}) ->
+    ejabberd_config:get_option(
+      {outgoing_s2s_timeout, LServer},
+      fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 ->
+              timer:seconds(TimeOut);
+         (infinity) ->
+              infinity
+      end, timer:seconds(10)).
+
+default_port(#{server := LServer}) ->
+    ejabberd_config:get_option(
+      {outgoing_s2s_port, LServer},
+      fun(I) when is_integer(I), I > 0, I =< 65536 -> I end,
+      5269).
+
+address_families(#{server := LServer}) ->
+    ejabberd_config:get_option(
+      {outgoing_s2s_families, LServer},
+      fun(Families) ->
+             lists:map(
+               fun(ipv4) -> inet;
+                  (ipv6) -> inet6
+               end, Families)
+      end, [inet, inet6]).
+
+dns_retries(#{server := LServer}) ->
+    ejabberd_config:get_option(
+      {s2s_dns_retries, LServer},
+      fun(I) when is_integer(I), I>=0 -> I end,
+      2).
+
+dns_timeout(#{server := LServer}) ->
+    ejabberd_config:get_option(
+      {s2s_dns_timeout, LServer},
+      fun(I) when is_integer(I), I>=0 ->
+             timer:seconds(I);
+        (infinity) ->
+             infinity
+      end, timer:seconds(10)).
+
+handle_auth_success(Mech, #{sockmod := SockMod,
+                           socket := Socket, ip := IP,
                            remote_server := RServer,
                            server := LServer} = State) ->
     ?INFO_MSG("(~s) Accepted outbound s2s ~s authentication ~s -> ~s (~s)",
-             [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
+             [SockMod:pp(Socket), Mech, LServer, RServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [true]).
 
 handle_auth_failure(Mech, Reason,
-                   #{socket := Socket, ip := IP,
+                   #{sockmod := SockMod,
+                     socket := Socket, ip := IP,
                      remote_server := RServer,
                      server := LServer} = State) ->
     ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s",
-             [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
-              ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
-    State1 = State#{stop_reason => {auth, Reason}},
-    ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]).
+             [SockMod:pp(Socket), Mech, LServer, RServer,
+              ejabberd_config:may_hide_data(jlib:ip_to_list(IP)),
+              xmpp_stream_out:format_error(Reason)]),
+    ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [{false, Reason}]).
 
 handle_packet(Pkt, #{server := LServer} = State) ->
     ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]).
@@ -185,9 +242,8 @@ handle_cdata(Data, #{server := LServer} = State) ->
 handle_recv(El, Pkt, #{server := LServer} = State) ->
     ejabberd_hooks:run_fold(s2s_out_handle_recv, LServer, State, [El, Pkt]).
 
-handle_send(Pkt, El, Data, #{server := LServer} = State) ->
-    ejabberd_hooks:run_fold(s2s_out_handle_send, LServer,
-                           State, [Pkt, El, Data]).
+handle_send(El, Pkt, #{server := LServer} = State) ->
+    ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, State, [El, Pkt]).
 
 handle_timeout(#{on_route := Action} = State) ->
     case Action of
@@ -298,7 +354,7 @@ get_delay() ->
                 s2s_max_retry_delay,
                 fun(I) when is_integer(I), I > 0 -> I end,
                 300),
-    crypto:rand_uniform(0, MaxDelay).
+    crypto:rand_uniform(1, MaxDelay).
 
 -spec set_idle_timeout(state()) -> state().
 set_idle_timeout(#{on_route := send, server := LServer} = State) ->
@@ -316,6 +372,7 @@ transform_options({outgoing_s2s_options, Families, Timeout}, Opts) ->
                  "but it is better to fix your config: "
                  "use 'outgoing_s2s_timeout' and "
                  "'outgoing_s2s_families' instead.", []),
+    maybe_report_huge_timeout(outgoing_s2s_timeout, Timeout),
     [{outgoing_s2s_families, Families},
      {outgoing_s2s_timeout, Timeout}
      | Opts];
@@ -327,15 +384,27 @@ transform_options({s2s_dns_options, S2SDNSOpts}, AllOpts) ->
                  "'s2s_dns_retries' instead", []),
     lists:foldr(
       fun({timeout, T}, AccOpts) ->
+             maybe_report_huge_timeout(s2s_dns_timeout, T),
               [{s2s_dns_timeout, T}|AccOpts];
          ({retries, R}, AccOpts) ->
               [{s2s_dns_retries, R}|AccOpts];
          (_, AccOpts) ->
               AccOpts
       end, AllOpts, S2SDNSOpts);
+transform_options({Opt, T}, Opts)
+  when Opt == outgoing_s2s_timeout; Opt == s2s_dns_timeout ->
+    maybe_report_huge_timeout(Opt, T),
+    [{outgoing_s2s_timeout, T}|Opts];
 transform_options(Opt, Opts) ->
     [Opt|Opts].
 
+maybe_report_huge_timeout(Opt, T) when is_integer(T), T >= 1000 ->
+    ?WARNING_MSG("value '~p' of option '~p' is too big, "
+                "are you sure you have set seconds?",
+                [T, Opt]);
+maybe_report_huge_timeout(_, _) ->
+    ok.
+
 opt_type(outgoing_s2s_families) ->
     fun (Families) ->
            true = lists:all(fun (ipv4) -> true;
@@ -354,7 +423,10 @@ opt_type(outgoing_s2s_timeout) ->
 opt_type(s2s_dns_retries) ->
     fun (I) when is_integer(I), I >= 0 -> I end;
 opt_type(s2s_dns_timeout) ->
-    fun (I) when is_integer(I), I >= 0 -> I end;
+    fun (TimeOut) when is_integer(TimeOut), TimeOut > 0 ->
+           TimeOut;
+       (infinity) -> infinity
+    end;
 opt_type(s2s_max_retry_delay) ->
     fun (I) when is_integer(I), I > 0 -> I end;
 opt_type(_) ->
index 6ecd03a4c2008672ac88004c0e6fc88d22038d0c..d84de3db419b85e1cfe9889f60c6ecda2beea9c1 100644 (file)
@@ -85,7 +85,8 @@ init([State, Opts]) ->
                       dict:from_list([{global, Pass}])
               end,
     CheckFrom = gen_mod:get_opt(check_from, Opts,
-                               fun(Flag) when is_boolean(Flag) -> Flag end),
+                               fun(Flag) when is_boolean(Flag) -> Flag end,
+                               true),
     xmpp_stream_in:change_shaper(State, Shaper),
     State1 = State#{access => Access,
                    xmlns => ?NS_COMPONENT,
@@ -119,7 +120,7 @@ handle_stream_start(_StreamStart,
     end.
 
 get_password_fun(#{remote_server := RemoteServer,
-                  socket := Socket,
+                  socket := Socket, sockmod := SockMod,
                   ip := IP,
                   host_opts := HostOpts}) ->
     fun(_) ->
@@ -129,7 +130,7 @@ get_password_fun(#{remote_server := RemoteServer,
                error ->
                    ?ERROR_MSG("(~s) Domain ~s is unconfigured for "
                               "external component from ~s",
-                              [ejabberd_socket:pp(Socket), RemoteServer,
+                              [SockMod:pp(Socket), RemoteServer,
                                ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
                    {false, undefined}
            end
@@ -137,10 +138,11 @@ get_password_fun(#{remote_server := RemoteServer,
 
 handle_auth_success(_, Mech, _,
                    #{remote_server := RemoteServer, host_opts := HostOpts,
-                     socket := Socket, ip := IP} = State) ->
+                     socket := Socket, sockmod := SockMod,
+                     ip := IP} = State) ->
     ?INFO_MSG("(~s) Accepted external component ~s authentication "
              "for ~s from ~s",
-             [ejabberd_socket:pp(Socket), Mech, RemoteServer,
+             [SockMod:pp(Socket), Mech, RemoteServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     lists:foreach(
       fun (H) ->
@@ -151,10 +153,11 @@ handle_auth_success(_, Mech, _,
 
 handle_auth_failure(_, Mech, Reason,
                    #{remote_server := RemoteServer,
+                     sockmod := SockMod,
                      socket := Socket, ip := IP} = State) ->
     ?ERROR_MSG("(~s) Failed external component ~s authentication "
               "for ~s from ~s: ~s",
-              [ejabberd_socket:pp(Socket), Mech, RemoteServer,
+              [SockMod:pp(Socket), Mech, RemoteServer,
                ejabberd_config:may_hide_data(jlib:ip_to_list(IP)),
                Reason]),
     State.
index a15d788d0abdf33aeba77a047ca19521d559e7f3..38b7ed15f6443878ad8a7a4dc8180c775159a67b 100644 (file)
@@ -83,7 +83,6 @@
 -include("xmpp.hrl").
 
 -include("ejabberd_commands.hrl").
--include("mod_privacy.hrl").
 -include("ejabberd_sm.hrl").
 
 -callback init() -> ok | {error, any()}.
@@ -576,24 +575,10 @@ do_route(From, To, Packet) ->
 %% or if there are no current sessions for the user.
 -spec is_privacy_allow(jid(), jid(), stanza()) -> boolean().
 is_privacy_allow(From, To, Packet) ->
-    User = To#jid.user,
-    Server = To#jid.server,
-    PrivacyList =
-       ejabberd_hooks:run_fold(privacy_get_user_list, Server,
-                               #userlist{}, [User, Server]),
-    is_privacy_allow(From, To, Packet, PrivacyList).
-
-%% Check if privacy rules allow this delivery
-%% Function copied from ejabberd_c2s.erl
--spec is_privacy_allow(jid(), jid(), stanza(), #userlist{}) -> boolean().
-is_privacy_allow(From, To, Packet, PrivacyList) ->
-    User = To#jid.user,
-    Server = To#jid.server,
-    allow ==
-      ejabberd_hooks:run_fold(privacy_check_packet, Server,
-                             allow,
-                             [User, Server, PrivacyList, {From, To, Packet},
-                              in]).
+    LServer = To#jid.server,
+    allow == ejabberd_hooks:run_fold(
+              privacy_check_packet, LServer, allow,
+              [To, xmpp:set_from_to(Packet, From, To), in]).
 
 -spec route_message(jid(), jid(), message(), message_type()) -> any().
 route_message(From, To, Packet, Type) ->
@@ -757,10 +742,14 @@ process_iq(From, To, #iq{type = T, lang = Lang, sub_els = [El]} = Packet)
            Err = xmpp:err_service_unavailable(Txt, Lang),
            ejabberd_router:route_error(To, From, Packet, Err)
     end;
-process_iq(From, To, #iq{type = T} = Packet) when T == get; T == set ->
-    Err = xmpp:err_bad_request(),
-    ejabberd_router:route_error(To, From, Packet, Err),
-    ok;
+process_iq(From, To, #iq{type = T, lang = Lang, sub_els = SubEls} = Packet)
+  when T == get; T == set ->
+    Txt = case SubEls of
+             [] -> <<"No child elements found">>;
+             _ -> <<"Too many child elements">>
+         end,
+    Err = xmpp:err_bad_request(Txt, Lang),
+    ejabberd_router:route_error(To, From, Packet, Err);
 process_iq(_From, _To, #iq{}) ->
     ok.
 
@@ -770,7 +759,7 @@ force_update_presence({LUser, LServer}) ->
     Mod = get_sm_backend(LServer),
     Ss = online(Mod:get_sessions(LUser, LServer)),
     lists:foreach(fun (#session{sid = {_, Pid}}) ->
-                         Pid ! {force_update_presence, LUser, LServer}
+                         Pid ! force_update_presence
                  end,
                  Ss).
 
index 4e523a7e525745531a1bd52b2b20be2cc1370cd8..83b7ae9b90f17d0da7001149d4c50f262b7a5189 100644 (file)
         connect/4,
         connect/5,
         starttls/2,
-        starttls/3,
         compress/1,
         compress/2,
         reset_stream/1,
+        send_element/2,
+        send_header/2,
+        send_trailer/1,
         send/2,
         send_xml/2,
         change_shaper/2,
                     [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
 -callback socket_type() -> xml_stream | independent | raw.
 
+-define(is_http_socket(S),
+       (S#socket_state.sockmod == ejabberd_bosh orelse
+        S#socket_state.sockmod == ejabberd_http_ws)).
+
 %%====================================================================
 %% API
 %%====================================================================
--spec start(atom(), sockmod(), socket(), [{atom(), any()}]) -> any().
-
+-spec start(atom(), sockmod(), socket(), [proplists:propery()])
+      -> {ok, pid() | independent} | {error, inet:posix() | any()}.
 start(Module, SockMod, Socket, Opts) ->
     case Module:socket_type() of
-      xml_stream ->
-         MaxStanzaSize = case lists:keysearch(max_stanza_size, 1,
-                                              Opts)
-                             of
-                           {value, {_, Size}} -> Size;
-                           _ -> infinity
-                         end,
-         {ReceiverMod, Receiver, RecRef} = case catch
-                                                  SockMod:custom_receiver(Socket)
-                                               of
-                                             {receiver, RecMod, RecPid} ->
-                                                 {RecMod, RecPid, RecMod};
-                                             _ ->
-                                                 RecPid =
-                                                     ejabberd_receiver:start(Socket,
-                                                                             SockMod,
-                                                                             none,
-                                                                             MaxStanzaSize),
-                                                 {ejabberd_receiver, RecPid,
-                                                  RecPid}
-                                           end,
-         SocketData = #socket_state{sockmod = SockMod,
-                                    socket = Socket, receiver = RecRef},
-         case Module:start({?MODULE, SocketData}, Opts) of
-           {ok, Pid} ->
-               case SockMod:controlling_process(Socket, Receiver) of
-                 ok -> ok;
-                 {error, _Reason} -> SockMod:close(Socket)
+       independent -> {ok, independent};
+       xml_stream ->
+           MaxStanzaSize = proplists:get_value(max_stanza_size, Opts, infinity),
+           {ReceiverMod, Receiver, RecRef} =
+               try SockMod:custom_receiver(Socket) of
+                   {receiver, RecMod, RecPid} ->
+                       {RecMod, RecPid, RecMod}
+               catch _:_ ->
+                       RecPid = ejabberd_receiver:start(
+                                  Socket, SockMod, none, MaxStanzaSize),
+                       {ejabberd_receiver, RecPid, RecPid}
                end,
-               ReceiverMod:become_controller(Receiver, Pid);
-           _ ->
-               SockMod:close(Socket),
-               case ReceiverMod of
-                 ejabberd_receiver -> ReceiverMod:close(Receiver);
-                 _ -> ok
-               end
-         end;
-      independent -> ok;
-      raw ->
-         case Module:start({SockMod, Socket}, Opts) of
-           {ok, Pid} ->
-               case SockMod:controlling_process(Socket, Pid) of
-                 ok -> ok;
-                 {error, _Reason} -> SockMod:close(Socket)
-               end;
-           {error, _Reason} -> SockMod:close(Socket)
-         end
+           SocketData = #socket_state{sockmod = SockMod,
+                                      socket = Socket, receiver = RecRef},
+           case Module:start({?MODULE, SocketData}, Opts) of
+               {ok, Pid} ->
+                   case SockMod:controlling_process(Socket, Receiver) of
+                       ok ->
+                           ReceiverMod:become_controller(Receiver, Pid),
+                           {ok, Receiver};
+                       Err ->
+                           SockMod:close(Socket),
+                           Err
+                   end;
+               Err ->
+                   SockMod:close(Socket),
+                   case ReceiverMod of
+                       ejabberd_receiver -> ReceiverMod:close(Receiver);
+                       _ -> ok
+                   end,
+                   Err
+           end;
+       raw ->
+           case Module:start({SockMod, Socket}, Opts) of
+               {ok, Pid} ->
+                   case SockMod:controlling_process(Socket, Pid) of
+                       ok ->
+                           {ok, Pid};
+                       {error, _} = Err ->
+                           SockMod:close(Socket),
+                           Err
+                   end;
+               Err ->
+                   SockMod:close(Socket),
+                   Err
+           end
     end.
 
 connect(Addr, Port, Opts) ->
@@ -156,35 +161,31 @@ connect(Addr, Port, Opts, Timeout, Owner) ->
       {error, _Reason} = Error -> Error
     end.
 
-starttls(SocketData, TLSOpts) ->
-    case fast_tls:tcp_to_tls(SocketData#socket_state.socket, TLSOpts) of
+starttls(#socket_state{socket = Socket,
+                      receiver = Receiver} = SocketData, TLSOpts) ->
+    case fast_tls:tcp_to_tls(Socket, TLSOpts) of
        {ok, TLSSocket} ->
-           ejabberd_receiver:starttls(SocketData#socket_state.receiver, TLSSocket),
-           {ok, SocketData#socket_state{socket = TLSSocket, sockmod = fast_tls}};
-       Err ->
-           ?ERROR_MSG("starttls failed: ~p", [Err]),
-           Err
-    end.
-
-starttls(SocketData, TLSOpts, Data) ->
-    case fast_tls:tcp_to_tls(SocketData#socket_state.socket, TLSOpts) of
-       {ok, TLSSocket} ->
-           ejabberd_receiver:starttls(SocketData#socket_state.receiver, TLSSocket),
-           send(SocketData, Data),
-           {ok, SocketData#socket_state{socket = TLSSocket, sockmod = fast_tls}};
-       Err ->
-           ?ERROR_MSG("starttls failed: ~p", [Err]),
+           case ejabberd_receiver:starttls(Receiver, TLSSocket) of
+               ok ->
+                   {ok, SocketData#socket_state{socket = TLSSocket,
+                                                sockmod = fast_tls}};
+               {error, _} = Err ->
+                   Err
+           end;
+       {error, _} = Err ->
            Err
     end.
 
 compress(SocketData) -> compress(SocketData, undefined).
 
 compress(SocketData, Data) ->
-    {ok, ZlibSocket} =
-       ejabberd_receiver:compress(SocketData#socket_state.receiver,
-                                  Data),
-    SocketData#socket_state{socket = ZlibSocket,
-                           sockmod = ezlib}.
+    case ejabberd_receiver:compress(SocketData#socket_state.receiver, Data) of
+       {ok, ZlibSocket} ->
+           {ok, SocketData#socket_state{socket = ZlibSocket, sockmod = ezlib}};
+       Err ->
+           ?ERROR_MSG("compress failed: ~p", [Err]),
+           Err
+    end.
 
 reset_stream(SocketData)
     when is_pid(SocketData#socket_state.receiver) ->
@@ -193,30 +194,41 @@ reset_stream(SocketData)
     when is_atom(SocketData#socket_state.receiver) ->
     (SocketData#socket_state.receiver):reset_stream(SocketData#socket_state.socket).
 
--spec send(socket_state(), iodata()) -> ok.
-
-send(SocketData, Data) ->
-    ?DEBUG("Send XML on stream = ~p", [Data]),
-    case catch (SocketData#socket_state.sockmod):send(
-            SocketData#socket_state.socket, Data) of
-        ok -> ok;
-       {error, timeout} ->
-           ?INFO_MSG("Timeout on ~p:send",[SocketData#socket_state.sockmod]),
-           {error, timeout};
-        Error ->
-           ?DEBUG("Error in ~p:send: ~p",[SocketData#socket_state.sockmod, Error]),
-           Error
+-spec send_element(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}.
+send_element(SocketData, El) when ?is_http_socket(SocketData) ->
+    send_xml(SocketData, {xmlstreamelement, El});
+send_element(SocketData, El) ->
+    send(SocketData, fxml:element_to_binary(El)).
+
+-spec send_header(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}.
+send_header(SocketData, El) when ?is_http_socket(SocketData) ->
+    send_xml(SocketData, {xmlstreamstart, El#xmlel.name, El#xmlel.attrs});
+send_header(SocketData, El) ->
+    send(SocketData, fxml:element_to_header(El)).
+
+-spec send_trailer(socket_state()) -> ok | {error, inet:posix()}.
+send_trailer(SocketData) when ?is_http_socket(SocketData) ->
+    send_xml(SocketData, {xmlstreamend, <<"stream:stream">>});
+send_trailer(SocketData) ->
+    send(SocketData, <<"</stream:stream>">>).
+
+-spec send(socket_state(), iodata()) -> ok | {error, inet:posix()}.
+send(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) ->
+    ?DEBUG("(~s) Send XML on stream = ~p", [pp(SocketData), Data]),
+    try SockMod:send(Socket, Data)
+    catch _:badarg ->
+           %% Some modules throw badarg exceptions on closed sockets
+           %% TODO: their code should be improved
+           {error, einval}
     end.
 
-%% Can only be called when in c2s StateData#state.xml_socket is true
-%% This function is used for HTTP bind
-%% sockmod=ejabberd_http_ws|ejabberd_http_bind or any custom module
--spec send_xml(socket_state(), fxml:xmlel()) -> any().
-
-send_xml(SocketData, Data) ->
-    catch
-      (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket,
-                                                Data).
+-spec send_xml(socket_state(),
+              {xmlstreamelement, fxml:xmlel()} |
+              {xmlstreamstart, binary(), [{binary(), binary()}]} |
+              {xmlstreamend, binary()} |
+              {xmlstreamraw, iodata()}) -> term().
+send_xml(SocketData, El) ->
+    (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, El).
 
 change_shaper(SocketData, Shaper)
     when is_pid(SocketData#socket_state.receiver) ->
index 2e182ed1e4d3acdbf5aec270ec28733436076a26..15524ce1ef397b4befca5d4f292a14d711862dbe 100644 (file)
@@ -68,7 +68,7 @@ start(Host, Opts) ->
     ejabberd_hooks:add(disco_local_items, Host, ?MODULE, disco_items, 50),
     ejabberd_hooks:add(adhoc_local_items, Host, ?MODULE, announce_items, 50),
     ejabberd_hooks:add(adhoc_local_commands, Host, ?MODULE, announce_commands, 50),
-    ejabberd_hooks:add(user_available_hook, Host,
+    ejabberd_hooks:add(c2s_self_presence, Host,
                       ?MODULE, send_motd, 50),
     register(gen_mod:get_module_proc(Host, ?PROCNAME),
             proc_lib:spawn(?MODULE, init, [])).
@@ -123,7 +123,7 @@ stop(Host) ->
     ejabberd_hooks:delete(disco_local_items, Host, ?MODULE, disco_items, 50),
     ejabberd_hooks:delete(local_send_to_resource_hook, Host,
                          ?MODULE, announce, 50),
-    ejabberd_hooks:delete(user_available_hook, Host,
+    ejabberd_hooks:delete(c2s_self_presence, Host,
                          ?MODULE, send_motd, 50),
     Proc = gen_mod:get_module_proc(Host, ?PROCNAME),
     exit(whereis(Proc), stop),
@@ -733,8 +733,13 @@ announce_motd_delete(LServer) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:delete_motd(LServer).
 
--spec send_motd(jid()) -> ok | {atomic, any()}.
-send_motd(#jid{luser = LUser, lserver = LServer} = JID) when LUser /= <<>> ->
+-spec send_motd({presence(), ejabberd_c2s:state()}) -> {presence(), ejabberd_c2s:state()}.
+send_motd({_, #{pres_last := _}} = Acc) ->
+    %% This is just a presence update, nothing to do
+    Acc;
+send_motd({#presence{type = available},
+          #{jid := #jid{luser = LUser, lserver = LServer} = JID}} = Acc)
+  when LUser /= <<>> ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:get_motd(LServer) of
        {ok, Packet} ->
@@ -754,9 +759,10 @@ send_motd(#jid{luser = LUser, lserver = LServer} = JID) when LUser /= <<>> ->
            end;
        error ->
            ok
-    end;
-send_motd(_) ->
-    ok.
+    end,
+    Acc;
+send_motd(Acc) ->
+    Acc.
 
 get_stored_motd(LServer) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
index 45564daf4ced02aa47b675b08b86ab14ad055996..5195bfb733a1cc12d4bb7259b0773a560ccf39ec 100644 (file)
@@ -29,8 +29,8 @@
 
 -protocol({xep, 191, '1.2'}).
 
--export([start/2, stop/1, process_iq/1, c2s_handle_info/2,
-        process_iq_set/3, process_iq_get/3, mod_opt_type/1, depends/2]).
+-export([start/2, stop/1, process_iq/1, mod_opt_type/1, depends/2,
+        disco_features/5]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
 start(Host, Opts) ->
     IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1,
                              one_queue),
-    ejabberd_hooks:add(privacy_iq_get, Host, ?MODULE,
-                      process_iq_get, 40),
-    ejabberd_hooks:add(privacy_iq_set, Host, ?MODULE,
-                      process_iq_set, 40),
-    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
-                      c2s_handle_info, 40),
-    mod_disco:register_feature(Host, ?NS_BLOCKING),
+    ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50),
     gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
                                  ?NS_BLOCKING, ?MODULE, process_iq, IQDisc).
 
 stop(Host) ->
-    ejabberd_hooks:delete(privacy_iq_get, Host, ?MODULE,
-                         process_iq_get, 40),
-    ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE,
-                         process_iq_set, 40),
-    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
-                         c2s_handle_info, 40),
-    mod_disco:unregister_feature(Host, ?NS_BLOCKING),
-    gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
-                                    ?NS_BLOCKING).
+    ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 50),
+    gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING).
 
 depends(_Host, _Opts) ->
     [{mod_privacy, hard}].
 
+-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty,
+                    jid(), jid(), binary(), binary()) ->
+                           {error, stanza_error()} | {result, [binary()]}.
+disco_features({error, Err}, _From, _To, _Node, _Lang) ->
+    {error, Err};
+disco_features(empty, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_BLOCKING]};
+disco_features({result, Feats}, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_BLOCKING|Feats]};
+disco_features(Acc, _From, _To, _Node, _Lang) ->
+    Acc.
+
 -spec process_iq(iq()) -> iq().
-process_iq(IQ) ->
-    xmpp:make_error(IQ, xmpp:err_not_allowed()).
+process_iq(#iq{type = Type,
+              from = #jid{luser = U, lserver = S},
+              to = #jid{luser = U, lserver = S}} = IQ) ->
+    case Type of
+       get -> process_iq_get(IQ);
+       set -> process_iq_set(IQ)
+    end;
+process_iq(#iq{lang = Lang} = IQ) ->
+    Txt = <<"Query to another users is forbidden">>,
+    xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)).
 
--spec process_iq_get({error, stanza_error()} | {result, xmpp_element() | undefined},
-                    iq(), userlist()) ->
-                           {error, stanza_error()} |
-                           {result, xmpp_element() | undefined}.
-process_iq_get(_, #iq{lang = Lang, from = From,
-                     sub_els = [#block_list{}]}, _) ->
-    #jid{luser = LUser, lserver = LServer} = From,
-    process_blocklist_get(LUser, LServer, Lang);
-process_iq_get(Acc, _, _) -> Acc.
+-spec process_iq_get(iq()) -> iq().
+process_iq_get(#iq{sub_els = [#block_list{}]} = IQ) ->
+    process_get(IQ);
+process_iq_get(#iq{lang = Lang} = IQ) ->
+    Txt = <<"No module is handling this query">>,
+    xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)).
 
--spec process_iq_set({error, stanza_error()} |
-                    {result, xmpp_element() | undefined} |
-                    {result, xmpp_element() | undefined, userlist()},
-                    iq(), userlist()) ->
-                           {error, stanza_error()} |
-                           {result, xmpp_element() | undefined} |
-                           {result, xmpp_element() | undefined, userlist()}.
-process_iq_set(Acc, #iq{from = From, lang = Lang, sub_els = [SubEl]}, _) ->
-    #jid{luser = LUser, lserver = LServer} = From,
+-spec process_iq_set(iq()) -> iq().
+process_iq_set(#iq{lang = Lang, sub_els = [SubEl]} = IQ) ->
     case SubEl of
        #block{items = []} ->
            Txt = <<"No items found in this query">>,
-           {error, xmpp:err_bad_request(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang));
        #block{items = Items} ->
            JIDs = [jid:tolower(Item) || Item <- Items],
-           process_blocklist_block(LUser, LServer, JIDs, Lang);
+           process_block(IQ, JIDs);
        #unblock{items = []} ->
-           process_blocklist_unblock_all(LUser, LServer, Lang);
+           process_unblock_all(IQ);
        #unblock{items = Items} ->
            JIDs = [jid:tolower(Item) || Item <- Items],
-           process_blocklist_unblock(LUser, LServer, JIDs, Lang);
+           process_unblock(IQ, JIDs);
        _ ->
-           Acc
-    end;
-process_iq_set(Acc, _, _) -> Acc.
+           Txt = <<"No module is handling this query">>,
+           xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang))
+    end.
 
--spec list_to_blocklist_jids([listitem()], [ljid()]) -> [ljid()].
-list_to_blocklist_jids([], JIDs) -> JIDs;
-list_to_blocklist_jids([#listitem{type = jid,
-                                 action = deny, value = JID} =
-                           Item
-                       | Items],
-                      JIDs) ->
+-spec listitems_to_jids([listitem()], [ljid()]) -> [ljid()].
+listitems_to_jids([], JIDs) ->
+    JIDs;
+listitems_to_jids([#listitem{type = jid,
+                            action = deny, value = JID} = Item | Items],
+                 JIDs) ->
     Match = case Item of
                #listitem{match_all = true} ->
                    true;
@@ -130,20 +126,18 @@ list_to_blocklist_jids([#listitem{type = jid,
                _ ->
                    false
            end,
-    if Match -> list_to_blocklist_jids(Items, [JID | JIDs]);
-       true -> list_to_blocklist_jids(Items, JIDs)
+    if Match -> listitems_to_jids(Items, [JID | JIDs]);
+       true -> listitems_to_jids(Items, JIDs)
     end;
 % Skip Privacy List items than cannot be mapped to Blocking items
-list_to_blocklist_jids([_ | Items], JIDs) ->
-    list_to_blocklist_jids(Items, JIDs).
+listitems_to_jids([_ | Items], JIDs) ->
+    listitems_to_jids(Items, JIDs).
 
--spec process_blocklist_block(binary(), binary(), [ljid()],
-                             binary()) ->
-                                    {error, stanza_error()} |
-                                    {result, undefined, userlist()}.
-process_blocklist_block(LUser, LServer, JIDs, Lang) ->
+-spec process_block(iq(), [ljid()]) -> iq().
+process_block(#iq{from = #jid{luser = LUser, lserver = LServer},
+                 lang = Lang} = IQ, JIDs) ->
     Filter = fun (List) ->
-                    AlreadyBlocked = list_to_blocklist_jids(List, []),
+                    AlreadyBlocked = listitems_to_jids(List, []),
                     lists:foldr(fun (JID, List1) ->
                                         case lists:member(JID, AlreadyBlocked)
                                             of
@@ -161,23 +155,21 @@ process_blocklist_block(LUser, LServer, JIDs, Lang) ->
             end,
     Mod = db_mod(LServer),
     case Mod:process_blocklist_block(LUser, LServer, Filter) of
-      {atomic, {ok, Default, List}} ->
-         UserList = make_userlist(Default, List),
-         broadcast_list_update(LUser, LServer, Default,
-                               UserList),
-         broadcast_blocklist_event(LUser, LServer,
-                                   {block, [jid:make(J) || J <- JIDs]}),
-         {result, undefined, UserList};
-      _Err ->
+       {atomic, {ok, Default, List}} ->
+           UserList = make_userlist(Default, List),
+           broadcast_list_update(LUser, LServer, UserList, Default),
+           broadcast_event(LUser, LServer,
+                           #block{items = [jid:make(J) || J <- JIDs]}),
+           xmpp:make_iq_result(IQ);
+       _Err ->
            ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer, JIDs}, _Err]),
-           {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)}
+           Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang),
+           xmpp:make_error(IQ, Err)
     end.
 
--spec process_blocklist_unblock_all(binary(), binary(), binary()) ->
-                                          {error, stanza_error()} |
-                                          {result, undefined} |
-                                          {result, undefined, userlist()}.
-process_blocklist_unblock_all(LUser, LServer, Lang) ->
+-spec process_unblock_all(iq()) -> iq().
+process_unblock_all(#iq{from = #jid{luser = LUser, lserver = LServer},
+                       lang = Lang} = IQ) ->
     Filter = fun (List) ->
                     lists:filter(fun (#listitem{action = A}) -> A =/= deny
                                  end,
@@ -185,23 +177,22 @@ process_blocklist_unblock_all(LUser, LServer, Lang) ->
             end,
     Mod = db_mod(LServer),
     case Mod:unblock_by_filter(LUser, LServer, Filter) of
-      {atomic, ok} -> {result, undefined};
-      {atomic, {ok, Default, List}} ->
-         UserList = make_userlist(Default, List),
-         broadcast_list_update(LUser, LServer, Default,
-                               UserList),
-         broadcast_blocklist_event(LUser, LServer, unblock_all),
-         {result, undefined, UserList};
-      _Err ->
+       {atomic, ok} ->
+           xmpp:make_iq_result(IQ);
+       {atomic, {ok, Default, List}} ->
+           UserList = make_userlist(Default, List),
+           broadcast_list_update(LUser, LServer, UserList, Default),
+           broadcast_event(LUser, LServer, #unblock{}),
+           xmpp:make_iq_result(IQ);
+       _Err ->
            ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer}, _Err]),
-           {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)}
+           Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang),
+           xmpp:make_error(IQ, Err)
     end.
 
--spec process_blocklist_unblock(binary(), binary(), [ljid()], binary()) ->
-                                      {error, stanza_error()} |
-                                      {result, undefined} |
-                                      {result, undefined, userlist()}.
-process_blocklist_unblock(LUser, LServer, JIDs, Lang) ->
+-spec process_unblock(iq(), [ljid()]) -> iq().
+process_unblock(#iq{from = #jid{luser = LUser, lserver = LServer},
+                   lang = Lang} = IQ, JIDs) ->
     Filter = fun (List) ->
                     lists:filter(fun (#listitem{action = deny, type = jid,
                                                 value = JID}) ->
@@ -212,17 +203,18 @@ process_blocklist_unblock(LUser, LServer, JIDs, Lang) ->
             end,
     Mod = db_mod(LServer),
     case Mod:unblock_by_filter(LUser, LServer, Filter) of
-      {atomic, ok} -> {result, undefined};
-      {atomic, {ok, Default, List}} ->
-         UserList = make_userlist(Default, List),
-         broadcast_list_update(LUser, LServer, Default,
-                               UserList),
-         broadcast_blocklist_event(LUser, LServer,
-                                   {unblock, [jid:make(J) || J <- JIDs]}),
-         {result, undefined, UserList};
-      _Err ->
+       {atomic, ok} ->
+           xmpp:make_iq_result(IQ);
+       {atomic, {ok, Default, List}} ->
+           UserList = make_userlist(Default, List),
+           broadcast_list_update(LUser, LServer, UserList, Default),
+           broadcast_event(LUser, LServer,
+                           #unblock{items = [jid:make(J) || J <- JIDs]}),
+           xmpp:make_iq_result(IQ);
+       _Err ->
            ?ERROR_MSG("Error processing ~p: ~p", [{LUser, LServer, JIDs}, _Err]),
-           {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)}
+           Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang),
+           xmpp:make_error(IQ, Err)
     end.
 
 -spec make_userlist(binary(), [listitem()]) -> userlist().
@@ -230,52 +222,36 @@ make_userlist(Name, List) ->
     NeedDb = mod_privacy:is_list_needdb(List),
     #userlist{name = Name, list = List, needdb = NeedDb}.
 
--spec broadcast_list_update(binary(), binary(), binary(), userlist()) -> ok.
-broadcast_list_update(LUser, LServer, Name, UserList) ->
-    ejabberd_sm:route(jid:make(LUser, LServer, <<"">>),
-                      {privacy_list, UserList, Name}).
+-spec broadcast_list_update(binary(), binary(), userlist(), binary()) -> ok.
+broadcast_list_update(LUser, LServer, UserList, Name) ->
+    mod_privacy:push_list_update(jid:make(LUser, LServer), UserList, Name).
 
--spec broadcast_blocklist_event(binary(), binary(), block_event()) -> ok.
-broadcast_blocklist_event(LUser, LServer, Event) ->
-    JID = jid:make(LUser, LServer, <<"">>),
-    ejabberd_sm:route(JID, {blocking, Event}).
+-spec broadcast_event(binary(), binary(), block_event()) -> ok.
+broadcast_event(LUser, LServer, Event) ->
+    From = jid:make(LUser, LServer),
+    lists:foreach(
+      fun(R) ->
+             To = jid:replace_resource(From, R),
+             IQ = #iq{type = set, from = From, to = To,
+                      id = <<"push", (randoms:get_string())/binary>>,
+                      sub_els = [Event]},
+             ejabberd_router:route(From, To, IQ)
+      end, ejabberd_sm:get_user_resources(LUser, LServer)).
 
--spec process_blocklist_get(binary(), binary(), binary()) ->
-                                  {error, stanza_error()} | {result, block_list()}.
-process_blocklist_get(LUser, LServer, Lang) ->
+-spec process_get(iq()) -> iq().
+process_get(#iq{from = #jid{luser = LUser, lserver = LServer},
+               lang = Lang} = IQ) ->
     Mod = db_mod(LServer),
     case Mod:process_blocklist_get(LUser, LServer) of
-      error ->
-         {error, xmpp:err_internal_server_error(<<"Database failure">>, Lang)};
-      List ->
-         LJIDs = list_to_blocklist_jids(List, []),
-         Items = [jid:make(J) || J <- LJIDs],
-         {result, #block_list{items = Items}}
+       error ->
+           Err = xmpp:err_internal_server_error(<<"Database failure">>, Lang),
+           xmpp:make_error(IQ, Err);
+       List ->
+           LJIDs = listitems_to_jids(List, []),
+           Items = [jid:make(J) || J <- LJIDs],
+           xmpp:make_iq_result(IQ, #block_list{items = Items})
     end.
 
--spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
-c2s_handle_info(#{user := U, server := S, resource := R} = State,
-               {blocking, Action}) ->
-    SubEl = case Action of
-               {block, JIDs} ->
-                   #block{items = JIDs};
-               {unblock, JIDs} ->
-                   #unblock{items = JIDs};
-               unblock_all ->
-                   #unblock{}
-           end,
-    PushIQ = #iq{type = set,
-                from = jid:make(U, S),
-                to = jid:make(U, S, R),
-                id = <<"push", (randoms:get_string())/binary>>,
-                sub_els = [SubEl]},
-    %% No need to replace active privacy list here,
-    %% blocking pushes are always accompanied by
-    %% Privacy List pushes
-    {stop, ejabberd_c2s:send(State, PushIQ)};
-c2s_handle_info(State, _) ->
-    State.
-
 -spec db_mod(binary()) -> module().
 db_mod(LServer) ->
     DBType = gen_mod:db_type(LServer, mod_privacy),
index e2ec303051098bb2bd2200dce74d6d06ff683eb3..d5a623669970258362496204b6ba209fe8c5237e 100644 (file)
@@ -47,7 +47,7 @@
 -export([init/1, handle_info/2, handle_call/3,
         handle_cast/2, terminate/2, code_change/3]).
 
--export([user_send_packet/4, user_receive_packet/5,
+-export([user_send_packet/1, user_receive_packet/1,
         c2s_presence_in/2, mod_opt_type/1]).
 
 -include("ejabberd.hrl").
@@ -126,47 +126,51 @@ read_caps(Presence) ->
        Caps -> Caps
     end.
 
--spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza().
-user_send_packet(#presence{type = available} = Pkt,
-                _C2SState,
-                #jid{luser = User, lserver = Server} = From,
-                #jid{luser = User, lserver = Server,
-                     lresource = <<"">>}) ->
+-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_send_packet({#presence{type = available,
+                           from = #jid{luser = U, lserver = LServer} = From,
+                           to = #jid{luser = U, lserver = LServer,
+                                     lresource = <<"">>}} = Pkt,
+                 State}) ->
     case read_caps(Pkt) of
        nothing -> ok;
        #caps{version = Version, exts = Exts} = Caps ->
-           feature_request(Server, From, Caps, [Version | Exts])
+           feature_request(LServer, From, Caps, [Version | Exts])
     end,
-    Pkt;
-user_send_packet(Pkt, _C2SState, _From, _To) ->
-    Pkt.
-
--spec user_receive_packet(stanza(), ejabberd_c2s:state(),
-                         jid(), jid(), jid()) -> stanza().
-user_receive_packet(#presence{type = available} = Pkt,
-                   _C2SState,
-                   #jid{lserver = Server},
-                   From, _To) ->
-    IsRemote = not lists:member(From#jid.lserver, ?MYHOSTS),
+    {Pkt, State};
+user_send_packet(Acc) ->
+    Acc.
+
+-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_receive_packet({#presence{from = From, type = available} = Pkt,
+                    #{lserver := LServer} = State}) ->
+    IsRemote = not ejabberd_router:is_my_host(From#jid.lserver),
     if IsRemote ->
-          case read_caps(Pkt) of
-            nothing -> ok;
-            #caps{version = Version, exts = Exts} = Caps ->
-                feature_request(Server, From, Caps, [Version | Exts])
-          end;
+           case read_caps(Pkt) of
+               nothing -> ok;
+               #caps{version = Version, exts = Exts} = Caps ->
+                   feature_request(LServer, From, Caps, [Version | Exts])
+           end;
        true -> ok
     end,
-    Pkt;
-user_receive_packet(Pkt, _C2SState, _JID, _From, _To) ->
-    Pkt.
+    {Pkt, State};
+user_receive_packet(Acc) ->
+    Acc.
 
 -spec caps_stream_features([xmpp_element()], binary()) -> [xmpp_element()].
 
 caps_stream_features(Acc, MyHost) ->
-    case make_my_disco_hash(MyHost) of
-      <<"">> -> Acc;
-      Hash ->
-         [#caps{hash = <<"sha-1">>, node = ?EJABBERD_URI, version = Hash}|Acc]
+    case gen_mod:is_loaded(MyHost, ?MODULE) of
+       true ->
+           case make_my_disco_hash(MyHost) of
+               <<"">> ->
+                   Acc;
+               Hash ->
+                   [#caps{hash = <<"sha-1">>, node = ?EJABBERD_URI,
+                          version = Hash}|Acc]
+           end;
+       false ->
+           Acc
     end.
 
 -spec disco_features({error, stanza_error()} | {result, [binary()]} | empty,
@@ -238,7 +242,7 @@ c2s_presence_in(C2SState,
                            end;
                        _ -> gb_trees:delete_any(LFrom, Rs)
                    end,
-           C2SState#{caps_resources := NewRs};
+           C2SState#{caps_resources => NewRs};
        true ->
            C2SState
     end.
@@ -266,7 +270,7 @@ init([Host, Opts]) ->
                       user_receive_packet, 75),
     ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE,
                       caps_stream_features, 75),
-    ejabberd_hooks:add(s2s_stream_features, Host, ?MODULE,
+    ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE,
                       caps_stream_features, 75),
     ejabberd_hooks:add(disco_local_features, Host, ?MODULE,
                       disco_features, 75),
@@ -295,7 +299,7 @@ terminate(_Reason, State) ->
                          ?MODULE, user_receive_packet, 75),
     ejabberd_hooks:delete(c2s_post_auth_features, Host,
                          ?MODULE, caps_stream_features, 75),
-    ejabberd_hooks:delete(s2s_stream_features, Host,
+    ejabberd_hooks:delete(s2s_in_post_auth_features, Host,
                          ?MODULE, caps_stream_features, 75),
     ejabberd_hooks:delete(disco_local_features, Host,
                          ?MODULE, disco_features, 75),
index 5839a65b29c08e28a335c6be015ae66a4e12acc7..ea44aed956c07b545986835d35642bb900b1aa96 100644 (file)
@@ -35,8 +35,8 @@
 -export([start/2,
          stop/1]).
 
--export([user_send_packet/4, user_receive_packet/5,
-        iq_handler/1, remove_connection/4,
+-export([user_send_packet/1, user_receive_packet/1,
+        iq_handler/1, remove_connection/4, disco_features/5,
         is_carbon_copy/1, mod_opt_type/1, depends/2]).
 
 -include("ejabberd.hrl").
@@ -59,7 +59,7 @@ is_carbon_copy(_) ->
 
 start(Host, Opts) ->
     IQDisc = gen_mod:get_opt(iqdisc, Opts,fun gen_iq_handler:check_type/1, one_queue),
-    mod_disco:register_feature(Host, ?NS_CARBONS_2),
+    ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50),
     Mod = gen_mod:db_mod(Host, ?MODULE),
     Mod:init(Host, Opts),
     ejabberd_hooks:add(unset_presence_hook,Host, ?MODULE, remove_connection, 10),
@@ -70,12 +70,24 @@ start(Host, Opts) ->
 
 stop(Host) ->
     gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_CARBONS_2),
-    mod_disco:unregister_feature(Host, ?NS_CARBONS_2),
+    ejabberd_hooks:delete(disco_local_features, Host, ?MODULE, disco_features, 50),
     %% why priority 89: to define clearly that we must run BEFORE mod_logdb hook (90)
     ejabberd_hooks:delete(user_send_packet,Host, ?MODULE, user_send_packet, 89),
     ejabberd_hooks:delete(user_receive_packet,Host, ?MODULE, user_receive_packet, 89),
     ejabberd_hooks:delete(unset_presence_hook,Host, ?MODULE, remove_connection, 10).
 
+-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty,
+                    jid(), jid(), binary(), binary()) ->
+                           {error, stanza_error()} | {result, [binary()]}.
+disco_features({error, Err}, _From, _To, _Node, _Lang) ->
+    {error, Err};
+disco_features(empty, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_CARBONS_2]};
+disco_features({result, Feats}, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_CARBONS_2|Feats]};
+disco_features(Acc, _From, _To, _Node, _Lang) ->
+    Acc.
+
 -spec iq_handler(iq()) -> iq().
 iq_handler(#iq{type = set, lang = Lang, from = From,
               sub_els = [El]} = IQ) when is_record(El, carbons_enable);
@@ -105,16 +117,24 @@ iq_handler(#iq{type = get, lang = Lang} = IQ)->
     Txt = <<"Value 'get' of 'type' attribute is not allowed">>,
     xmpp:make_error(IQ, xmpp:err_not_allowed(Txt, Lang)).
 
--spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) ->
-                             stanza() | {stop, stanza()}.
-user_send_packet(Packet, _C2SState, From, To) ->
-    check_and_forward(From, To, Packet, sent).
+-spec user_send_packet({stanza(), ejabberd_c2s:state()})
+      -> {stanza(), ejabberd_c2s:state()} | {stop, {stanza(), ejabberd_c2s:state()}}.
+user_send_packet({Packet, C2SState}) ->
+    From = xmpp:get_from(Packet),
+    To = xmpp:get_to(Packet),
+    case check_and_forward(From, To, Packet, sent) of
+       {stop, Pkt} -> {stop, {Pkt, C2SState}};
+       Pkt -> {Pkt, C2SState}
+    end.
 
--spec user_receive_packet(stanza(), ejabberd_c2s:state(),
-                         jid(), jid(), jid()) ->
-                                stanza() | {stop, stanza()}.
-user_receive_packet(Packet, _C2SState, JID, _From, To) ->
-    check_and_forward(JID, To, Packet, received).
+-spec user_receive_packet({stanza(), ejabberd_c2s:state()})
+      -> {stanza(), ejabberd_c2s:state()} | {stop, {stanza(), ejabberd_c2s:state()}}.
+user_receive_packet({Packet, #{jid := JID} = C2SState}) ->
+    To = xmpp:get_to(Packet),
+    case check_and_forward(JID, To, Packet, received) of
+       {stop, Pkt} -> {stop, {Pkt, C2SState}};
+       Pkt -> {Pkt, C2SState}
+    end.
 
 % Modified from original version:
 %    - registered to the user_send_packet hook, to be called only once even for multicast
index a838088fccf7fa8dd50a672cbe64c4d521331926..175929a570ffd1ed2536923152390b482d68d3a3 100644 (file)
 -export([start/2, stop/1, mod_opt_type/1, depends/2]).
 
 %% ejabberd_hooks callbacks.
--export([filter_presence/4, filter_chat_states/4, filter_pep/4, filter_other/4,
-        flush_queue/3, add_stream_feature/2]).
+-export([filter_presence/1, filter_chat_states/1,
+        filter_pep/1, filter_other/1,
+        c2s_stream_started/2, add_stream_feature/2,
+        c2s_copy_session/2, c2s_authenticated_packet/2,
+        c2s_session_resumed/1]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
 -define(CSI_QUEUE_MAX, 100).
 
 -type csi_type() :: presence | chatstate | {pep, binary()}.
--type csi_key() :: {ljid(), csi_type()}.
--type csi_stanza() :: {csi_key(), erlang:timestamp(), xmlel()}.
--type csi_queue() :: [csi_stanza()].
+-type csi_queue() :: {non_neg_integer(), non_neg_integer(), map()}.
+-type csi_timestamp() :: {non_neg_integer(), erlang:timestamp()}.
+-type c2s_state() :: ejabberd_c2s:state().
+-type filter_acc() :: {stanza() | drop, c2s_state()}.
 
 %%--------------------------------------------------------------------
 %% gen_mod callbacks.
@@ -68,27 +72,33 @@ start(Host, Opts) ->
                        fun(B) when is_boolean(B) -> B end,
                        true),
     if QueuePresence; QueueChatStates; QueuePEP ->
+          ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE,
+                             c2s_stream_started, 50),
           ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE,
                              add_stream_feature, 50),
+          ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
+                             c2s_authenticated_packet, 50),
+          ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE,
+                             c2s_copy_session, 50),
+          ejabberd_hooks:add(c2s_session_resumed, Host, ?MODULE,
+                             c2s_session_resumed, 50),
           if QueuePresence ->
-                 ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE,
+                 ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
                                     filter_presence, 50);
              true -> ok
           end,
           if QueueChatStates ->
-                 ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE,
+                 ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
                                     filter_chat_states, 50);
              true -> ok
           end,
           if QueuePEP ->
-                 ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE,
+                 ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
                                     filter_pep, 50);
              true -> ok
           end,
-          ejabberd_hooks:add(csi_filter_stanza, Host, ?MODULE,
-                             filter_other, 100),
-          ejabberd_hooks:add(csi_flush_queue, Host, ?MODULE,
-                             flush_queue, 50);
+          ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
+                             filter_other, 75);
        true -> ok
     end.
 
@@ -108,27 +118,33 @@ stop(Host) ->
                               fun(B) when is_boolean(B) -> B end,
                               true),
     if QueuePresence; QueueChatStates; QueuePEP ->
+          ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE,
+                                c2s_stream_started, 50),
           ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE,
                                 add_stream_feature, 50),
+          ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
+                                c2s_authenticated_packet, 50),
+          ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE,
+                                c2s_copy_session, 50),
+          ejabberd_hooks:delete(c2s_session_resumed, Host, ?MODULE,
+                                c2s_session_resumed, 50),
           if QueuePresence ->
-                 ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE,
+                 ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
                                        filter_presence, 50);
              true -> ok
           end,
           if QueueChatStates ->
-                 ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE,
+                 ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
                                        filter_chat_states, 50);
              true -> ok
           end,
           if QueuePEP ->
-                 ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE,
+                 ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
                                        filter_pep, 50);
              true -> ok
           end,
-          ejabberd_hooks:delete(csi_filter_stanza, Host, ?MODULE,
-                                filter_other, 100),
-          ejabberd_hooks:delete(csi_flush_queue, Host, ?MODULE,
-                                flush_queue, 50);
+          ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
+                                filter_other, 75);
        true -> ok
     end.
 
@@ -150,29 +166,46 @@ depends(_Host, _Opts) ->
 %%--------------------------------------------------------------------
 %% ejabberd_hooks callbacks.
 %%--------------------------------------------------------------------
-
--spec filter_presence({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza())
-      -> {ejabberd_c2s:state(), [stanza()]} |
-        {stop, {ejabberd_c2s:state(), [stanza()]}}.
-
-filter_presence({C2SState, _OutStanzas} = Acc, Host, To,
-               #presence{type = Type} = Stanza) ->
-    if Type == available; Type == unavailable ->
-           ?DEBUG("Got availability presence stanza for ~s",
-                  [jid:to_string(To)]),
-           queue_add(presence, Stanza, Host, C2SState);
-       true ->
-           Acc
-    end;
-filter_presence(Acc, _Host, _To, _Stanza) -> Acc.
-
--spec filter_chat_states({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza())
-      -> {ejabberd_c2s:state(), [stanza()]} |
-        {stop, {ejabberd_c2s:state(), [stanza()]}}.
-
-filter_chat_states({C2SState, _OutStanzas} = Acc, Host, To,
-                  #message{from = From} = Stanza) ->
-    case xmpp_util:is_standalone_chat_state(Stanza) of
+-spec c2s_stream_started(c2s_state(), stream_start()) -> c2s_state().
+c2s_stream_started(State, _) ->
+    State#{csi_state => active, csi_queue => queue_new()}.
+
+-spec c2s_authenticated_packet(c2s_state(), xmpp_element()) -> c2s_state().
+c2s_authenticated_packet(C2SState, #csi{type = active}) ->
+    C2SState1 = C2SState#{csi_state => active},
+    flush_queue(C2SState1);
+c2s_authenticated_packet(C2SState, #csi{type = inactive}) ->
+    C2SState#{csi_state => inactive};
+c2s_authenticated_packet(C2SState, _) ->
+    C2SState.
+
+-spec c2s_copy_session(c2s_state(), c2s_state()) -> c2s_state().
+c2s_copy_session(C2SState, #{csi_state := State, csi_queue := Q}) ->
+    C2SState#{csi_state => State, csi_queue => Q};
+c2s_copy_session(C2SState, _) ->
+    C2SState.
+
+-spec c2s_session_resumed(c2s_state()) -> c2s_state().
+c2s_session_resumed(C2SState) ->
+    flush_queue(C2SState).
+
+-spec filter_presence(filter_acc()) -> filter_acc().
+filter_presence({#presence{meta = #{csi_resend := true}}, _} = Acc) ->
+    Acc;
+filter_presence({#presence{to = To, type = Type} = Pres,
+                #{csi_state := inactive} = C2SState})
+  when Type == available; Type == unavailable ->
+    ?DEBUG("Got availability presence stanza for ~s", [jid:to_string(To)]),
+    enqueue_stanza(presence, Pres, C2SState);
+filter_presence(Acc) ->
+    Acc.
+
+-spec filter_chat_states(filter_acc()) -> filter_acc().
+filter_chat_states({#message{meta = #{csi_resend := true}}, _} = Acc) ->
+    Acc;
+filter_chat_states({#message{from = From, to = To} = Msg,
+                   #{csi_state := inactive} = C2SState} = Acc) ->
+    case xmpp_util:is_standalone_chat_state(Msg) of
        true ->
            case {From, To} of
                {#jid{luser = U, lserver = S}, #jid{luser = U, lserver = S}} ->
@@ -181,105 +214,109 @@ filter_chat_states({C2SState, _OutStanzas} = Acc, Host, To,
                    %% conversations across clients.
                    Acc;
                _ ->
-               ?DEBUG("Got standalone chat state notification for ~s",
-                      [jid:to_string(To)]),
-                   queue_add(chatstate, Stanza, Host, C2SState)
+                   ?DEBUG("Got standalone chat state notification for ~s",
+                          [jid:to_string(To)]),
+                   enqueue_stanza(chatstate, Msg, C2SState)
            end;
        false ->
            Acc
     end;
-filter_chat_states(Acc, _Host, _To, _Stanza) -> Acc.
-
--spec filter_pep({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza())
-      -> {ejabberd_c2s:state(), [stanza()]} |
-        {stop, {ejabberd_c2s:state(), [stanza()]}}.
-
-filter_pep({C2SState, _OutStanzas} = Acc, Host, To, #message{} = Stanza) ->
-    case get_pep_node(Stanza) of
+filter_chat_states(Acc) ->
+    Acc.
+
+-spec filter_pep(filter_acc()) -> filter_acc().
+filter_pep({#message{meta = #{csi_resend := true}}, _} = Acc) ->
+    Acc;
+filter_pep({#message{to = To} = Msg,
+           #{csi_state := inactive} = C2SState} = Acc) ->
+    case get_pep_node(Msg) of
        undefined ->
            Acc;
        Node ->
            ?DEBUG("Got PEP notification for ~s", [jid:to_string(To)]),
-           queue_add({pep, Node}, Stanza, Host, C2SState)
+           enqueue_stanza({pep, Node}, Msg, C2SState)
     end;
-filter_pep(Acc, _Host, _To, _Stanza) -> Acc.
-
--spec filter_other({ejabberd_c2s:state(), [stanza()]}, binary(), jid(), stanza())
-      -> {ejabberd_c2s:state(), [stanza()]}.
-
-filter_other({C2SState, _OutStanzas}, Host, To, Stanza) ->
-    ?DEBUG("Won't add stanza for ~s to CSI queue", [jid:to_string(To)]),
-    queue_take(Stanza, Host, C2SState).
+filter_pep(Acc) ->
+    Acc.
 
--spec flush_queue({ejabberd_c2s:state(), [stanza()]}, binary(), jid())
-      -> {ejabberd_c2s:state(), [stanza()]}.
-
-flush_queue({C2SState, _OutStanzas}, Host, JID) ->
-    ?DEBUG("Going to flush CSI queue of ~s", [jid:to_string(JID)]),
-    Queue = get_queue(C2SState),
-    NewState = set_queue([], C2SState),
-    {NewState, get_stanzas(Queue, Host)}.
-
--spec add_stream_feature([stanza()], binary) -> [stanza()].
+-spec filter_other(filter_acc()) -> filter_acc().
+filter_other({Stanza, #{jid := JID} = C2SState} = Acc) when ?is_stanza(Stanza) ->
+    case xmpp:get_meta(Stanza) of
+       #{csi_resend := true} ->
+           Acc;
+       _ ->
+           ?DEBUG("Won't add stanza for ~s to CSI queue", [jid:to_string(JID)]),
+           From = xmpp:get_from(Stanza),
+           C2SState1 = dequeue_sender(From, C2SState),
+           {Stanza, C2SState1}
+    end;
+filter_other(Acc) ->
+    Acc.
 
-add_stream_feature(Features, _Host) ->
-    [#feature_csi{xmlns = <<"urn:xmpp:csi:0">>} | Features].
+-spec add_stream_feature([xmpp_element()], binary) -> [xmpp_element()].
+add_stream_feature(Features, Host) ->
+    case gen_mod:is_loaded(Host, ?MODULE) of
+       true ->
+           [#feature_csi{xmlns = <<"urn:xmpp:csi:0">>} | Features];
+       false ->
+           Features
+    end.
 
 %%--------------------------------------------------------------------
 %% Internal functions.
 %%--------------------------------------------------------------------
-
--spec queue_add(csi_type(), stanza(), binary(), term())
-      -> {stop, {term(), [stanza()]}}.
-
-queue_add(Type, Stanza, Host, C2SState) ->
-    case get_queue(C2SState) of
-      Queue when length(Queue) >= ?CSI_QUEUE_MAX ->
-         ?DEBUG("CSI queue too large, going to flush it", []),
-         NewState = set_queue([], C2SState),
-         {stop, {NewState, get_stanzas(Queue, Host) ++ [Stanza]}};
-      Queue ->
-         ?DEBUG("Adding stanza to CSI queue", []),
-         From = xmpp:get_from(Stanza),
-         Key = {jid:tolower(From), Type},
-         Entry = {Key, p1_time_compat:timestamp(), Stanza},
-         NewQueue = lists:keystore(Key, 1, Queue, Entry),
-         NewState = set_queue(NewQueue, C2SState),
-         {stop, {NewState, []}}
+-spec enqueue_stanza(csi_type(), stanza(), c2s_state()) -> filter_acc().
+enqueue_stanza(Type, Stanza, #{csi_state := inactive,
+                              csi_queue := Q} = C2SState) ->
+    case queue_len(Q) >= ?CSI_QUEUE_MAX of
+       true ->
+           ?DEBUG("CSI queue too large, going to flush it", []),
+           C2SState1 = flush_queue(C2SState),
+           enqueue_stanza(Type, Stanza, C2SState1);
+       false ->
+           #jid{luser = U, lserver = S} = xmpp:get_from(Stanza),
+           Q1 = queue_in({U, S}, Type, Stanza, Q),
+           {stop, {drop, C2SState#{csi_queue => Q1}}}
+    end;
+enqueue_stanza(_Type, Stanza, State) ->
+    {Stanza, State}.
+
+-spec dequeue_sender(jid(), c2s_state()) -> c2s_state().
+dequeue_sender(#jid{luser = U, lserver = S},
+              #{csi_queue := Q, jid := JID} = C2SState) ->
+    ?DEBUG("Flushing packets of ~s@~s from CSI queue of ~s",
+          [U, S, jid:to_string(JID)]),
+    case queue_take({U, S}, Q) of
+       {Stanzas, Q1} ->
+           C2SState1 = flush_stanzas(C2SState, Stanzas),
+           C2SState1#{csi_queue => Q1};
+       error ->
+           C2SState
     end.
 
--spec queue_take(stanza(), binary(), term()) -> {term(), [stanza()]}.
-
-queue_take(Stanza, Host, C2SState) ->
-    From = xmpp:get_from(Stanza),
-    {LUser, LServer, _LResource} = jid:tolower(From),
-    {Selected, Rest} = lists:partition(
-                        fun({{{U, S, _R}, _Type}, _Time, _Stanza}) ->
-                                U == LUser andalso S == LServer
-                        end, get_queue(C2SState)),
-    NewState = set_queue(Rest, C2SState),
-    {NewState, get_stanzas(Selected, Host) ++ [Stanza]}.
-
--spec set_queue(csi_queue(), ejabberd_c2s:state()) -> ejabberd_c2s:state().
-
-set_queue(Queue, C2SState) ->
-    C2SState#{csi_queue => Queue}.
-
--spec get_queue(ejabberd_c2s:state()) -> csi_queue().
-
-get_queue(C2SState) ->
-    maps:get(csi_queue, C2SState, []).
-
--spec get_stanzas(csi_queue(), binary()) -> [stanza()].
-
-get_stanzas(Queue, Host) ->
-    lists:map(fun({_Key, Time, Stanza}) ->
-                     xmpp_util:add_delay_info(Stanza, jid:make(Host), Time,
-                                              <<"Client Inactive">>)
-             end, Queue).
+-spec flush_queue(c2s_state()) -> c2s_state().
+flush_queue(#{csi_queue := Q, jid := JID} = C2SState) ->
+    ?DEBUG("Flushing CSI queue of ~s", [jid:to_string(JID)]),
+    C2SState1 = flush_stanzas(C2SState, queue_to_list(Q)),
+    C2SState1#{csi_queue => queue_new()}.
+
+-spec flush_stanzas(c2s_state(),
+                   [{csi_type(), csi_timestamp(), stanza()}]) -> c2s_state().
+flush_stanzas(#{lserver := LServer} = C2SState, Elems) ->
+    lists:foldl(
+      fun({_Type, Time, Stanza}, AccState) ->
+             Stanza1 = add_delay_info(Stanza, LServer, Time),
+             ejabberd_c2s:send(AccState, Stanza1)
+      end, C2SState, Elems).
+
+-spec add_delay_info(stanza(), binary(), csi_timestamp()) -> stanza().
+add_delay_info(Stanza, LServer, {_Seq, TimeStamp}) ->
+    Stanza1 = xmpp_util:add_delay_info(
+               Stanza, jid:make(LServer), TimeStamp,
+               <<"Client Inactive">>),
+    xmpp:put_meta(Stanza1, csi_resend, true).
 
 -spec get_pep_node(message()) -> binary() | undefined.
-
 get_pep_node(#message{from = #jid{luser = <<>>}}) ->
     %% It's not PEP.
     undefined;
@@ -290,3 +327,53 @@ get_pep_node(#message{} = Msg) ->
        _ ->
            undefined
     end.
+
+%%--------------------------------------------------------------------
+%% Queue interface
+%%--------------------------------------------------------------------
+-spec queue_new() -> csi_queue().
+queue_new() ->
+    {0, 0, #{}}.
+
+-spec queue_in(term(), term(), term(), csi_queue()) -> csi_queue().
+queue_in(Key, Type, Val, {N, Seq, Q}) ->
+    Seq1 = Seq + 1,
+    Time = {Seq1, p1_time_compat:timestamp()},
+    try maps:get(Key, Q) of
+       TypeVals ->
+           case lists:keymember(Type, 1, TypeVals) of
+               true ->
+                   TypeVals1 = lists:keyreplace(
+                                 Type, 1, TypeVals, {Type, Time, Val}),
+                   Q1 = maps:put(Key, TypeVals1, Q),
+                   {N, Seq1, Q1};
+               false ->
+                   TypeVals1 = [{Type, Time, Val}|TypeVals],
+                   Q1 = maps:put(Key, TypeVals1, Q),
+                   {N + 1, Seq1, Q1}
+           end
+    catch _:{badkey, _} ->
+           Q1 = maps:put(Key, [{Type, Time, Val}], Q),
+           {N + 1, Seq1, Q1}
+    end.
+
+-spec queue_take(term(), csi_queue()) -> {list(), csi_queue()} | error.
+queue_take(Key, {N, Seq, Q}) ->
+    case maps:take(Key, Q) of
+       {TypeVals, Q1} ->
+           {lists:keysort(2, TypeVals), {N-length(TypeVals), Seq, Q1}};
+       error ->
+           error
+    end.
+
+-spec queue_len(csi_queue()) -> non_neg_integer().
+queue_len({N, _, _}) ->
+    N.
+
+-spec queue_to_list(csi_queue()) -> [term()].
+queue_to_list({_, _, Q}) ->
+    TypeVals = maps:fold(
+                fun(_, Vals, Acc) ->
+                        Vals ++ Acc
+                end, [], Q),
+    lists:keysort(2, TypeVals).
index 953d1da10f1a70182fc60e02c8402aea81a8b4f6..54720f7167b50945c93353495e6661e1952f34d5 100644 (file)
@@ -37,9 +37,7 @@
         get_local_features/5, get_local_services/5,
         process_sm_iq_items/1, process_sm_iq_info/1,
         get_sm_identity/5, get_sm_features/5, get_sm_items/5,
-        get_info/5, register_feature/2, unregister_feature/2,
-        register_extra_domain/2, unregister_extra_domain/2,
-        transform_module_options/1, mod_opt_type/1, depends/2]).
+        get_info/5, transform_module_options/1, mod_opt_type/1, depends/2]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
 -include_lib("stdlib/include/ms_transform.hrl").
 -include("mod_roster.hrl").
 
+-type features_acc() :: {error, stanza_error()} | {result, [binary()]} | empty.
+-type items_acc() :: {error, stanza_error()} | {result, [disco_item()]} | empty.
+
 start(Host, Opts) ->
-    ejabberd_local:refresh_iq_handlers(),
     IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1,
                              one_queue),
     gen_iq_handler:add_iq_handler(ejabberd_local, Host,
@@ -64,12 +64,9 @@ start(Host, Opts) ->
     gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
                                  ?NS_DISCO_INFO, ?MODULE, process_sm_iq_info,
                                  IQDisc),
-    catch ets:new(disco_features,
-                 [named_table, ordered_set, public]),
-    register_feature(Host, <<"iq">>),
-    register_feature(Host, <<"presence">>),
     catch ets:new(disco_extra_domains,
-                 [named_table, ordered_set, public]),
+                 [named_table, ordered_set, public,
+                  {heir, erlang:group_leader(), none}]),
     ExtraDomains = gen_mod:get_opt(extra_domains, Opts,
                                    fun(Hs) ->
                                            [iolist_to_binary(H) || H <- Hs]
@@ -78,10 +75,6 @@ start(Host, Opts) ->
                          register_extra_domain(Host, Domain)
                  end,
                  ExtraDomains),
-    catch ets:new(disco_sm_features,
-                 [named_table, ordered_set, public]),
-    catch ets:new(disco_sm_nodes,
-                 [named_table, ordered_set, public]),
     ejabberd_hooks:add(disco_local_items, Host, ?MODULE,
                       get_local_services, 100),
     ejabberd_hooks:add(disco_local_features, Host, ?MODULE,
@@ -121,35 +114,14 @@ stop(Host) ->
                                     ?NS_DISCO_ITEMS),
     gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
                                     ?NS_DISCO_INFO),
-    catch ets:match_delete(disco_features, {{'_', Host}}),
     catch ets:match_delete(disco_extra_domains,
                           {{'_', Host}}),
     ok.
 
--spec register_feature(binary(), binary()) -> true.
-register_feature(Host, Feature) ->
-    catch ets:new(disco_features,
-                 [named_table, ordered_set, public]),
-    ets:insert(disco_features, {{Feature, Host}}).
-
--spec unregister_feature(binary(), binary()) -> true.
-unregister_feature(Host, Feature) ->
-    catch ets:new(disco_features,
-                 [named_table, ordered_set, public]),
-    ets:delete(disco_features, {Feature, Host}).
-
 -spec register_extra_domain(binary(), binary()) -> true.
 register_extra_domain(Host, Domain) ->
-    catch ets:new(disco_extra_domains,
-                 [named_table, ordered_set, public]),
     ets:insert(disco_extra_domains, {{Domain, Host}}).
 
--spec unregister_extra_domain(binary(), binary()) -> true.
-unregister_extra_domain(Host, Domain) ->
-    catch ets:new(disco_extra_domains,
-                 [named_table, ordered_set, public]),
-    ets:delete(disco_extra_domains, {Domain, Host}).
-
 -spec process_local_iq_items(iq()) -> iq().
 process_local_iq_items(#iq{type = set, lang = Lang} = IQ) ->
     Txt = <<"Value 'set' of 'type' attribute is not allowed">>,
@@ -198,22 +170,17 @@ get_local_identity(Acc, _From, _To, <<"">>, _Lang) ->
 get_local_identity(Acc, _From, _To, _Node, _Lang) ->
     Acc.
 
--spec get_local_features({error, stanza_error()} | {result, [binary()]} | empty,
-                        jid(), jid(), binary(), binary()) ->
+-spec get_local_features(features_acc(), jid(), jid(), binary(), binary()) ->
                                {error, stanza_error()} | {result, [binary()]}.
 get_local_features({error, _Error} = Acc, _From, _To,
                   _Node, _Lang) ->
     Acc;
-get_local_features(Acc, _From, To, <<"">>, _Lang) ->
+get_local_features(Acc, _From, _To, <<"">>, _Lang) ->
     Feats = case Acc of
                {result, Features} -> Features;
                empty -> []
            end,
-    Host = To#jid.lserver,
-    {result,
-     ets:select(disco_features,
-               ets:fun2ms(fun({{F, H}}) when H == Host -> F end))
-     ++ Feats};
+    {result, [<<"iq">>, <<"presence">>|Feats]};
 get_local_features(Acc, _From, _To, _Node, Lang) ->
     case Acc of
       {result, _Features} -> Acc;
@@ -222,9 +189,7 @@ get_local_features(Acc, _From, _To, _Node, Lang) ->
            {error, xmpp:err_item_not_found(Txt, Lang)}
     end.
 
--spec get_local_services({error, stanza_error()} | {result, [disco_item()]} | empty,
-                        jid(), jid(),
-                        binary(), binary()) ->
+-spec get_local_services(items_acc(), jid(), jid(), binary(), binary()) ->
                                {error, stanza_error()} | {result, [disco_item()]}.
 get_local_services({error, _Error} = Acc, _From, _To,
                   _Node, _Lang) ->
@@ -296,9 +261,7 @@ process_sm_iq_items(#iq{type = get, lang = Lang,
            xmpp:make_error(IQ, xmpp:err_subscription_required(Txt, Lang))
     end.
 
--spec get_sm_items({error, stanza_error()} | {result, [disco_item()]} | empty,
-                  jid(), jid(),
-                  binary(), binary()) ->
+-spec get_sm_items(items_acc(), jid(), jid(), binary(), binary()) ->
                          {error, stanza_error()} | {result, [disco_item()]}.
 get_sm_items({error, _Error} = Acc, _From, _To, _Node,
             _Lang) ->
@@ -383,8 +346,7 @@ get_sm_identity(Acc, _From,
        _ -> []
       end.
 
--spec get_sm_features({error, stanza_error()} | {result, [binary()]} | empty,
-                     jid(), jid(), binary(), binary()) ->
+-spec get_sm_features(features_acc(), jid(), jid(), binary(), binary()) ->
                             {error, stanza_error()} | {result, [binary()]}.
 get_sm_features(empty, From, To, _Node, Lang) ->
     #jid{luser = LFrom, lserver = LSFrom} = From,
index cc3b4bf7f4bac9f9e08d19695464c8b12c6ef734..e8cc298166e3ca87311b05865b134b553681d2b3 100644 (file)
@@ -29,7 +29,8 @@
 -behaviour(gen_server).
 
 %% API
--export([start_link/2, start/2, stop/1, c2s_auth_result/4, check_bl_c2s/3]).
+-export([start_link/2, start/2, stop/1, c2s_auth_result/3,
+        c2s_stream_started/2]).
 
 -export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3,
@@ -38,6 +39,7 @@
 -include_lib("stdlib/include/ms_transform.hrl").
 -include("ejabberd.hrl").
 -include("logger.hrl").
+-include("xmpp.hrl").
 
 -define(C2S_AUTH_BAN_LIFETIME, 3600). %% 1 hour
 -define(C2S_MAX_AUTH_FAILURES, 20).
@@ -52,12 +54,12 @@ start_link(Host, Opts) ->
     Proc = gen_mod:get_module_proc(Host, ?MODULE),
     gen_server:start_link({local, Proc}, ?MODULE, [Host, Opts], []).
 
--spec c2s_auth_result(boolean(), binary(), binary(),
-                     {inet:ip_address(), non_neg_integer()}) -> ok.
-c2s_auth_result(false, _User, LServer, {Addr, _Port}) ->
+-spec c2s_auth_result(ejabberd_c2s:state(), boolean(), binary())
+      -> ejabberd_c2s:state() | {stop, ejabberd_c2s:state()}.
+c2s_auth_result(#{ip := {Addr, _}, lserver := LServer} = State, false, _User) ->
     case is_whitelisted(LServer, Addr) of
        true ->
-           ok;
+           State;
        false ->
            BanLifetime = gen_mod:get_module_opt(
                            LServer, ?MODULE, c2s_auth_ban_lifetime,
@@ -68,47 +70,41 @@ c2s_auth_result(false, _User, LServer, {Addr, _Port}) ->
                            fun(I) when is_integer(I), I > 0 -> I end,
                            ?C2S_MAX_AUTH_FAILURES),
            UnbanTS = p1_time_compat:system_time(seconds) + BanLifetime,
-           case ets:lookup(failed_auth, Addr) of
-               [{Addr, N, _, _}] ->
-                   ets:insert(failed_auth, {Addr, N+1, UnbanTS, MaxFailures});
-               [] ->
-                   ets:insert(failed_auth, {Addr, 1, UnbanTS, MaxFailures})
-           end,
-           ok
+           Attempts = case ets:lookup(failed_auth, Addr) of
+                          [{Addr, N, _, _}] ->
+                              ets:insert(failed_auth,
+                                         {Addr, N+1, UnbanTS, MaxFailures}),
+                              N+1;
+                          [] ->
+                              ets:insert(failed_auth,
+                                         {Addr, 1, UnbanTS, MaxFailures}),
+                              1
+                      end,
+           if Attempts >= MaxFailures ->
+                   log_and_disconnect(State, Attempts, UnbanTS);
+              true ->
+                   State
+           end
     end;
-c2s_auth_result(true, _User, _Server, _AddrPort) ->
-    ok.
-
--spec check_bl_c2s({true, binary(), binary()} | false,
-                  {inet:ip_address(), non_neg_integer()},
-                  binary()) -> {stop, {true, binary(), binary()}} | false.
-check_bl_c2s(_Acc, Addr, Lang) ->
+c2s_auth_result(#{ip := {Addr, _}} = State, true, _User) ->
+    ets:delete(failed_auth, Addr),
+    State.
+
+-spec c2s_stream_started(ejabberd_c2s:state(), stream_start())
+      -> ejabberd_c2s:state() | {stop, ejabberd_c2s:state()}.
+c2s_stream_started(#{ip := {Addr, _}} = State, _) ->
+    ets:tab2list(failed_auth),
     case ets:lookup(failed_auth, Addr) of
        [{Addr, N, TS, MaxFailures}] when N >= MaxFailures ->
            case TS > p1_time_compat:system_time(seconds) of
                true ->
-                   IP = jlib:ip_to_list(Addr),
-                   UnbanDate = format_date(
-                                   calendar:now_to_universal_time(seconds_to_now(TS))),
-                   LogReason = io_lib:fwrite(
-                                 "Too many (~p) failed authentications "
-                                 "from this IP address (~s). The address "
-                                 "will be unblocked at ~s UTC",
-                                 [N, IP, UnbanDate]),
-                   ReasonT = io_lib:fwrite(
-                               translate:translate(
-                                 Lang,
-                                 <<"Too many (~p) failed authentications "
-                                   "from this IP address (~s). The address "
-                                   "will be unblocked at ~s UTC">>),
-                               [N, IP, UnbanDate]),
-                   {stop, {true, LogReason, ReasonT}};
+                   log_and_disconnect(State, N, TS);
                false ->
                    ets:delete(failed_auth, Addr),
-                   false
+                   State
            end;
        _ ->
-           false
+           State
     end.
 
 %%====================================================================
@@ -134,7 +130,7 @@ depends(_Host, _Opts) ->
 %%%===================================================================
 init([Host, _Opts]) ->
     ejabberd_hooks:add(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100),
-    ejabberd_hooks:add(check_bl_c2s, ?MODULE, check_bl_c2s, 100),
+    ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, c2s_stream_started, 100),
     erlang:send_after(?CLEAN_INTERVAL, self(), clean),
     {ok, #state{host = Host}}.
 
@@ -160,11 +156,11 @@ handle_info(_Info, State) ->
 
 terminate(_Reason, #state{host = Host}) ->
     ejabberd_hooks:delete(c2s_auth_result, Host, ?MODULE, c2s_auth_result, 100),
+    ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, c2s_stream_started, 100),
     case is_loaded_at_other_hosts(Host) of
        true ->
            ok;
        false ->
-           ejabberd_hooks:delete(check_bl_c2s, ?MODULE, check_bl_c2s, 100),
            ets:delete(failed_auth)
     end.
 
@@ -174,6 +170,21 @@ code_change(_OldVsn, State, _Extra) ->
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
+-spec log_and_disconnect(ejabberd_c2s:state(), pos_integer(), erlang:timestamp())
+      -> {stop, ejabberd_c2s:state()}.
+log_and_disconnect(#{ip := {Addr, _}, lang := Lang} = State, Attempts, UnbanTS) ->
+    IP = jlib:ip_to_list(Addr),
+    UnbanDate = format_date(
+                 calendar:now_to_universal_time(seconds_to_now(UnbanTS))),
+    Format = <<"Too many (~p) failed authentications "
+              "from this IP address (~s). The address "
+              "will be unblocked at ~s UTC">>,
+    Args = [Attempts, IP, UnbanDate],
+    ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s",
+             [IP, io_lib:fwrite(Format, Args)]),
+    Err = xmpp:serr_policy_violation({Format, Args}, Lang),
+    {stop, ejabberd_c2s:send(State, Err)}.
+
 is_whitelisted(Host, Addr) ->
     Access = gen_mod:get_module_opt(Host, ?MODULE, access,
                                    fun(A) -> A end,
index a896cb8b42929a1212bb1c58ad6aaf93b9e04eae..734dbb12679be03002a06a613fe84bc52a2aca44 100644 (file)
@@ -46,7 +46,7 @@
 %% utility for other http modules
 -export([content_type/3]).
 
--export([reopen_log/1, mod_opt_type/1, depends/2]).
+-export([reopen_log/0, mod_opt_type/1, depends/2]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
@@ -236,7 +236,7 @@ check_docroot_is_readable(DRInfo, DocRoot) ->
 
 try_open_log(undefined, _Host) ->
     undefined;
-try_open_log(FN, Host) ->
+try_open_log(FN, _Host) ->
     FD = try open_log(FN) of
             FD1 -> FD1
         catch
@@ -244,7 +244,7 @@ try_open_log(FN, Host) ->
                 ?ERROR_MSG("Cannot open access log file: ~p~nReason: ~p", [FN, Reason]),
                 undefined
         end,
-    ejabberd_hooks:add(reopen_log_hook, Host, ?MODULE, reopen_log, 50),
+    ejabberd_hooks:add(reopen_log_hook, ?MODULE, reopen_log, 50),
     FD.
 
 %%--------------------------------------------------------------------
@@ -298,7 +298,8 @@ handle_info(_Info, State) ->
 %%--------------------------------------------------------------------
 terminate(_Reason, State) ->
     close_log(State#state.accesslogfd),
-    ejabberd_hooks:delete(reopen_log_hook, State#state.host, ?MODULE, reopen_log, 50),
+    %% TODO: unregister the hook gracefully
+    %% ejabberd_hooks:delete(reopen_log_hook, State#state.host, ?MODULE, reopen_log, 50),
     ok.
 
 %%--------------------------------------------------------------------
@@ -410,8 +411,11 @@ reopen_log(FN, FD) ->
     close_log(FD),
     open_log(FN).
 
-reopen_log(Host) ->
-    gen_server:cast(get_proc_name(Host), reopen_log).
+reopen_log() ->
+    lists:foreach(
+      fun(Host) ->
+             gen_server:cast(get_proc_name(Host), reopen_log)
+      end, ?MYHOSTS).
 
 add_to_log(FileSize, Code, Request) ->
     gen_server:cast(get_proc_name(Request#request.host),
index 2c17dcda36cdea2872e3e79a57e19e0f03102ec2..7a08d362bbec0353f29cf22c32b71c714c4f78d6 100644 (file)
@@ -130,13 +130,10 @@ process_sm_iq(#iq{from = From, to = To, lang = Lang} = IQ) ->
     if (Subscription == both) or (Subscription == from) or
        (From#jid.luser == To#jid.luser) and
        (From#jid.lserver == To#jid.lserver) ->
-           UserListRecord =
-               ejabberd_hooks:run_fold(privacy_get_user_list, Server,
-                                       #userlist{}, [User, Server]),
+           Pres = xmpp:set_from_to(#presence{}, To, From),
            case ejabberd_hooks:run_fold(privacy_check_packet,
                                         Server, allow,
-                                        [User, Server, UserListRecord,
-                                         {To, From, #presence{}}, out]) of
+                                        [To, Pres, out]) of
                allow -> get_last_iq(IQ, User, Server);
                deny -> xmpp:make_error(IQ, xmpp:err_forbidden())
            end;
index edb0d148529360f51593ec348a3da96a6f0f851e..0e2d77d9667fbcd2c2310efa787682fc8c9a147a 100644 (file)
 %% API
 -export([start/2, stop/1, depends/2]).
 
--export([user_send_packet/4, user_send_packet_strip_tag/4, user_receive_packet/5,
+-export([user_send_packet/1, user_send_packet_strip_tag/1, user_receive_packet/1,
         process_iq_v0_2/1, process_iq_v0_3/1, disco_sm_features/5,
         remove_user/2, remove_room/3, mod_opt_type/1, muc_process_iq/2,
-        muc_filter_message/5, message_is_archived/5, delete_old_messages/2,
+        muc_filter_message/5, message_is_archived/3, delete_old_messages/2,
         get_commands_spec/0, msg_to_el/4, get_room_config/4, set_room_option/3]).
 
 -include("xmpp.hrl").
@@ -200,46 +200,50 @@ set_room_option(_Acc, {mam, Val}, _Lang) ->
 set_room_option(Acc, _Property, _Lang) ->
     Acc.
 
--spec user_receive_packet(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza().
-user_receive_packet(Pkt, C2SState, JID, Peer, _To) ->
+-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_receive_packet({Pkt, #{jid := JID} = C2SState}) ->
+    Peer = xmpp:get_from(Pkt),
     LUser = JID#jid.luser,
     LServer = JID#jid.lserver,
-    case should_archive(Pkt, LServer) of
-       true ->
-           NewPkt = strip_my_archived_tag(Pkt, LServer),
-           case store_msg(C2SState, NewPkt, LUser, LServer, Peer, recv) of
-               {ok, ID} ->
-                   set_stanza_id(NewPkt, JID, ID);
-               _ ->
-                   NewPkt
-           end;
-       _ ->
-           Pkt
-    end.
-
--spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza().
-user_send_packet(Pkt, C2SState, JID, Peer) ->
+    Pkt2 = case should_archive(Pkt, LServer) of
+              true ->
+                  Pkt1 = strip_my_archived_tag(Pkt, LServer),
+                  case store_msg(C2SState, Pkt1, LUser, LServer, Peer, recv) of
+                      {ok, ID} ->
+                          set_stanza_id(Pkt1, JID, ID);
+                      _ ->
+                          Pkt1
+                  end;
+              _ ->
+                  Pkt
+          end,
+    {Pkt2, C2SState}.
+
+-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_send_packet({Pkt, #{jid := JID} = C2SState}) ->
+    Peer = xmpp:get_to(Pkt),
     LUser = JID#jid.luser,
     LServer = JID#jid.lserver,
-    case should_archive(Pkt, LServer) of
-       true ->
-           NewPkt = strip_my_archived_tag(Pkt, LServer),
-           case store_msg(C2SState, xmpp:set_from_to(NewPkt, JID, Peer),
-                     LUser, LServer, Peer, send) of
-              {ok, ID} ->
-                   set_stanza_id(NewPkt, JID, ID);
-            _ ->
-                NewPkt
-        end;
-       false ->
-           Pkt
-    end.
-
--spec user_send_packet_strip_tag(stanza(), ejabberd_c2s:state(),
-                                jid(), jid()) -> stanza().
-user_send_packet_strip_tag(Pkt, _C2SState, JID, _Peer) ->
+    Pkt2 = case should_archive(Pkt, LServer) of
+              true ->
+                  Pkt1 = strip_my_archived_tag(Pkt, LServer),
+                  case store_msg(C2SState, xmpp:set_from_to(Pkt1, JID, Peer),
+                                 LUser, LServer, Peer, send) of
+                      {ok, ID} ->
+                          set_stanza_id(Pkt1, JID, ID);
+                      _ ->
+                          Pkt1
+                  end;
+              false ->
+                  Pkt
+          end,
+    {Pkt2, C2SState}.
+
+-spec user_send_packet_strip_tag({stanza(), ejabberd_c2s:state()}) ->
+                                       {stanza(), ejabberd_c2s:state()}.
+user_send_packet_strip_tag({Pkt, #{jid := JID} = C2SState}) ->
     LServer = JID#jid.lserver,
-    strip_my_archived_tag(Pkt, LServer).
+    {strip_my_archived_tag(Pkt, LServer), C2SState}.
 
 -spec muc_filter_message(message(), mod_muc_room:state(),
                         jid(), jid(), binary()) -> message().
@@ -338,12 +342,12 @@ disco_sm_features({result, OtherFeatures},
 disco_sm_features(Acc, _From, _To, _Node, _Lang) ->
     Acc.
 
--spec message_is_archived(boolean(), ejabberd_c2s:state(),
-                         jid(), jid(), message()) -> boolean().
-message_is_archived(true, _C2SState, _Peer, _JID, _Pkt) ->
+-spec message_is_archived(boolean(), ejabberd_c2s:state(), message()) -> boolean().
+message_is_archived(true, _C2SState, _Pkt) ->
     true;
-message_is_archived(false, C2SState, Peer,
-                   #jid{luser = LUser, lserver = LServer}, Pkt) ->
+message_is_archived(false, #{jid := JID} = C2SState, Pkt) ->
+    #jid{luser = LUser, lserver = LServer} = JID,
+    Peer = xmpp:get_from(Pkt),
     case gen_mod:get_module_opt(LServer, ?MODULE, assume_mam_usage,
                                fun(B) when is_boolean(B) -> B end, false) of
        true ->
index 7861542c5fd4aa88557eea6b66503b39ffe173c0..1698690d4d2d72db824800d4cfbed87497445161 100644 (file)
@@ -38,8 +38,8 @@
 
 -export([offline_message_hook/3,
          sm_register_connection_hook/3, sm_remove_connection_hook/3,
-         user_send_packet/4, user_receive_packet/5,
-         s2s_send_packet/3, s2s_receive_packet/3,
+         user_send_packet/1, user_receive_packet/1,
+         s2s_send_packet/3, s2s_receive_packet/1,
          remove_user/2, register_user/2]).
 
 %%====================================================================
@@ -86,23 +86,27 @@ sm_register_connection_hook(_SID, #jid{lserver=LServer}, _Info) ->
 sm_remove_connection_hook(_SID, #jid{lserver=LServer}, _Info) ->
     push(LServer, sm_remove_connection).
 
--spec user_send_packet(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza().
-user_send_packet(Packet, _C2SState, #jid{lserver=LServer}, _To) ->
+-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_send_packet({Packet, #{jid := #jid{lserver = LServer}} = C2SState}) ->
     push(LServer, user_send_packet),
-    Packet.
+    {Packet, C2SState}.
 
--spec user_receive_packet(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza().
-user_receive_packet(Packet, _C2SState, _JID, _From, #jid{lserver=LServer}) ->
+-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_receive_packet({Packet, #{jid := #jid{lserver = LServer}} = C2SState}) ->
     push(LServer, user_receive_packet),
-    Packet.
+    {Packet, C2SState}.
 
 -spec s2s_send_packet(jid(), jid(), stanza()) -> any().
 s2s_send_packet(#jid{lserver=LServer}, _To, _Packet) ->
     push(LServer, s2s_send_packet).
 
--spec s2s_receive_packet(jid(), jid(), stanza()) -> any().
-s2s_receive_packet(_From, #jid{lserver=LServer}, _Packet) ->
-    push(LServer, s2s_receive_packet).
+-spec s2s_receive_packet({stanza(), ejabberd_s2s_in:state()}) ->
+                               {stanza(), ejabberd_s2s_in:state()}.
+s2s_receive_packet({Packet, S2SState}) ->
+    To = xmpp:get_to(Packet),
+    LServer = ejabberd_router:host_of_route(To#jid.lserver),
+    push(LServer, s2s_receive_packet),
+    {Packet, S2SState}.
 
 -spec remove_user(binary(), binary()) -> any().
 remove_user(_User, Server) ->
index 8d58b14c9606028e6c4dbb26f2712c7e76651ced..c1768bf1cb1dd9102a6e52c512f2cb0317c40431 100644 (file)
@@ -44,7 +44,7 @@
         store_packet/3,
         store_offline_msg/5,
         resend_offline_messages/2,
-        pop_offline_messages/3,
+        c2s_self_presence/1,
         get_sm_features/5,
         get_sm_identity/5,
         get_sm_items/5,
@@ -62,6 +62,7 @@
         get_offline_els/2,
         find_x_expire/2,
         c2s_handle_info/2,
+        c2s_copy_session/2,
         webadmin_page/3,
         webadmin_user/4,
         webadmin_user_parse_query/5]).
@@ -91,6 +92,8 @@
 -define(MAX_USER_MESSAGES, infinity).
 
 -type us() :: {binary(), binary()}.
+-type c2s_state() :: ejabberd_c2s:state().
+
 -callback init(binary(), gen_mod:opts()) -> any().
 -callback import(#offline_msg{}) -> ok.
 -callback store_messages(binary(), us(), [#offline_msg{}],
@@ -142,8 +145,7 @@ init([Host, Opts]) ->
                             no_queue),
     ejabberd_hooks:add(offline_message_hook, Host, ?MODULE,
                       store_packet, 50),
-    ejabberd_hooks:add(resend_offline_messages_hook, Host,
-                      ?MODULE, pop_offline_messages, 50),
+    ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE, c2s_self_presence, 50),
     ejabberd_hooks:add(remove_user, Host,
                       ?MODULE, remove_user, 50),
     ejabberd_hooks:add(anonymous_purge_hook, Host,
@@ -158,6 +160,7 @@ init([Host, Opts]) ->
                       ?MODULE, get_sm_items, 50),
     ejabberd_hooks:add(disco_info, Host, ?MODULE, get_info, 50),
     ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
+    ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE, c2s_copy_session, 50),
     ejabberd_hooks:add(webadmin_page_host, Host,
                       ?MODULE, webadmin_page, 50),
     ejabberd_hooks:add(webadmin_user, Host,
@@ -202,8 +205,7 @@ terminate(_Reason, State) ->
     Host = State#state.host,
     ejabberd_hooks:delete(offline_message_hook, Host,
                          ?MODULE, store_packet, 50),
-    ejabberd_hooks:delete(resend_offline_messages_hook,
-                         Host, ?MODULE, pop_offline_messages, 50),
+    ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE, c2s_self_presence, 50),
     ejabberd_hooks:delete(remove_user, Host, ?MODULE,
                          remove_user, 50),
     ejabberd_hooks:delete(anonymous_purge_hook, Host,
@@ -214,6 +216,7 @@ terminate(_Reason, State) ->
     ejabberd_hooks:delete(disco_sm_items, Host, ?MODULE, get_sm_items, 50),
     ejabberd_hooks:delete(disco_info, Host, ?MODULE, get_info, 50),
     ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
+    ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE, c2s_copy_session, 50),
     ejabberd_hooks:delete(webadmin_page_host, Host,
                          ?MODULE, webadmin_page, 50),
     ejabberd_hooks:delete(webadmin_user, Host,
@@ -309,12 +312,18 @@ get_info(_Acc, #jid{luser = U, lserver = S} = JID,
 get_info(Acc, _From, _To, _Node, _Lang) ->
     Acc.
 
--spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+-spec c2s_handle_info(c2s_state(), term()) -> c2s_state().
 c2s_handle_info(State, {resend_offline, Flag}) ->
     {stop, State#{resend_offline => Flag}};
 c2s_handle_info(State, _) ->
     State.
 
+-spec c2s_copy_session(c2s_state(), c2s_state()) -> c2s_state().
+c2s_copy_session(State, #{resend_offline := Flag}) ->
+    State#{resend_offline => Flag};
+c2s_copy_session(State, _) ->
+    State.
+
 -spec handle_offline_query(iq()) -> iq().
 handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1},
                         to = #jid{luser = U2, lserver = S2},
@@ -398,10 +407,10 @@ handle_offline_fetch(#jid{luser = U, lserver = S} = JID) ->
     ejabberd_sm:route(JID, {resend_offline, false}),
     lists:foreach(
       fun({Node, El}) ->
-             NewEl = set_offline_tag(El, Node),
-             From = xmpp:get_from(El),
-             To = xmpp:get_to(El),
-             ejabberd_router:route(From, To, NewEl)
+             El1 = set_offline_tag(El, Node),
+             From = xmpp:get_from(El1),
+             To = xmpp:get_to(El1),
+             ejabberd_router:route(From, To, El1)
       end, read_messages(U, S)).
 
 -spec fetch_msg_by_node(jid(), binary()) -> error | {ok, #offline_msg{}}.
@@ -557,41 +566,67 @@ resend_offline_messages(User, Server) ->
       _ -> ok
     end.
 
--spec pop_offline_messages([{route, jid(), jid(), message()}],
-                          binary(), binary()) ->
-      [{route, jid(), jid(), message()}].
-pop_offline_messages(Ls, User, Server) ->
-    LUser = jid:nodeprep(User),
-    LServer = jid:nameprep(Server),
+c2s_self_presence({#presence{type = available} = NewPres, State} = Acc) ->
+    NewPrio = get_priority_from_presence(NewPres),
+    LastPrio = try maps:get(pres_last, State) of
+                  LastPres -> get_priority_from_presence(LastPres)
+              catch _:{badkey, _} ->
+                      -1
+              end,
+    if LastPrio < 0 andalso NewPrio >= 0 ->
+           route_offline_messages(State);
+       true ->
+           ok
+    end,
+    Acc;
+c2s_self_presence(Acc) ->
+    Acc.
+
+-spec route_offline_messages(c2s_state()) -> ok.
+route_offline_messages(#{jid := #jid{luser = LUser, lserver = LServer}} = State) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:pop_messages(LUser, LServer) of
-       {ok, Rs} ->
-           TS = p1_time_compat:timestamp(),
-           Ls ++
-               lists:flatmap(
-                 fun(R) ->
-                         case offline_msg_to_route(LServer, R) of
-                             error -> [];
-                             RouteMsg -> [RouteMsg]
-                         end
-                 end,
-                 lists:filter(
-                   fun(#offline_msg{packet = Pkt} = R) ->
-                           Expire = case R#offline_msg.expire of
-                                        undefined ->
-                                            find_x_expire(TS, Pkt);
-                                        Exp ->
-                                            Exp
-                                    end,
-                           case Expire of
-                               never -> true;
-                               TimeStamp -> TS < TimeStamp
-                           end
-                   end, Rs));
+       {ok, OffMsgs} ->
+           lists:foreach(
+             fun(OffMsg) ->
+                     route_offline_message(State, OffMsg)
+             end, OffMsgs);
        _ ->
-           Ls
+           ok
     end.
 
+-spec route_offline_message(c2s_state(), #offline_msg{}) -> ok.
+route_offline_message(#{lserver := LServer} = State,
+                     #offline_msg{expire = Expire} = OffMsg) ->
+    case offline_msg_to_route(LServer, OffMsg) of
+       error ->
+           ok;
+       {route, From, To, Msg} ->
+           case is_message_expired(Expire, Msg) of
+               true ->
+                   ok;
+               false ->
+                   case privacy_check_packet(State, Msg, in) of
+                       allow -> ejabberd_router:route(From, To, Msg);
+                       false -> ok
+                   end
+           end
+    end.
+
+-spec is_message_expired(erlang:timestamp() | never, message()) -> boolean().
+is_message_expired(Expire, Msg) ->
+    TS = p1_time_compat:timestamp(),
+    Expire1 = case Expire of
+                 undefined -> find_x_expire(TS, Msg);
+                 _ -> Expire
+             end,
+    Expire1 /= never andalso Expire1 =< TS.
+
+-spec privacy_check_packet(c2s_state(), stanza(), in | out) -> allow | deny.
+privacy_check_packet(#{lserver := LServer} = State, Pkt, Dir) ->
+    ejabberd_hooks:run_fold(privacy_check_packet,
+                           LServer, allow, [State, Pkt, Dir]).
+
 remove_expired_messages(Server) ->
     LServer = jid:nameprep(Server),
     Mod = gen_mod:db_mod(LServer, ?MODULE),
@@ -635,14 +670,15 @@ get_offline_els(LUser, LServer) ->
 
 -spec offline_msg_to_route(binary(), #offline_msg{}) ->
                                  {route, jid(), jid(), message()} | error.
-offline_msg_to_route(LServer, #offline_msg{} = R) ->
+offline_msg_to_route(LServer, #offline_msg{from = From, to = To} = R) ->
     try xmpp:decode(R#offline_msg.packet, ?NS_CLIENT, [ignore_els]) of
        Pkt ->
-           NewPkt = add_delay_info(Pkt, LServer, R#offline_msg.timestamp),
-           {route, R#offline_msg.from, R#offline_msg.to, NewPkt}
+           Pkt1 = xmpp:set_from_to(Pkt, From, To),
+           Pkt2 = add_delay_info(Pkt1, LServer, R#offline_msg.timestamp),
+           {route, From, To, Pkt2}
     catch _:{xmpp_codec, Why} ->
            ?ERROR_MSG("failed to decode packet ~p of user ~s: ~s",
-                      [R#offline_msg.packet, jid:to_string(R#offline_msg.to),
+                      [R#offline_msg.packet, jid:to_string(To),
                        xmpp:format_error(Why)]),
            error
     end.
@@ -840,9 +876,17 @@ count_offline_messages(User, Server) ->
 add_delay_info(Packet, _LServer, undefined) ->
     Packet;
 add_delay_info(Packet, LServer, {_, _, _} = TS) ->
-    xmpp_util:add_delay_info(Packet, jid:make(LServer), TS,
+    Packet1 = xmpp:put_meta(Packet, from_offline, true),
+    xmpp_util:add_delay_info(Packet1, jid:make(LServer), TS,
                             <<"Offline storage">>).
 
+-spec get_priority_from_presence(presence()) -> integer().
+get_priority_from_presence(#presence{priority = Prio}) ->
+    case Prio of
+       undefined -> 0;
+       _ -> Prio
+    end.
+
 export(LServer) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:export(LServer).
index 5e861b7f77b63b77192de8ff9202d2534fa2d842..09550ee9a9000c955d95be6d491d212fe3aea506 100644 (file)
@@ -54,8 +54,8 @@
 -export([init/1, terminate/2, handle_call/3,
         handle_cast/2, handle_info/2, code_change/3]).
 
--export([iq_ping/1, user_online/3, user_offline/3,
-        user_send/4, mod_opt_type/1, depends/2]).
+-export([iq_ping/1, user_online/3, user_offline/3, disco_features/5,
+        user_send/1, mod_opt_type/1, depends/2]).
 
 -record(state,
        {host = <<"">>,
@@ -116,7 +116,7 @@ init([Host, Opts]) ->
                                     end, none),
     IQDisc = gen_mod:get_opt(iqdisc, Opts, fun gen_iq_handler:check_type/1,
                              no_queue),
-    mod_disco:register_feature(Host, ?NS_PING),
+    ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50),
     gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
                                  ?NS_PING, ?MODULE, iq_ping, IQDisc),
     gen_iq_handler:add_iq_handler(ejabberd_local, Host,
@@ -145,11 +145,12 @@ terminate(_Reason, #state{host = Host}) ->
                          ?MODULE, user_online, 100),
     ejabberd_hooks:delete(user_send_packet, Host, ?MODULE,
                          user_send, 100),
+    ejabberd_hooks:delete(disco_local_features, Host, ?MODULE,
+                         disco_features, 50),
     gen_iq_handler:remove_iq_handler(ejabberd_local, Host,
                                     ?NS_PING),
     gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
-                                    ?NS_PING),
-    mod_disco:unregister_feature(Host, ?NS_PING).
+                                    ?NS_PING).
 
 handle_call(stop, _From, State) ->
     {stop, normal, ok, State};
@@ -215,10 +216,22 @@ user_online(_SID, JID, _Info) ->
 user_offline(_SID, JID, _Info) ->
     stop_ping(JID#jid.lserver, JID).
 
--spec user_send(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza().
-user_send(Packet, _C2SState, JID, _From) ->
+-spec user_send({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_send({Packet, #{jid := JID} = C2SState}) ->
     start_ping(JID#jid.lserver, JID),
-    Packet.
+    {Packet, C2SState}.
+
+-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty,
+                    jid(), jid(), binary(), binary()) ->
+                           {error, stanza_error()} | {result, [binary()]}.
+disco_features({error, Err}, _From, _To, _Node, _Lang) ->
+    {error, Err};
+disco_features(empty, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_PING]};
+disco_features({result, Feats}, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_PING|Feats]};
+disco_features(Acc, _From, _To, _Node, _Lang) ->
+    Acc.
 
 %%====================================================================
 %% Internal functions
index 955e53f6f9fd531b06ff524a4dcb8c4fff1d1a89..8da4f7b29150c35207887fca29e9ad5c3c0968ba 100644 (file)
@@ -27,7 +27,7 @@
 
 -behavior(gen_mod).
 
--export([start/2, stop/1, check_packet/6,
+-export([start/2, stop/1, check_packet/4,
         mod_opt_type/1, depends/2]).
 
 -include("ejabberd.hrl").
@@ -51,10 +51,12 @@ stop(Host) ->
 depends(_Host, _Opts) ->
     [].
 
--spec check_packet(allow | deny, binary(), binary(), _,
-                  {jid(), jid(), stanza()}, in | out) -> allow | deny.
-check_packet(_, _User, Server, _PrivacyList,
-            {From, To, #presence{type = Type}}, Dir) ->
+-spec check_packet(allow | deny, ejabberd_c2s:state() | jid(),
+                  stanza(), in | out) -> allow | deny.
+check_packet(Acc, #{jid := JID}, Packet, Dir) ->
+    check_packet(Acc, JID, Packet, Dir);
+check_packet(_, #jid{lserver = LServer},
+            #presence{from = From, to = To, type = Type}, Dir) ->
     IsSubscription = case Type of
                         subscribe -> true;
                         subscribed -> true;
@@ -67,11 +69,11 @@ check_packet(_, _User, Server, _PrivacyList,
                      in -> To;
                      out -> From
                  end,
-           update(Server, JID, Dir);
+           update(LServer, JID, Dir);
        true -> allow
     end;
-check_packet(_, _User, _Server, _PrivacyList, _Pkt, _Dir) ->
-    allow.
+check_packet(Acc, _, _, _) ->
+    Acc.
 
 update(Server, JID, Dir) ->
     StormCount = gen_mod:get_module_opt(Server, ?MODULE, count,
index b28bbcea2fefcfb15d0bb0deb98061b4e24e68fb..6eb939c3c126d8fe643e297340766a4bdf923167 100644 (file)
 -behaviour(gen_mod).
 
 -export([start/2, stop/1, process_iq/1, export/1, import_info/0,
-        process_iq_set/3, process_iq_get/3, get_user_list/3,
-        check_packet/6, remove_user/2, encode_list_item/1,
-        is_list_needdb/1, updated_list/3,
-        import_start/2, import_stop/2, c2s_handle_info/2,
+        c2s_session_opened/1, c2s_copy_session/2, push_list_update/3,
+        user_send_packet/1, user_receive_packet/1, disco_features/5,
+        check_packet/4, remove_user/2, encode_list_item/1,
+        is_list_needdb/1, import_start/2, import_stop/2,
          item_to_xml/1, get_user_lists/2, import/5,
         set_privacy_list/1, mod_opt_type/1, depends/2]).
 
@@ -64,106 +64,124 @@ start(Host, Opts) ->
                              one_queue),
     Mod = gen_mod:db_mod(Host, Opts, ?MODULE),
     Mod:init(Host, Opts),
-    mod_disco:register_feature(Host, ?NS_PRIVACY),
-    ejabberd_hooks:add(privacy_iq_get, Host, ?MODULE,
-                      process_iq_get, 50),
-    ejabberd_hooks:add(privacy_iq_set, Host, ?MODULE,
-                      process_iq_set, 50),
-    ejabberd_hooks:add(privacy_get_user_list, Host, ?MODULE,
-                      get_user_list, 50),
+    ejabberd_hooks:add(disco_local_features, Host, ?MODULE,
+                      disco_features, 50),
+    ejabberd_hooks:add(c2s_session_opened, Host, ?MODULE,
+                      c2s_session_opened, 50),
+    ejabberd_hooks:add(c2s_copy_session, Host, ?MODULE,
+                      c2s_copy_session, 50),
+    ejabberd_hooks:add(user_send_packet, Host, ?MODULE,
+                      user_send_packet, 50),
+    ejabberd_hooks:add(user_receive_packet, Host, ?MODULE,
+                      user_receive_packet, 50),
     ejabberd_hooks:add(privacy_check_packet, Host, ?MODULE,
                       check_packet, 50),
-    ejabberd_hooks:add(privacy_updated_list, Host, ?MODULE,
-                      updated_list, 50),
     ejabberd_hooks:add(remove_user, Host, ?MODULE,
                       remove_user, 50),
-    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
-                      c2s_handle_info, 50),
     gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
                                  ?NS_PRIVACY, ?MODULE, process_iq, IQDisc).
 
 stop(Host) ->
-    mod_disco:unregister_feature(Host, ?NS_PRIVACY),
-    ejabberd_hooks:delete(privacy_iq_get, Host, ?MODULE,
-                         process_iq_get, 50),
-    ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE,
-                         process_iq_set, 50),
-    ejabberd_hooks:delete(privacy_get_user_list, Host,
-                         ?MODULE, get_user_list, 50),
+    ejabberd_hooks:delete(disco_local_features, Host, ?MODULE,
+                         disco_features, 50),
+    ejabberd_hooks:delete(c2s_session_opened, Host, ?MODULE,
+                         c2s_session_opened, 50),
+    ejabberd_hooks:delete(c2s_copy_session, Host, ?MODULE,
+                         c2s_copy_session, 50),
+    ejabberd_hooks:delete(user_send_packet, Host, ?MODULE,
+                         user_send_packet, 50),
+    ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE,
+                         user_receive_packet, 50),
     ejabberd_hooks:delete(privacy_check_packet, Host,
                          ?MODULE, check_packet, 50),
-    ejabberd_hooks:delete(privacy_updated_list, Host,
-                         ?MODULE, updated_list, 50),
     ejabberd_hooks:delete(remove_user, Host, ?MODULE,
                          remove_user, 50),
-    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
-                         c2s_handle_info, 50),
     gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
                                     ?NS_PRIVACY).
 
+-spec disco_features({error, stanza_error()} | {result, [binary()]} | empty,
+                    jid(), jid(), binary(), binary()) ->
+                           {error, stanza_error()} | {result, [binary()]}.
+disco_features({error, Err}, _From, _To, _Node, _Lang) ->
+    {error, Err};
+disco_features(empty, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_PRIVACY]};
+disco_features({result, Feats}, _From, _To, <<"">>, _Lang) ->
+    {result, [?NS_PRIVACY|Feats]};
+disco_features(Acc, _From, _To, _Node, _Lang) ->
+    Acc.
+
 -spec process_iq(iq()) -> iq().
-process_iq(IQ) ->
-    xmpp:make_error(IQ, xmpp:err_not_allowed()).
-
--spec process_iq_get({error, stanza_error()} | {result, xmpp_element() | undefined},
-                    iq(), userlist()) -> {error, stanza_error()} |
-                                         {result, xmpp_element() | undefined}.
-process_iq_get(_, #iq{lang = Lang,
-                     sub_els = [#privacy_query{default = Default,
-                                               active = Active}]},
-              _) when Default /= undefined; Active /= undefined ->
+process_iq(#iq{type = Type,
+              from = #jid{luser = U, lserver = S},
+              to = #jid{luser = U, lserver = S}} = IQ) ->
+    case Type of
+       get -> process_iq_get(IQ);
+       set -> process_iq_set(IQ)
+    end;
+process_iq(#iq{lang = Lang} = IQ) ->
+    Txt = <<"Query to another users is forbidden">>,
+    xmpp:make_error(IQ, xmpp:err_forbidden(Txt, Lang)).
+
+-spec process_iq_get(iq()) -> iq().
+process_iq_get(#iq{lang = Lang,
+                  sub_els = [#privacy_query{default = Default,
+                                            active = Active}]} = IQ)
+  when Default /= undefined; Active /= undefined ->
     Txt = <<"Only <list/> element is allowed in this query">>,
-    {error, xmpp:err_bad_request(Txt, Lang)};
-process_iq_get(_, #iq{from = From, lang = Lang,
-                     sub_els = [#privacy_query{lists = Lists}]},
-              #userlist{name = Active}) ->
-    #jid{luser = LUser, lserver = LServer} = From,
+    xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang));
+process_iq_get(#iq{lang = Lang,
+                  sub_els = [#privacy_query{lists = Lists}]} = IQ) ->
     case Lists of
        [] ->
-           process_lists_get(LUser, LServer, Active, Lang);
+           process_lists_get(IQ);
        [#privacy_list{name = ListName}] ->
-           process_list_get(LUser, LServer, ListName, Lang);
+           process_list_get(IQ, ListName);
        _ ->
            Txt = <<"Too many <list/> elements">>,
-           {error, xmpp:err_bad_request(Txt, Lang)}
+           xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang))
     end;
-process_iq_get(Acc, _, _) ->
-    Acc.
-
--spec process_lists_get(binary(), binary(), binary(), binary()) ->
-                              {error, stanza_error()} | {result, privacy_query()}.
-process_lists_get(LUser, LServer, Active, Lang) ->
+process_iq_get(#iq{lang = Lang} = IQ) ->
+    Txt = <<"No module is handling this query">>,
+    xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)).
+
+-spec process_lists_get(iq()) -> iq().
+process_lists_get(#iq{from = #jid{luser = LUser, lserver = LServer},
+                     lang = Lang,
+                     meta = #{privacy_active_list := Active}} = IQ) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:process_lists_get(LUser, LServer) of
        error ->
            Txt = <<"Database failure">>,
-           {error, xmpp:err_internal_server_error(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang));
        {_Default, []} ->
-           {result, #privacy_query{}};
+           xmpp:make_iq_result(IQ, #privacy_query{});
        {Default, ListNames} ->
-           {result,
-            #privacy_query{active = Active,
-                           default = Default,
-                           lists = [#privacy_list{name = ListName}
-                                    || ListName <- ListNames]}}
+           xmpp:make_iq_result(
+             IQ,
+             #privacy_query{active = Active,
+                            default = Default,
+                            lists = [#privacy_list{name = ListName}
+                                     || ListName <- ListNames]})
     end.
 
--spec process_list_get(binary(), binary(), binary(), binary()) ->
-                             {error, stanza_error()} | {result, privacy_query()}.
-process_list_get(LUser, LServer, Name, Lang) ->
+-spec process_list_get(iq(), binary()) -> iq().
+process_list_get(#iq{from = #jid{luser = LUser, lserver = LServer},
+                    lang = Lang} = IQ, Name) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:process_list_get(LUser, LServer, Name) of
        error ->
            Txt = <<"Database failure">>,
-           {error, xmpp:err_internal_server_error(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang));
        not_found ->
            Txt = <<"No privacy list with this name found">>,
-           {error, xmpp:err_item_not_found(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang));
        Items ->
            LItems = lists:map(fun encode_list_item/1, Items),
-           {result,
-            #privacy_query{
-               lists = [#privacy_list{name = Name, items = LItems}]}}
+           xmpp:make_iq_result(
+             IQ,
+             #privacy_query{
+                lists = [#privacy_list{name = Name, items = LItems}]})
     end.
 
 -spec item_to_xml(listitem()) -> xmlel().
@@ -228,69 +246,61 @@ decode_value(Type, Value) ->
        undefined -> none
     end.
 
--spec process_iq_set({error, stanza_error()} |
-                    {result, xmpp_element() | undefined} |
-                    {result, xmpp_element() | undefined, userlist()},
-                    iq(), #userlist{}) ->
-                           {error, stanza_error()} |
-                           {result, xmpp_element() | undefined} |
-                           {result, xmpp_element() | undefined, userlist()}.
-process_iq_set(_, #iq{from = From, lang = Lang,
-                     sub_els = [#privacy_query{default = Default,
-                                               active = Active,
-                                               lists = Lists}]},
-             #userlist{} = UserList) ->
-    #jid{luser = LUser, lserver = LServer} = From,
+-spec process_iq_set(iq()) -> iq().
+process_iq_set(#iq{lang = Lang,
+                  sub_els = [#privacy_query{default = Default,
+                                            active = Active,
+                                            lists = Lists}]} = IQ) ->
     case Lists of
        [#privacy_list{items = Items, name = ListName}]
          when Default == undefined, Active == undefined ->
-           process_lists_set(LUser, LServer, ListName, Items, UserList, Lang);
+           process_lists_set(IQ, ListName, Items);
        [] when Default == undefined, Active /= undefined ->
-           process_active_set(LUser, LServer, Active, Lang);
+           process_active_set(IQ, Active);
        [] when Active == undefined, Default /= undefined ->
-           process_default_set(LUser, LServer, Default, Lang);
+           process_default_set(IQ, Default);
        _ ->
            Txt = <<"The stanza MUST contain only one <active/> element, "
                    "one <default/> element, or one <list/> element">>,
-           {error, xmpp:err_bad_request(Txt, Lang)}
+           xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang))
     end;
-process_iq_set(Acc, _, _) ->
-    Acc.
+process_iq_set(#iq{lang = Lang} = IQ) ->
+    Txt = <<"No module is handling this query">>,
+    xmpp:make_error(IQ, xmpp:err_service_unavailable(Txt, Lang)).
 
--spec process_default_set(binary(), binary(), none | binary(),
-                         binary()) -> {error, stanza_error()} | {result, undefined}.
-process_default_set(LUser, LServer, Value, Lang) ->
+-spec process_default_set(iq(), binary()) -> iq().
+process_default_set(#iq{from = #jid{luser = LUser, lserver = LServer},
+                       lang = Lang} = IQ, Value) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:process_default_set(LUser, LServer, Value) of
        {atomic, error} ->
            Txt = <<"Database failure">>,
-           {error, xmpp:err_internal_server_error(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang));
        {atomic, not_found} ->
            Txt = <<"No privacy list with this name found">>,
-           {error, xmpp:err_item_not_found(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang));
        {atomic, ok} ->
-           {result, undefined};
+           xmpp:make_iq_result(IQ);
        Err ->
            ?ERROR_MSG("failed to set default list '~s' for user ~s@~s: ~p",
                       [Value, LUser, LServer, Err]),
-           {error, xmpp:err_internal_server_error()}
+           xmpp:make_error(IQ, xmpp:err_internal_server_error())
     end.
 
--spec process_active_set(binary(), binary(), none | binary(), binary()) ->
-                               {error, stanza_error()} |
-                               {result, undefined, userlist()}.
-process_active_set(_LUser, _LServer, none, _Lang) ->
-    {result, undefined, #userlist{}};
-process_active_set(LUser, LServer, Name, Lang) ->
+-spec process_active_set(IQ, none | binary()) -> IQ.
+process_active_set(IQ, none) ->
+    xmpp:make_iq_result(xmpp:put_meta(IQ, privacy_list, #userlist{}));
+process_active_set(#iq{from = #jid{luser = LUser, lserver = LServer},
+                      lang = Lang} = IQ, Name) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:process_active_set(LUser, LServer, Name) of
        error ->
            Txt = <<"No privacy list with this name found">>,
-           {error, xmpp:err_item_not_found(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang));
        Items ->
            NeedDb = is_list_needdb(Items),
-           {result, undefined,
-            #userlist{name = Name, list = Items, needdb = NeedDb}}
+           List = #userlist{name = Name, list = Items, needdb = NeedDb},
+           xmpp:make_iq_result(xmpp:put_meta(IQ, privacy_list, List))
     end.
 
 -spec set_privacy_list(privacy()) -> any().
@@ -298,57 +308,100 @@ set_privacy_list(#privacy{us = {_, LServer}} = Privacy) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:set_privacy_list(Privacy).
 
--spec process_lists_set(binary(), binary(), binary(), [privacy_item()],
-                       #userlist{}, binary()) -> {error, stanza_error()} |
-                                                 {result, undefined}.
-process_lists_set(_LUser, _LServer, Name, [], #userlist{name = Name}, Lang) ->
+-spec process_lists_set(iq(), binary(), [privacy_item()]) -> iq().
+process_lists_set(#iq{meta = #{privacy_active_list := Name},
+                     lang = Lang} = IQ, Name, []) ->
     Txt = <<"Cannot remove active list">>,
-    {error, xmpp:err_conflict(Txt, Lang)};
-process_lists_set(LUser, LServer, Name, [], _UserList, Lang) ->
+    xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang));
+process_lists_set(#iq{from = #jid{luser = LUser, lserver = LServer} = From,
+                     lang = Lang} = IQ, Name, []) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     case Mod:remove_privacy_list(LUser, LServer, Name) of
        {atomic, conflict} ->
            Txt = <<"Cannot remove default list">>,
-           {error, xmpp:err_conflict(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_conflict(Txt, Lang));
        {atomic, not_found} ->
            Txt = <<"No privacy list with this name found">>,
-           {error, xmpp:err_item_not_found(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_item_not_found(Txt, Lang));
        {atomic, ok} ->
-           ejabberd_sm:route(jid:make(LUser, LServer, <<"">>),
-                             {privacy_list, #userlist{name = Name}, Name}),
-           {result, undefined};
+           push_list_update(From, #userlist{name = Name}, Name),
+           xmpp:make_iq_result(IQ);
        Err ->
            ?ERROR_MSG("failed to remove privacy list '~s' for user ~s@~s: ~p",
                       [Name, LUser, LServer, Err]),
            Txt = <<"Database failure">>,
-           {error, xmpp:err_internal_server_error(Txt, Lang)}
+           xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang))
     end;
-process_lists_set(LUser, LServer, Name, Items, _UserList, Lang) ->
+process_lists_set(#iq{from = #jid{luser = LUser, lserver = LServer} = From,
+                     lang = Lang} = IQ, Name, Items) ->
     case catch lists:map(fun decode_item/1, Items) of
        {error, Why} ->
            Txt = xmpp:format_error(Why),
-           {error, xmpp:err_bad_request(Txt, Lang)};
+           xmpp:make_error(IQ, xmpp:err_bad_request(Txt, Lang));
        List ->
            Mod = gen_mod:db_mod(LServer, ?MODULE),
            case Mod:set_privacy_list(LUser, LServer, Name, List) of
                {atomic, ok} ->
-                   NeedDb = is_list_needdb(List),
-                   ejabberd_sm:route(jid:make(LUser, LServer, <<"">>),
-                                     {privacy_list,
-                                      #userlist{name = Name,
-                                                list = List,
-                                                needdb = NeedDb},
-                                      Name}),
-                   {result, undefined};
+                   UserList = #userlist{name = Name, list = List,
+                                        needdb = is_list_needdb(List)},
+                   push_list_update(From, UserList, Name),
+                   xmpp:make_iq_result(IQ);
                Err ->
                    ?ERROR_MSG("failed to set privacy list '~s' "
                               "for user ~s@~s: ~p",
                               [Name, LUser, LServer, Err]),
                    Txt = <<"Database failure">>,
-                   {error, xmpp:err_internal_server_error(Txt, Lang)}
+                   xmpp:make_error(IQ, xmpp:err_internal_server_error(Txt, Lang))
            end
     end.
 
+-spec push_list_update(jid(), #userlist{}, binary() | none) -> ok.
+push_list_update(From, List, Name) ->
+    BareFrom = jid:remove_resource(From),
+    lists:foreach(
+      fun(R) ->
+             To = jid:replace_resource(From, R),
+             IQ = #iq{type = set, from = BareFrom, to = To,
+                      id = <<"push", (randoms:get_string())/binary>>,
+                      sub_els = [#privacy_query{
+                                    lists = [#privacy_list{name = Name}]}],
+                      meta = #{privacy_updated_list => List}},
+             ejabberd_router:route(BareFrom, To, IQ)
+      end, ejabberd_sm:get_user_resources(From#jid.luser, From#jid.lserver)).
+
+-spec user_send_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_send_packet({#iq{type = Type,
+                     to = #jid{luser = U, lserver = S, lresource = <<"">>},
+                     from = #jid{luser = U, lserver = S},
+                     sub_els = [_]} = IQ,
+                 #{privacy_list := #userlist{name = Name}} = State})
+  when Type == get; Type == set ->
+    NewIQ = case xmpp:has_subtag(IQ, #privacy_query{}) of
+               true -> xmpp:put_meta(IQ, privacy_active_list, Name);
+               false -> IQ
+           end,
+    {NewIQ, State};
+user_send_packet(Acc) ->
+    Acc.
+
+-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_receive_packet({#iq{type = result, meta = #{privacy_list := List}} = IQ,
+                    State}) ->
+    {IQ, State#{privacy_list => List}};
+user_receive_packet({#iq{type = set, meta = #{privacy_updated_list := New}} = IQ,
+                    #{user := U, server := S, resource := R,
+                      privacy_list := Old} = State}) ->
+    State1 = if Old#userlist.name == New#userlist.name ->
+                    State#{privacy_list => New};
+               true ->
+                    State
+            end,
+    From = jid:make(U, S, <<"">>),
+    To = jid:make(U, S, R),
+    {xmpp:set_from_to(IQ, From, To), State1};
+user_receive_packet(Acc) ->
+    Acc.
+
 -spec decode_item(privacy_item()) -> listitem().
 decode_item(#privacy_item{order = Order,
                          action = Action,
@@ -391,15 +444,20 @@ is_list_needdb(Items) ->
              end,
              Items).
 
--spec get_user_list(userlist(), binary(), binary()) -> userlist().
-get_user_list(_Acc, User, Server) ->
-    LUser = jid:nodeprep(User),
-    LServer = jid:nameprep(Server),
+-spec get_user_list(binary(), binary()) -> #userlist{}.
+get_user_list(LUser, LServer) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     {Default, Items} = Mod:get_user_list(LUser, LServer),
     NeedDb = is_list_needdb(Items),
-    #userlist{name = Default, list = Items,
-             needdb = NeedDb}.
+    #userlist{name = Default, list = Items, needdb = NeedDb}.
+
+-spec c2s_session_opened(ejabberd_c2s:state()) -> ejabberd_c2s:state().
+c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer}} = State) ->
+    State#{privacy_list => get_user_list(LUser, LServer)}.
+
+-spec c2s_copy_session(ejabberd_c2s:state(), ejabberd_c2s:state()) -> ejabberd_c2s:state().
+c2s_copy_session(State, #{privacy_list := List}) ->
+    State#{privacy_list => List}.
 
 -spec get_user_lists(binary(), binary()) -> {ok, privacy()} | error.
 get_user_lists(User, Server) ->
@@ -411,59 +469,66 @@ get_user_lists(User, Server) ->
 %% From is the sender, To is the destination.
 %% If Dir = out, User@Server is the sender account (From).
 %% If Dir = in, User@Server is the destination account (To).
--spec check_packet(allow | deny, binary(), binary(), userlist(),
-                  {jid(), jid(), stanza()}, in | out) -> allow | deny.
-check_packet(_, _User, _Server, _UserList,
-            {#jid{luser = <<"">>, lserver = Server} = _From,
-             #jid{lserver = Server} = _To, _},
-            in) ->
-    allow;
-check_packet(_, _User, _Server, _UserList,
-            {#jid{lserver = Server} = _From,
-             #jid{luser = <<"">>, lserver = Server} = _To, _},
-            out) ->
-    allow;
-check_packet(_, _User, _Server, _UserList,
-            {#jid{luser = User, lserver = Server} = _From,
-             #jid{luser = User, lserver = Server} = _To, _},
-            _Dir) ->
-    allow;
-check_packet(_, User, Server,
-            #userlist{list = List, needdb = NeedDb},
-            {From, To, Packet}, Dir) ->
-    case List of
-      [] -> allow;
-      _ ->
-         PType = case Packet of
-                   #message{} -> message;
-                   #iq{} -> iq;
-                   #presence{type = available} -> presence;
-                   #presence{type = unavailable} -> presence;
-                   _ -> other
-                 end,
-         PType2 = case {PType, Dir} of
-                    {message, in} -> message;
-                    {iq, in} -> iq;
-                    {presence, in} -> presence_in;
-                    {presence, out} -> presence_out;
-                    {_, _} -> other
+-spec check_packet(allow | deny, ejabberd_c2s:state() | jid(),
+                  stanza(), in | out) -> allow | deny.
+check_packet(_, #{jid := #jid{luser = LUser, lserver = LServer},
+                 privacy_list := #userlist{list = List, needdb = NeedDb}},
+            Packet, Dir) ->
+    From = xmpp:get_from(Packet),
+    To = xmpp:get_to(Packet),
+    case {From, To} of
+       {#jid{luser = <<"">>, lserver = LServer},
+        #jid{lserver = LServer}} when Dir == in ->
+           %% Allow any packets from local server
+           allow;
+       {#jid{lserver = LServer},
+        #jid{luser = <<"">>, lserver = LServer}} when Dir == out ->
+           %% Allow any packets to local server
+           allow;
+       {#jid{luser = LUser, lserver = LServer, lresource = <<"">>},
+        #jid{luser = LUser, lserver = LServer}} when Dir == in ->
+           %% Allow incoming packets from user's bare jid to his full jid
+           allow;
+       {#jid{luser = LUser, lserver = LServer},
+        #jid{luser = LUser, lserver = LServer, lresource = <<"">>}} when Dir == out ->
+           %% Allow outgoing packets from user's full jid to his bare JID
+           allow;
+       _ when List == [] ->
+           allow;
+       _ ->
+           PType = case Packet of
+                       #message{} -> message;
+                       #iq{} -> iq;
+                       #presence{type = available} -> presence;
+                       #presence{type = unavailable} -> presence;
+                       _ -> other
+                   end,
+           PType2 = case {PType, Dir} of
+                        {message, in} -> message;
+                        {iq, in} -> iq;
+                        {presence, in} -> presence_in;
+                        {presence, out} -> presence_out;
+                        {_, _} -> other
+                    end,
+           LJID = case Dir of
+                      in -> jid:tolower(From);
+                      out -> jid:tolower(To)
                   end,
-         LJID = case Dir of
-                  in -> jid:tolower(From);
-                  out -> jid:tolower(To)
-                end,
-         {Subscription, Groups} = case NeedDb of
-                                    true ->
-                                        ejabberd_hooks:run_fold(roster_get_jid_info,
-                                                                jid:nameprep(Server),
-                                                                {none, []},
-                                                                [User, Server,
-                                                                 LJID]);
-                                    false -> {[], []}
-                                  end,
-         check_packet_aux(List, PType2, LJID, Subscription,
-                          Groups)
-    end.
+           {Subscription, Groups} =
+               case NeedDb of
+                   true ->
+                       ejabberd_hooks:run_fold(roster_get_jid_info,
+                                               LServer,
+                                               {none, []},
+                                               [LUser, LServer, LJID]);
+                   false ->
+                       {[], []}
+               end,
+           check_packet_aux(List, PType2, LJID, Subscription, Groups)
+    end;
+check_packet(Acc, #jid{luser = LUser, lserver = LServer} = JID, Packet, Dir) ->
+    List = get_user_list(LUser, LServer),
+    check_packet(Acc, #{jid => JID, privacy_list => List}, Packet, Dir).
 
 -spec check_packet_aux([listitem()],
                       message | iq | presence_in | presence_out | other,
@@ -535,30 +600,6 @@ remove_user(User, Server) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:remove_user(LUser, LServer).
 
-c2s_handle_info(#{privacy_list := Old,
-                 user := U, server := S, resource := R} = State,
-               {privacy_list, New, Name}) ->
-    List = if Old#userlist.name == New#userlist.name -> New;
-             true -> Old
-          end,
-    From = jid:make(U, S),
-    To = jid:make(U, S, R),
-    PushIQ = #iq{type = set, from = From, to = To,
-                id = <<"push", (randoms:get_string())/binary>>,
-                sub_els = [#privacy_query{
-                              lists = [#privacy_list{name = Name}]}]},
-    State1 = State#{privacy_list => List},
-    {stop, ejabberd_c2s:send(State1, PushIQ)};
-c2s_handle_info(State, _) ->
-    State.
-
--spec updated_list(userlist(), userlist(), userlist()) -> userlist().
-updated_list(_, #userlist{name = OldName} = Old,
-            #userlist{name = NewName} = New) ->
-    if OldName == NewName -> New;
-       true -> Old
-    end.
-
 numeric_to_binary(<<0, 0, _/binary>>) ->
     <<"0">>;
 numeric_to_binary(<<0, _, _:6/binary, T/binary>>) ->
index c1ac5a3fc528bac7a55cfa5b4d8b06691224eb40..ae0a67e72b0096fbf4502657590ef87d0f813890 100644 (file)
@@ -38,7 +38,7 @@
         terminate/2, code_change/3]).
 -export([component_connected/1, component_disconnected/2,
         roster_access/2, process_message/3,
-        process_presence_out/4, process_presence_in/5]).
+        process_presence_out/1, process_presence_in/1]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
@@ -133,10 +133,11 @@ roster_access(false, #iq{from = From, to = To, type = Type}) ->
            false
     end.
 
--spec process_presence_out(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza().
-process_presence_out(#presence{type = Type} = Pres, _C2SState,
-                    #jid{luser = LUser, lserver = LServer} = From,
-                    #jid{luser = LUser, lserver = LServer, lresource = <<"">>})
+-spec process_presence_out({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+process_presence_out({#presence{
+                        from = #jid{luser = LUser, lserver = LServer} = From,
+                        to = #jid{luser = LUser, lserver = LServer, lresource = <<"">>},
+                        type = Type} = Pres, C2SState})
   when Type == available; Type == unavailable ->
     %% Self-presence processing
     Permissions = get_permissions(LServer),
@@ -151,15 +152,15 @@ process_presence_out(#presence{type = Type} = Pres, _C2SState,
                      ok
              end
       end, dict:to_list(Permissions)),
-    Pres;
-process_presence_out(Acc, _, _, _) ->
+    {Pres, C2SState};
+process_presence_out(Acc) ->
     Acc.
 
--spec process_presence_in(stanza(), ejabberd_c2s:state(),
-                         jid(), jid(), jid()) -> stanza().
-process_presence_in(#presence{type = Type} = Pres, _C2SState, _,
-                   #jid{luser = U, lserver = S} = From,
-                   #jid{luser = LUser, lserver = LServer})
+-spec process_presence_in({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+process_presence_in({#presence{
+                       from = #jid{luser = U, lserver = S} = From,
+                       to = #jid{luser = LUser, lserver = LServer},
+                       type = Type} = Pres, C2SState})
   when {U, S} /= {LUser, LServer} andalso
        (Type == available orelse Type == unavailable) ->
     Permissions = get_permissions(LServer),
@@ -179,8 +180,8 @@ process_presence_in(#presence{type = Type} = Pres, _C2SState, _,
                      ok
              end
       end, dict:to_list(Permissions)),
-    Pres;
-process_presence_in(Acc, _, _, _, _) ->
+    {Pres, C2SState};
+process_presence_in(Acc) ->
     Acc.
 
 %%%===================================================================
index 8819e3a994d5134daa4b6edc79677ac1207ce1e8..d631b0ad052e14a9e8deb478d5bb368d79eecc53 100644 (file)
@@ -272,7 +272,6 @@ init([ServerHost, Opts]) ->
     ejabberd_mnesia:create(?MODULE, pubsub_last_item,
        [{ram_copies, [node()]},
            {attributes, record_info(fields, pubsub_last_item)}]),
-    mod_disco:register_feature(ServerHost, ?NS_PUBSUB),
     lists:foreach(
       fun(H) ->
              T = gen_mod:get_module_proc(H, config),
@@ -533,7 +532,7 @@ disco_local_features(Acc, _From, To, <<>>, _Lang) ->
        {result, I} -> I;
        _ -> []
     end,
-    {result, Feats ++ [feature(F) || F <- features(Host, <<>>)]};
+    {result, Feats ++ [?NS_PUBSUB|[feature(F) || F <- features(Host, <<>>)]]};
 disco_local_features(Acc, _From, _To, _Node, _Lang) ->
     Acc.
 
@@ -923,7 +922,6 @@ terminate(_Reason,
     gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_PUBSUB_OWNER),
     gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_VCARD),
     gen_iq_handler:remove_iq_handler(ejabberd_local, Host, ?NS_COMMANDS),
-    mod_disco:unregister_feature(ServerHost, ?NS_PUBSUB),
     case whereis(gen_mod:get_module_proc(ServerHost, ?LOOPNAME)) of
        undefined ->
            ?ERROR_MSG("~s process is dead, pubsub was broken", [?LOOPNAME]);
index 5c207f3a484e5dd94f7753246b49a84c5d4d4e53..085f502253619532e8c61f6c0a37151cc6123896 100644 (file)
@@ -43,9 +43,9 @@
 
 -export([start/2, stop/1, process_iq/1, export/1,
         import_info/0, process_local_iq/1, get_user_roster/2,
-        import/5, get_subscription_lists/3, get_roster/2,
-        import_start/2, import_stop/2, c2s_handle_info/2,
-        get_in_pending_subscriptions/3, in_subscription/6,
+        import/5, c2s_session_opened/1, get_roster/2,
+        import_start/2, import_stop/2, user_receive_packet/1,
+        c2s_self_presence/1, in_subscription/6,
         out_subscription/4, set_items/3, remove_user/2,
         get_jid_info/4, encode_item/1, webadmin_page/3,
         webadmin_user/4, get_versioning_feature/2,
@@ -94,24 +94,24 @@ start(Host, Opts) ->
                       ?MODULE, in_subscription, 50),
     ejabberd_hooks:add(roster_out_subscription, Host,
                       ?MODULE, out_subscription, 50),
-    ejabberd_hooks:add(roster_get_subscription_lists, Host,
-                      ?MODULE, get_subscription_lists, 50),
+    ejabberd_hooks:add(c2s_session_opened, Host, ?MODULE,
+                      c2s_session_opened, 50),
     ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE,
                       get_jid_info, 50),
     ejabberd_hooks:add(remove_user, Host, ?MODULE,
                       remove_user, 50),
     ejabberd_hooks:add(anonymous_purge_hook, Host, ?MODULE,
                       remove_user, 50),
-    ejabberd_hooks:add(resend_subscription_requests_hook,
-                      Host, ?MODULE, get_in_pending_subscriptions, 50),
+    ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE,
+                      c2s_self_presence, 50),
     ejabberd_hooks:add(c2s_post_auth_features, Host,
                       ?MODULE, get_versioning_feature, 50),
     ejabberd_hooks:add(webadmin_page_host, Host, ?MODULE,
                       webadmin_page, 50),
     ejabberd_hooks:add(webadmin_user, Host, ?MODULE,
                       webadmin_user, 50),
-    ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
-                      c2s_handle_info, 50),
+    ejabberd_hooks:add(user_receive_packet, Host, ?MODULE,
+                      user_receive_packet, 50),
     gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
                                  ?NS_ROSTER, ?MODULE, process_iq, IQDisc).
 
@@ -122,24 +122,24 @@ stop(Host) ->
                          ?MODULE, in_subscription, 50),
     ejabberd_hooks:delete(roster_out_subscription, Host,
                          ?MODULE, out_subscription, 50),
-    ejabberd_hooks:delete(roster_get_subscription_lists,
-                         Host, ?MODULE, get_subscription_lists, 50),
+    ejabberd_hooks:delete(c2s_session_opened, Host, ?MODULE,
+                         c2s_session_opened, 50),
     ejabberd_hooks:delete(roster_get_jid_info, Host,
                          ?MODULE, get_jid_info, 50),
     ejabberd_hooks:delete(remove_user, Host, ?MODULE,
                          remove_user, 50),
     ejabberd_hooks:delete(anonymous_purge_hook, Host,
                          ?MODULE, remove_user, 50),
-    ejabberd_hooks:delete(resend_subscription_requests_hook,
-                         Host, ?MODULE, get_in_pending_subscriptions, 50),
+    ejabberd_hooks:delete(c2s_self_presence, Host, ?MODULE,
+                         c2s_self_presence, 50),
     ejabberd_hooks:delete(c2s_post_auth_features,
                          Host, ?MODULE, get_versioning_feature, 50),
     ejabberd_hooks:delete(webadmin_page_host, Host, ?MODULE,
                          webadmin_page, 50),
     ejabberd_hooks:delete(webadmin_user, Host, ?MODULE,
                          webadmin_user, 50),
-    ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
-                         c2s_handle_info, 50),
+    ejabberd_hooks:delete(user_receive_packet, Host, ?MODULE,
+                         user_receive_packet, 50),
     gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
                                     ?NS_ROSTER).
 
@@ -220,10 +220,16 @@ roster_version_on_db(Host) ->
 %% Returns a list that may contain an xmlelement with the XEP-237 feature if it's enabled.
 -spec get_versioning_feature([xmpp_element()], binary()) -> [xmpp_element()].
 get_versioning_feature(Acc, Host) ->
-    case roster_versioning_enabled(Host) of
-      true ->
-         [#rosterver_feature{}|Acc];
-      false -> []
+    case gen_mod:is_loaded(Host, ?MODULE) of
+       true ->
+           case roster_versioning_enabled(Host) of
+               true ->
+                   [#rosterver_feature{}|Acc];
+               false ->
+                   Acc
+           end;
+       false ->
+           Acc
     end.
 
 roster_version(LServer, LUser) ->
@@ -423,8 +429,6 @@ process_iq_set(#iq{from = From, to = To,
     end.
 
 push_item(User, Server, From, Item) ->
-    ejabberd_sm:route(jid:make(User, Server, <<"">>),
-                      {item, Item#roster.jid, Item#roster.subscription}),
     case roster_versioning_enabled(Server) of
       true ->
          push_item_version(Server, User, From, Item,
@@ -446,15 +450,12 @@ push_item(User, Server, Resource, From, Item,
              not_found -> undefined;
              _ -> RosterVersion
          end,
-    ResIQ = #iq{type = set,
-%% @doc Roster push, calculate and include the version attribute.
-%% TODO: don't push to those who didn't load roster
+    To = jid:make(User, Server, Resource),
+    ResIQ = #iq{type = set, from = From, to = To,
                id = <<"push", (randoms:get_string())/binary>>,
                sub_els = [#roster_query{ver = Ver,
                                         items = [encode_item(Item)]}]},
-    ejabberd_router:route(From,
-                         jid:make(User, Server, Resource),
-                         ResIQ).
+    ejabberd_router:route(From, To, xmpp:put_meta(ResIQ, roster_item, Item)).
 
 push_item_version(Server, User, From, Item,
                  RosterVersion) ->
@@ -464,19 +465,19 @@ push_item_version(Server, User, From, Item,
                  end,
                  ejabberd_sm:get_user_resources(User, Server)).
 
-c2s_handle_info(State, {item, JID, Sub}) ->
-    {stop, roster_change(State, JID, Sub)};
-c2s_handle_info(State, _) ->
-    State.
+-spec user_receive_packet({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+user_receive_packet({#iq{type = set, meta = #{roster_item := Item}} = IQ, State}) ->
+    {IQ, roster_change(State, Item)};
+user_receive_packet(Acc) ->
+    Acc.
 
--spec roster_change(ejabberd_c2s:state(), jid(), subscription()) -> ejabberd_c2s:state().
-roster_change(#{user := U, server := S, resource := R} = State,
-             IJID, ISubscription) ->
+-spec roster_change(ejabberd_c2s:state(), #roster{}) -> ejabberd_c2s:state().
+roster_change(#{user := U, server := S, resource := R,
+               pres_a := PresA, pres_f := PresF, pres_t := PresT} = State,
+             #roster{jid = IJID, subscription = ISubscription}) ->
     LIJID = jid:tolower(IJID),
     IsFrom = (ISubscription == both) or (ISubscription == from),
     IsTo = (ISubscription == both) or (ISubscription == to),
-    PresF = maps:get(pres_f, State, ?SETS:new()),
-    PresT = maps:get(pres_t, State, ?SETS:new()),
     OldIsFrom = ?SETS:is_element(LIJID, PresF),
     FSet = if IsFrom -> ?SETS:add_element(LIJID, PresF);
              true -> ?SETS:del_element(LIJID, PresF)
@@ -490,7 +491,6 @@ roster_change(#{user := U, server := S, resource := R} = State,
            State1;
        LastPres ->
            From = jid:make(U, S, R),
-           PresA = maps:get(pres_a, State1, ?SETS:new()),
            To = jid:make(IJID),
            Cond1 = IsFrom andalso not OldIsFrom,
            Cond2 = not IsFrom andalso OldIsFrom andalso
@@ -507,7 +507,7 @@ roster_change(#{user := U, server := S, resource := R} = State,
                    end,
                    A = ?SETS:add_element(LIJID, PresA),
                    State1#{pres_a => A};
-            Cond2 ->
+              Cond2 ->
                    PU = #presence{from = From, to = To, type = unavailable},
                    case ejabberd_hooks:run_fold(
                           privacy_check_packet, allow,
@@ -524,26 +524,29 @@ roster_change(#{user := U, server := S, resource := R} = State,
            end
     end.
 
--spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary())
-      -> {[ljid()], [ljid()]}.
-get_subscription_lists(_Acc, User, Server) ->
-    LUser = jid:nodeprep(User),
-    LServer = jid:nameprep(Server),
+-spec c2s_session_opened(ejabberd_c2s:state()) -> ejabberd_c2s:state().
+c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID,
+                    pres_f := PresF, pres_t := PresT} = State) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Items = Mod:get_only_items(LUser, LServer),
-    fill_subscription_lists(LServer, Items, [], []).
+    {F, T} = fill_subscription_lists(Items, PresF, PresT),
+    LJID = jid:tolower(jid:remove_resource(JID)),
+    State#{pres_f => ?SETS:add(LJID, F), pres_t => ?SETS:add(LJID, T)}.
 
-fill_subscription_lists(LServer, [I | Is], F, T) ->
+fill_subscription_lists([I | Is], F, T) ->
     J = element(3, I#roster.usj),
-    case I#roster.subscription of
-       both ->
-           fill_subscription_lists(LServer, Is, [J | F], [J | T]);
-       from ->
-           fill_subscription_lists(LServer, Is, [J | F], T);
-       to -> fill_subscription_lists(LServer, Is, F, [J | T]);
-       _ -> fill_subscription_lists(LServer, Is, F, T)
-    end;
-fill_subscription_lists(_LServer, [], F, T) ->
+    {F1, T1} = case I#roster.subscription of
+                  both ->
+                      {?SETS:add_element(J, F), ?SETS:add_element(J, T)};
+                  from ->
+                      {?SETS:add_element(J, F), T};
+                  to ->
+                      {F, ?SETS:add_element(J, T)};
+                  _ ->
+                      {F, T}
+              end,
+    fill_subscription_lists(Is, F1, T1);
+fill_subscription_lists([], F, T) ->
     {F, T}.
 
 ask_to_pending(subscribe) -> out;
@@ -836,27 +839,47 @@ process_item_set_t(LUser, LServer, #roster_item{jid = JID1} = QueryItem) ->
     end;
 process_item_set_t(_LUser, _LServer, _) -> ok.
 
--spec get_in_pending_subscriptions([presence()], binary(), binary()) -> [presence()].
-get_in_pending_subscriptions(Ls, User, Server) ->
-    LServer = jid:nameprep(Server),
-    Mod = gen_mod:db_mod(LServer, ?MODULE),
-    get_in_pending_subscriptions(Ls, User, Server, Mod).
+-spec c2s_self_presence({presence(), ejabberd_c2s:state()})
+      -> {presence(), ejabberd_c2s:state()}.
+c2s_self_presence({_, #{pres_last := _}} = Acc) ->
+    Acc;
+c2s_self_presence({#presence{type = available} = Pkt,
+                  #{lserver := LServer} = State}) ->
+    Prio = get_priority_from_presence(Pkt),
+    if Prio >= 0 ->
+           Mod = gen_mod:db_mod(LServer, ?MODULE),
+           State1 = resend_pending_subscriptions(State, Mod),
+           {Pkt, State1};
+       true ->
+           {Pkt, State}
+    end;
+c2s_self_presence(Acc) ->
+    Acc.
 
-get_in_pending_subscriptions(Ls, User, Server, Mod) ->
-    JID = jid:make(User, Server, <<"">>),
+-spec resend_pending_subscriptions(ejabberd_c2s:state(), module()) -> ejabberd_c2s:state().
+resend_pending_subscriptions(#{jid := JID} = State, Mod) ->
+    BareJID = jid:remove_resource(JID),
     Result = Mod:get_only_items(JID#jid.luser, JID#jid.lserver),
-    Ls ++ lists:flatmap(
-           fun(#roster{ask = Ask} = R) when Ask == in; Ask == both ->
-                   Message = R#roster.askmessage,
-                   Status = if is_binary(Message) -> (Message);
-                               true -> <<"">>
-                            end,
-                   [#presence{from = R#roster.jid, to = JID,
-                              type = subscribe,
-                              status = xmpp:mk_text(Status)}];
-              (_) ->
-                   []
-           end, Result).
+    lists:foldl(
+      fun(#roster{ask = Ask} = R, AccState) when Ask == in; Ask == both ->
+             Message = R#roster.askmessage,
+             Status = if is_binary(Message) -> (Message);
+                         true -> <<"">>
+                      end,
+             Sub = #presence{from = R#roster.jid, to = BareJID,
+                             type = subscribe,
+                             status = xmpp:mk_text(Status)},
+             ejabberd_c2s:send(AccState, Sub);
+        (_, AccState) ->
+             AccState
+      end, State, Result).
+
+-spec get_priority_from_presence(presence()) -> integer().
+get_priority_from_presence(#presence{priority = Prio}) ->
+    case Prio of
+       undefined -> 0;
+       _ -> Prio
+    end.
 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 
index d0d78a30c272b9cf07b5d7a07e4b003e58d026ce..7e952f5766c05eef4820bb20d757ab7282b5ce5c 100644 (file)
@@ -131,13 +131,14 @@ s2s_out_auth_result(#{db_verify := _} = State, _) ->
     %% in section 2.1.2, step 2
     {stop, send_verify_request(State)};
 s2s_out_auth_result(#{db_enabled := true,
+                     sockmod := SockMod,
                      socket := Socket, ip := IP,
                      server := LServer,
-                     remote_server := RServer} = State, false) ->
+                     remote_server := RServer} = State, {false, _}) ->
     %% SASL authentication has failed, retrying with dialback
     %% Sending dialback request, section 2.1.1, step 1
     ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
-             [ejabberd_socket:pp(Socket), LServer, RServer,
+             [SockMod:pp(Socket), LServer, RServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     State1 = maps:remove(stop_reason, State#{on_route => queue}),
     {stop, send_db_request(State1)};
@@ -150,6 +151,7 @@ s2s_out_downgraded(#{db_verify := _} = State, _) ->
     %% section 2.1.2, step 2
     {stop, send_verify_request(State)};
 s2s_out_downgraded(#{db_enabled := true,
+                    sockmod := SockMod,
                     socket := Socket, ip := IP,
                     server := LServer,
                     remote_server := RServer} = State, _) ->
@@ -157,7 +159,7 @@ s2s_out_downgraded(#{db_enabled := true,
     %% section 2.1.1, step 1
     ?INFO_MSG("(~s) Trying s2s dialback authentication with "
              "non-RFC compliant server: ~s -> ~s (~s)",
-             [ejabberd_socket:pp(Socket), LServer, RServer,
+             [SockMod:pp(Socket), LServer, RServer,
               ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
     {stop, send_db_request(State)};
 s2s_out_downgraded(State, _) ->
index ea7768bcae339e1125689f3964e3a96df886b5e6..f27c4d0d804d1072732d6df9935fb0354cf9eddc 100644 (file)
@@ -29,8 +29,8 @@
 
 -behaviour(gen_mod).
 
--export([start/2, stop/1, log_user_send/4,
-        log_user_receive/5, mod_opt_type/1, depends/2]).
+-export([start/2, stop/1, log_user_send/1,
+        log_user_receive/1, mod_opt_type/1, depends/2]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
@@ -54,15 +54,19 @@ stop(Host) ->
 depends(_Host, _Opts) ->
     [].
 
--spec log_user_send(stanza(), ejabberd_c2s:state(), jid(), jid()) -> stanza().
-log_user_send(Packet, _C2SState, From, To) ->
+-spec log_user_send({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+log_user_send({Packet, C2SState}) ->
+    From = xmpp:get_from(Packet),
+    To = xmpp:get_to(Packet),
     log_packet(From, To, Packet, From#jid.lserver),
-    Packet.
+    {Packet, C2SState}.
 
--spec log_user_receive(stanza(), ejabberd_c2s:state(), jid(), jid(), jid()) -> stanza().
-log_user_receive(Packet, _C2SState, _JID, From, To) ->
+-spec log_user_receive({stanza(), ejabberd_c2s:state()}) -> {stanza(), ejabberd_c2s:state()}.
+log_user_receive({Packet, C2SState}) ->
+    From = xmpp:get_from(Packet),
+    To = xmpp:get_to(Packet),
     log_packet(From, To, Packet, To#jid.lserver),
-    Packet.
+    {Packet, C2SState}.
 
 -spec log_packet(jid(), jid(), stanza(), binary()) -> ok.
 log_packet(From, To, Packet, Host) ->
index e91f7481afe2b4c7f9cc90802ac079a6b55a4e1c..e7510936f125c1e58824f69cdd5852b196189a17 100644 (file)
@@ -31,9 +31,9 @@
 
 -export([start/2, stop/1, export/1,
         import_info/0, webadmin_menu/3, webadmin_page/3,
-        get_user_roster/2, get_subscription_lists/3,
+        get_user_roster/2, c2s_session_opened/1,
         get_jid_info/4, import/5, process_item/2, import_start/2,
-        in_subscription/6, out_subscription/4, user_available/1,
+        in_subscription/6, out_subscription/4, c2s_self_presence/1,
         unset_presence/4, register_user/2, remove_user/2,
         list_groups/1, create_group/2, create_group/3,
         delete_group/2, get_group_opts/2, set_group_opts/3,
@@ -54,6 +54,8 @@
 
 -include("mod_shared_roster.hrl").
 
+-define(SETS, gb_sets).
+
 -type group_options() :: [{atom(), any()}].
 -callback init(binary(), gen_mod:opts()) -> any().
 -callback import(binary(), binary(), [binary()]) -> ok.
@@ -84,14 +86,14 @@ start(Host, Opts) ->
                       ?MODULE, in_subscription, 30),
     ejabberd_hooks:add(roster_out_subscription, Host,
                       ?MODULE, out_subscription, 30),
-    ejabberd_hooks:add(roster_get_subscription_lists, Host,
-                      ?MODULE, get_subscription_lists, 70),
+    ejabberd_hooks:add(c2s_session_opened, Host,
+                      ?MODULE, c2s_session_opened, 70),
     ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE,
                       get_jid_info, 70),
     ejabberd_hooks:add(roster_process_item, Host, ?MODULE,
                       process_item, 50),
-    ejabberd_hooks:add(user_available_hook, Host, ?MODULE,
-                      user_available, 50),
+    ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE,
+                      c2s_self_presence, 50),
     ejabberd_hooks:add(unset_presence_hook, Host, ?MODULE,
                       unset_presence, 50),
     ejabberd_hooks:add(register_user, Host, ?MODULE,
@@ -112,14 +114,14 @@ stop(Host) ->
                          ?MODULE, in_subscription, 30),
     ejabberd_hooks:delete(roster_out_subscription, Host,
                          ?MODULE, out_subscription, 30),
-    ejabberd_hooks:delete(roster_get_subscription_lists,
-                         Host, ?MODULE, get_subscription_lists, 70),
+    ejabberd_hooks:delete(c2s_session_opened,
+                         Host, ?MODULE, c2s_session_opened, 70),
     ejabberd_hooks:delete(roster_get_jid_info, Host,
                          ?MODULE, get_jid_info, 70),
     ejabberd_hooks:delete(roster_process_item, Host,
                          ?MODULE, process_item, 50),
-    ejabberd_hooks:delete(user_available_hook, Host,
-                         ?MODULE, user_available, 50),
+    ejabberd_hooks:delete(c2s_self_presence, Host,
+                         ?MODULE, c2s_self_presence, 50),
     ejabberd_hooks:delete(unset_presence_hook, Host,
                          ?MODULE, unset_presence, 50),
     ejabberd_hooks:delete(register_user, Host, ?MODULE,
@@ -294,19 +296,21 @@ set_item(User, Server, Resource, Item) ->
                          jid:make(Server),
                          ResIQ).
 
--spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary())
-      -> {[ljid()], [ljid()]}.
-get_subscription_lists({F, T}, User, Server) ->
-    LUser = jid:nodeprep(User),
-    LServer = jid:nameprep(Server),
+c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID,
+                    pres_f := PresF, pres_t := PresT} = State) ->
     US = {LUser, LServer},
     DisplayedGroups = get_user_displayed_groups(US),
-    SRUsers = lists:usort(lists:flatmap(fun (Group) ->
-                                               get_group_users(LServer, Group)
-                                       end,
-                                       DisplayedGroups)),
-    SRJIDs = [{U1, S1, <<"">>} || {U1, S1} <- SRUsers],
-    {lists:usort(SRJIDs ++ F), lists:usort(SRJIDs ++ T)}.
+    SRUsers = lists:flatmap(fun(Group) ->
+                                   get_group_users(LServer, Group)
+                           end,
+                           DisplayedGroups),
+    BareLJID = jid:tolower(jid:remove_resource(JID)),
+    PresBoth = lists:foldl(
+                fun({U, S}, Acc) ->
+                        ?SETS:add_element({U, S, <<"">>}, Acc)
+                end, ?SETS:new(), [BareLJID|SRUsers]),
+    State#{pres_f => ?SETS:union(PresBoth, PresF),
+          pres_t => ?SETS:union(PresBoth, PresT)}.
 
 -spec get_jid_info({subscription(), [binary()]}, binary(), binary(), jid())
       -> {subscription(), [binary()]}.
@@ -739,12 +743,15 @@ push_roster_item(User, Server, ContactU, ContactS,
                   groups = [GroupName]},
     push_item(User, Server, Item).
 
--spec user_available(jid()) -> ok.
-user_available(New) ->
+-spec c2s_self_presence({presence(), ejabberd_c2s:state()})
+      -> {presence(), ejabberd_c2s:state()}.
+c2s_self_presence({_, #{pres_last := _}} = Acc) ->
+    %% This is just a presence update, nothing to do
+    Acc;
+c2s_self_presence({#presence{type = available}, #{jid := New}} = Acc) ->
     LUser = New#jid.luser,
     LServer = New#jid.lserver,
-    Resources = ejabberd_sm:get_user_resources(LUser,
-                                              LServer),
+    Resources = ejabberd_sm:get_user_resources(LUser, LServer),
     ?DEBUG("user_available for ~p @ ~p (~p resources)",
           [LUser, LServer, length(Resources)]),
     case length(Resources) of
@@ -761,7 +768,10 @@ user_available(New) ->
                        end,
                        UserGroups);
       _ -> ok
-    end.
+    end,
+    Acc;
+c2s_self_presence(Acc) ->
+    Acc.
 
 -spec unset_presence(binary(), binary(), binary(), binary()) -> ok.
 unset_presence(LUser, LServer, Resource, Status) ->
index 97ead9f3db2479b9972ca7bbb0d58a75aacd7dcb..777854b8ef23aef99806ca34429103d283b6b734 100644 (file)
@@ -39,7 +39,7 @@
 -export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3]).
 
--export([get_user_roster/2, get_subscription_lists/3,
+-export([get_user_roster/2, c2s_session_opened/1,
         get_jid_info/4, process_item/2, in_subscription/6,
         out_subscription/4, mod_opt_type/1, opt_type/1, depends/2]).
 
@@ -49,6 +49,7 @@
 -include("mod_roster.hrl").
 -include("eldap.hrl").
 
+-define(SETS, gb_sets).
 -define(CACHE_SIZE, 1000).
 -define(USER_CACHE_VALIDITY, 300).  %% in seconds
 -define(GROUP_CACHE_VALIDITY, 300).
@@ -160,19 +161,21 @@ process_item(RosterItem, _Host) ->
       _ -> RosterItem#roster{subscription = both, ask = none}
     end.
 
--spec get_subscription_lists({[ljid()], [ljid()]}, binary(), binary())
-      -> {[ljid()], [ljid()]}.
-get_subscription_lists({F, T}, User, Server) ->
-    LUser = jid:nodeprep(User),
-    LServer = jid:nameprep(Server),
+c2s_session_opened(#{jid := #jid{luser = LUser, lserver = LServer} = JID,
+                    pres_f := PresF, pres_t := PresT} = State) ->
     US = {LUser, LServer},
     DisplayedGroups = get_user_displayed_groups(US),
-    SRUsers = lists:usort(lists:flatmap(fun (Group) ->
-                                               get_group_users(LServer, Group)
-                                       end,
-                                       DisplayedGroups)),
-    SRJIDs = [{U1, S1, <<"">>} || {U1, S1} <- SRUsers],
-    {lists:usort(SRJIDs ++ F), lists:usort(SRJIDs ++ T)}.
+    SRUsers = lists:flatmap(fun(Group) ->
+                                   get_group_users(LServer, Group)
+                           end,
+                           DisplayedGroups),
+    BareLJID = jid:tolower(jid:remove_resource(JID)),
+    PresBoth = lists:foldl(
+                fun({U, S}, Acc) ->
+                        ?SETS:add_element({U, S, <<"">>}, Acc)
+                end, ?SETS:new(), [BareLJID|SRUsers]),
+    State#{pres_f => ?SETS:union(PresBoth, PresF),
+          pres_t => ?SETS:union(PresBoth, PresT)}.
 
 -spec get_jid_info({subscription(), [binary()]}, binary(), binary(), jid())
       -> {subscription(), [binary()]}.
@@ -246,8 +249,8 @@ init([Host, Opts]) ->
                       ?MODULE, in_subscription, 30),
     ejabberd_hooks:add(roster_out_subscription, Host,
                       ?MODULE, out_subscription, 30),
-    ejabberd_hooks:add(roster_get_subscription_lists, Host,
-                      ?MODULE, get_subscription_lists, 70),
+    ejabberd_hooks:add(c2s_session_opened, Host,
+                      ?MODULE, c2s_session_opened, 70),
     ejabberd_hooks:add(roster_get_jid_info, Host, ?MODULE,
                       get_jid_info, 70),
     ejabberd_hooks:add(roster_process_item, Host, ?MODULE,
@@ -275,8 +278,8 @@ terminate(_Reason, State) ->
                          ?MODULE, in_subscription, 30),
     ejabberd_hooks:delete(roster_out_subscription, Host,
                          ?MODULE, out_subscription, 30),
-    ejabberd_hooks:delete(roster_get_subscription_lists,
-                         Host, ?MODULE, get_subscription_lists, 70),
+    ejabberd_hooks:delete(c2s_session_opened,
+                         Host, ?MODULE, c2s_session_opened, 70),
     ejabberd_hooks:delete(roster_get_jid_info, Host,
                          ?MODULE, get_jid_info, 70),
     ejabberd_hooks:delete(roster_process_item, Host,
index 7e64e6a000360b477aa21b23f581aa44abfc2b48..0382c60a967e9d3391492849d2ed5fc108c68671 100644 (file)
@@ -31,8 +31,8 @@
 -export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
         c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
         c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2,
-        c2s_handle_send/3, c2s_filter_send/1, c2s_handle_info/2,
-        c2s_handle_call/3, c2s_handle_recv/3]).
+        c2s_handle_send/3, c2s_handle_info/2, c2s_handle_call/3,
+        c2s_handle_recv/3]).
 
 -include("xmpp.hrl").
 -include("logger.hrl").
@@ -63,7 +63,6 @@ start(Host, _Opts) ->
                       c2s_authenticated_packet, 50),
     ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
     ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
-    ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
     ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
     ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
     ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50),
@@ -83,7 +82,6 @@ stop(Host) ->
                          c2s_authenticated_packet, 50),
     ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50),
     ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50),
-    ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50),
     ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50),
     ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50),
     ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50),
@@ -179,21 +177,33 @@ c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) ->
 c2s_handle_recv(State, _, _) ->
     State.
 
-c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, _Result)
+c2s_handle_send(#{mgmt_state := MgmtState,
+                 lang := Lang} = State, Pkt, SendResult)
   when MgmtState == pending; MgmtState == active ->
-    State1 = mgmt_queue_add(State, Pkt),
     case xmpp:is_stanza(Pkt) of
        true ->
-           send_rack(State1);
+           case mgmt_queue_add(State, Pkt) of
+               #{mgmt_max_queue := exceeded} = State1 ->
+                   State2 = State1#{mgmt_resend => false},
+                   case MgmtState of
+                       active ->
+                           Err = xmpp:serr_policy_violation(
+                                   <<"Too many unacked stanzas">>, Lang),
+                           send(State2, Err);
+                       _ ->
+                           ejabberd_c2s:stop(State2)
+                   end;
+               State1 when SendResult == ok ->
+                   send_rack(State1);
+               State1 ->
+                   State1
+           end;
        false ->
-           State1
+           State
     end;
 c2s_handle_send(State, _Pkt, _Result) ->
     State.
 
-c2s_filter_send({Pkt, State}) ->
-    {Pkt, State}.
-
 c2s_handle_call(#{sid := {Time, _}} = State,
                {resume_session, Time}, From) ->
     ejabberd_c2s:reply(From, {resume, State}),
@@ -216,6 +226,13 @@ c2s_handle_info(#{mgmt_state := pending, jid := JID} = State,
     ?DEBUG("Timed out waiting for resumption of stream for ~s",
           [jid:to_string(JID)]),
     ejabberd_c2s:stop(State#{mgmt_state => timeout});
+c2s_handle_info(#{jid := JID} = State, {_Ref, {resume, OldState}}) ->
+    %% This happens if the resume_session/1 request timed out; the new session
+    %% now receives the late response.
+    ?DEBUG("Received old session state for ~s after failed resumption",
+          [jid:to_string(JID)]),
+    route_unacked_stanzas(OldState#{mgmt_resend => false}),
+    State;
 c2s_handle_info(State, _) ->
     State.
 
@@ -325,7 +342,7 @@ handle_a(State, #sm_a{h = H}) ->
     resend_rack(State1).
 
 -spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
-handle_resume(#{user := User, lserver := LServer,
+handle_resume(#{user := User, lserver := LServer, sockmod := SockMod,
                lang := Lang, socket := Socket} = State,
              #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
     R = case inherit_session_state(State, PrevID) of
@@ -354,7 +371,7 @@ handle_resume(#{user := User, lserver := LServer,
            %% csi_flush_queue(State4),
            State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []),
            ?INFO_MSG("(~s) Resumed session for ~s",
-                     [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+                     [SockMod:pp(Socket), jid:to_string(JID)]),
            {ok, State5};
        {error, El, Msg} ->
            ?INFO_MSG("Cannot resume session for ~s@~s: ~s",
@@ -363,6 +380,8 @@ handle_resume(#{user := User, lserver := LServer,
     end.
 
 -spec transition_to_pending(state()) -> state().
+transition_to_pending(#{mgmt_state := active, mgmt_timeout := 0} = State) ->
+    ejabberd_c2s:stop(State);
 transition_to_pending(#{mgmt_state := active, jid := JID,
                        lserver := LServer, mgmt_timeout := Timeout} = State) ->
     State1 = cancel_ack_timer(State),
@@ -405,9 +424,9 @@ send_rack(#{mgmt_ack_timer := _} = State) ->
 send_rack(#{mgmt_xmlns := Xmlns,
            mgmt_stanzas_out := NumStanzasOut,
            mgmt_ack_timeout := AckTimeout} = State) ->
-    State1 = send(State, #sm_r{xmlns = Xmlns}),
     TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
-    State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
+    State1 = State#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut},
+    send(State1, #sm_r{xmlns = Xmlns}).
 
 resend_rack(#{mgmt_ack_timer := _,
              mgmt_queue := Queue,
@@ -424,18 +443,13 @@ resend_rack(State) ->
 -spec mgmt_queue_add(state(), xmpp_element()) -> state().
 mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut,
                 mgmt_queue := Queue} = State, Pkt) ->
-    case xmpp:is_stanza(Pkt) of
-       true ->
-           NewNum = case NumStanzasOut of
-                        4294967295 -> 0;
-                        Num -> Num + 1
-                    end,
-           Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue),
-           State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
-           check_queue_length(State1);
-       false ->
-           State
-    end.
+    NewNum = case NumStanzasOut of
+                4294967295 -> 0;
+                Num -> Num + 1
+            end,
+    Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue),
+    State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
+    check_queue_length(State1).
 
 -spec mgmt_queue_drop(state(), non_neg_integer()) -> state().
 mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) ->
@@ -510,20 +524,24 @@ route_unacked_stanzas(#{mgmt_state := MgmtState,
              %% easily lead to unexpected results as well.
              ?DEBUG("Dropping forwarded message stanza from ~s",
                     [jid:to_string(From)]);
-        ({_, Time, El}) ->
+        ({_, Time, #message{} = Msg}) ->
              case ejabberd_hooks:run_fold(message_is_archived,
                                           LServer, false,
-                                          [State, El]) of
+                                          [State, Msg]) of
                  true ->
                      ?DEBUG("Dropping archived message stanza from ~s",
-                            [jid:to_string(xmpp:get_from(El))]);
+                            [jid:to_string(xmpp:get_from(Msg))]);
                  false when ResendOnTimeout ->
-                     NewEl = add_resent_delay_info(State, El, Time),
+                     NewEl = add_resent_delay_info(State, Msg, Time),
                      route(NewEl);
                  false ->
                      Txt = <<"User session terminated">>,
-                     route_error(El, xmpp:err_service_unavailable(Txt, Lang))
-             end
+                     route_error(Msg, xmpp:err_service_unavailable(Txt, Lang))
+             end;
+        ({_, _Time, El}) ->
+             %% Raw element of type 'error' resulting from a validation error
+             %% We cannot pass it to the router, it will generate an error
+             ?DEBUG("Do not route raw element from ack queue: ~p", [El])
       end, Queue);
 route_unacked_stanzas(_State) ->
     ok.
@@ -587,11 +605,13 @@ resume_session({Time, Pid}, _State) ->
 make_resume_id(#{sid := {Time, _}, resource := Resource}) ->
     jlib:term_to_base64({Resource, Time}).
 
--spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza().
-add_resent_delay_info(_State, #iq{} = El, _Time) ->
-    El;
-add_resent_delay_info(#{lserver := LServer}, El, Time) ->
-    xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>).
+-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza();
+                          (state(), xmlel(), erlang:timestamp()) -> xmlel().
+add_resent_delay_info(#{lserver := LServer}, El, Time)
+  when is_record(El, message); is_record(El, presence) ->
+    xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>);
+add_resent_delay_info(_State, El, _Time) ->
+    El.
 
 -spec route(stanza()) -> ok.
 route(Pkt) ->
index 4d1dfa2fc5946b5e64e083b473659d1af94845e3..900758e398732fa3a32f43e061d8bfb4c437a0f0 100644 (file)
@@ -12,7 +12,7 @@
 %% gen_mod callbacks
 -export([start/2, stop/1]).
 
--export([update_presence/3, vcard_set/3, export/1,
+-export([update_presence/1, vcard_set/3, export/1,
         import_info/0, import/5, import_start/2,
         mod_opt_type/1, depends/2]).
 
 start(Host, Opts) ->
     Mod = gen_mod:db_mod(Host, Opts, ?MODULE),
     Mod:init(Host, Opts),
-    ejabberd_hooks:add(c2s_update_presence, Host, ?MODULE,
+    ejabberd_hooks:add(c2s_self_presence, Host, ?MODULE,
                       update_presence, 100),
     ejabberd_hooks:add(vcard_set, Host, ?MODULE, vcard_set,
                       100),
     ok.
 
 stop(Host) ->
-    ejabberd_hooks:delete(c2s_update_presence, Host,
+    ejabberd_hooks:delete(c2s_self_presence, Host,
                          ?MODULE, update_presence, 100),
     ejabberd_hooks:delete(vcard_set, Host, ?MODULE,
                          vcard_set, 100),
@@ -52,10 +52,15 @@ depends(_Host, _Opts) ->
 %%====================================================================
 %% Hooks
 %%====================================================================
--spec update_presence(presence(), binary(), binary()) -> presence().
-update_presence(#presence{type = available} = Packet, User, Host) ->
-    presence_with_xupdate(Packet, User, Host);
-update_presence(Packet, _User, _Host) -> Packet.
+-spec update_presence({presence(), ejabberd_c2s:state()})
+      -> {presence(), ejabberd_c2s:state()}.
+update_presence({#presence{type = available} = Pres,
+                #{jid := #jid{luser = LUser, lserver = LServer}} = State}) ->
+    Hash = get_xupdate(LUser, LServer),
+    Pres1 = xmpp:set_subtag(Pres, #vcard_xupdate{hash = Hash}),
+    {Pres1, State};
+update_presence(Acc) ->
+    Acc.
 
 -spec vcard_set(binary(), binary(), xmlel()) -> ok.
 vcard_set(LUser, LServer, VCARD) ->
@@ -86,15 +91,6 @@ remove_xupdate(LUser, LServer) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:remove_xupdate(LUser, LServer).
 
-%%%----------------------------------------------------------------------
-%%% Presence stanza rebuilding
-%%%----------------------------------------------------------------------
-
-presence_with_xupdate(Presence, User, Host) ->
-    Hash = get_xupdate(User, Host),
-    Presence1 = xmpp:remove_subtag(Presence, #vcard_xupdate{}),
-    xmpp:set_subtag(Presence1, #vcard_xupdate{hash = Hash}).
-
 import_info() ->
     [{<<"vcard_xupdate">>, 3}].
 
@@ -110,5 +106,8 @@ export(LServer) ->
     Mod = gen_mod:db_mod(LServer, ?MODULE),
     Mod:export(LServer).
 
+%%====================================================================
+%% Options
+%%====================================================================
 mod_opt_type(db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
 mod_opt_type(_) -> [db_type].
index 1ad78d45b5b0501a58813f6078cf51c570f97e75..b2b3b3072575177a3308de47b9add1aef0db2008 100644 (file)
 %%%
 %%%-------------------------------------------------------------------
 -module(xmpp_stream_in).
--behaviour(gen_server).
+-define(GEN_SERVER, gen_server).
+-behaviour(?GEN_SERVER).
 
 -protocol({rfc, 6120}).
+-protocol({xep, 114, '1.6'}).
 
 %% API
 -export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
 -include("xmpp.hrl").
 -type state() :: map().
 -type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
-                      {tls, term()} |
+                      {tls, inet:posix() | atom() | binary()} |
                       {socket, inet:posix() | closed | timeout} |
                       internal_failure.
 
--callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
+-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
 -callback handle_cast(term(), state()) -> state().
 -callback handle_call(term(), term(), state()) -> state().
 -callback handle_info(term(), state()) -> state().
 -callback terminate(term(), state()) -> any().
 -callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
--callback handle_stream_start(state()) -> state().
+-callback handle_stream_start(stream_start(), state()) -> state().
+-callback handle_stream_established(state()) -> state().
 -callback handle_stream_end(stop_reason(), state()) -> state().
 -callback handle_cdata(binary(), state()) -> state().
 -callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
@@ -63,6 +66,7 @@
 -callback handle_auth_failure(binary(), binary(), atom(), state()) -> state().
 -callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
 -callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
+-callback handle_timeout(state()) -> state().
 -callback get_password_fun(state()) -> fun().
 -callback check_password_fun(state()) -> fun().
 -callback check_password_digest_fun(state()) -> fun().
@@ -71,6 +75,8 @@
 -callback tls_options(state()) -> [proplists:property()].
 -callback tls_required(state()) -> boolean().
 -callback tls_verify(state()) -> boolean().
+-callback tls_enabled(state()) -> boolean().
+-callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()].
 -callback unauthenticated_stream_features(state()) -> [xmpp_element()].
 -callback authenticated_stream_features(state()) -> [xmpp_element()].
 
@@ -81,7 +87,8 @@
                     handle_info/2,
                     terminate/2,
                     code_change/3,
-                    handle_stream_start/1,
+                    handle_stream_start/2,
+                    handle_stream_established/1,
                     handle_stream_end/2,
                     handle_cdata/2,
                     handle_authenticated_packet/2,
@@ -91,6 +98,7 @@
                     handle_auth_failure/4,
                     handle_send/3,
                     handle_recv/3,
+                    handle_timeout/1,
                     get_password_fun/1,
                     check_password_fun/1,
                     check_password_digest_fun/1,
                     tls_options/1,
                     tls_required/1,
                     tls_verify/1,
+                    tls_enabled/1,
+                    sasl_mechanisms/2,
                     unauthenticated_stream_features/1,
                     authenticated_stream_features/1]).
 
 %%% API
 %%%===================================================================
 start(Mod, Args, Opts) ->
-    gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+    ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
 
 start_link(Mod, Args, Opts) ->
-    gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+    ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
 
 call(Ref, Msg, Timeout) ->
-    gen_server:call(Ref, Msg, Timeout).
+    ?GEN_SERVER:call(Ref, Msg, Timeout).
 
 cast(Ref, Msg) ->
-    gen_server:cast(Ref, Msg).
+    ?GEN_SERVER:cast(Ref, Msg).
 
 reply(Ref, Reply) ->
-    gen_server:reply(Ref, Reply).
+    ?GEN_SERVER:reply(Ref, Reply).
 
 -spec stop(pid()) -> ok;
          (state()) -> no_return().
@@ -135,7 +145,7 @@ stop(_) ->
 send(Pid, Pkt) when is_pid(Pid) ->
     cast(Pid, {send, Pkt});
 send(#{owner := Owner} = State, Pkt) when Owner == self() ->
-    send_element(State, Pkt);
+    send_pkt(State, Pkt);
 send(_, _) ->
     erlang:error(badarg).
 
@@ -193,7 +203,7 @@ format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
 format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
     format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
 format_error({tls, Reason}) ->
-    format("TLS failed: ~w", [Reason]);
+    format("TLS failed: ~s", [format_tls_error(Reason)]);
 format_error(internal_failure) ->
     <<"Internal server error">>;
 format_error(Err) ->
@@ -203,13 +213,9 @@ format_error(Err) ->
 %%% gen_server callbacks
 %%%===================================================================
 init([Module, {SockMod, Socket}, Opts]) ->
-    XMLSocket = case lists:keyfind(xml_socket, 1, Opts) of
-                   {_, XS} -> XS;
-                   false -> false
-               end,
     Encrypted = proplists:get_bool(tls, Opts),
     SocketMonitor = SockMod:monitor(Socket),
-    case peername(SockMod, Socket) of
+    case SockMod:peername(Socket) of
        {ok, IP} ->
            Time = p1_time_compat:monotonic_time(milli_seconds),
            State = #{owner => self(),
@@ -227,7 +233,6 @@ init([Module, {SockMod, Socket}, Opts]) ->
                      stream_encrypted => Encrypted,
                      stream_version => {1,0},
                      stream_authenticated => false,
-                     xml_socket => XMLSocket,
                      xmlns => ?NS_CLIENT,
                      lang => <<"">>,
                      user => <<"">>,
@@ -238,18 +243,32 @@ init([Module, {SockMod, Socket}, Opts]) ->
            case try Module:init([State, Opts])
                 catch _:undef -> {ok, State}
                 end of
-               {ok, State1} ->
+               {ok, State1} when not Encrypted ->
                    {_, State2, Timeout} = noreply(State1),
                    {ok, State2, Timeout};
-               Err ->
-                   Err
+               {ok, State1} when Encrypted ->
+                   TLSOpts = try Module:tls_options(State1)
+                             catch _:undef -> []
+                             end,
+                   case SockMod:starttls(Socket, TLSOpts) of
+                       {ok, TLSSocket} ->
+                           State2 = State1#{socket => TLSSocket},
+                           {_, State3, Timeout} = noreply(State2),
+                           {ok, State3, Timeout};
+                       {error, Reason} ->
+                           {stop, Reason}
+                   end;
+               {error, Reason} ->
+                   {stop, Reason};
+               ignore ->
+                   ignore
            end;
-       {error, Reason} ->
-           {stop, Reason}
+       {error, _Reason} ->
+           ignore
     end.
 
 handle_cast({send, Pkt}, State) ->
-    noreply(send_element(State, Pkt));
+    noreply(send_pkt(State, Pkt));
 handle_cast(stop, State) ->
     {stop, normal, State};
 handle_cast(Cast, #{mod := Mod} = State) ->
@@ -278,7 +297,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
              State1 = send_header(State),
              case is_disconnected(State1) of
                  true -> State1;
-                 false -> send_element(State1, xmpp:serr_invalid_xml())
+                 false -> send_pkt(State1, xmpp:serr_invalid_xml())
              end
       catch _:{xmpp_codec, Why} ->
              State1 = send_header(State),
@@ -288,7 +307,7 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
                      Txt = xmpp:io_format_error(Why),
                      Lang = select_lang(MyLang, xmpp:get_lang(El)),
                      Err = xmpp:serr_invalid_xml(Txt, Lang),
-                     send_element(State1, Err)
+                     send_pkt(State1, Err)
              end
       end);
 handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
@@ -303,7 +322,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
                        _ ->
                            xmpp:serr_not_well_formed()
                    end,
-             send_element(State1, Err)
+             send_pkt(State1, Err)
       end);
 handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS, mod := Mod} = State) ->
@@ -339,7 +358,7 @@ handle_info(timeout, #{mod := Mod} = State) ->
     Disconnected = is_disconnected(State),
     noreply(try Mod:handle_timeout(State)
            catch _:undef when not Disconnected ->
-                   send_element(State, xmpp:serr_connection_timeout());
+                   send_pkt(State, xmpp:serr_connection_timeout());
                  _:undef ->
                    stop(State)
            end);
@@ -385,14 +404,6 @@ new_id() ->
 is_disconnected(#{stream_state := StreamState}) ->
     StreamState == disconnected.
 
--spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}|
-                                 {error, inet:posix()}.
-peername(SockMod, Socket) ->
-    case SockMod of
-       gen_tcp -> inet:peername(Socket);
-       _ -> SockMod:peername(Socket)
-    end.
-
 -spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
 process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
     case xmpp:is_stanza(El) of
@@ -408,12 +419,12 @@ process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
                    Txt = xmpp:io_format_error(Reason),
                    Err = #sasl_failure{reason = 'malformed-request',
                                        text = xmpp:mk_text(Txt, MyLang)},
-                   send_element(State, Err);
+                   send_pkt(State, Err);
                {<<"starttls">>, ?NS_TLS} ->
-                   send_element(State, #starttls_failure{});
+                   send_pkt(State, #starttls_failure{});
                {<<"compress">>, ?NS_COMPRESS} ->
                    Err = #compress_failure{reason = 'setup-failed'},
-                   send_element(State, Err);
+                   send_pkt(State, Err);
                _ ->
                    %% Maybe add something more?
                    State
@@ -434,9 +445,9 @@ process_stream(#stream_start{xmlns = XML_NS,
                             stream_xmlns = STREAM_NS},
               #{xmlns := NS} = State)
   when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
-    send_element(State, xmpp:serr_invalid_namespace());
+    send_pkt(State, xmpp:serr_invalid_namespace());
 process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
-    send_element(State, xmpp:serr_unsupported_version());
+    send_pkt(State, xmpp:serr_unsupported_version());
 process_stream(#stream_start{lang = Lang},
               #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
   when size(Lang) > 35 ->
@@ -445,14 +456,14 @@ process_stream(#stream_start{lang = Lang},
     %% language tags MUST allow for language tags of at least 35 characters.
     %% Do not store long language tag to avoid possible DoS/flood attacks
     Txt = <<"Too long value of 'xml:lang' attribute">>,
-    send_element(State, xmpp:serr_policy_violation(Txt, DefaultLang));
+    send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang));
 process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
     Txt = <<"Missing 'to' attribute">>,
-    send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
+    send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
 process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
               #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
     Txt = <<"Improper 'to' attribute">>,
-    send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
+    send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
 process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
               #{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
     State1 = State#{remote_server => RemoteServer,
@@ -509,29 +520,29 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
        #starttls{} ->
            process_starttls_failure(unexpected_starttls_request, State);
        #sasl_auth{} when StateName == wait_for_starttls ->
-           send_element(State, #sasl_failure{reason = 'encryption-required'});
+           send_pkt(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_auth{} when StateName == wait_for_sasl_request ->
            process_sasl_request(Pkt, State);
        #sasl_auth{} ->
            Txt = <<"SASL negotiation is not allowed in this state">>,
-           send_element(State, #sasl_failure{reason = 'not-authorized',
+           send_pkt(State, #sasl_failure{reason = 'not-authorized',
                                              text = xmpp:mk_text(Txt, Lang)});
        #sasl_response{} when StateName == wait_for_starttls ->
-           send_element(State, #sasl_failure{reason = 'encryption-required'});
+           send_pkt(State, #sasl_failure{reason = 'encryption-required'});
        #sasl_response{} when StateName == wait_for_sasl_response ->
            process_sasl_response(Pkt, State);
        #sasl_response{} ->
            Txt = <<"SASL negotiation is not allowed in this state">>,
-           send_element(State, #sasl_failure{reason = 'not-authorized',
+           send_pkt(State, #sasl_failure{reason = 'not-authorized',
                                              text = xmpp:mk_text(Txt, Lang)});
        #sasl_abort{} when StateName == wait_for_sasl_response ->
            process_sasl_abort(State);
        #sasl_abort{} ->
-           send_element(State, #sasl_failure{reason = 'aborted'});
+           send_pkt(State, #sasl_failure{reason = 'aborted'});
        #sasl_success{} ->
            State;
        #compress{} when StateName == wait_for_sasl_response ->
-           send_element(State, #compress_failure{reason = 'setup-failed'});
+           send_pkt(State, #compress_failure{reason = 'setup-failed'});
        #compress{} ->
            process_compress(Pkt, State);
        #handshake{} when StateName == wait_for_handshake ->
@@ -570,7 +581,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
        {ok, #iq{type = set, sub_els = [_]} = Pkt2} when NS == ?NS_CLIENT ->
            case xmpp:get_subtag(Pkt2, #xmpp_session{}) of
                #xmpp_session{} ->
-                   send_element(State, xmpp:make_iq_result(Pkt2));
+                   send_pkt(State, xmpp:make_iq_result(Pkt2));
                _ ->
                    try Mod:handle_authenticated_packet(Pkt2, State)
                    catch _:undef ->
@@ -585,7 +596,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
                    send_error(State, Pkt, Err)
            end;
        {error, Err} ->
-           send_element(State, Err)
+           send_pkt(State, Err)
     end.
 
 -spec process_bind(xmpp_element(), state()) -> state().
@@ -604,7 +615,7 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt,
                               server := S,
                               resource := NewR} = State1} when NewR /= <<"">> ->
                            Reply = #bind{jid = jid:make(U, S, NewR)},
-                           State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)),
+                           State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)),
                            process_stream_established(State2);
                        {error, #stanza_error{}, State1} = Err ->
                            send_error(State1, Pkt, Err)
@@ -646,7 +657,7 @@ process_handshake(#handshake{data = Digest},
            case is_disconnected(State1) of
                true -> State1;
                false ->
-                   State2 = send_element(State1, #handshake{}),
+                   State2 = send_pkt(State1, #handshake{}),
                    process_stream_established(State2)
            end;
        false ->
@@ -656,7 +667,7 @@ process_handshake(#handshake{data = Digest},
                     end,
            case is_disconnected(State1) of
                true -> State1;
-               false -> send_element(State1, xmpp:serr_not_authorized())
+               false -> send_pkt(State1, xmpp:serr_not_authorized())
            end
     end.
 
@@ -674,7 +685,7 @@ process_stream_established(#{mod := Mod} = State) ->
 
 -spec process_compress(compress(), state()) -> state().
 process_compress(#compress{}, #{stream_compressed := true} = State) ->
-    send_element(State, #compress_failure{reason = 'setup-failed'});
+    send_pkt(State, #compress_failure{reason = 'setup-failed'});
 process_compress(#compress{methods = HisMethods},
                 #{socket := Socket, sockmod := SockMod, mod := Mod} = State) ->
     MyMethods = try Mod:compress_methods(State)
@@ -683,44 +694,60 @@ process_compress(#compress{methods = HisMethods},
     CommonMethods = lists_intersection(MyMethods, HisMethods),
     case lists:member(<<"zlib">>, CommonMethods) of
        true ->
-           BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})),
-           ZlibSocket = SockMod:compress(Socket, BCompressed),
-           State#{socket => ZlibSocket,
-                  stream_id => new_id(),
-                  stream_header_sent => false,
-                  stream_restarted => true,
-                  stream_state => wait_for_stream,
-                  stream_compressed => true};
+           State1 = send_pkt(State, #compressed{}),
+           case is_disconnected(State1) of
+               true -> State1;
+               false ->
+                   case SockMod:compress(Socket) of
+                       {ok, ZlibSocket} ->
+                           State1#{socket => ZlibSocket,
+                                   stream_id => new_id(),
+                                   stream_header_sent => false,
+                                   stream_restarted => true,
+                                   stream_state => wait_for_stream,
+                                   stream_compressed => true};
+                       {error, _} ->
+                           Err = #compress_failure{reason = 'setup-failed'},
+                           send_pkt(State1, Err)
+                   end
+           end;
        false ->
-           send_element(State, #compress_failure{reason = 'unsupported-method'})
+           send_pkt(State, #compress_failure{reason = 'unsupported-method'})
     end.
 
 -spec process_starttls(state()) -> state().
+process_starttls(#{stream_encrypted := true} = State) ->
+    process_starttls_failure(already_encrypted, State);
 process_starttls(#{socket := Socket,
                   sockmod := SockMod, mod := Mod} = State) ->
-    TLSOpts = try Mod:tls_options(State)
-             catch _:undef -> []
-             end,
-    case SockMod:starttls(Socket, TLSOpts) of
-       {ok, TLSSocket} ->
-           State1 = send_element(State, #starttls_proceed{}),
-           case is_disconnected(State1) of
-               true -> State1;
-               false ->
-                   State1#{socket => TLSSocket,
-                           stream_id => new_id(),
-                           stream_header_sent => false,
-                           stream_restarted => true,
-                           stream_state => wait_for_stream,
-                           stream_encrypted => true}
+    case is_starttls_available(State) of
+       true ->
+           TLSOpts = try Mod:tls_options(State)
+                     catch _:undef -> []
+                     end,
+           case SockMod:starttls(Socket, TLSOpts) of
+               {ok, TLSSocket} ->
+                   State1 = send_pkt(State, #starttls_proceed{}),
+                   case is_disconnected(State1) of
+                       true -> State1;
+                       false ->
+                           State1#{socket => TLSSocket,
+                                   stream_id => new_id(),
+                                   stream_header_sent => false,
+                                   stream_restarted => true,
+                                   stream_state => wait_for_stream,
+                                   stream_encrypted => true}
+                   end;
+               {error, Reason} ->
+                   process_starttls_failure(Reason, State)
            end;
-       {error, Reason} ->
-           process_starttls_failure(Reason, State)
+       false ->
+           process_starttls_failure(starttls_unsupported, State)
     end.
 
 -spec process_starttls_failure(term(), state()) -> state().
 process_starttls_failure(Why, State) ->
-    State1 = send_element(State, #starttls_failure{}),
+    State1 = send_pkt(State, #starttls_failure{}),
     case is_disconnected(State1) of
        true -> State1;
        false -> process_stream_end({tls, Why}, State1)
@@ -780,17 +807,17 @@ process_sasl_success(Props, ServerOut,
                       mod := Mod, sasl_mech := Mech} = State) ->
     User = identity(Props),
     AuthModule = proplists:get_value(auth_module, Props),
-    State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State)
-            catch _:undef -> State
-            end,
+    State1 = send_pkt(State, #sasl_success{text = ServerOut}),
     case is_disconnected(State1) of
        true -> State1;
        false ->
-           SockMod:reset_stream(Socket),
-           State2 = send_element(State1, #sasl_success{text = ServerOut}),
+           State2 = try Mod:handle_auth_success(User, Mech, AuthModule, State1)
+                    catch _:undef -> State1
+                    end,
            case is_disconnected(State2) of
                true -> State2;
                false ->
+                   SockMod:reset_stream(Socket),
                    State3 = maps:remove(sasl_state,
                                         maps:remove(sasl_mech, State2)),
                    State3#{stream_id => new_id(),
@@ -806,19 +833,23 @@ process_sasl_success(Props, ServerOut,
 process_sasl_continue(ServerOut, NewSASLState, State) ->
     State1 = State#{sasl_state => NewSASLState,
                    stream_state => wait_for_sasl_response},
-    send_element(State1, #sasl_challenge{text = ServerOut}).
+    send_pkt(State1, #sasl_challenge{text = ServerOut}).
 
 -spec process_sasl_failure(atom(), binary(), state()) -> state().
 process_sasl_failure(Err, User,
                     #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
     {Reason, Text} = format_sasl_error(Mech, Err),
-    State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
-            catch _:undef -> State
-            end,
-    State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
-    State3 = State2#{stream_state => wait_for_sasl_request},
-    send_element(State3, #sasl_failure{reason = Reason,
-                                      text = xmpp:mk_text(Text, Lang)}).
+    State1 = send_pkt(State, #sasl_failure{reason = Reason,
+                                          text = xmpp:mk_text(Text, Lang)}),
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
+           State2 = try Mod:handle_auth_failure(User, Mech, Text, State1)
+                    catch _:undef -> State1
+                    end,
+           State3 = maps:remove(sasl_state, maps:remove(sasl_mech, State2)),
+           State3#{stream_state => wait_for_sasl_request}
+    end.
 
 -spec process_sasl_abort(state()) -> state().
 process_sasl_abort(State) ->
@@ -835,7 +866,7 @@ send_features(#{stream_version := {1,0},
                           ++ get_tls_feature(State) ++ get_bind_feature(State)
                           ++ get_session_feature(State) ++ get_other_features(State)
               end,
-    send_element(State, #stream_features{sub_els = Features});
+    send_pkt(State, #stream_features{sub_els = Features});
 send_features(State) ->
     %% clients and servers from stone age
     State.
@@ -849,10 +880,13 @@ get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
     TLSVerify = try Mod:tls_verify(State)
                catch _:undef -> false
                end,
-    if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
-           [<<"EXTERNAL">>|Mechs];
-       true ->
-           Mechs
+    Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
+                    [<<"EXTERNAL">>|Mechs];
+               true ->
+                    Mechs
+            end,
+    try Mod:sasl_mechanisms(Mechs1, State)
+    catch _:undef -> Mechs1
     end.
 
 -spec get_sasl_feature(state()) -> [sasl_mechanisms()].
@@ -882,8 +916,13 @@ get_compress_feature(_) ->
 -spec get_tls_feature(state()) -> [starttls()].
 get_tls_feature(#{stream_authenticated := false,
                  stream_encrypted := false} = State) ->
-    TLSRequired = is_starttls_required(State),
-    [#starttls{required = TLSRequired}];
+    case is_starttls_available(State) of
+       true ->
+           TLSRequired = is_starttls_required(State),
+           [#starttls{required = TLSRequired}];
+       false ->
+           []
+    end;
 get_tls_feature(_) ->
     [].
 
@@ -913,6 +952,12 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
            []
     end.
 
+-spec is_starttls_available(state()) -> boolean().
+is_starttls_available(#{mod := Mod} = State) ->
+    try Mod:tls_enabled(State)
+    catch _:undef -> true
+    end.
+
 -spec is_starttls_required(state()) -> boolean().
 is_starttls_required(#{mod := Mod} = State) ->
     try Mod:tls_required(State)
@@ -967,13 +1012,14 @@ send_header(#{stream_id := StreamID,
              lang := MyLang,
              xmlns := NS,
              server := DefaultServer} = State,
-           #stream_start{to = To, lang = HisLang, version = HisVersion}) ->
+           #stream_start{to = HisTo, from = HisFrom,
+                         lang = HisLang, version = HisVersion}) ->
     Lang = select_lang(MyLang, HisLang),
     NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
               true -> <<"">>
            end,
-    From = case To of
-              #jid{} -> To;
+    From = case HisTo of
+              #jid{} -> HisTo;
               undefined -> jid:make(DefaultServer)
           end,
     Version = case HisVersion of
@@ -981,45 +1027,40 @@ send_header(#{stream_id := StreamID,
                  {0,_} -> HisVersion;
                  _ -> MyVersion
              end,
-    Header = xmpp:encode(#stream_start{version = Version,
-                                      lang = Lang,
-                                      xmlns = NS,
-                                      stream_xmlns = ?NS_STREAM,
-                                      db_xmlns = NS_DB,
-                                      id = StreamID,
-                                      from = From}),
+    StreamStart = #stream_start{version = Version,
+                               lang = Lang,
+                               xmlns = NS,
+                               stream_xmlns = ?NS_STREAM,
+                               db_xmlns = NS_DB,
+                               id = StreamID,
+                               to = HisFrom,
+                               from = From},
     State1 = State#{lang => Lang,
                    stream_version => Version,
                    stream_header_sent => true},
-    case send_text(State1, fxml:element_to_header(Header)) of
+    case socket_send(State1, StreamStart) of
        ok -> State1;
        {error, Why} -> process_stream_end({socket, Why}, State1)
     end;
 send_header(State, _) ->
     State.
 
--spec send_element(state(), xmpp_element()) -> state().
-send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
-    El = xmpp:encode(Pkt, NS),
-    Data = fxml:element_to_binary(El),
-    Result = send_text(State, Data),
+-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
+send_pkt(#{mod := Mod} = State, Pkt) ->
+    Result = socket_send(State, Pkt),
     State1 = try Mod:handle_send(Pkt, Result, State)
             catch _:undef -> State
             end,
-    case is_disconnected(State1) of
-       true -> State1;
-       false ->
-           case Result of
-               _ when is_record(Pkt, stream_error) ->
-                   process_stream_end({stream, {out, Pkt}}, State1);
-               ok ->
-                   State1;
-               {error, Why} ->
-                   process_stream_end({socket, Why}, State1)
-           end
+    case Result of
+       _ when is_record(Pkt, stream_error) ->
+           process_stream_end({stream, {out, Pkt}}, State1);
+       ok ->
+           State1;
+       {error, Why} ->
+           process_stream_end({socket, Why}, State1)
     end.
 
--spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
 send_error(State, Pkt, Err) ->
     case xmpp:is_stanza(Pkt) of
        true ->
@@ -1030,7 +1071,7 @@ send_error(State, Pkt, Err) ->
                <<"error">> -> State;
                _ ->
                    ErrPkt = xmpp:make_error(Pkt, Err),
-                   send_element(State, ErrPkt)
+                   send_pkt(State, ErrPkt)
            end;
        false ->
            State
@@ -1038,15 +1079,23 @@ send_error(State, Pkt, Err) ->
 
 -spec send_trailer(state()) -> state().
 send_trailer(State) ->
-    send_text(State, <<"</stream:stream>">>),
+    socket_send(State, trailer),
     close_socket(State).
 
--spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
-send_text(#{socket := Sock, sockmod := SockMod,
-           stream_state := StateName,
-           stream_header_sent := true}, Data) when StateName /= disconnected ->
-    SockMod:send(Sock, Data);
-send_text(_, _) ->
+-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
+socket_send(#{socket := Sock, sockmod := SockMod,
+             stream_state := StateName,
+             xmlns := NS,
+             stream_header_sent := true}, Pkt) when StateName /= disconnected ->
+    case Pkt of
+       trailer ->
+           SockMod:send_trailer(Sock);
+       #stream_start{} ->
+           SockMod:send_header(Sock, xmpp:encode(Pkt));
+       _ ->
+           SockMod:send_element(Sock, xmpp:encode(Pkt, NS))
+    end;
+socket_send(_, _) ->
     {error, closed}.
 
 -spec close_socket(state()) -> state().
@@ -1096,6 +1145,12 @@ format_sasl_error(<<"EXTERNAL">>, Err) ->
 format_sasl_error(Mech, Err) ->
     cyrsasl:format_error(Mech, Err).
 
+-spec format_tls_error(atom() | binary()) -> list().
+format_tls_error(Reason) when is_atom(Reason) ->
+    format_inet_error(Reason);
+format_tls_error(Reason) ->
+    Reason.
+
 -spec format(io:format(), list()) -> binary().
 format(Fmt, Args) ->
     iolist_to_binary(io_lib:format(Fmt, Args)).
index adbc6ffba1017db0ca462546db7973acc4742867..3dcecf6f6519735ac3d828f735844267f1fccf78 100644 (file)
 %%%
 %%%-------------------------------------------------------------------
 -module(xmpp_stream_out).
--behaviour(gen_server).
+-define(GEN_SERVER, gen_server).
+-behaviour(?GEN_SERVER).
 
 -protocol({rfc, 6120}).
+-protocol({xep, 114, '1.6'}).
 
 %% API
 -export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
@@ -42,7 +44,6 @@
 -define(TCP_SEND_TIMEOUT, 15000).
 
 -include("xmpp.hrl").
--include("logger.hrl").
 -include_lib("kernel/include/inet.hrl").
 
 -type state() :: map().
 -type stop_reason() :: {idna, bad_string} |
                       {dns, inet:posix() | inet_res:res_error()} |
                       {stream, reset | {in | out, stream_error()}} |
-                      {tls, term()} |
+                      {tls, inet:posix() | atom() | binary()} |
                       {pkix, binary()} |
                       {auth, atom() | binary() | string()} |
                       {socket, inet:posix() | closed | timeout} |
                       internal_failure.
 
--callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
+-callback init(list()) -> {ok, state()} | {error, term()} | ignore.
+-callback handle_cast(term(), state()) -> state().
+-callback handle_call(term(), term(), state()) -> state().
+-callback handle_info(term(), state()) -> state().
+-callback terminate(term(), state()) -> any().
+-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
+-callback handle_stream_start(stream_start(), state()) -> state().
+-callback handle_stream_established(state()) -> state().
+-callback handle_stream_downgraded(stream_start(), state()) -> state().
+-callback handle_stream_end(stop_reason(), state()) -> state().
+-callback handle_cdata(binary(), state()) -> state().
+-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
+-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
+-callback handle_timeout(state()) -> state().
+-callback handle_authenticated_features(stream_features(), state()) -> state().
+-callback handle_unauthenticated_features(stream_features(), state()) -> state().
+-callback handle_auth_success(cyrsasl:mechanism(), state()) -> state().
+-callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state().
+-callback handle_packet(xmpp_element(), state()) -> state().
+-callback tls_options(state()) -> [proplists:property()].
+-callback tls_required(state()) -> boolean().
+-callback tls_verify(state()) -> boolean().
+-callback tls_enabled(state()) -> boolean().
+-callback dns_timeout(state()) -> timeout().
+-callback dns_retries(state()) -> non_neg_integer().
+-callback default_port(state()) -> inet:port_number().
+-callback address_families(state()) -> [inet:address_family()].
+-callback connect_timeout(state()) -> timeout().
+
+-optional_callbacks([init/1,
+                    handle_cast/2,
+                    handle_call/3,
+                    handle_info/2,
+                    terminate/2,
+                    code_change/3,
+                    handle_stream_start/2,
+                    handle_stream_established/1,
+                    handle_stream_downgraded/2,
+                    handle_stream_end/2,
+                    handle_cdata/2,
+                    handle_send/3,
+                    handle_recv/3,
+                    handle_timeout/1,
+                    handle_authenticated_features/2,
+                    handle_unauthenticated_features/2,
+                    handle_auth_success/2,
+                    handle_auth_failure/3,
+                    handle_packet/2,
+                    tls_options/1,
+                    tls_required/1,
+                    tls_verify/1,
+                    tls_enabled/1,
+                    dns_timeout/1,
+                    dns_retries/1,
+                    default_port/1,
+                    address_families/1,
+                    connect_timeout/1]).
 
 %%%===================================================================
 %%% API
 %%%===================================================================
 start(Mod, Args, Opts) ->
-    gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+    ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
 
 start_link(Mod, Args, Opts) ->
-    gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+    ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
 
 call(Ref, Msg, Timeout) ->
-    gen_server:call(Ref, Msg, Timeout).
+    ?GEN_SERVER:call(Ref, Msg, Timeout).
 
 cast(Ref, Msg) ->
-    gen_server:cast(Ref, Msg).
+    ?GEN_SERVER:cast(Ref, Msg).
 
 reply(Ref, Reply) ->
-    gen_server:reply(Ref, Reply).
+    ?GEN_SERVER:reply(Ref, Reply).
 
 -spec connect(pid()) -> ok.
 connect(Ref) ->
@@ -98,7 +155,7 @@ stop(_) ->
 send(Pid, Pkt) when is_pid(Pid) ->
     cast(Pid, {send, Pkt});
 send(#{owner := Owner} = State, Pkt) when Owner == self() ->
-    send_element(State, Pkt);
+    send_pkt(State, Pkt);
 send(_, _) ->
     erlang:error(badarg).
 
@@ -154,7 +211,8 @@ format_error({dns, Reason}) ->
 format_error({socket, Reason}) ->
     format("Connection failed: ~s", [format_inet_error(Reason)]);
 format_error({pkix, Reason}) ->
-    format("Peer certificate rejected: ~s", [Reason]);
+    {_, ErrTxt} = xmpp_stream_pkix:format_error(Reason),
+    format("Peer certificate rejected: ~s", [ErrTxt]);
 format_error({stream, reset}) ->
     <<"Stream reset by peer">>;
 format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
@@ -162,7 +220,7 @@ format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
 format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
     format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
 format_error({tls, Reason}) ->
-    format("TLS failed: ~w", [Reason]);
+    format("TLS failed: ~s", [format_tls_error(Reason)]);
 format_error({auth, Reason}) ->
     format("Authentication failed: ~s", [Reason]);
 format_error(internal_failure) ->
@@ -199,8 +257,10 @@ init([Mod, SockMod, From, To, Opts]) ->
        {ok, State1} ->
            {_, State2, Timeout} = noreply(State1),
            {ok, State2, Timeout};
-       Err ->
-           Err
+       {error, Reason} ->
+           {stop, Reason};
+       ignore ->
+           ignore
     end.
 
 -spec handle_call(term(), term(), state()) -> noreply().
@@ -239,7 +299,7 @@ handle_cast(connect, State) ->
     %% Ignoring connection attempts in other states
     noreply(State);
 handle_cast({send, Pkt}, State) ->
-    noreply(send_element(State, Pkt));
+    noreply(send_pkt(State, Pkt));
 handle_cast(stop, State) ->
     {stop, normal, State};
 handle_cast(Cast, #{mod := Mod} = State) ->
@@ -257,12 +317,12 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
          #stream_start{} = Pkt ->
              process_stream(Pkt, State);
          _ ->
-             send_element(State, xmpp:serr_invalid_xml())
+             send_pkt(State, xmpp:serr_invalid_xml())
       catch _:{xmpp_codec, Why} ->
              Txt = xmpp:io_format_error(Why),
              Lang = select_lang(MyLang, xmpp:get_lang(El)),
              Err = xmpp:serr_invalid_xml(Txt, Lang),
-             send_element(State, Err)
+             send_pkt(State, Err)
       end);
 handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
     State1 = send_header(State),
@@ -276,7 +336,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
                        _ ->
                            xmpp:serr_not_well_formed()
                    end,
-             send_element(State1, Err)
+             send_pkt(State1, Err)
       end);
 handle_info({'$gen_event', {xmlstreamelement, El}},
            #{xmlns := NS, mod := Mod} = State) ->
@@ -291,7 +351,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
                  false -> process_element(Pkt, State1)
              end
       catch _:{xmpp_codec, Why} ->
-             State1 = try Mod:handle_recv(El, undefined, State)
+             State1 = try Mod:handle_recv(El, {error, Why}, State)
                       catch _:undef -> State
                       end,
              case is_disconnected(State1) of
@@ -312,7 +372,7 @@ handle_info(timeout, #{mod := Mod} = State) ->
     Disconnected = is_disconnected(State),
     noreply(try Mod:handle_timeout(State)
            catch _:undef when not Disconnected ->
-                   send_element(State, xmpp:serr_connection_timeout());
+                   send_pkt(State, xmpp:serr_connection_timeout());
                  _:undef ->
                    stop(State)
            end);
@@ -384,9 +444,9 @@ process_stream(#stream_start{xmlns = XML_NS,
                             stream_xmlns = STREAM_NS},
               #{xmlns := NS} = State)
   when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
-    send_element(State, xmpp:serr_invalid_namespace());
+    send_pkt(State, xmpp:serr_invalid_namespace());
 process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
-    send_element(State, xmpp:serr_unsupported_version());
+    send_pkt(State, xmpp:serr_unsupported_version());
 process_stream(#stream_start{lang = Lang, id = ID,
                             version = Version} = StreamStart,
               #{mod := Mod} = State) ->
@@ -451,15 +511,19 @@ process_features(#stream_features{sub_els = Els} = StreamFeatures,
        true -> State1;
        false ->
            TLSRequired = is_starttls_required(State1),
+           TLSAvailable = is_starttls_available(State1),
            %% TODO: improve xmpp.erl
            Msg = #message{sub_els = Els},
            case xmpp:get_subtag(Msg, #starttls{}) of
                false when TLSRequired and not Encrypted ->
                    Txt = <<"Use of STARTTLS required">>,
-                   send_element(State1, xmpp:err_policy_violation(Txt, Lang));
-               #starttls{} when not Encrypted ->
+                   send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang));
+               #starttls{required = true} when not TLSAvailable and not Encrypted ->
+                   Txt = <<"Use of STARTTLS forbidden">>,
+                   send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang));
+               #starttls{} when TLSAvailable and not Encrypted ->
                    State2 = State1#{stream_state => wait_for_starttls_response},
-                   send_element(State2, #starttls{});
+                   send_pkt(State2, #starttls{});
                _ ->
                    State2 = process_cert_verification(State1),
                    case is_disconnected(State2) of
@@ -497,7 +561,7 @@ process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
        true ->
            State1 = State#{stream_state => wait_for_sasl_response},
            Authzid = jid:to_string(jid:make(User, Server)),
-           send_element(State1, #sasl_auth{mechanism = Mech, text = Authzid});
+           send_pkt(State1, #sasl_auth{mechanism = Mech, text = Authzid});
        false ->
            process_sasl_failure(
              #sasl_failure{reason = 'invalid-mechanism'}, State)
@@ -527,12 +591,12 @@ process_stream_downgrade(StreamStart,
     TLSRequired = is_starttls_required(State),
     if not Encrypted and TLSRequired ->
            Txt = <<"Use of STARTTLS required">>,
-           send_element(State, xmpp:err_policy_violation(Txt, Lang));
+           send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
        true ->
            State1 = State#{stream_state => downgraded},
            try Mod:handle_stream_downgraded(StreamStart, State1)
            catch _:undef ->
-                   send_element(State1, xmpp:serr_unsupported_version())
+                   send_pkt(State1, xmpp:serr_unsupported_version())
            end
     end.
 
@@ -576,7 +640,7 @@ process_sasl_success(#{mod := Mod,
 
 -spec process_sasl_failure(sasl_failure(), state()) -> state().
 process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
-    try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
+    try Mod:handle_auth_failure(<<"EXTERNAL">>, {auth, Reason}, State)
     catch _:undef -> process_stream_end({auth, Reason}, State)
     end.
 
@@ -592,6 +656,12 @@ is_starttls_required(#{mod := Mod} = State) ->
     catch _:undef -> false
     end.
 
+-spec is_starttls_available(state()) -> boolean().
+is_starttls_available(#{mod := Mod} = State) ->
+    try Mod:tls_enabled(State)
+    catch _:undef -> true
+    end.
+
 -spec send_header(state()) -> state().
 send_header(#{remote_server := RemoteServer,
              stream_encrypted := Encrypted,
@@ -610,40 +680,34 @@ send_header(#{remote_server := RemoteServer,
              true ->
                   undefined
           end,
-    Header = xmpp:encode(
-              #stream_start{xmlns = NS,
-                            lang = Lang,
-                            stream_xmlns = ?NS_STREAM,
-                            db_xmlns = NS_DB,
-                            from = From,
-                            to = jid:make(RemoteServer),
-                            version = {1,0}}),
-    case send_text(State, fxml:element_to_header(Header)) of
+    StreamStart = #stream_start{xmlns = NS,
+                               lang = Lang,
+                               stream_xmlns = ?NS_STREAM,
+                               db_xmlns = NS_DB,
+                               from = From,
+                               to = jid:make(RemoteServer),
+                               version = {1,0}},
+    case socket_send(State, StreamStart) of
        ok -> State;
        {error, Why} -> process_stream_end({socket, Why}, State)
     end.
 
--spec send_element(state(), xmpp_element()) -> state().
-send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
-    El = xmpp:encode(Pkt, NS),
-    Data = fxml:element_to_binary(El),
-    State1 = try Mod:handle_send(Pkt, El, Data, State)
+-spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
+send_pkt(#{mod := Mod} = State, Pkt) ->
+    Result = socket_send(State, Pkt),
+    State1 = try Mod:handle_send(Pkt, Result, State)
             catch _:undef -> State
             end,
-    case is_disconnected(State1) of
-       true -> State1;
-       false ->
-           case send_text(State1, Data) of
-               _ when is_record(Pkt, stream_error) ->
-                   process_stream_end({stream, {out, Pkt}}, State1);
-               ok ->
-                   State1;
-               {error, Why} ->
-                   process_stream_end({socket, Why}, State1)
-           end
+    case Result of
+       _ when is_record(Pkt, stream_error) ->
+           process_stream_end({stream, {out, Pkt}}, State1);
+       ok ->
+           State1;
+       {error, Why} ->
+           process_stream_end({socket, Why}, State1)
     end.
 
--spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+-spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
 send_error(State, Pkt, Err) ->
     case xmpp:is_stanza(Pkt) of
        true ->
@@ -654,22 +718,29 @@ send_error(State, Pkt, Err) ->
                <<"error">> -> State;
                _ ->
                    ErrPkt = xmpp:make_error(Pkt, Err),
-                   send_element(State, ErrPkt)
+                   send_pkt(State, ErrPkt)
            end;
        false ->
            State
     end.
 
--spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
-send_text(#{sockmod := SockMod, socket := Socket,
-           stream_state := StateName}, Data) when StateName /= disconnected ->
-    SockMod:send(Socket, Data);
-send_text(_, _) ->
+-spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
+socket_send(#{sockmod := SockMod, socket := Socket, xmlns := NS,
+             stream_state := StateName}, Pkt) when StateName /= disconnected ->
+    case Pkt of
+       trailer ->
+           SockMod:send_trailer(Socket);
+       #stream_start{} ->
+           SockMod:send_header(Socket, xmpp:encode(Pkt));
+       _ ->
+           SockMod:send_element(Socket, xmpp:encode(Pkt, NS))
+    end;
+socket_send(_, _) ->
     {error, closed}.
 
 -spec send_trailer(state()) -> state().
 send_trailer(State) ->
-    send_text(State, <<"</stream:stream>">>),
+    socket_send(State, trailer),
     close_socket(State).
 
 -spec close_socket(state()) -> state().
@@ -710,6 +781,12 @@ format_stream_error(Reason, Txt) ->
            binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
     end.
 
+-spec format_tls_error(atom() | binary()) -> list().
+format_tls_error(Reason) when is_atom(Reason) ->
+    format_inet_error(Reason);
+format_tls_error(Reason) ->
+    binary_to_list(Reason).
+
 -spec format(io:format(), list()) -> binary().
 format(Fmt, Args) ->
     iolist_to_binary(io_lib:format(Fmt, Args)).
@@ -747,13 +824,16 @@ resolve(Host, State) ->
     end.
 
 -spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
+srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) ->
+    %% Do not attempt to lookup SRV for component connections
+    {error, nxdomain};
 srv_lookup(Host, State) ->
     %% Only perform SRV lookups for FQDN names
     case string:chr(Host, $.) of
        0 ->
            {error, nxdomain};
        _ ->
-           case inet_parse:address(Host) of
+           case inet:parse_address(Host) of
                {ok, _} ->
                    {error, nxdomain};
                {error, _} ->
@@ -763,7 +843,7 @@ srv_lookup(Host, State) ->
            end
     end.
 
--spec srv_lookup(string(), non_neg_integer(), integer()) ->
+-spec srv_lookup(string(), timeout(), integer()) ->
                        {ok, [host_port()]} | network_error().
 srv_lookup(_Host, _Timeout, Retries) when Retries < 1 ->
     {error, timeout};
@@ -807,7 +887,7 @@ a_lookup([], _State, Err) ->
     Err.
 
 -spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(),
-              non_neg_integer(), integer()) -> {ok, [ip_port()]} | network_error().
+              timeout(), integer()) -> {ok, [ip_port()]} | network_error().
 a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 ->
     {error, timeout};
 a_lookup(Host, Port, Family, Timeout, Retries) ->
@@ -861,7 +941,7 @@ connect(AddrPorts, #{sockmod := SockMod} = State) ->
     Timeout = get_connect_timeout(State),
     connect(AddrPorts, SockMod, Timeout, {error, nxdomain}).
 
--spec connect([ip_port()], module(), non_neg_integer(), network_error()) ->
+-spec connect([ip_port()], module(), timeout(), network_error()) ->
                     {ok, term(), ip_port()} | network_error().
 connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
     Type = get_addr_type(Addr),
@@ -883,12 +963,11 @@ connect([], _SockMod, _Timeout, Err) ->
 get_addr_type({_, _, _, _}) -> inet;
 get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
 
--spec get_dns_timeout(state()) -> non_neg_integer().
+-spec get_dns_timeout(state()) -> timeout().
 get_dns_timeout(#{mod := Mod} = State) ->
-    timer:seconds(
-      try Mod:dns_timeout(State)
-      catch _:undef -> 10
-      end).
+    try Mod:dns_timeout(State)
+    catch _:undef -> timer:seconds(10)
+    end.
 
 -spec get_dns_retries(state()) -> non_neg_integer().
 get_dns_retries(#{mod := Mod} = State) ->
@@ -909,9 +988,8 @@ get_address_families(#{mod := Mod} = State) ->
     catch _:undef -> [inet, inet6]
     end.
 
--spec get_connect_timeout(state()) -> non_neg_integer().
+-spec get_connect_timeout(state()) -> timeout().
 get_connect_timeout(#{mod := Mod} = State) ->
-    timer:seconds(
-      try Mod:connect_timeout(State)
-      catch _:undef -> 10
-      end).
+    try Mod:connect_timeout(State)
+    catch _:undef -> timer:seconds(10)
+    end.