]> granicus.if.org Git - ejabberd/commitdiff
Improve type specs for ejabberd_s2s
authorEvgeny Khramtsov <ekhramtsov@process-one.net>
Tue, 9 Jul 2019 13:42:24 +0000 (16:42 +0300)
committerEvgeny Khramtsov <ekhramtsov@process-one.net>
Tue, 9 Jul 2019 13:42:24 +0000 (16:42 +0300)
Also minor code cleanup

src/ejabberd_s2s.erl

index 2128d6b6a66c89a89d1173b914573450e43ede0b..a1937da7e4f1652c7e286d332767d9ae26638692 100644 (file)
 -include("logger.hrl").
 -include("xmpp.hrl").
 -include("ejabberd_commands.hrl").
--include_lib("public_key/include/public_key.hrl").
+-include_lib("stdlib/include/ms_transform.hrl").
 -include("ejabberd_stacktrace.hrl").
 -include("translate.hrl").
 
--define(PKIXEXPLICIT, 'OTP-PUB-KEY').
-
--define(PKIXIMPLICIT, 'OTP-PUB-KEY').
-
--include("XmppAddr.hrl").
-
 -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER, 1).
-
 -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE, 1).
-
 -define(S2S_OVERLOAD_BLOCK_PERIOD, 60).
 
 %% once a server is temporarly blocked, it stay blocked for 60 seconds
 
--record(s2s, {fromto = {<<"">>, <<"">>} :: {binary(), binary()} | '_',
-              pid = self()              :: pid() | '_' | '$1'}).
+-record(s2s, {fromto :: {binary(), binary()},
+              pid    :: pid()}).
 
 -record(state, {}).
 
--record(temporarily_blocked, {host = <<"">>     :: binary(),
-                              timestamp         :: integer()}).
+-record(temporarily_blocked, {host      :: binary(),
+                              timestamp :: integer()}).
 
 -type temporarily_blocked() :: #temporarily_blocked{}.
-
 start_link() ->
-    gen_server:start_link({local, ?MODULE}, ?MODULE, [],
-                         []).
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
 clean_temporarily_blocked_table() ->
     mnesia:clear_table(temporarily_blocked).
 
 -spec list_temporarily_blocked_hosts() -> [temporarily_blocked()].
-
 list_temporarily_blocked_hosts() ->
     ets:tab2list(temporarily_blocked).
 
 -spec external_host_overloaded(binary()) -> {aborted, any()} | {atomic, ok}.
-
 external_host_overloaded(Host) ->
-    ?INFO_MSG("Disabling connections from ~s for ~p "
-             "seconds",
+    ?INFO_MSG("Disabling s2s connections to ~s for ~p seconds",
              [Host, ?S2S_OVERLOAD_BLOCK_PERIOD]),
     mnesia:transaction(fun () ->
                                Time = erlang:monotonic_time(),
@@ -107,21 +94,20 @@ external_host_overloaded(Host) ->
                       end).
 
 -spec is_temporarly_blocked(binary()) -> boolean().
-
 is_temporarly_blocked(Host) ->
     case mnesia:dirty_read(temporarily_blocked, Host) of
-      [] -> false;
-      [#temporarily_blocked{timestamp = T} = Entry] ->
-          Diff = erlang:monotonic_time() - T,
-         case erlang:convert_time_unit(Diff, native, microsecond) of
-           N when N > (?S2S_OVERLOAD_BLOCK_PERIOD) * 1000 * 1000 ->
-               mnesia:dirty_delete_object(Entry), false;
-           _ -> true
-         end
+       [] -> false;
+       [#temporarily_blocked{timestamp = T} = Entry] ->
+           Diff = erlang:monotonic_time() - T,
+           case erlang:convert_time_unit(Diff, native, microsecond) of
+               N when N > (?S2S_OVERLOAD_BLOCK_PERIOD) * 1000 * 1000 ->
+                   mnesia:dirty_delete_object(Entry), false;
+               _ -> true
+           end
     end.
 
 -spec remove_connection({binary(), binary()}, pid()) -> ok.
-remove_connection(FromTo, Pid) ->
+remove_connection({From, To} = FromTo, Pid) ->
     case mnesia:dirty_match_object(s2s, #s2s{fromto = FromTo, pid = Pid}) of
        [#s2s{pid = Pid}] ->
            F = fun() ->
@@ -130,25 +116,24 @@ remove_connection(FromTo, Pid) ->
            case mnesia:transaction(F) of
                {atomic, _} -> ok;
                {aborted, Reason} ->
-                   ?ERROR_MSG("Failed to unregister s2s connection: "
-                              "Mnesia failure: ~p", [Reason])
+                   ?ERROR_MSG("Failed to unregister s2s connection ~s -> ~s: "
+                              "Mnesia failure: ~p",
+                              [From, To, Reason])
            end;
        _ ->
            ok
     end.
 
 -spec have_connection({binary(), binary()}) -> boolean().
-
 have_connection(FromTo) ->
     case catch mnesia:dirty_read(s2s, FromTo) of
-       [_] ->
+       [_] ->
             true;
         _ ->
             false
     end.
 
 -spec get_connections_pids({binary(), binary()}) -> [pid()].
-
 get_connections_pids(FromTo) ->
     case catch mnesia:dirty_read(s2s, FromTo) of
        L when is_list(L) ->
@@ -158,8 +143,7 @@ get_connections_pids(FromTo) ->
     end.
 
 -spec try_register({binary(), binary()}) -> boolean().
-
-try_register(FromTo) ->
+try_register({From, To} = FromTo) ->
     MaxS2SConnectionsNumber = max_s2s_connections_number(FromTo),
     MaxS2SConnectionsNumberPerNode =
        max_s2s_connections_number_per_node(FromTo),
@@ -169,18 +153,21 @@ try_register(FromTo) ->
                                                              MaxS2SConnectionsNumber,
                                                              MaxS2SConnectionsNumberPerNode),
                if NeededConnections > 0 ->
-                      mnesia:write(#s2s{fromto = FromTo, pid = self()}),
-                      true;
+                       mnesia:write(#s2s{fromto = FromTo, pid = self()}),
+                       true;
                   true -> false
                end
        end,
     case mnesia:transaction(F) of
-      {atomic, Res} -> Res;
-      _ -> false
+       {atomic, Res} -> Res;
+       {aborted, Reason} ->
+           ?ERROR_MSG("Failed to register s2s connection ~s -> ~s: "
+                      "Mnesia failure: ~p",
+                      [From, To, Reason]),
+           false
     end.
 
 -spec dirty_get_connections() -> [{binary(), binary()}].
-
 dirty_get_connections() ->
     mnesia:dirty_all_keys(s2s).
 
@@ -276,10 +263,12 @@ init([]) ->
            {stop, Reason}
     end.
 
-handle_call(_Request, _From, State) ->
-    {reply, ok, State}.
+handle_call(Request, From, State) ->
+    ?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]),
+    {noreply, State}.
 
-handle_cast(_Msg, State) ->
+handle_cast(Msg, State) ->
+    ?WARNING_MSG("Unexpected cast: ~p", [Msg]),
     {noreply, State}.
 
 handle_info({mnesia_system_event, {mnesia_down, Node}}, State) ->
@@ -294,14 +283,15 @@ handle_info({route, Packet}, State) ->
                        misc:format_exception(2, Class, Reason, StackTrace)])
     end,
     {noreply, State};
-handle_info(_Info, State) -> {noreply, State}.
+handle_info(Info, State) ->
+    ?WARNING_MSG("Unexpected info: ~p", [Info]),
+    {noreply, State}.
 
 terminate(_Reason, _State) ->
     ejabberd_commands:unregister_commands(get_commands_spec()),
     lists:foreach(fun host_down/1, ejabberd_option:hosts()),
     ejabberd_hooks:delete(host_up, ?MODULE, host_up, 50),
-    ejabberd_hooks:delete(host_down, ?MODULE, host_down, 60),
-    ok.
+    ejabberd_hooks:delete(host_down, ?MODULE, host_down, 60).
 
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
@@ -309,10 +299,12 @@ code_change(_OldVsn, State, _Extra) ->
 %%--------------------------------------------------------------------
 %%% Internal functions
 %%--------------------------------------------------------------------
+-spec host_up(binary()) -> ok.
 host_up(Host) ->
     ejabberd_s2s_in:host_up(Host),
     ejabberd_s2s_out:host_up(Host).
 
+-spec host_down(binary()) -> ok.
 host_down(Host) ->
     lists:foreach(
       fun(#s2s{fromto = {From, _}, pid = Pid}) when node(Pid) == node() ->
@@ -334,12 +326,11 @@ clean_table_from_bad_node(Node) ->
     F = fun() ->
                Es = mnesia:select(
                       s2s,
-                      [{#s2s{pid = '$1', _ = '_'},
-                        [{'==', {node, '$1'}, Node}],
-                        ['$_']}]),
-               lists:foreach(fun(E) ->
-                                     mnesia:delete_object(E)
-                             end, Es)
+                      ets:fun2ms(
+                        fun(#s2s{pid = Pid} = E) when node(Pid) == Node ->
+                                E
+                        end)),
+               lists:foreach(fun mnesia:delete_object/1, Es)
        end,
     mnesia:async_dirty(F).
 
@@ -350,12 +341,12 @@ route(Packet) ->
     To = xmpp:get_to(Packet),
     case start_connection(From, To) of
        {ok, Pid} when is_pid(Pid) ->
-         ?DEBUG("Sending to process ~p~n", [Pid]),
-         #jid{lserver = MyServer} = From,
+           ?DEBUG("Sending to process ~p~n", [Pid]),
+           #jid{lserver = MyServer} = From,
            ejabberd_hooks:run(s2s_send_packet, MyServer, [Packet]),
            ejabberd_s2s_out:route(Pid, Packet);
        {error, Reason} ->
-         Lang = xmpp:get_lang(Packet),
+           Lang = xmpp:get_lang(Packet),
            Err = case Reason of
                      forbidden ->
                          xmpp:err_forbidden(?T("Access denied by service policy"), Lang);
@@ -366,12 +357,12 @@ route(Packet) ->
     end.
 
 -spec start_connection(jid(), jid())
-      -> {ok, pid()} | {error, forbidden | internal_server_error}.
+                     -> {ok, pid()} | {error, forbidden | internal_server_error}.
 start_connection(From, To) ->
     start_connection(From, To, []).
 
 -spec start_connection(jid(), jid(), [proplists:property()])
-      -> {ok, pid()} | {error, forbidden | internal_server_error}.
+                     -> {ok, pid()} | {error, forbidden | internal_server_error}.
 start_connection(From, To, Opts) ->
     #jid{lserver = MyServer} = From,
     #jid{lserver = Server} = To,
@@ -382,11 +373,11 @@ start_connection(From, To, Opts) ->
        max_s2s_connections_number_per_node(FromTo),
     ?DEBUG("Finding connection for ~p~n", [FromTo]),
     case mnesia:dirty_read(s2s, FromTo) of
-      [] ->
-         %% We try to establish all the connections if the host is not a
-         %% service and if the s2s host is not blacklisted or
-         %% is in whitelist:
-         LServer = ejabberd_router:host_of_route(MyServer),
+       [] ->
+           %% We try to establish all the connections if the host is not a
+           %% service and if the s2s host is not blacklisted or
+           %% is in whitelist:
+           LServer = ejabberd_router:host_of_route(MyServer),
            case allow_host(LServer, Server) of
                true ->
                    NeededConnections = needed_connections_number(
@@ -400,20 +391,20 @@ start_connection(From, To, Opts) ->
                false ->
                    {error, forbidden}
            end;
-      L when is_list(L) ->
-         NeededConnections = needed_connections_number(L,
-                                                       MaxS2SConnectionsNumber,
-                                                       MaxS2SConnectionsNumberPerNode),
-         if NeededConnections > 0 ->
-                %% We establish the missing connections for this pair.
-                open_several_connections(NeededConnections, MyServer,
-                                         Server, From, FromTo,
-                                         MaxS2SConnectionsNumber,
-                                         MaxS2SConnectionsNumberPerNode, Opts);
-            true ->
-                %% We choose a connexion from the pool of opened ones.
-                {ok, choose_connection(From, L)}
-         end
+       L when is_list(L) ->
+           NeededConnections = needed_connections_number(L,
+                                                         MaxS2SConnectionsNumber,
+                                                         MaxS2SConnectionsNumberPerNode),
+           if NeededConnections > 0 ->
+                   %% We establish the missing connections for this pair.
+                   open_several_connections(NeededConnections, MyServer,
+                                            Server, From, FromTo,
+                                            MaxS2SConnectionsNumber,
+                                            MaxS2SConnectionsNumberPerNode, Opts);
+              true ->
+                   %% We choose a connexion from the pool of opened ones.
+                   {ok, choose_connection(From, L)}
+           end
     end.
 
 -spec choose_connection(jid(), [#s2s{}]) -> pid().
@@ -423,8 +414,8 @@ choose_connection(From, Connections) ->
 -spec choose_pid(jid(), [pid()]) -> pid().
 choose_pid(From, Pids) ->
     Pids1 = case [P || P <- Pids, node(P) == node()] of
-             [] -> Pids;
-             Ps -> Ps
+               [] -> Pids;
+               Ps -> Ps
            end,
     Pid =
        lists:nth(erlang:phash(jid:remove_resource(From),
@@ -433,13 +424,17 @@ choose_pid(From, Pids) ->
     ?DEBUG("Using ejabberd_s2s_out ~p~n", [Pid]),
     Pid.
 
+-spec open_several_connections(pos_integer(), binary(), binary(),
+                              jid(), {binary(), binary()},
+                              integer(), integer(), [proplists:property()]) ->
+                                     {ok, pid()} | {error, internal_server_error}.
 open_several_connections(N, MyServer, Server, From,
                         FromTo, MaxS2SConnectionsNumber,
                         MaxS2SConnectionsNumberPerNode, Opts) ->
     case lists:flatmap(
           fun(_) ->
                   new_connection(MyServer, Server,
-                                       From, FromTo, MaxS2SConnectionsNumber,
+                                 From, FromTo, MaxS2SConnectionsNumber,
                                  MaxS2SConnectionsNumberPerNode, Opts)
           end, lists:seq(1, N)) of
        [] ->
@@ -448,6 +443,8 @@ open_several_connections(N, MyServer, Server, From,
            {ok, choose_pid(From, PIDs)}
     end.
 
+-spec new_connection(binary(), binary(), jid(), {binary(), binary()},
+                    integer(), integer(), [proplists:property()]) -> [pid()].
 new_connection(MyServer, Server, From, FromTo,
               MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) ->
     {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts),
@@ -457,22 +454,23 @@ new_connection(MyServer, Server, From, FromTo,
                                                              MaxS2SConnectionsNumber,
                                                              MaxS2SConnectionsNumberPerNode),
                if NeededConnections > 0 ->
-                      mnesia:write(#s2s{fromto = FromTo, pid = Pid}),
-                      Pid;
+                       mnesia:write(#s2s{fromto = FromTo, pid = Pid}),
+                       Pid;
                   true -> choose_connection(From, L)
                end
        end,
     TRes = mnesia:transaction(F),
     case TRes of
-      {atomic, Pid1} ->
+       {atomic, Pid1} ->
            if Pid1 == Pid ->
                    ejabberd_s2s_out:connect(Pid);
               true ->
                    ejabberd_s2s_out:stop(Pid)
            end,
            [Pid1];
-      {aborted, Reason} ->
-           ?ERROR_MSG("Failed to register connection ~s -> ~s: ~p",
+       {aborted, Reason} ->
+           ?ERROR_MSG("Failed to register s2s connection ~s -> ~s: "
+                      "Mnesia failure: ~p",
                       [MyServer, Server, Reason]),
            ejabberd_s2s_out:stop(Pid),
            []
@@ -529,11 +527,13 @@ incoming_s2s_number() ->
 outgoing_s2s_number() ->
     supervisor_count(ejabberd_s2s_out_sup).
 
+-spec supervisor_count(atom()) -> non_neg_integer().
 supervisor_count(Supervisor) ->
-    case catch supervisor:which_children(Supervisor) of
-        {'EXIT', _} -> 0;
-        Result ->
-            length(Result)
+    try supervisor:count_children(Supervisor) of
+       Props ->
+           proplists:get_value(workers, Props, 0)
+    catch _:_ ->
+           0
     end.
 
 -spec stop_s2s_connections() -> ok.
@@ -557,10 +557,12 @@ update_tables() ->
     ok.
 
 %% Check if host is in blacklist or white list
+-spec allow_host(binary(), binary()) -> boolean().
 allow_host(MyServer, S2SHost) ->
     allow_host1(MyServer, S2SHost) andalso
-      not is_temporarly_blocked(S2SHost).
+       not is_temporarly_blocked(S2SHost).
 
+-spec allow_host1(binary(), binary()) -> boolean().
 allow_host1(MyHost, S2SHost) ->
     Rule = ejabberd_option:s2s_access(MyHost),
     JID = jid:make(S2SHost),
@@ -570,8 +572,7 @@ allow_host1(MyHost, S2SHost) ->
             case ejabberd_hooks:run_fold(s2s_allow_host, MyHost,
                                          allow, [MyHost, S2SHost]) of
                 deny -> false;
-                allow -> true;
-                _ -> true
+                allow -> true
             end
     end.
 
@@ -581,8 +582,8 @@ allow_host1(MyHost, S2SHost) ->
 %%       Info = [{InfoName::atom(), InfoValue::any()}]
 get_info_s2s_connections(Type) ->
     ChildType = case Type of
-                 in -> ejabberd_s2s_in_sup;
-                 out -> ejabberd_s2s_out_sup
+                   in -> ejabberd_s2s_in_sup;
+                   out -> ejabberd_s2s_out_sup
                end,
     Connections = supervisor:which_children(ChildType),
     get_s2s_info(Connections, Type).
@@ -597,13 +598,12 @@ complete_s2s_info([Connection | T], Type, Result) ->
     complete_s2s_info(T, Type, [State | Result]).
 
 -spec get_s2s_state(pid()) -> [{status, open | closed | error} | {s2s_pid, pid()}].
-
 get_s2s_state(S2sPid) ->
     Infos = case p1_fsm:sync_send_all_state_event(S2sPid,
-                                                  get_state_infos)
-               of
-             {state_infos, Is} -> [{status, open} | Is];
-             {noproc, _} -> [{status, closed}]; %% Connection closed
-             {badrpc, _} -> [{status, error}]
+                                                 get_state_infos)
+           of
+               {state_infos, Is} -> [{status, open} | Is];
+               {noproc, _} -> [{status, closed}]; %% Connection closed
+               {badrpc, _} -> [{status, error}]
            end,
     [{s2s_pid, S2sPid} | Infos].