-define(COPYRIGHT, "Copyright (c) 2002-2016 ProcessOne").
--define(S2STIMEOUT, timer:minutes(10)).
-
%%-define(DBGFSM, true).
-record(scram,
-include("ejabberd.hrl").
-include("logger.hrl").
-%%
--export_type([
- mechanism/0,
- mechanisms/0,
- sasl_mechanism/0
-]).
-
-record(sasl_mechanism,
{mechanism = <<"">> :: mechanism() | '$1',
module :: atom(),
-type(mechanism() :: binary()).
-type(mechanisms() :: [mechanism(),...]).
-type(password_type() :: plain | digest | scram).
--type(props() :: [{username, binary()} |
- {authzid, binary()} |
- {mechanism, binary()} |
- {auth_module, atom()}]).
+-type sasl_property() :: {username, binary()} |
+ {authzid, binary()} |
+ {mechanism, binary()} |
+ {auth_module, atom()}.
+-type sasl_return() :: {ok, [sasl_property()]} |
+ {ok, [sasl_property()], binary()} |
+ {continue, binary(), any()} |
+ {error, atom()} |
+ {error, atom(), binary()}.
-type(sasl_mechanism() :: #sasl_mechanism{}).
mech_state
}).
-type sasl_state() :: #sasl_state{}.
--export_type([sasl_state/0]).
+-export_type([mechanism/0, mechanisms/0, sasl_mechanism/0,
+ sasl_state/0, sasl_return/0, sasl_property/0]).
-callback mech_new(binary(), fun(), fun(), fun()) -> any().
--callback mech_step(any(), binary()) -> {ok, props()} |
- {ok, props(), binary()} |
- {continue, binary(), any()} |
- {error, atom()} |
- {error, atom(), binary()}.
+-callback mech_step(any(), binary()) -> sasl_return().
start() ->
ets:new(sasl_mechanism,
Children = ejabberd_sm:get_all_pids(),
lists:foreach(
fun(C2SPid) when node(C2SPid) == node() ->
- C2SPid ! system_shutdown;
+ ejabberd_c2s:send(C2SPid, xmpp:serr_system_shutdown());
(_) ->
ok
end, Children).
get_password_s/2, get_password_with_authmodule/2,
is_user_exists/2, is_user_exists_in_other_modules/3,
remove_user/2, remove_user/3, plain_password_required/1,
- store_type/1, entropy/1]).
+ store_type/1, entropy/1, backend_type/1]).
-export([auth_modules/1, opt_type/1]).
length(S) * math:log(lists:sum(Set)) / math:log(2)
end.
+-spec backend_type(atom()) -> atom().
+backend_type(Mod) ->
+ case atom_to_list(Mod) of
+ "ejabberd_auth_" ++ T -> list_to_atom(T);
+ _ -> Mod
+ end.
+
%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------
-module(ejabberd_c2s).
-behaviour(xmpp_stream_in).
-behaviour(ejabberd_config).
+-behaviour(ejabberd_socket).
-protocol({rfc, 6121}).
%% ejabberd_socket callbacks
--export([start/2, socket_type/0]).
+-export([start/2, start_link/2, socket_type/0]).
%% ejabberd_config callbacks
-export([opt_type/1, transform_listen_option/2]).
%% xmpp_stream_in callbacks
-export([init/1, handle_call/3, handle_cast/2,
handle_info/2, terminate/2, code_change/3]).
--export([tls_options/1, tls_required/1, compress_methods/1,
- sasl_mechanisms/1, init_sasl/1, bind/2, handshake/2,
+-export([tls_options/1, tls_required/1, tls_verify/1,
+ compress_methods/1, bind/2, get_password_fun/1,
+ check_password_fun/1, check_password_digest_fun/1,
unauthenticated_stream_features/1, authenticated_stream_features/1,
- handle_stream_start/1, handle_stream_end/1, handle_stream_close/1,
+ handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2,
- handle_auth_success/4, handle_auth_failure/4, handle_send/5,
- handle_unbinded_packet/2, handle_cdata/2]).
+ handle_auth_success/4, handle_auth_failure/4, handle_send/3,
+ handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]).
+%% Hooks
+-export([handle_unexpected_info/2, handle_unexpected_cast/2,
+ reject_unauthenticated_packet/2, process_closed/2]).
%% API
-export([get_presence/1, get_subscription/2, get_subscribed/1,
- send/2, close/1]).
+ open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1,
+ copy_state/2, add_hooks/0]).
-include("ejabberd.hrl").
-include("xmpp.hrl").
-define(SETS, gb_sets).
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
-
-type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
--export_type([state/0, next_state/0]).
+-export_type([state/0]).
%%%===================================================================
%%% ejabberd_socket API
%%%===================================================================
start(SockData, Opts) ->
xmpp_stream_in:start(?MODULE, [SockData, Opts],
- fsm_limit_opts(Opts) ++ ?FSMOPTS).
+ ejabberd_config:fsm_limit_opts(Opts)).
+
+start_link(SockData, Opts) ->
+ xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
+ ejabberd_config:fsm_limit_opts(Opts)).
socket_type() ->
xml_stream.
+-spec call(pid(), term(), non_neg_integer() | infinity) -> term().
+call(Ref, Msg, Timeout) ->
+ xmpp_stream_in:call(Ref, Msg, Timeout).
+
-spec get_presence(pid()) -> presence().
get_presence(Ref) ->
- xmpp_stream_in:call(Ref, get_presence, 1000).
+ call(Ref, get_presence, 1000).
-spec get_subscription(jid() | ljid(), state()) -> both | from | to | none.
get_subscription(#jid{} = From, State) ->
-spec get_subscribed(pid()) -> [ljid()].
%% Return list of all available resources of contacts
get_subscribed(Ref) ->
- xmpp_stream_in:call(Ref, get_subscribed, 1000).
+ call(Ref, get_subscribed, 1000).
--spec close(pid()) -> ok.
close(Ref) ->
- xmpp_stream_in:cast(Ref, closed).
+ xmpp_stream_in:close(Ref).
+
+close(Ref, SendTrailer) ->
+ xmpp_stream_in:close(Ref, SendTrailer).
+
+stop(Ref) ->
+ xmpp_stream_in:stop(Ref).
+
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+ xmpp_stream_in:send(Pid, Pkt);
+send(#{lserver := LServer} = State, Pkt) ->
+ case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of
+ drop -> State;
+ Pkt1 -> xmpp_stream_in:send(State, Pkt1)
+ end.
+
+-spec establish(state()) -> state().
+establish(State) ->
+ xmpp_stream_in:establish(State).
+
+-spec add_hooks() -> ok.
+add_hooks() ->
+ lists:foreach(
+ fun(Host) ->
+ ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100),
+ ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
+ reject_unauthenticated_packet, 100),
+ ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
+ handle_unexpected_info, 100),
+ ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE,
+ handle_unexpected_cast, 100)
+
+ end, ?MYHOSTS).
+
+%% Copies content of one c2s state to another.
+%% This is needed for session migration from one pid to another.
+-spec copy_state(state(), state()) -> state().
+copy_state(#{owner := Owner} = NewState,
+ #{jid := JID, resource := Resource, sid := {Time, _},
+ auth_module := AuthModule, lserver := LServer,
+ pres_t := PresT, pres_a := PresA,
+ pres_f := PresF} = OldState) ->
+ State1 = case OldState of
+ #{pres_last := Pres, pres_timestamp := PresTS} ->
+ NewState#{pres_last => Pres, pres_timestamp => PresTS};
+ _ ->
+ NewState
+ end,
+ Conn = get_conn_type(State1),
+ State2 = State1#{jid => JID, resource => Resource,
+ conn => Conn,
+ sid => {Time, Owner},
+ auth_module => AuthModule,
+ pres_t => PresT, pres_a => PresA,
+ pres_f => PresF},
+ ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]).
+
+%%%===================================================================
+%%% Hooks
+%%%===================================================================
+handle_unexpected_info(State, Info) ->
+ ?WARNING_MSG("got unexpected info: ~p", [Info]),
+ State.
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
- xmpp_stream_in:send(State, Pkt).
+handle_unexpected_cast(State, Msg) ->
+ ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
+ State.
+
+reject_unauthenticated_packet(State, Pkt) ->
+ Err = xmpp:err_not_authorized(),
+ xmpp_stream_in:send_error(State, Pkt, Err).
+
+process_closed(State, _Reason) ->
+ stop(State).
%%%===================================================================
%%% xmpp_stream_in callbacks
tls_required(#{tls_required := TLSRequired}) ->
TLSRequired.
+tls_verify(#{tls_verify := TLSVerify}) ->
+ TLSVerify.
+
compress_methods(#{zlib := true}) ->
[<<"zlib">>];
compress_methods(_) ->
[].
-sasl_mechanisms(#{lserver := LServer}) ->
- cyrsasl:listmech(LServer).
-
unauthenticated_stream_features(#{lserver := LServer}) ->
ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]).
authenticated_stream_features(#{lserver := LServer}) ->
ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]).
-init_sasl(#{lserver := LServer}) ->
- cyrsasl:server_new(
- <<"jabber">>, LServer, <<"">>, [],
- fun(U) ->
- ejabberd_auth:get_password_with_authmodule(U, LServer)
- end,
- fun(U, AuthzId, P) ->
- ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
- end,
- fun(U, AuthzId, P, D, DG) ->
- ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
- end).
+get_password_fun(#{lserver := LServer}) ->
+ fun(U) ->
+ ejabberd_auth:get_password_with_authmodule(U, LServer)
+ end.
+
+check_password_fun(#{lserver := LServer}) ->
+ fun(U, AuthzId, P) ->
+ ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
+ end.
+
+check_password_digest_fun(#{lserver := LServer}) ->
+ fun(U, AuthzId, P, D, DG) ->
+ ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
+ end.
bind(<<"">>, State) ->
bind(new_uniq_id(), State);
-bind(R, #{user := U, server := S} = State) ->
+bind(R, #{user := U, server := S, access := Access, lang := Lang,
+ lserver := LServer, socket := Socket, ip := IP} = State) ->
case resource_conflict_action(U, S, R) of
closenew ->
{error, xmpp:err_conflict(), State};
{accept_resource, Resource} ->
- open_session(State, Resource)
+ JID = jid:make(U, S, Resource),
+ case acl:access_matches(Access,
+ #{usr => jid:split(JID), ip => IP},
+ LServer) of
+ allow ->
+ State1 = open_session(State#{resource => Resource}),
+ State2 = ejabberd_hooks:run_fold(
+ c2s_session_opened, LServer, State1, []),
+ ?INFO_MSG("(~s) Opened session for ~s",
+ [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+ {ok, State2};
+ deny ->
+ ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
+ ?INFO_MSG("(~s) Forbidden session for ~s",
+ [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+ Txt = <<"Denied by ACL">>,
+ {error, xmpp:err_not_allowed(Txt, Lang), State}
+ end
end.
-handshake(_Data, State) ->
- %% This is only for jabber component
- {ok, State}.
+-spec open_session(state()) -> {ok, state()} | state().
+open_session(#{user := U, server := S, resource := R,
+ sid := SID, ip := IP, auth_module := AuthModule} = State) ->
+ JID = jid:make(U, S, R),
+ change_shaper(State),
+ Conn = get_conn_type(State),
+ State1 = State#{conn => Conn, resource => R, jid => JID},
+ Prio = try maps:get(pres_last, State) of
+ Pres -> get_priority_from_presence(Pres)
+ catch _:{badkey, _} ->
+ undefined
+ end,
+ Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}],
+ ejabberd_sm:open_session(SID, U, S, R, Prio, Info),
+ State1.
-handle_stream_start(#{lserver := LServer, ip := IP, lang := Lang} = State) ->
+handle_stream_start(StreamStart,
+ #{lserver := LServer, ip := IP, lang := Lang} = State) ->
case lists:member(LServer, ?MYHOSTS) of
false ->
- xmpp_stream_in:send(State, xmpp:serr_host_unknown());
+ send(State, xmpp:serr_host_unknown());
true ->
case check_bl_c2s(IP, Lang) of
false ->
change_shaper(State),
- {noreply, State};
+ ejabberd_hooks:run_fold(
+ c2s_stream_started, LServer, State, [StreamStart]);
{true, LogReason, ReasonT} ->
?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s",
[jlib:ip_to_list(IP), LogReason]),
Err = xmpp:serr_policy_violation(ReasonT, Lang),
- xmpp_stream_in:send(State, Err)
+ send(State, Err)
end
end.
-handle_stream_end(State) ->
- {stop, normal, State}.
+handle_stream_end(Reason, #{lserver := LServer} = State) ->
+ ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]).
-handle_stream_close(State) ->
- {stop, normal, State}.
+handle_stream_close(_Reason, #{lserver := LServer} = State) ->
+ ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]).
handle_auth_success(User, Mech, AuthModule,
#{socket := Socket, ip := IP, lserver := LServer} = State) ->
- ?INFO_MSG("(~w) Accepted ~s authentication for ~s@~s by ~p from ~s",
- [Socket, Mech, User, LServer, AuthModule,
+ ?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s",
+ [ejabberd_socket:pp(Socket), Mech, User, LServer,
+ ejabberd_auth:backend_type(AuthModule),
ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
State1 = State#{auth_module => AuthModule},
ejabberd_hooks:run_fold(c2s_auth_result, LServer,
- {noreply, State1}, [true, User]).
+ State1, [true, User]).
handle_auth_failure(User, Mech, Reason,
#{socket := Socket, ip := IP, lserver := LServer} = State) ->
- ?INFO_MSG("(~w) Failed ~s authentication ~sfrom ~s: ~s",
- [Socket, Mech,
+ ?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s",
+ [ejabberd_socket:pp(Socket), Mech,
if User /= <<"">> -> ["for ", User, "@", LServer, " "];
true -> ""
end,
ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
ejabberd_hooks:run_fold(c2s_auth_result, LServer,
- {noreply, State}, [false, User]).
+ State, [false, User]).
handle_unbinded_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer,
- {noreply, State}, [Pkt]).
+ State, [Pkt]).
handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_unauthenticated_packet,
- LServer, {noreply, State}, [Pkt]).
+ LServer, State, [Pkt]).
handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) ->
ejabberd_hooks:run_fold(c2s_authenticated_packet,
- LServer, {noreply, State}, [Pkt]);
+ LServer, State, [Pkt]);
handle_authenticated_packet(Pkt, #{lserver := LServer} = State) ->
- case ejabberd_hooks:run_fold(c2s_authenticated_packet,
- LServer, {noreply, State}, [Pkt]) of
- {noreply, State1} ->
- Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
- Res = case Pkt1 of
- #presence{to = #jid{lresource = <<"">>}} ->
- process_self_presence(State1, Pkt1);
- #presence{} ->
- process_presence_out(State1, Pkt1);
- _ ->
- check_privacy_then_route(State1, Pkt1)
- end,
- ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]),
- Res;
- Err ->
- ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]),
- Err
+ State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet,
+ LServer, State, [Pkt]),
+ Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]),
+ case Pkt1 of
+ #presence{to = #jid{lresource = <<"">>}} ->
+ process_self_presence(State1, Pkt1);
+ #presence{} ->
+ process_presence_out(State1, Pkt1);
+ _ ->
+ check_privacy_then_route(State1, Pkt1)
end.
handle_cdata(Data, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(c2s_handle_cdata, LServer,
- {noreply, State}, [Data]).
+ State, [Data]).
+
+handle_recv(El, Pkt, #{lserver := LServer} = State) ->
+ ejabberd_hooks:run_fold(c2s_handle_recv, LServer, State, [El, Pkt]).
-handle_send(Reason, Pkt, El, Data, #{lserver := LServer} = State) ->
- ejabberd_hooks:run_fold(c2s_handle_send, LServer,
- {noreply, State}, [Reason, Pkt, El, Data]).
+handle_send(Pkt, Result, #{lserver := LServer} = State) ->
+ ejabberd_hooks:run_fold(c2s_handle_send, LServer, State, [Pkt, Result]).
init([State, Opts]) ->
Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
server => ?MYNAME,
access => Access,
shaper => Shaper},
- ejabberd_hooks:run_fold(c2s_init, {ok, State1}, []).
+ ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]).
handle_call(get_presence, _From, #{jid := JID} = State) ->
- Pres = case maps:get(pres_last, State, undefined) of
- undefined ->
+ Pres = try maps:get(pres_last, State)
+ catch _:{badkey, _} ->
BareJID = jid:remove_resource(JID),
- #presence{from = JID, to = BareJID, type = unavailable};
- P ->
- P
+ #presence{from = JID, to = BareJID, type = unavailable}
end,
{reply, Pres, State};
handle_call(get_subscribed, _From, #{pres_f := PresF} = State) ->
{reply, Subscribed, State};
handle_call(Request, From, #{lserver := LServer} = State) ->
ejabberd_hooks:run_fold(
- c2s_handle_call, LServer, {noreply, State}, [Request, From]).
+ c2s_handle_call, LServer, State, [Request, From]).
-handle_cast(closed, State) ->
- handle_stream_close(State);
handle_cast(Msg, #{lserver := LServer} = State) ->
- ejabberd_hooks:run_fold(c2s_handle_cast, LServer, {noreply, State}, [Msg]).
+ ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]).
handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) ->
Packet = xmpp:set_from_to(Packet0, From, To),
Packet1 = ejabberd_hooks:run_fold(
user_receive_packet, LServer, Packet, [NewState]),
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
- xmpp_stream_in:send(NewState, Packet1);
+ send(NewState, Packet1);
true ->
ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]),
- {noreply, NewState}
+ NewState
end;
-handle_info(system_shutdown, State) ->
- xmpp_stream_in:send(State, xmpp:serr_system_shutdown());
handle_info(Info, #{lserver := LServer} = State) ->
- ejabberd_hooks:run_fold(c2s_handle_info, LServer, {noreply, State}, [Info]).
+ ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]).
terminate(_Reason, _State) ->
ok.
check_bl_c2s({IP, _Port}, Lang) ->
ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]).
--spec open_session(state(), binary()) -> {ok, state()} | {error, stanza_error(), state()}.
-open_session(#{user := U, server := S, lserver := LServer, sid := SID,
- socket := Socket, ip := IP, auth_module := AuthMod,
- access := Access, lang := Lang} = State, R) ->
- JID = jid:make(U, S, R),
- case acl:access_matches(Access,
- #{usr => jid:split(JID), ip => IP},
- LServer) of
- allow ->
- ?INFO_MSG("(~w) Opened session for ~s",
- [Socket, jid:to_string(JID)]),
- change_shaper(State),
- Conn = get_conn_type(State),
- Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
- ejabberd_sm:open_session(SID, U, LServer, R, Info),
- State1 = State#{conn => Conn, resource => R, jid => JID},
- State2 = ejabberd_hooks:run_fold(
- c2s_session_opened, LServer, State1, []),
- {ok, State2};
- deny ->
- ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
- ?INFO_MSG("(~w) Forbidden session for ~s",
- [Socket, jid:to_string(JID)]),
- Txt = <<"Denied by ACL">>,
- {error, xmpp:err_not_allowed(Txt, Lang), State}
- end.
-
-spec process_iq_in(state(), iq()) -> {boolean(), state()}.
process_iq_in(State, #iq{} = IQ) ->
case privacy_check_packet(State, IQ, in) of
route_probe_reply(_, _, _) ->
ok.
--spec process_presence_out(state(), presence()) -> next_state().
+-spec process_presence_out(state(), presence()) -> state().
process_presence_out(#{user := User, server := Server, lserver := LServer,
jid := JID, lang := Lang, pres_a := PresA} = State,
#presence{from = From, to = To, type = Type} = Pres) ->
[User, Server, To, Type]),
BareFrom = jid:remove_resource(From),
route(xmpp:set_from_to(Pres, BareFrom, To)),
- {noreply, State}
+ State
end;
allow when Type == error; Type == probe ->
route(Pres),
- {noreply, State};
+ State;
allow ->
route(Pres),
A = case Type of
available -> ?SETS:add_element(LTo, PresA);
unavailable -> ?SETS:del_element(LTo, PresA)
end,
- {noreply, State#{pres_a => A}}
+ State#{pres_a => A}
end.
--spec process_self_presence(state(), presence()) -> {noreply, state()}.
+-spec process_self_presence(state(), presence()) -> state().
process_self_presence(#{ip := IP, conn := Conn,
auth_module := AuthMod, sid := SID,
user := U, server := S, resource := R} = State,
Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}],
ejabberd_sm:unset_presence(SID, U, S, R, Status, Info),
State1 = broadcast_presence_unavailable(State, Pres),
- State2 = maps:remove(pres_last, maps:remove(pres_timestamp, State1)),
- {noreply, State2};
+ maps:remove(pres_last, maps:remove(pres_timestamp, State1));
process_self_presence(#{lserver := LServer} = State,
#presence{type = available} = Pres) ->
PreviousPres = maps:get(pres_last, State, undefined),
State2 = State1#{pres_last => Pres,
pres_timestamp => p1_time_compat:timestamp()},
FromUnavailable = PreviousPres == undefined,
- State3 = broadcast_presence_available(State2, Pres, FromUnavailable),
- {noreply, State3};
+ broadcast_presence_available(State2, Pres, FromUnavailable);
process_self_presence(State, _Pres) ->
- {noreply, State}.
+ State.
-spec update_priority(state(), presence()) -> ok.
update_priority(#{ip := IP, conn := Conn, auth_module := AuthMod,
route_multiple(State, JIDs, Pres),
State.
--spec check_privacy_then_route(state(), stanza()) -> next_state().
+-spec check_privacy_then_route(state(), stanza()) -> state().
check_privacy_then_route(#{lang := Lang} = State, Pkt) ->
case privacy_check_packet(State, Pkt, out) of
deny ->
xmpp_stream_in:send_error(State, Pkt, Err);
allow ->
route(Pkt),
- {noreply, State}
+ State
end.
-spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny.
end
end.
--spec fsm_limit_opts([proplists:property()]) -> [proplists:property()].
-fsm_limit_opts(Opts) ->
- case lists:keysearch(max_fsm_queue, 1, Opts) of
- {value, {_, N}} when is_integer(N) -> [{max_queue, N}];
- _ ->
- case ejabberd_config:get_option(
- max_fsm_queue,
- fun(I) when is_integer(I), I > 0 -> I end) of
- undefined -> [];
- N -> [{max_queue, N}]
- end
- end.
-
transform_listen_option(Opt, Opts) ->
[Opt|Opts].
opt_type(domain_certfile) -> fun iolist_to_binary/1;
-opt_type(max_fsm_queue) ->
- fun (I) when is_integer(I), I > 0 -> I end;
opt_type(resource_conflict) ->
fun (setresource) -> setresource;
(closeold) -> closeold;
(acceptnew) -> acceptnew
end;
opt_type(_) ->
- [domain_certfile, max_fsm_queue, resource_conflict].
+ [domain_certfile, resource_conflict].
transform_options/1, collect_options/1, default_db/2,
convert_to_yaml/1, convert_to_yaml/2, v_db/2,
env_binary_to_list/2, opt_type/1, may_hide_data/1,
- is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1]).
+ is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1,
+ fsm_limit_opts/1]).
-export([start/2]).
end;
opt_type(language) ->
fun iolist_to_binary/1;
+opt_type(max_fsm_queue) ->
+ fun (I) when is_integer(I), I > 0 -> I end;
opt_type(_) ->
[hide_sensitive_log_data, hosts, language].
true ->
"hidden_by_ejabberd"
end.
+
+-spec fsm_limit_opts([proplists:property()]) -> [{max_queue, pos_integer()}].
+fsm_limit_opts(Opts) ->
+ case lists:keyfind(max_fsm_queue, 1, Opts) of
+ {_, I} when is_integer(I), I>0 ->
+ [{max_queue, I}];
+ false ->
+ case get_option(
+ max_fsm_queue,
+ fun(I) when is_integer(I), I>0 -> I end) of
+ undefined -> [];
+ N -> [{max_queue, N}]
+ end
+ end.
end.
safe_apply(Module, Function, Args) ->
- if is_function(Function) ->
- catch apply(Function, Args);
- true ->
- catch apply(Module, Function, Args)
+ try if is_function(Function) ->
+ apply(Function, Args);
+ true ->
+ apply(Module, Function, Args)
+ end
+ catch E:R when E /= exit, R /= normal ->
+ {'EXIT', {E, {R, erlang:get_stacktrace()}}}
end.
{ok, Socket} ->
case {inet:sockname(Socket), inet:peername(Socket)} of
{{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} ->
- ?INFO_MSG("(~w) Accepted connection ~s:~p -> ~s:~p",
- [Socket, ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), PPort,
- inet_parse:ntoa(Addr), Port]);
+ ?INFO_MSG("Accepted connection ~s:~p -> ~s:~p",
+ [ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)),
+ PPort, inet_parse:ntoa(Addr), Port]);
_ ->
ok
end,
unregister_route/1,
unregister_routes/1,
dirty_get_all_routes/0,
- dirty_get_all_domains/0
+ dirty_get_all_domains/0,
+ is_my_route/1,
+ is_my_host/1
]).
-export([start_link/0]).
[?MODULE, ?MODULE]),
register_route(Domain, ?MYNAME).
--spec register_route(binary(), binary()) -> term().
+-spec register_route(binary(), binary()) -> ok.
register_route(Domain, ServerHost) ->
register_route(Domain, ServerHost, undefined).
--spec register_route(binary(), binary(), local_hint()) -> term().
+-spec register_route(binary(), binary(), local_hint()) -> ok.
register_route(Domain, ServerHost, LocalHint) ->
case {jid:nameprep(Domain), jid:nameprep(ServerHost)} of
end
end,
mnesia:transaction(F)
+ end,
+ if LocalHint == undefined ->
+ ?INFO_MSG("Route registered: ~s", [LDomain]);
+ true ->
+ ok
end
end.
end,
Domains).
--spec unregister_route(binary()) -> term().
+-spec unregister_route(binary()) -> ok.
unregister_route(Domain) ->
case jid:nameprep(Domain) of
end
end,
mnesia:transaction(F)
- end
+ end,
+ ?INFO_MSG("Route unregistered: ~s", [LDomain])
end.
-spec unregister_routes([binary()]) -> ok.
end
end.
+-spec is_my_route(binary()) -> boolean().
+is_my_route(Domain) ->
+ case jid:nameprep(Domain) of
+ error ->
+ erlang:error({invalid_domain, Domain});
+ LDomain ->
+ mnesia:dirty_read(route, LDomain) /= []
+ end.
+
+-spec is_my_host(binary()) -> boolean().
+is_my_host(Domain) ->
+ case jid:nameprep(Domain) of
+ error ->
+ erlang:error({invalid_domain, Domain});
+ LDomain ->
+ case mnesia:dirty_read(route, LDomain) of
+ [#route{server_host = Host}|_] ->
+ Host == LDomain;
+ [] ->
+ false
+ end
+ end.
+
-spec process_iq(jid(), jid(), iq() | xmlel()) -> any().
process_iq(From, To, #iq{} = IQ) ->
if To#jid.luser == <<"">> ->
%% API
-export([start_link/0, route/3, have_connection/1,
- make_key/2, get_connections_pids/1, try_register/1,
- remove_connection/2, find_connection/2,
+ get_connections_pids/1, try_register/1,
+ remove_connection/2, start_connection/2, start_connection/3,
dirty_get_connections/0, allow_host/2,
incoming_s2s_number/0, outgoing_s2s_number/0,
stop_all_connections/0,
clean_temporarily_blocked_table/0,
list_temporarily_blocked_hosts/0,
external_host_overloaded/1, is_temporarly_blocked/1,
- check_peer_certificate/3,
- get_commands_spec/0]).
+ get_commands_spec/0, zlib_enabled/1, get_idle_timeout/1,
+ tls_required/1, tls_verify/1, tls_enabled/1, tls_options/2]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2,
dirty_get_connections() ->
mnesia:dirty_all_keys(s2s).
-check_peer_certificate(SockMod, Sock, Peer) ->
- case SockMod:get_peer_certificate(Sock) of
- {ok, Cert} ->
- case SockMod:get_verify_result(Sock) of
- 0 ->
- case ejabberd_idna:domain_utf8_to_ascii(Peer) of
- false ->
- {error, <<"Cannot decode remote server name">>};
- AsciiPeer ->
- case
- lists:any(fun(D) -> match_domain(AsciiPeer, D) end,
- get_cert_domains(Cert)) of
- true ->
- {ok, <<"Verification successful">>};
- false ->
- {error, <<"Certificate host name mismatch">>}
- end
- end;
- VerifyRes ->
- {error, fast_tls:get_cert_verify_string(VerifyRes, Cert)}
- end;
- {error, _Reason} ->
- {error, <<"Cannot get peer certificate">>};
- error ->
- {error, <<"Cannot get peer certificate">>}
+-spec tls_options(binary(), [proplists:property()]) -> [proplists:property()].
+tls_options(LServer, DefaultOpts) ->
+ TLSOpts1 = case ejabberd_config:get_option(
+ {s2s_certfile, LServer},
+ fun iolist_to_binary/1,
+ ejabberd_config:get_option(
+ {domain_certfile, LServer},
+ fun iolist_to_binary/1)) of
+ undefined -> [];
+ CertFile -> lists:keystore(certfile, 1, DefaultOpts,
+ {certfile, CertFile})
+ end,
+ TLSOpts2 = case ejabberd_config:get_option(
+ {s2s_ciphers, LServer},
+ fun iolist_to_binary/1) of
+ undefined -> TLSOpts1;
+ Ciphers -> lists:keystore(ciphers, 1, TLSOpts1,
+ {ciphers, Ciphers})
+ end,
+ TLSOpts3 = case ejabberd_config:get_option(
+ {s2s_protocol_options, LServer},
+ fun (Options) -> str:join(Options, <<$|>>) end) of
+ undefined -> TLSOpts2;
+ ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2,
+ {protocol_options, ProtoOpts})
+ end,
+ TLSOpts4 = case ejabberd_config:get_option(
+ {s2s_dhfile, LServer},
+ fun iolist_to_binary/1) of
+ undefined -> TLSOpts3;
+ DHFile -> lists:keystore(dhfile, 1, TLSOpts3,
+ {dhfile, DHFile})
+ end,
+ TLSOpts5 = case ejabberd_config:get_option(
+ {s2s_cafile, LServer},
+ fun iolist_to_binary/1) of
+ undefined -> TLSOpts4;
+ CAFile -> lists:keystore(cafile, 1, TLSOpts4,
+ {cafile, CAFile})
+ end,
+ case ejabberd_config:get_option(
+ {s2s_tls_compression, LServer},
+ fun(B) when is_boolean(B) -> B end) of
+ undefined -> TLSOpts5;
+ false -> [compression_none | TLSOpts5];
+ true -> lists:delete(compression_none, TLSOpts5)
end.
--spec make_key({binary(), binary()}, binary()) -> binary().
-make_key({From, To}, StreamID) ->
- Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
- p1_sha:to_hexlist(
- crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
- [To, " ", From, " ", StreamID])).
+-spec tls_required(binary()) -> boolean().
+tls_required(LServer) ->
+ TLS = use_starttls(LServer),
+ TLS == required orelse TLS == required_trusted.
+
+-spec tls_verify(binary()) -> boolean().
+tls_verify(LServer) ->
+ TLS = use_starttls(LServer),
+ TLS == required_trusted.
+
+-spec tls_enabled(binary()) -> boolean().
+tls_enabled(LServer) ->
+ TLS = use_starttls(LServer),
+ TLS == true orelse TLS == optional.
+
+-spec zlib_enabled(binary()) -> boolean().
+zlib_enabled(LServer) ->
+ ejabberd_config:get_option(
+ {s2s_zlib, LServer},
+ fun(B) when is_boolean(B) -> B end,
+ false).
+
+-spec use_starttls(binary()) -> boolean() | optional | required | required_trusted.
+use_starttls(LServer) ->
+ ejabberd_config:get_option(
+ {s2s_use_starttls, LServer},
+ fun(true) -> true;
+ (false) -> false;
+ (optional) -> optional;
+ (required) -> required;
+ (required_trusted) -> required_trusted
+ end, false).
+
+-spec get_idle_timeout(binary()) -> non_neg_integer() | infinity.
+get_idle_timeout(LServer) ->
+ ejabberd_config:get_option(
+ {s2s_timeout, LServer},
+ fun(I) when is_integer(I), I >= 0 -> timer:seconds(I);
+ (infinity) -> infinity
+ end, timer:minutes(10)).
%%====================================================================
%% gen_server callbacks
ejabberd_mnesia:create(?MODULE, temporarily_blocked,
[{ram_copies, [node()]},
{attributes, record_info(fields, temporarily_blocked)}]),
+ ejabberd_s2s_in:add_hooks(),
+ ejabberd_s2s_out:add_hooks(),
{ok, #state{}}.
handle_call(_Request, _From, State) ->
end,
mnesia:async_dirty(F).
--spec do_route(jid(), jid(), stanza()) -> ok | false.
+-spec do_route(jid(), jid(), stanza()) -> ok.
do_route(From, To, Packet) ->
?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket "
"~P~n",
[From, To, Packet, 8]),
- case find_connection(From, To) of
- {atomic, Pid} when is_pid(Pid) ->
- ?DEBUG("sending to process ~p~n", [Pid]),
- #jid{lserver = MyServer} = From,
- ejabberd_hooks:run(s2s_send_packet, MyServer,
- [From, To, Packet]),
- send_element(Pid, xmpp:set_from_to(Packet, From, To)),
- ok;
- {aborted, _Reason} ->
- Lang = xmpp:get_lang(Packet),
- Txt = <<"No s2s connection found">>,
- Err = xmpp:err_service_unavailable(Txt, Lang),
- ejabberd_router:route_error(To, From, Packet, Err),
- false
+ case start_connection(From, To) of
+ {ok, Pid} when is_pid(Pid) ->
+ ?DEBUG("sending to process ~p~n", [Pid]),
+ #jid{lserver = MyServer} = From,
+ ejabberd_hooks:run(s2s_send_packet, MyServer, [From, To, Packet]),
+ ejabberd_s2s_out:route(Pid, xmpp:set_from_to(Packet, From, To));
+ {error, Reason} ->
+ Err = case Reason of
+ forbidden ->
+ Lang = xmpp:get_lang(Packet),
+ xmpp:err_forbidden(<<"Denied by ACL">>, Lang);
+ internal_server_error ->
+ xmpp:err_internal_server_error()
+ end,
+ ejabberd_router:route_error(To, From, Packet, Err)
end.
--spec find_connection(jid(), jid()) -> {aborted, any()} | {atomic, pid()}.
+-spec start_connection(jid(), jid()) -> {ok, pid()} |
+ {error, forbidden | internal_server_error}.
+start_connection(From, To) ->
+ start_connection(From, To, []).
-find_connection(From, To) ->
+-spec start_connection(jid(), jid(), [proplists:property()])
+ -> {ok, pid()} | {error, forbidden | internal_server_error}.
+start_connection(From, To, Opts) ->
#jid{lserver = MyServer} = From,
#jid{lserver = Server} = To,
FromTo = {MyServer, Server},
MaxS2SConnectionsNumberPerNode =
max_s2s_connections_number_per_node(FromTo),
?DEBUG("Finding connection for ~p~n", [FromTo]),
- case catch mnesia:dirty_read(s2s, FromTo) of
- {'EXIT', Reason} -> {aborted, Reason};
+ case mnesia:dirty_read(s2s, FromTo) of
[] ->
%% We try to establish all the connections if the host is not a
%% service and if the s2s host is not blacklisted or
%% is in whitelist:
- case not is_service(From, To) andalso
- allow_host(MyServer, Server)
- of
+ LServer = ejabberd_router:host_of_route(MyServer),
+ case not is_service(From, To) andalso allow_host(LServer, Server) of
true ->
NeededConnections = needed_connections_number([],
MaxS2SConnectionsNumber,
open_several_connections(NeededConnections, MyServer,
Server, From, FromTo,
MaxS2SConnectionsNumber,
- MaxS2SConnectionsNumberPerNode);
- false -> {aborted, error}
+ MaxS2SConnectionsNumberPerNode, Opts);
+ false -> {error, forbidden}
end;
L when is_list(L) ->
NeededConnections = needed_connections_number(L,
open_several_connections(NeededConnections, MyServer,
Server, From, FromTo,
MaxS2SConnectionsNumber,
- MaxS2SConnectionsNumberPerNode);
+ MaxS2SConnectionsNumberPerNode, Opts);
true ->
%% We choose a connexion from the pool of opened ones.
- {atomic, choose_connection(From, L)}
+ {ok, choose_connection(From, L)}
end
end.
open_several_connections(N, MyServer, Server, From,
FromTo, MaxS2SConnectionsNumber,
- MaxS2SConnectionsNumberPerNode) ->
- ConnectionsResult = [new_connection(MyServer, Server,
- From, FromTo, MaxS2SConnectionsNumber,
- MaxS2SConnectionsNumberPerNode)
- || _N <- lists:seq(1, N)],
- case [PID || {atomic, PID} <- ConnectionsResult] of
- [] -> hd(ConnectionsResult);
- PIDs -> {atomic, choose_pid(From, PIDs)}
+ MaxS2SConnectionsNumberPerNode, Opts) ->
+ case lists:flatmap(
+ fun(_) ->
+ new_connection(MyServer, Server,
+ From, FromTo, MaxS2SConnectionsNumber,
+ MaxS2SConnectionsNumberPerNode, Opts)
+ end, lists:seq(1, N)) of
+ [] ->
+ {error, internal_server_error};
+ PIDs ->
+ {ok, choose_pid(From, PIDs)}
end.
new_connection(MyServer, Server, From, FromTo,
- MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) ->
- {ok, Pid} = ejabberd_s2s_out:start(
- MyServer, Server, new),
+ MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) ->
+ {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts),
F = fun() ->
L = mnesia:read({s2s, FromTo}),
NeededConnections = needed_connections_number(L,
MaxS2SConnectionsNumberPerNode),
if NeededConnections > 0 ->
mnesia:write(#s2s{fromto = FromTo, pid = Pid}),
- ?INFO_MSG("New s2s connection started ~p", [Pid]),
Pid;
true -> choose_connection(From, L)
end
end,
TRes = mnesia:transaction(F),
case TRes of
- {atomic, Pid} -> ejabberd_s2s_out:start_connection(Pid);
- _ -> ejabberd_s2s_out:stop_connection(Pid)
- end,
- TRes.
+ {atomic, Pid} ->
+ ejabberd_s2s_out:connect(Pid),
+ [Pid];
+ {aborted, Reason} ->
+ ?ERROR_MSG("failed to register connection ~s -> ~s: ~p",
+ [MyServer, Server, Reason]),
+ ejabberd_s2s_out:stop(Pid),
+ []
+ end.
-spec max_s2s_connections_number({binary(), binary()}) -> integer().
max_s2s_connections_number({From, To}) ->
end,
[], lists:reverse(str:tokens(Domain, <<".">>))).
-send_element(Pid, El) ->
- Pid ! {send_element, El}.
-
%%%----------------------------------------------------------------------
%%% ejabberd commands
%% Check if host is in blacklist or white list
allow_host(MyServer, S2SHost) ->
- allow_host2(MyServer, S2SHost) andalso
+ allow_host1(MyServer, S2SHost) andalso
not is_temporarly_blocked(S2SHost).
-allow_host2(MyServer, S2SHost) ->
- Hosts = (?MYHOSTS),
- case lists:dropwhile(fun (ParentDomain) ->
- not lists:member(ParentDomain, Hosts)
- end,
- parent_domains(MyServer))
- of
- [MyHost | _] -> allow_host1(MyHost, S2SHost);
- [] -> allow_host1(MyServer, S2SHost)
- end.
-
allow_host1(MyHost, S2SHost) ->
Rule = ejabberd_config:get_option(
- s2s_access,
- fun(A) -> A end,
+ {s2s_access, MyHost},
+ fun acl:access_rules_validator/1,
all),
JID = jid:make(S2SHost),
case acl:match_rule(MyHost, Rule, JID) of
end,
[{s2s_pid, S2sPid} | Infos].
-get_cert_domains(Cert) ->
- TBSCert = Cert#'Certificate'.tbsCertificate,
- Subject = case TBSCert#'TBSCertificate'.subject of
- {rdnSequence, Subj} -> lists:flatten(Subj);
- _ -> []
- end,
- Extensions = case TBSCert#'TBSCertificate'.extensions of
- Exts when is_list(Exts) -> Exts;
- _ -> []
- end,
- lists:flatmap(fun (#'AttributeTypeAndValue'{type =
- ?'id-at-commonName',
- value = Val}) ->
- case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
- {ok, {_, D1}} ->
- D = if is_binary(D1) -> D1;
- is_list(D1) -> list_to_binary(D1);
- true -> error
- end,
- if D /= error ->
- case jid:from_string(D) of
- #jid{luser = <<"">>, lserver = LD,
- lresource = <<"">>} ->
- [LD];
- _ -> []
- end;
- true -> []
- end;
- _ -> []
- end;
- (_) -> []
- end,
- Subject)
- ++
- lists:flatmap(fun (#'Extension'{extnID =
- ?'id-ce-subjectAltName',
- extnValue = Val}) ->
- BVal = if is_list(Val) -> list_to_binary(Val);
- true -> Val
- end,
- case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
- of
- {ok, SANs} ->
- lists:flatmap(fun ({otherName,
- #'AnotherName'{'type-id' =
- ?'id-on-xmppAddr',
- value =
- XmppAddr}}) ->
- case
- 'XmppAddr':decode('XmppAddr',
- XmppAddr)
- of
- {ok, D}
- when
- is_binary(D) ->
- case
- jid:from_string((D))
- of
- #jid{luser =
- <<"">>,
- lserver =
- LD,
- lresource =
- <<"">>} ->
- case
- ejabberd_idna:domain_utf8_to_ascii(LD)
- of
- false ->
- [];
- PCLD ->
- [PCLD]
- end;
- _ -> []
- end;
- _ -> []
- end;
- ({dNSName, D})
- when is_list(D) ->
- case
- jid:from_string(list_to_binary(D))
- of
- #jid{luser = <<"">>,
- lserver = LD,
- lresource =
- <<"">>} ->
- [LD];
- _ -> []
- end;
- (_) -> []
- end,
- SANs);
- _ -> []
- end;
- (_) -> []
- end,
- Extensions).
-
-match_domain(Domain, Domain) -> true;
-match_domain(Domain, Pattern) ->
- DLabels = str:tokens(Domain, <<".">>),
- PLabels = str:tokens(Pattern, <<".">>),
- match_labels(DLabels, PLabels).
-
-match_labels([], []) -> true;
-match_labels([], [_ | _]) -> false;
-match_labels([_ | _], []) -> false;
-match_labels([DL | DLabels], [PL | PLabels]) ->
- case lists:all(fun (C) ->
- $a =< C andalso C =< $z orelse
- $0 =< C andalso C =< $9 orelse
- C == $- orelse C == $*
- end,
- binary_to_list(PL))
- of
- true ->
- Regexp = ejabberd_regexp:sh_to_awk(PL),
- case ejabberd_regexp:run(DL, Regexp) of
- match -> match_labels(DLabels, PLabels);
- nomatch -> false
- end;
- false -> false
- end.
-
opt_type(route_subdomains) ->
fun (s2s) -> s2s;
(local) -> local
end;
opt_type(s2s_access) ->
fun acl:access_rules_validator/1;
-opt_type(_) -> [route_subdomains, s2s_access].
+opt_type(domain_certfile) -> fun iolist_to_binary/1;
+opt_type(s2s_certfile) -> fun iolist_to_binary/1;
+opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
+opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
+opt_type(s2s_protocol_options) ->
+ fun (Options) -> str:join(Options, <<"|">>) end;
+opt_type(s2s_tls_compression) ->
+ fun (true) -> true;
+ (false) -> false
+ end;
+opt_type(s2s_use_starttls) ->
+ fun (true) -> true;
+ (false) -> false;
+ (optional) -> optional;
+ (required) -> required;
+ (required_trusted) -> required_trusted
+ end;
+opt_type(s2s_timeout) ->
+ fun(I) when is_integer(I), I>=0 -> I;
+ (infinity) -> infinity
+ end;
+opt_type(_) ->
+ [route_subdomains, s2s_access, s2s_certfile,
+ s2s_ciphers, s2s_dhfile, s2s_protocol_options,
+ s2s_tls_compression, s2s_use_starttls, s2s_timeout].
-%%%----------------------------------------------------------------------
-%%% File : ejabberd_s2s_in.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Serve incoming s2s connection
-%%% Created : 6 Dec 2002 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+%%% Created : 12 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
%%%
%%%
%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
%%% with this program; if not, write to the Free Software Foundation, Inc.,
%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
%%%
-%%%----------------------------------------------------------------------
-
+%%%-------------------------------------------------------------------
-module(ejabberd_s2s_in).
-
+-behaviour(xmpp_stream_in).
-behaviour(ejabberd_config).
+-behaviour(ejabberd_socket).
--author('alexey@process-one.net').
-
--behaviour(p1_fsm).
-
-%% External exports
+%% ejabberd_socket callbacks
-export([start/2, start_link/2, socket_type/0]).
-
--export([init/1, wait_for_stream/2,
- wait_for_feature_request/2, stream_established/2,
- handle_event/3, handle_sync_event/4, code_change/4,
- handle_info/3, print_state/1, terminate/3, opt_type/1]).
+%% ejabberd_config callbacks
+-export([opt_type/1]).
+%% xmpp_stream_in callbacks
+-export([init/1, handle_call/3, handle_cast/2,
+ handle_info/2, terminate/2, code_change/3]).
+-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+ compress_methods/1,
+ unauthenticated_stream_features/1, authenticated_stream_features/1,
+ handle_stream_start/2, handle_stream_end/2, handle_stream_close/2,
+ handle_stream_established/1, handle_auth_success/4,
+ handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2,
+ handle_unauthenticated_packet/2, handle_authenticated_packet/2]).
+%% Hooks
+-export([handle_unexpected_info/2, handle_unexpected_cast/2,
+ reject_unauthenticated_packet/2, process_closed/2]).
+%% API
+-export([stop/1, close/1, send/2, update_state/2, establish/1, add_hooks/0]).
-include("ejabberd.hrl").
--include("logger.hrl").
-
-include("xmpp.hrl").
+-include("logger.hrl").
--define(DICT, dict).
-
--record(state,
- {socket :: ejabberd_socket:socket_state(),
- sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket,
- streamid = <<"">> :: binary(),
- shaper = none :: shaper:shaper(),
- tls = false :: boolean(),
- tls_enabled = false :: boolean(),
- tls_required = false :: boolean(),
- tls_certverify = false :: boolean(),
- tls_options = [] :: list(),
- server = <<"">> :: binary(),
- authenticated = false :: boolean(),
- auth_domain = <<"">> :: binary(),
- connections = (?DICT):new() :: ?TDICT,
- timer = make_ref() :: reference()}).
-
--type state_name() :: wait_for_stream | wait_for_feature_request | stream_established.
--type state() :: #state{}.
--type fsm_next() :: {next_state, state_name(), state()}.
--type fsm_stop() :: {stop, normal, state()}.
--type fsm_transition() :: fsm_stop() | fsm_next().
-
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
+-type state() :: map().
+-export_type([state/0]).
+%%%===================================================================
+%%% API
+%%%===================================================================
start(SockData, Opts) ->
- supervisor:start_child(ejabberd_s2s_in_sup,
- [SockData, Opts]).
+ xmpp_stream_in:start(?MODULE, [SockData, Opts],
+ ejabberd_config:fsm_limit_opts(Opts)).
start_link(SockData, Opts) ->
- p1_fsm:start_link(ejabberd_s2s_in, [SockData, Opts],
- ?FSMOPTS ++ fsm_limit_opts(Opts)).
+ xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
+ ejabberd_config:fsm_limit_opts(Opts)).
+
+close(Ref) ->
+ xmpp_stream_in:close(Ref).
+
+stop(Ref) ->
+ xmpp_stream_in:stop(Ref).
+
+socket_type() ->
+ xml_stream.
+
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Stream, Pkt) ->
+ xmpp_stream_in:send(Stream, Pkt).
+
+-spec establish(state()) -> state().
+establish(State) ->
+ xmpp_stream_in:establish(State).
+
+-spec update_state(pid(), fun((state()) -> state()) |
+ {module(), atom(), list()}) -> ok.
+update_state(Ref, Callback) ->
+ xmpp_stream_in:cast(Ref, {update_state, Callback}).
+
+-spec add_hooks() -> ok.
+add_hooks() ->
+ lists:foreach(
+ fun(Host) ->
+ ejabberd_hooks:add(s2s_in_closed, Host, ?MODULE,
+ process_closed, 100),
+ ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
+ reject_unauthenticated_packet, 100),
+ ejabberd_hooks:add(s2s_in_handle_info, Host, ?MODULE,
+ handle_unexpected_info, 100),
+ ejabberd_hooks:add(s2s_in_handle_cast, Host, ?MODULE,
+ handle_unexpected_cast, 100)
+ end, ?MYHOSTS).
+
+%%%===================================================================
+%%% Hooks
+%%%===================================================================
+handle_unexpected_info(State, Info) ->
+ ?WARNING_MSG("got unexpected info: ~p", [Info]),
+ State.
+
+handle_unexpected_cast(State, Msg) ->
+ ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
+ State.
+
+reject_unauthenticated_packet(State, Pkt) ->
+ Err = xmpp:err_not_authorized(),
+ xmpp_stream_in:send_error(State, Pkt, Err).
+
+process_closed(State, _Reason) ->
+ stop(State).
+
+%%%===================================================================
+%%% xmpp_stream_in callbacks
+%%%===================================================================
+tls_options(#{tls_compression := Compression, server_host := LServer}) ->
+ Opts = case Compression of
+ false -> [compression_none];
+ true -> []
+ end,
+ ejabberd_s2s:tls_options(LServer, Opts).
+
+tls_required(#{server_host := LServer}) ->
+ ejabberd_s2s:tls_required(LServer).
-socket_type() -> xml_stream.
+tls_verify(#{server_host := LServer}) ->
+ ejabberd_s2s:tls_verify(LServer).
-%%%----------------------------------------------------------------------
-%%% Callback functions from gen_fsm
-%%%----------------------------------------------------------------------
+tls_enabled(#{server_host := LServer}) ->
+ ejabberd_s2s:tls_enabled(LServer).
-init([{SockMod, Socket}, Opts]) ->
- ?DEBUG("started: ~p", [{SockMod, Socket}]),
- Shaper = case lists:keysearch(shaper, 1, Opts) of
- {value, {_, S}} -> S;
- _ -> none
+compress_methods(#{server_host := LServer}) ->
+ case ejabberd_s2s:zlib_enabled(LServer) of
+ true -> [<<"zlib">>];
+ false -> []
+ end.
+
+unauthenticated_stream_features(#{server_host := LServer}) ->
+ ejabberd_hooks:run_fold(s2s_in_pre_auth_features, LServer, [], [LServer]).
+
+authenticated_stream_features(#{server_host := LServer}) ->
+ ejabberd_hooks:run_fold(s2s_in_post_auth_features, LServer, [], [LServer]).
+
+handle_stream_start(_StreamStart, #{lserver := LServer} = State) ->
+ case check_to(jid:make(LServer), State) of
+ false ->
+ send(State, xmpp:serr_host_unknown());
+ true ->
+ ServerHost = ejabberd_router:host_of_route(LServer),
+ State#{server_host => ServerHost}
+ end.
+
+handle_stream_end(Reason, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]).
+
+handle_stream_close(_Reason, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]).
+
+handle_stream_established(State) ->
+ set_idle_timeout(State#{established => true}).
+
+handle_auth_success(RServer, Mech, _AuthModule,
+ #{socket := Socket, ip := IP,
+ auth_domains := AuthDomains,
+ server_host := ServerHost,
+ lserver := LServer} = State) ->
+ ?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)",
+ [ejabberd_socket:pp(Socket), Mech, RServer, LServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+ State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of
+ true ->
+ AuthDomains1 = sets:add_element(RServer, AuthDomains),
+ State#{auth_domains => AuthDomains1};
+ false ->
+ State
end,
- {StartTLS, TLSRequired, TLSCertverify} =
- case ejabberd_config:get_option(
- s2s_use_starttls,
- fun(false) -> false;
- (true) -> true;
- (optional) -> optional;
- (required) -> required;
- (required_trusted) -> required_trusted
- end,
- false) of
- UseTls
- when (UseTls == undefined) or
- (UseTls == false) ->
- {false, false, false};
- UseTls
- when (UseTls == true) or
- (UseTls ==
- optional) ->
- {true, false, false};
- required -> {true, true, false};
- required_trusted ->
- {true, true, true}
- end,
- TLSOpts1 = case ejabberd_config:get_option(
- s2s_certfile,
- fun iolist_to_binary/1) of
- undefined -> [];
- CertFile -> [{certfile, CertFile}]
- end,
- TLSOpts2 = case ejabberd_config:get_option(
- s2s_ciphers, fun iolist_to_binary/1) of
- undefined -> TLSOpts1;
- Ciphers -> [{ciphers, Ciphers} | TLSOpts1]
- end,
- TLSOpts3 = case ejabberd_config:get_option(
- s2s_protocol_options,
- fun (Options) ->
- [_|O] = lists:foldl(
- fun(X, Acc) -> X ++ Acc end, [],
- [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)]
- ),
- iolist_to_binary(O)
- end) of
- undefined -> TLSOpts2;
- ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2]
- end,
- TLSOpts4 = case ejabberd_config:get_option(
- s2s_dhfile, fun iolist_to_binary/1) of
- undefined -> TLSOpts3;
- DHFile -> [{dhfile, DHFile} | TLSOpts3]
- end,
- TLSOpts = case proplists:get_bool(tls_compression, Opts) of
- false -> [compression_none | TLSOpts4];
- true -> TLSOpts4
- end,
- Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- {ok, wait_for_stream,
- #state{socket = Socket, sockmod = SockMod,
- streamid = new_id(), shaper = Shaper, tls = StartTLS,
- tls_enabled = false, tls_required = TLSRequired,
- tls_certverify = TLSCertverify, tls_options = TLSOpts,
- timer = Timer}}.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/2
-%% Returns: {next_state, NextStateName, NextStateData} |
-%% {next_state, NextStateName, NextStateData, Timeout} |
-%% {stop, Reason, NewStateData}
-%%----------------------------------------------------------------------
-wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
- try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
- #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
- when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
- send_header(StateData, {1,0}),
- send_element(StateData, xmpp:serr_invalid_namespace()),
- {stop, normal, StateData};
- #stream_start{to = #jid{lserver = Server},
- from = From, version = {1,0}}
- when StateData#state.tls and not StateData#state.authenticated ->
- send_header(StateData, {1,0}),
- Auth = if StateData#state.tls_enabled ->
- case From of
- #jid{} ->
- {Result, Message} =
- ejabberd_s2s:check_peer_certificate(
- StateData#state.sockmod,
- StateData#state.socket,
- From#jid.lserver),
- {Result, From#jid.lserver, Message};
- undefined ->
- {error, <<"(unknown)">>,
- <<"Got no valid 'from' attribute">>}
- end;
- true ->
- {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>}
- end,
- StartTLS = if StateData#state.tls_enabled -> [];
- not StateData#state.tls_enabled and
- not StateData#state.tls_required ->
- [#starttls{required = false}];
- not StateData#state.tls_enabled and
- StateData#state.tls_required ->
- [#starttls{required = true}]
- end,
- case Auth of
- {error, RemoteServer, CertError}
- when StateData#state.tls_certverify ->
- ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
- [StateData#state.server, RemoteServer, CertError]),
- send_element(StateData,
- xmpp:serr_policy_violation(CertError, ?MYLANG)),
- {stop, normal, StateData};
- {VerifyResult, RemoteServer, Msg} ->
- {SASL, NewStateData} =
- case VerifyResult of
- ok ->
- {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}],
- StateData#state{auth_domain = RemoteServer}};
- error ->
- ?DEBUG("Won't accept certificate of ~s: ~s",
- [RemoteServer, Msg]),
- {[], StateData};
- no_verify ->
- {[], StateData}
- end,
- send_element(NewStateData,
- #stream_features{
- sub_els = SASL ++ StartTLS ++
- ejabberd_hooks:run_fold(
- s2s_stream_features, Server, [],
- [Server])}),
- {next_state, wait_for_feature_request,
- NewStateData#state{server = Server}}
- end;
- #stream_start{to = #jid{lserver = Server},
- version = {1,0}} when StateData#state.authenticated ->
- send_header(StateData, {1,0}),
- send_element(StateData,
- #stream_features{
- sub_els = ejabberd_hooks:run_fold(
- s2s_stream_features, Server, [],
- [Server])}),
- {next_state, stream_established, StateData};
- #stream_start{db_xmlns = ?NS_SERVER_DIALBACK}
- when (StateData#state.tls_required and StateData#state.tls_enabled)
- or (not StateData#state.tls_required) ->
- send_header(StateData, undefined),
- {next_state, stream_established, StateData};
- #stream_start{} ->
- send_header(StateData, {1,0}),
- send_element(StateData, xmpp:serr_undefined_condition()),
- {stop, normal, StateData};
- _ ->
- send_header(StateData, {1,0}),
- send_element(StateData, xmpp:serr_invalid_xml()),
- {stop, normal, StateData}
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:format_error(Why),
- send_header(StateData, {1,0}),
- send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
- {stop, normal, StateData}
- end;
-wait_for_stream({xmlstreamerror, _}, StateData) ->
- send_header(StateData, {1,0}),
- send_element(StateData, xmpp:serr_not_well_formed()),
- {stop, normal, StateData};
-wait_for_stream(timeout, StateData) ->
- send_header(StateData, {1,0}),
- send_element(StateData, xmpp:serr_connection_timeout()),
- {stop, normal, StateData};
-wait_for_stream(closed, StateData) ->
- {stop, normal, StateData}.
-
-wait_for_feature_request({xmlstreamelement, El}, StateData) ->
- decode_element(El, wait_for_feature_request, StateData);
-wait_for_feature_request(#starttls{},
- #state{tls = true, tls_enabled = false} = StateData) ->
- case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of
- gen_tcp ->
- ?DEBUG("starttls", []),
- Socket = StateData#state.socket,
- TLSOpts1 = case
- ejabberd_config:get_option(
- {domain_certfile, StateData#state.server},
- fun iolist_to_binary/1) of
- undefined -> StateData#state.tls_options;
- CertFile ->
- lists:keystore(certfile, 1,
- StateData#state.tls_options,
- {certfile, CertFile})
- end,
- TLSOpts2 = case ejabberd_config:get_option(
- {s2s_cafile, StateData#state.server},
- fun iolist_to_binary/1) of
- undefined -> TLSOpts1;
- CAFile ->
- lists:keystore(cafile, 1, TLSOpts1,
- {cafile, CAFile})
- end,
- TLSOpts = case ejabberd_config:get_option(
- {s2s_tls_compression, StateData#state.server},
- fun(true) -> true;
- (false) -> false
- end, false) of
- true -> lists:delete(compression_none, TLSOpts2);
- false -> [compression_none | TLSOpts2]
- end,
- TLSSocket = (StateData#state.sockmod):starttls(
- Socket, TLSOpts,
- fxml:element_to_binary(
- xmpp:encode(#starttls_proceed{}))),
- {next_state, wait_for_stream,
- StateData#state{socket = TLSSocket, streamid = new_id(),
- tls_enabled = true, tls_options = TLSOpts}};
- _ ->
- send_element(StateData, #starttls_failure{}),
- {stop, normal, StateData}
- end;
-wait_for_feature_request(#sasl_auth{mechanism = Mech},
- #state{tls_enabled = true} = StateData) ->
- case Mech of
- <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
- AuthDomain = StateData#state.auth_domain,
- AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain),
- if AllowRemoteHost ->
- (StateData#state.sockmod):reset_stream(StateData#state.socket),
- send_element(StateData, #sasl_success{}),
- ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
- [AuthDomain, StateData#state.tls_enabled]),
- change_shaper(StateData, <<"">>, jid:make(AuthDomain)),
- {next_state, wait_for_stream,
- StateData#state{streamid = new_id(),
- authenticated = true}};
- true ->
- Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG),
- send_element(StateData,
- #sasl_failure{reason = 'not-authorized',
- text = Txt}),
- {stop, normal, StateData}
- end;
- _ ->
- send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}),
- {stop, normal, StateData}
- end;
-wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
- {stop, normal, StateData};
-wait_for_feature_request({xmlstreamerror, _}, StateData) ->
- send_element(StateData, xmpp:serr_not_well_formed()),
- {stop, normal, StateData};
-wait_for_feature_request(closed, StateData) ->
- {stop, normal, StateData};
-wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired,
- tls_enabled = TLSEnabled} = StateData)
- when TLSRequired and not TLSEnabled ->
- Txt = <<"Use of STARTTLS required">>,
- send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)),
- {stop, normal, StateData};
-wait_for_feature_request(El, StateData) ->
- stream_established({xmlstreamelement, El}, StateData).
-
-stream_established({xmlstreamelement, El}, StateData) ->
- cancel_timer(StateData#state.timer),
- Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- decode_element(El, stream_established, StateData#state{timer = Timer});
-stream_established(#db_result{to = To, from = From, key = Key},
- StateData) ->
- ?DEBUG("GET KEY: ~p", [{To, From, Key}]),
- case {ejabberd_s2s:allow_host(To, From),
- lists:member(To, ejabberd_router:dirty_get_all_domains())} of
- {true, true} ->
- ejabberd_s2s_out:terminate_if_waiting_delay(To, From),
- ejabberd_s2s_out:start(To, From,
- {verify, self(), Key,
- StateData#state.streamid}),
- Conns = (?DICT):store({From, To},
- wait_for_verification,
- StateData#state.connections),
- change_shaper(StateData, To, jid:make(From)),
- {next_state, stream_established,
- StateData#state{connections = Conns}};
- {_, false} ->
- send_element(StateData, xmpp:serr_host_unknown()),
- {stop, normal, StateData};
- {false, _} ->
- send_element(StateData, xmpp:serr_invalid_from()),
- {stop, normal, StateData}
- end;
-stream_established(#db_verify{to = To, from = From, id = Id, key = Key},
- StateData) ->
- ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
- Type = case ejabberd_s2s:make_key({To, From}, Id) of
- Key -> valid;
- _ -> invalid
- end,
- send_element(StateData,
- #db_verify{from = To, to = From, id = Id, type = Type}),
- {next_state, stream_established, StateData};
-stream_established(Pkt, StateData) when ?is_stanza(Pkt) ->
+ ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]).
+
+handle_auth_failure(RServer, Mech, Reason,
+ #{socket := Socket, ip := IP,
+ server_host := ServerHost,
+ lserver := LServer} = State) ->
+ ?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s",
+ [ejabberd_socket:pp(Socket), Mech, RServer, LServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
+ ejabberd_hooks:run_fold(s2s_in_auth_result,
+ ServerHost, State, [false, RServer]).
+
+handle_unauthenticated_packet(Pkt, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_unauthenticated_packet,
+ LServer, State, [Pkt]).
+
+handle_authenticated_packet(Pkt, #{server_host := LServer} = State) when not ?is_stanza(Pkt) ->
+ ejabberd_hooks:run_fold(s2s_in_authenticated_packet, LServer, State, [Pkt]);
+handle_authenticated_packet(Pkt, State) ->
From = xmpp:get_from(Pkt),
To = xmpp:get_to(Pkt),
- if To /= undefined, From /= undefined ->
- LFrom = From#jid.lserver,
- LTo = To#jid.lserver,
- if StateData#state.authenticated ->
- case LFrom == StateData#state.auth_domain andalso
- lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of
- true ->
- ejabberd_hooks:run(s2s_receive_packet, LTo,
- [From, To, Pkt]),
- ejabberd_router:route(From, To, Pkt);
- false ->
- send_error(StateData, Pkt, xmpp:err_not_authorized())
- end;
- true ->
- case (?DICT):find({LFrom, LTo}, StateData#state.connections) of
- {ok, established} ->
- ejabberd_hooks:run(s2s_receive_packet, LTo,
- [From, To, Pkt]),
- ejabberd_router:route(From, To, Pkt);
- _ ->
- send_error(StateData, Pkt, xmpp:err_not_authorized())
- end
- end;
- true ->
- send_error(StateData, Pkt, xmpp:err_jid_malformed())
- end,
- ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
- {next_state, stream_established, StateData};
-stream_established({valid, From, To}, StateData) ->
- send_element(StateData,
- #db_result{from = To, to = From, type = valid}),
- ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)",
- [From, StateData#state.tls_enabled]),
- NSD = StateData#state{connections =
- (?DICT):store({From, To}, established,
- StateData#state.connections)},
- {next_state, stream_established, NSD};
-stream_established({invalid, From, To}, StateData) ->
- send_element(StateData,
- #db_result{from = To, to = From, type = invalid}),
- NSD = StateData#state{connections =
- (?DICT):erase({From, To},
- StateData#state.connections)},
- {next_state, stream_established, NSD};
-stream_established({xmlstreamend, _Name}, StateData) ->
- {stop, normal, StateData};
-stream_established({xmlstreamerror, _}, StateData) ->
- send_element(StateData, xmpp:serr_not_well_formed()),
- {stop, normal, StateData};
-stream_established(timeout, StateData) ->
- send_element(StateData, xmpp:serr_connection_timeout()),
- {stop, normal, StateData};
-stream_established(closed, StateData) ->
- {stop, normal, StateData};
-stream_established(Pkt, StateData) ->
- ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
- {next_state, stream_established, StateData}.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/3
-%% Returns: {next_state, NextStateName, NextStateData} |
-%% {next_state, NextStateName, NextStateData, Timeout} |
-%% {reply, Reply, NextStateName, NextStateData} |
-%% {reply, Reply, NextStateName, NextStateData, Timeout} |
-%% {stop, Reason, NewStateData} |
-%% {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-%state_name(Event, From, StateData) ->
-% Reply = ok,
-% {reply, Reply, state_name, StateData}.
-
-handle_event(_Event, StateName, StateData) ->
- {next_state, StateName, StateData}.
-
-handle_sync_event(get_state_infos, _From, StateName,
- StateData) ->
- SockMod = StateData#state.sockmod,
- {Addr, Port} = try
- SockMod:peername(StateData#state.socket)
- of
- {ok, {A, P}} -> {A, P};
- {error, _} -> {unknown, unknown}
- catch
- _:_ -> {unknown, unknown}
- end,
- Domains = get_external_hosts(StateData),
- Infos = [{direction, in}, {statename, StateName},
- {addr, Addr}, {port, Port},
- {streamid, StateData#state.streamid},
- {tls, StateData#state.tls},
- {tls_enabled, StateData#state.tls_enabled},
- {tls_options, StateData#state.tls_options},
- {authenticated, StateData#state.authenticated},
- {shaper, StateData#state.shaper}, {sockmod, SockMod},
- {domains, Domains}],
- Reply = {state_infos, Infos},
- {reply, Reply, StateName, StateData};
-%%----------------------------------------------------------------------
-%% Func: handle_sync_event/4
-%% Returns: {next_state, NextStateName, NextStateData} |
-%% {next_state, NextStateName, NextStateData, Timeout} |
-%% {reply, Reply, NextStateName, NextStateData} |
-%% {reply, Reply, NextStateName, NextStateData, Timeout} |
-%% {stop, Reason, NewStateData} |
-%% {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-handle_sync_event(_Event, _From, StateName,
- StateData) ->
- Reply = ok, {reply, Reply, StateName, StateData}.
-
-code_change(_OldVsn, StateName, StateData, _Extra) ->
- {ok, StateName, StateData}.
-
-handle_info({send_text, Text}, StateName, StateData) ->
- send_text(StateData, Text),
- {next_state, StateName, StateData};
-handle_info({timeout, Timer, _}, StateName,
- #state{timer = Timer} = StateData) ->
- if StateName == wait_for_stream ->
- send_header(StateData, undefined);
- true ->
- ok
- end,
- send_element(StateData, xmpp:serr_connection_timeout()),
- {stop, normal, StateData};
-handle_info(_, StateName, StateData) ->
- {next_state, StateName, StateData}.
-
-terminate(Reason, _StateName, StateData) ->
- ?DEBUG("terminated: ~p", [Reason]),
- case Reason of
- {process_limit, _} ->
- [ejabberd_s2s:external_host_overloaded(Host)
- || Host <- get_external_hosts(StateData)];
- _ -> ok
- end,
- catch send_trailer(StateData),
- (StateData#state.sockmod):close(StateData#state.socket),
- ok.
-
-get_external_hosts(StateData) ->
- case StateData#state.authenticated of
- true -> [StateData#state.auth_domain];
- false ->
- Connections = StateData#state.connections,
- [D
- || {{D, _}, established} <- dict:to_list(Connections)]
+ case check_from_to(From, To, State) of
+ ok ->
+ LServer = ejabberd_router:host_of_route(To#jid.lserver),
+ State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet,
+ LServer, State, [Pkt]),
+ Pkt1 = ejabberd_hooks:run_fold(s2s_receive_packet, LServer,
+ Pkt, [State1]),
+ ejabberd_router:route(From, To, Pkt1),
+ State1;
+ {error, Err} ->
+ send(State, Err)
end.
-print_state(State) -> State.
+handle_cdata(Data, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_handle_cdata, LServer, State, [Data]).
+
+handle_recv(El, Pkt, #{server_host := LServer} = State) ->
+ State1 = set_idle_timeout(State),
+ ejabberd_hooks:run_fold(s2s_in_handle_recv, LServer, State1, [El, Pkt]).
+
+handle_send(Pkt, Result, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_handle_send, LServer,
+ State, [Pkt, Result]).
+
+init([State, Opts]) ->
+ Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none),
+ TLSCompression = proplists:get_bool(tls_compression, Opts),
+ State1 = State#{tls_compression => TLSCompression,
+ auth_domains => sets:new(),
+ xmlns => ?NS_SERVER,
+ lang => ?MYLANG,
+ server => ?MYNAME,
+ lserver => ?MYNAME,
+ server_host => ?MYNAME,
+ established => false,
+ shaper => Shaper},
+ ejabberd_hooks:run_fold(s2s_in_init, {ok, State1}, [Opts]).
+
+handle_call(Request, From, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_handle_call, LServer, State, [Request, From]).
+
+handle_cast({update_state, Fun}, State) ->
+ case Fun of
+ {M, F, A} -> erlang:apply(M, F, [State|A]);
+ _ when is_function(Fun) -> Fun(State)
+ end;
+handle_cast(Msg, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_handle_cast, LServer, State, [Msg]).
+
+handle_info(Info, #{server_host := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]).
+
+terminate(_Reason, _State) ->
+ ok.
+
+code_change(_OldVsn, State, _Extra) ->
+ {ok, State}.
-%%%----------------------------------------------------------------------
+%%%===================================================================
%%% Internal functions
-%%%----------------------------------------------------------------------
-
--spec send_text(state(), iodata()) -> ok.
-send_text(StateData, Text) ->
- (StateData#state.sockmod):send(StateData#state.socket,
- Text).
-
--spec send_element(state(), xmpp_element()) -> ok.
-send_element(StateData, El) ->
- El1 = xmpp:encode(El, ?NS_SERVER),
- send_text(StateData, fxml:element_to_binary(El1)).
-
--spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok.
-send_error(StateData, Stanza, Error) ->
- Type = xmpp:get_type(Stanza),
- if Type == error; Type == result;
- Type == <<"error">>; Type == <<"result">> ->
- ok;
- true ->
- send_element(StateData, xmpp:make_error(Stanza, Error))
+%%%===================================================================
+-spec check_from_to(jid(), jid(), state()) -> ok | {error, stream_error()}.
+check_from_to(From, To, State) ->
+ case check_from(From, State) of
+ true ->
+ case check_to(To, State) of
+ true ->
+ ok;
+ false ->
+ {error, xmpp:serr_improper_addressing()}
+ end;
+ false ->
+ {error, xmpp:serr_invalid_from()}
end.
--spec send_trailer(state()) -> ok.
-send_trailer(StateData) ->
- send_text(StateData, <<"</stream:stream>">>).
-
--spec send_header(state(), undefined | {integer(), integer()}) -> ok.
-send_header(StateData, Version) ->
- Header = xmpp:encode(
- #stream_start{xmlns = ?NS_SERVER,
- stream_xmlns = ?NS_STREAM,
- db_xmlns = ?NS_SERVER_DIALBACK,
- id = StateData#state.streamid,
- version = Version}),
- send_text(StateData, fxml:element_to_header(Header)).
-
--spec change_shaper(state(), binary(), jid()) -> ok.
-change_shaper(StateData, Host, JID) ->
- Shaper = acl:match_rule(Host, StateData#state.shaper,
- JID),
- (StateData#state.sockmod):change_shaper(StateData#state.socket,
- Shaper).
-
--spec new_id() -> binary().
-new_id() -> randoms:get_string().
-
--spec cancel_timer(reference()) -> ok.
-cancel_timer(Timer) ->
- erlang:cancel_timer(Timer),
- receive {timeout, Timer, _} -> ok after 0 -> ok end.
-
-fsm_limit_opts(Opts) ->
- case lists:keysearch(max_fsm_queue, 1, Opts) of
- {value, {_, N}} when is_integer(N) -> [{max_queue, N}];
- _ ->
- case ejabberd_config:get_option(
- max_fsm_queue,
- fun(I) when is_integer(I), I > 0 -> I end) of
- undefined -> [];
- N -> [{max_queue, N}]
- end
- end.
+-spec check_from(jid(), state()) -> boolean().
+check_from(#jid{lserver = S1}, #{auth_domains := AuthDomains}) ->
+ sets:is_element(S1, AuthDomains).
+
+-spec check_to(jid(), state()) -> boolean().
+check_to(#jid{lserver = LServer}, _State) ->
+ ejabberd_router:is_my_route(LServer).
+
+-spec set_idle_timeout(state()) -> state().
+set_idle_timeout(#{server_host := LServer,
+ established := true} = State) ->
+ Timeout = ejabberd_s2s:get_idle_timeout(LServer),
+ xmpp_stream_in:set_timeout(State, Timeout);
+set_idle_timeout(State) ->
+ State.
--spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition().
-decode_element(#xmlel{} = El, StateName, StateData) ->
- Opts = if StateName == stream_established ->
- [ignore_els];
- true ->
- []
- end,
- try xmpp:decode(El, ?NS_SERVER, Opts) of
- Pkt -> ?MODULE:StateName(Pkt, StateData)
- catch error:{xmpp_codec, Why} ->
- case xmpp:is_stanza(El) of
- true ->
- Lang = xmpp:get_lang(El),
- Txt = xmpp:format_error(Why),
- send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
- false ->
- ok
- end,
- {next_state, StateName, StateData}
- end;
-decode_element(Pkt, StateName, StateData) ->
- ?MODULE:StateName(Pkt, StateData).
-
-opt_type(domain_certfile) -> fun iolist_to_binary/1;
-opt_type(max_fsm_queue) ->
- fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(s2s_certfile) -> fun iolist_to_binary/1;
-opt_type(s2s_cafile) -> fun iolist_to_binary/1;
-opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
-opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
-opt_type(s2s_protocol_options) ->
- fun (Options) ->
- [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [],
- [["|" | binary_to_list(Opt)]
- || Opt <- Options, is_binary(Opt)]),
- iolist_to_binary(O)
- end;
-opt_type(s2s_tls_compression) ->
- fun (true) -> true;
- (false) -> false
- end;
-opt_type(s2s_use_starttls) ->
- fun (false) -> false;
- (true) -> true;
- (optional) -> optional;
- (required) -> required;
- (required_trusted) -> required_trusted
- end;
opt_type(_) ->
- [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile,
- s2s_ciphers, s2s_dhfile, s2s_protocol_options,
- s2s_tls_compression, s2s_use_starttls].
+ [].
-%%%----------------------------------------------------------------------
-%%% File : ejabberd_s2s_out.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Manage outgoing server-to-server connections
-%%% Created : 6 Dec 2002 by Alexey Shchepin <alexey@process-one.net>
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2016, Evgeny Khramtsov
+%%% @doc
%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
+%%% @end
+%%% Created : 16 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%-------------------------------------------------------------------
-module(ejabberd_s2s_out).
-
+-behaviour(xmpp_stream_out).
-behaviour(ejabberd_config).
--author('alexey@process-one.net').
-
--behaviour(p1_fsm).
-
-%% External exports
--export([start/3,
- start_link/3,
- start_connection/1,
- terminate_if_waiting_delay/2,
- stop_connection/1,
- transform_options/1]).
-
--export([init/1, open_socket/2, wait_for_stream/2,
- wait_for_validation/2, wait_for_features/2,
- wait_for_auth_result/2, wait_for_starttls_proceed/2,
- relay_to_bridge/2, reopen_socket/2, wait_before_retry/2,
- stream_established/2, handle_event/3,
- handle_sync_event/4, handle_info/3, terminate/3,
- print_state/1, code_change/4, test_get_addr_port/1,
- get_addr_port/1, opt_type/1]).
+%% ejabberd_config callbacks
+-export([opt_type/1, transform_options/1]).
+%% xmpp_stream_out callbacks
+-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+ handle_auth_success/2, handle_auth_failure/3, handle_packet/2,
+ handle_stream_end/2, handle_stream_close/2,
+ handle_recv/3, handle_send/4, handle_cdata/2,
+ handle_stream_established/1, handle_timeout/1]).
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+ terminate/2, code_change/3]).
+%% Hooks
+-export([process_auth_result/2, process_closed/2, handle_unexpected_info/2,
+ handle_unexpected_cast/2]).
+%% API
+-export([start/3, start_link/3, connect/1, close/1, stop/1, send/2,
+ route/2, establish/1, update_state/2, add_hooks/0]).
-include("ejabberd.hrl").
--include("logger.hrl").
-include("xmpp.hrl").
+-include("logger.hrl").
--record(state,
- {socket :: ejabberd_socket:socket_state(),
- streamid = <<"">> :: binary(),
- remote_streamid = <<"">> :: binary(),
- use_v10 = true :: boolean(),
- tls = false :: boolean(),
- tls_required = false :: boolean(),
- tls_certverify = false :: boolean(),
- tls_enabled = false :: boolean(),
- tls_options = [connect] :: list(),
- authenticated = false :: boolean(),
- db_enabled = true :: boolean(),
- try_auth = true :: boolean(),
- myname = <<"">> :: binary(),
- server = <<"">> :: binary(),
- queue = queue:new() :: ?TQUEUE,
- delay_to_retry = undefined_delay :: undefined_delay | non_neg_integer(),
- new = false :: boolean(),
- verify = false :: false | {pid(), binary(), binary()},
- bridge :: {atom(), atom()},
- timer = make_ref() :: reference()}).
-
--type state_name() :: open_socket | wait_for_stream |
- wait_for_validation | wait_for_features |
- wait_for_auth_result | wait_for_starttls_proceed |
- relay_to_bridge | reopen_socket | wait_before_retry |
- stream_established.
--type state() :: #state{}.
--type fsm_stop() :: {stop, normal, state()}.
--type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()} |
- {next_state, state_name(), state()}.
--type fsm_transition() :: fsm_stop() | fsm_next().
-
-%%-define(DBGFSM, true).
-
--ifdef(DBGFSM).
-
--define(FSMOPTS, [{debug, [trace]}]).
-
--else.
-
--define(FSMOPTS, []).
-
--endif.
-
--define(FSMTIMEOUT, 30000).
-
-%% We do not block on send anymore.
--define(TCP_SEND_TIMEOUT, 15000).
-
-%% Maximum delay to wait before retrying to connect after a failed attempt.
-%% Specified in miliseconds. Default value is 5 minutes.
--define(MAX_RETRY_DELAY, 300000).
-
--define(SOCKET_DEFAULT_RESULT, {error, badarg}).
+-type state() :: map().
+-export_type([state/0]).
-%%%----------------------------------------------------------------------
+%%%===================================================================
%%% API
-%%%----------------------------------------------------------------------
-start(From, Host, Type) ->
- supervisor:start_child(ejabberd_s2s_out_sup,
- [From, Host, Type]).
-
-start_link(From, Host, Type) ->
- p1_fsm:start_link(ejabberd_s2s_out, [From, Host, Type],
- fsm_limit_opts() ++ (?FSMOPTS)).
-
-start_connection(Pid) -> p1_fsm:send_event(Pid, init).
-
-stop_connection(Pid) -> p1_fsm:send_event(Pid, closed).
-
-%%%----------------------------------------------------------------------
-%%% Callback functions from p1_fsm
-%%%----------------------------------------------------------------------
-
-init([From, Server, Type]) ->
- process_flag(trap_exit, true),
- ?DEBUG("started: ~p", [{From, Server, Type}]),
- {TLS, TLSRequired, TLSCertverify} =
- case ejabberd_config:get_option(
- s2s_use_starttls,
- fun(true) -> true;
- (false) -> false;
- (optional) -> optional;
- (required) -> required;
- (required_trusted) -> required_trusted
- end)
- of
- UseTls
- when (UseTls == undefined) or (UseTls == false) ->
- {false, false, false};
- UseTls
- when (UseTls == true) or (UseTls == optional) ->
- {true, false, false};
- required ->
- {true, true, false};
- required_trusted ->
- {true, true, true}
- end,
- UseV10 = TLS,
- TLSOpts1 = case
- ejabberd_config:get_option(
- s2s_certfile, fun iolist_to_binary/1)
- of
- undefined -> [connect];
- CertFile -> [{certfile, CertFile}, connect]
- end,
- TLSOpts2 = case ejabberd_config:get_option(
- s2s_ciphers, fun iolist_to_binary/1) of
- undefined -> TLSOpts1;
- Ciphers -> [{ciphers, Ciphers} | TLSOpts1]
- end,
- TLSOpts3 = case ejabberd_config:get_option(
- s2s_protocol_options,
- fun (Options) ->
- [_|O] = lists:foldl(
- fun(X, Acc) -> X ++ Acc end, [],
- [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)]
- ),
- iolist_to_binary(O)
- end) of
- undefined -> TLSOpts2;
- ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2]
- end,
- TLSOpts4 = case ejabberd_config:get_option(
- s2s_dhfile, fun iolist_to_binary/1) of
- undefined -> TLSOpts3;
- DHFile -> [{dhfile, DHFile} | TLSOpts3]
- end,
- TLSOpts = case ejabberd_config:get_option(
- {s2s_tls_compression, From},
- fun(true) -> true;
- (false) -> false
- end, false) of
- false -> [compression_none | TLSOpts4];
- true -> TLSOpts4
- end,
- {New, Verify} = case Type of
- new -> {true, false};
- {verify, Pid, Key, SID} ->
- start_connection(self()), {false, {Pid, Key, SID}}
- end,
- Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- {ok, open_socket,
- #state{use_v10 = UseV10, tls = TLS,
- tls_required = TLSRequired, tls_certverify = TLSCertverify,
- tls_options = TLSOpts, queue = queue:new(), myname = From,
- server = Server, new = New, verify = Verify, timer = Timer}}.
-
-open_socket(init, StateData) ->
- log_s2s_out(StateData#state.new, StateData#state.myname,
- StateData#state.server, StateData#state.tls),
- ?DEBUG("open_socket: ~p",
- [{StateData#state.myname, StateData#state.server,
- StateData#state.new, StateData#state.verify}]),
- AddrList = case
- ejabberd_idna:domain_utf8_to_ascii(StateData#state.server)
- of
- false -> [];
- ASCIIAddr -> get_addr_port(ASCIIAddr)
+%%%===================================================================
+start(From, To, Opts) ->
+ xmpp_stream_out:start(?MODULE, [ejabberd_socket, From, To, Opts],
+ ejabberd_config:fsm_limit_opts([])).
+
+start_link(From, To, Opts) ->
+ xmpp_stream_out:start_link(?MODULE, [ejabberd_socket, From, To, Opts],
+ ejabberd_config:fsm_limit_opts([])).
+
+connect(Ref) ->
+ xmpp_stream_out:connect(Ref).
+
+close(Ref) ->
+ xmpp_stream_out:close(Ref).
+
+stop(Ref) ->
+ xmpp_stream_out:stop(Ref).
+
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Stream, Pkt) ->
+ xmpp_stream_out:send(Stream, Pkt).
+
+-spec route(pid(), xmpp_element()) -> ok.
+route(Ref, Pkt) ->
+ Ref ! {route, Pkt}.
+
+-spec establish(state()) -> state().
+establish(State) ->
+ xmpp_stream_out:establish(State).
+
+-spec update_state(pid(), fun((state()) -> state()) |
+ {module(), atom(), list()}) -> ok.
+update_state(Ref, Callback) ->
+ xmpp_stream_out:cast(Ref, {update_state, Callback}).
+
+-spec add_hooks() -> ok.
+add_hooks() ->
+ lists:foreach(
+ fun(Host) ->
+ ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
+ process_auth_result, 100),
+ ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE,
+ process_closed, 100),
+ ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE,
+ handle_unexpected_info, 100),
+ ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE,
+ handle_unexpected_cast, 100)
+ end, ?MYHOSTS).
+
+%%%===================================================================
+%%% Hooks
+%%%===================================================================
+process_auth_result(#{server := LServer, remote_server := RServer} = State,
+ false) ->
+ Delay = get_delay(),
+ ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;"
+ " bouncing for ~p seconds",
+ [LServer, RServer, Delay]),
+ State1 = close(State),
+ State2 = bounce_queue(State1),
+ xmpp_stream_out:set_timeout(State2, timer:seconds(Delay));
+process_auth_result(State, true) ->
+ State.
+
+process_closed(#{server := LServer, remote_server := RServer} = State,
+ _Reason) ->
+ Delay = get_delay(),
+ ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; "
+ "bouncing for ~p seconds",
+ [LServer, RServer,
+ try maps:get(stop_reason, State) of
+ {error, Why} -> xmpp_stream_out:format_error(Why)
+ catch _:undef -> <<"unexplained reason">>
end,
- case lists:foldl(fun ({Addr, Port}, Acc) ->
- case Acc of
- {ok, Socket} -> {ok, Socket};
- _ -> open_socket1(Addr, Port)
- end
- end,
- ?SOCKET_DEFAULT_RESULT, AddrList)
- of
- {ok, Socket} ->
- Version = if StateData#state.use_v10 -> {1,0};
- true -> undefined
- end,
- NewStateData = StateData#state{socket = Socket,
- tls_enabled = false,
- streamid = new_id()},
- send_header(NewStateData, Version),
- {next_state, wait_for_stream, NewStateData,
- ?FSMTIMEOUT};
- {error, Reason} ->
- ?INFO_MSG("s2s connection: ~s -> ~s (remote server "
- "not found: ~p)",
- [StateData#state.myname, StateData#state.server, Reason]),
- case ejabberd_hooks:run_fold(find_s2s_bridge, undefined,
- [StateData#state.myname,
- StateData#state.server])
- of
- {Mod, Fun, Type} ->
- ?INFO_MSG("found a bridge to ~s for: ~s -> ~s",
- [Type, StateData#state.myname,
- StateData#state.server]),
- NewStateData = StateData#state{bridge = {Mod, Fun}},
- {next_state, relay_to_bridge, NewStateData};
- _ -> wait_before_reconnect(StateData)
- end
- end;
-open_socket(Event, StateData) ->
- handle_unexpected_event(Event, open_socket, StateData).
-
-open_socket1({_, _, _, _} = Addr, Port) ->
- open_socket2(inet, Addr, Port);
-%% IPv6
-open_socket1({_, _, _, _, _, _, _, _} = Addr, Port) ->
- open_socket2(inet6, Addr, Port);
-%% Hostname
-open_socket1(Host, Port) ->
- lists:foldl(fun (_Family, {ok, _Socket} = R) -> R;
- (Family, _) ->
- Addrs = get_addrs(Host, Family),
- lists:foldl(fun (_Addr, {ok, _Socket} = R) -> R;
- (Addr, _) -> open_socket1(Addr, Port)
- end,
- ?SOCKET_DEFAULT_RESULT, Addrs)
- end,
- ?SOCKET_DEFAULT_RESULT, outgoing_s2s_families()).
-
-open_socket2(Type, Addr, Port) ->
- ?DEBUG("s2s_out: connecting to ~p:~p~n", [Addr, Port]),
- Timeout = outgoing_s2s_timeout(),
- case catch ejabberd_socket:connect(Addr, Port,
- [binary, {packet, 0},
- {send_timeout, ?TCP_SEND_TIMEOUT},
- {send_timeout_close, true},
- {active, false}, Type],
- Timeout)
- of
- {ok, _Socket} = R -> R;
- {error, Reason} = R ->
- ?DEBUG("s2s_out: connect return ~p~n", [Reason]), R;
- {'EXIT', Reason} ->
- ?DEBUG("s2s_out: connect crashed ~p~n", [Reason]),
- {error, Reason}
+ Delay]),
+ State1 = bounce_queue(State),
+ xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)).
+
+handle_unexpected_info(State, Info) ->
+ ?WARNING_MSG("got unexpected info: ~p", [Info]),
+ State.
+
+handle_unexpected_cast(State, Msg) ->
+ ?WARNING_MSG("got unexpected cast: ~p", [Msg]),
+ State.
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+tls_options(#{server := LServer}) ->
+ ejabberd_s2s:tls_options(LServer, []).
+
+tls_required(#{server := LServer}) ->
+ ejabberd_s2s:tls_required(LServer).
+
+tls_verify(#{server := LServer}) ->
+ ejabberd_s2s:tls_verify(LServer).
+
+tls_enabled(#{server := LServer}) ->
+ ejabberd_s2s:tls_enabled(LServer).
+
+handle_auth_success(Mech, #{socket := Socket, ip := IP,
+ remote_server := RServer,
+ server := LServer} = State) ->
+ ?INFO_MSG("(~s) Accepted outbound s2s ~s authentication ~s -> ~s (~s)",
+ [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+ ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [true]).
+
+handle_auth_failure(Mech, Reason,
+ #{socket := Socket, ip := IP,
+ remote_server := RServer,
+ server := LServer} = State) ->
+ ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s",
+ [ejabberd_socket:pp(Socket), Mech, LServer, RServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]),
+ State1 = State#{on_route => bounce,
+ stop_reason => {error, {auth, Reason}}},
+ ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]).
+
+handle_packet(Pkt, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]).
+
+handle_stream_end(Reason, #{server := LServer} = State) ->
+ State1 = State#{on_route => bounce, stop_reason => Reason},
+ ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]).
+
+handle_stream_close(Reason, #{server := LServer} = State) ->
+ State1 = State#{on_route => bounce, stop_reason => Reason},
+ ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]).
+
+handle_stream_established(State) ->
+ State1 = State#{on_route => send},
+ State2 = resend_queue(State1),
+ set_idle_timeout(State2).
+
+handle_cdata(Data, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_handle_cdata, LServer, State, [Data]).
+
+handle_recv(El, Pkt, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_handle_recv, LServer, State, [El, Pkt]).
+
+handle_send(Pkt, El, Data, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_handle_send, LServer,
+ State, [Pkt, El, Data]).
+
+handle_timeout(#{server := LServer, remote_server := RServer,
+ on_route := Action} = State) ->
+ case Action of
+ bounce -> stop(State);
+ queue -> send(State, xmpp:serr_connection_timeout());
+ send ->
+ ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive",
+ [LServer, RServer]),
+ stop(State)
end.
-%%----------------------------------------------------------------------
-
-wait_for_stream({xmlstreamstart, Name, Attrs}, StateData0) ->
- {CertCheckRes, CertCheckMsg, StateData} =
- if StateData0#state.tls_certverify, StateData0#state.tls_enabled ->
- {Res, Msg} =
- ejabberd_s2s:check_peer_certificate(ejabberd_socket,
- StateData0#state.socket,
- StateData0#state.server),
- ?DEBUG("Certificate verification result for ~s: ~s",
- [StateData0#state.server, Msg]),
- {Res, Msg, StateData0#state{tls_certverify = false}};
- true ->
- {no_verify, <<"Not verified">>, StateData0}
- end,
- try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
- _ when CertCheckRes == error ->
- send_element(StateData,
- xmpp:serr_policy_violation(CertCheckMsg, ?MYLANG)),
- ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)",
- [StateData#state.myname, StateData#state.server,
- CertCheckMsg]),
- {stop, normal, StateData};
- #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
- when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
- send_element(StateData, xmpp:serr_invalid_namespace()),
- {stop, normal, StateData};
- #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID,
- version = V} when V /= {1,0} ->
- send_db_request(StateData#state{remote_streamid = ID});
- #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID}
- when StateData#state.use_v10 ->
- {next_state, wait_for_features,
- StateData#state{remote_streamid = ID}, ?FSMTIMEOUT};
- #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID}
- when not StateData#state.use_v10 ->
- %% Handle Tigase's workaround for an old ejabberd bug:
- send_db_request(StateData#state{remote_streamid = ID});
- #stream_start{id = ID} when StateData#state.use_v10 ->
- {next_state, wait_for_features,
- StateData#state{db_enabled = false, remote_streamid = ID},
- ?FSMTIMEOUT};
- #stream_start{} ->
- send_element(StateData, xmpp:serr_invalid_namespace()),
- {stop, normal, StateData};
- _ ->
- send_element(StateData, xmpp:serr_invalid_xml()),
- {stop, normal, StateData}
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:format_error(Why),
- send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
- {stop, normal, StateData}
- end;
-wait_for_stream(Event, StateData) ->
- handle_unexpected_event(Event, wait_for_stream, StateData).
-
-wait_for_validation({xmlstreamelement, El}, StateData) ->
- decode_element(El, wait_for_validation, StateData);
-wait_for_validation(#db_result{to = To, from = From, type = Type}, StateData) ->
- ?DEBUG("recv result: ~p", [{From, To, Type}]),
- case {Type, StateData#state.tls_enabled, StateData#state.tls_required} of
- {valid, Enabled, Required} when (Enabled == true) or (Required == false) ->
- send_queue(StateData, StateData#state.queue),
- ?INFO_MSG("Connection established: ~s -> ~s with "
- "TLS=~p",
- [StateData#state.myname, StateData#state.server,
- StateData#state.tls_enabled]),
- ejabberd_hooks:run(s2s_connect_hook,
- [StateData#state.myname,
- StateData#state.server]),
- {next_state, stream_established, StateData#state{queue = queue:new()}};
- {valid, Enabled, Required} when (Enabled == false) and (Required == true) ->
- ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS "
- "is required but unavailable)",
- [StateData#state.myname, StateData#state.server]),
- {stop, normal, StateData};
- _ ->
- ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
- "dialback key result)",
- [StateData#state.myname, StateData#state.server]),
- {stop, normal, StateData}
- end;
-wait_for_validation(#db_verify{to = To, from = From, id = Id, type = Type},
- StateData) ->
- ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]),
- case StateData#state.verify of
- false ->
- NextState = wait_for_validation,
- {next_state, NextState, StateData, get_timeout_interval(NextState)};
- {Pid, _Key, _SID} ->
- case Type of
- valid ->
- p1_fsm:send_event(Pid,
- {valid, StateData#state.server,
- StateData#state.myname});
- _ ->
- p1_fsm:send_event(Pid,
- {invalid, StateData#state.server,
- StateData#state.myname})
- end,
- if StateData#state.verify == false ->
- {stop, normal, StateData};
- true ->
- NextState = wait_for_validation,
- {next_state, NextState, StateData, get_timeout_interval(NextState)}
- end
- end;
-wait_for_validation(timeout,
- #state{verify = {VPid, VKey, SID}} = StateData)
- when is_pid(VPid) and is_binary(VKey) and is_binary(SID) ->
- ?DEBUG("wait_for_validation: ~s -> ~s (timeout in verify connection)",
- [StateData#state.myname, StateData#state.server]),
- {stop, normal, StateData};
-wait_for_validation(Event, StateData) ->
- handle_unexpected_event(Event, wait_for_validation, StateData).
-
-wait_for_features({xmlstreamelement, El}, StateData) ->
- decode_element(El, wait_for_features, StateData);
-wait_for_features(#stream_features{sub_els = Els}, StateData) ->
- {SASLEXT, StartTLS, StartTLSRequired} =
- lists:foldl(
- fun(#sasl_mechanisms{list = Mechs}, {_, STLS, STLSReq}) ->
- {lists:member(<<"EXTERNAL">>, Mechs), STLS, STLSReq};
- (#starttls{required = Required}, {SEXT, _, _}) ->
- {SEXT, true, Required};
- (_, Acc) ->
- Acc
- end, {false, false, false}, Els),
- if not SASLEXT and not StartTLS and StateData#state.authenticated ->
- send_queue(StateData, StateData#state.queue),
- ?INFO_MSG("Connection established: ~s -> ~s with "
- "SASL EXTERNAL and TLS=~p",
- [StateData#state.myname, StateData#state.server,
- StateData#state.tls_enabled]),
- ejabberd_hooks:run(s2s_connect_hook,
- [StateData#state.myname,
- StateData#state.server]),
- {next_state, stream_established,
- StateData#state{queue = queue:new()}};
- SASLEXT and StateData#state.try_auth and
- (StateData#state.new /= false) and
- (StateData#state.tls_enabled or
- not StateData#state.tls_required) ->
- send_element(StateData,
- #sasl_auth{mechanism = <<"EXTERNAL">>,
- text = StateData#state.myname}),
- {next_state, wait_for_auth_result,
- StateData#state{try_auth = false}, ?FSMTIMEOUT};
- StartTLS and StateData#state.tls and
- not StateData#state.tls_enabled ->
- send_element(StateData, #starttls{}),
- {next_state, wait_for_starttls_proceed, StateData, ?FSMTIMEOUT};
- StartTLSRequired and not StateData#state.tls ->
- ?DEBUG("restarted: ~p",
- [{StateData#state.myname, StateData#state.server}]),
- ejabberd_socket:close(StateData#state.socket),
- {next_state, reopen_socket,
- StateData#state{socket = undefined, use_v10 = false},
- ?FSMTIMEOUT};
- StateData#state.db_enabled ->
- send_db_request(StateData);
- true ->
- ?DEBUG("restarted: ~p",
- [{StateData#state.myname, StateData#state.server}]),
- ejabberd_socket:close(StateData#state.socket),
- {next_state, reopen_socket,
- StateData#state{socket = undefined, use_v10 = false}, ?FSMTIMEOUT}
+init([#{server := LServer, remote_server := RServer} = State, Opts]) ->
+ State1 = State#{on_route => queue,
+ queue => queue:new(),
+ xmlns => ?NS_SERVER,
+ lang => ?MYLANG,
+ shaper => none},
+ ?INFO_MSG("Outbound s2s connection started: ~s -> ~s",
+ [LServer, RServer]),
+ ejabberd_hooks:run_fold(s2s_out_init, LServer, {ok, State1}, [Opts]).
+
+handle_call(Request, From, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_handle_call, LServer, State, [Request, From]).
+
+handle_cast({update_state, Fun}, State) ->
+ case Fun of
+ {M, F, A} -> erlang:apply(M, F, [State|A]);
+ _ when is_function(Fun) -> Fun(State)
end;
-wait_for_features(Event, StateData) ->
- handle_unexpected_event(Event, wait_for_features, StateData).
-
-wait_for_auth_result({xmlstreamelement, El}, StateData) ->
- decode_element(El, wait_for_auth_result, StateData);
-wait_for_auth_result(#sasl_success{}, StateData) ->
- ?DEBUG("auth: ~p", [{StateData#state.myname, StateData#state.server}]),
- ejabberd_socket:reset_stream(StateData#state.socket),
- send_header(StateData, {1,0}),
- {next_state, wait_for_stream,
- StateData#state{streamid = new_id(), authenticated = true},
- ?FSMTIMEOUT};
-wait_for_auth_result(#sasl_failure{}, StateData) ->
- ?DEBUG("restarted: ~p", [{StateData#state.myname, StateData#state.server}]),
- ejabberd_socket:close(StateData#state.socket),
- {next_state, reopen_socket,
- StateData#state{socket = undefined}, ?FSMTIMEOUT};
-wait_for_auth_result(Event, StateData) ->
- handle_unexpected_event(Event, wait_for_auth_result, StateData).
-
-wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
- decode_element(El, wait_for_starttls_proceed, StateData);
-wait_for_starttls_proceed(#starttls_proceed{}, StateData) ->
- ?DEBUG("starttls: ~p", [{StateData#state.myname, StateData#state.server}]),
- Socket = StateData#state.socket,
- TLSOpts = case ejabberd_config:get_option(
- {domain_certfile, StateData#state.myname},
- fun iolist_to_binary/1) of
- undefined -> StateData#state.tls_options;
- CertFile ->
- [{certfile, CertFile}
- | lists:keydelete(certfile, 1,
- StateData#state.tls_options)]
- end,
- TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts),
- NewStateData = StateData#state{socket = TLSSocket,
- streamid = new_id(),
- tls_enabled = true,
- tls_options = TLSOpts},
- send_header(NewStateData, {1,0}),
- {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT};
-wait_for_starttls_proceed(Event, StateData) ->
- handle_unexpected_event(Event, wait_for_starttls_proceed, StateData).
-
-reopen_socket({xmlstreamelement, _El}, StateData) ->
- {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
-reopen_socket({xmlstreamend, _Name}, StateData) ->
- {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
-reopen_socket({xmlstreamerror, _}, StateData) ->
- {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
-reopen_socket(timeout, StateData) ->
- ?INFO_MSG("reopen socket: timeout", []),
- {stop, normal, StateData};
-reopen_socket(closed, StateData) ->
- p1_fsm:send_event(self(), init),
- {next_state, open_socket, StateData, ?FSMTIMEOUT}.
-
-%% This state is use to avoid reconnecting to often to bad sockets
-wait_before_retry(_Event, StateData) ->
- {next_state, wait_before_retry, StateData, ?FSMTIMEOUT}.
-
-relay_to_bridge(stop, StateData) ->
- wait_before_reconnect(StateData);
-relay_to_bridge(closed, StateData) ->
- ?INFO_MSG("relay to bridge: ~s -> ~s (closed)",
- [StateData#state.myname, StateData#state.server]),
- {stop, normal, StateData};
-relay_to_bridge(_Event, StateData) ->
- {next_state, relay_to_bridge, StateData}.
-
-stream_established({xmlstreamelement, El}, StateData) ->
- decode_element(El, stream_established, StateData);
-stream_established(#db_verify{to = VTo, from = VFrom, id = VId, type = VType},
- StateData) ->
- ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]),
- case StateData#state.verify of
- {VPid, _VKey, _SID} ->
- case VType of
- valid ->
- p1_fsm:send_event(VPid,
- {valid, StateData#state.server,
- StateData#state.myname});
- _ ->
- p1_fsm:send_event(VPid,
- {invalid, StateData#state.server,
- StateData#state.myname})
- end;
- _ -> ok
- end,
- {next_state, stream_established, StateData};
-stream_established(Event, StateData) ->
- handle_unexpected_event(Event, stream_established, StateData).
-
--spec handle_unexpected_event(term(), state_name(), state()) -> fsm_transition().
-handle_unexpected_event(Event, StateName, StateData) ->
- case Event of
- {xmlstreamerror, _} ->
- send_element(StateData, xmpp:serr_not_well_formed()),
- ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
- "got invalid XML from peer",
- [StateData#state.myname, StateData#state.server,
- StateName]),
- {stop, normal, StateData};
- {xmlstreamend, _} ->
- ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
- "XML stream closed by peer",
- [StateData#state.myname, StateData#state.server,
- StateName]),
- {stop, normal, StateData};
- timeout ->
- send_element(StateData, xmpp:serr_connection_timeout()),
- ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
- "timed out during establishing an XML stream",
- [StateData#state.myname, StateData#state.server,
- StateName]),
- {stop, normal, StateData};
- closed ->
- ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
- "connection socket closed",
- [StateData#state.myname, StateData#state.server,
- StateName]),
- {stop, normal, StateData};
- Pkt when StateName == wait_for_stream;
- StateName == wait_for_features;
- StateName == wait_for_auth_result;
- StateName == wait_for_starttls_proceed ->
- send_element(StateData, xmpp:serr_bad_format()),
- ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
- "got unexpected event ~p",
- [StateData#state.myname, StateData#state.server,
- StateName, Pkt]),
- {stop, normal, StateData};
- _ ->
- {next_state, StateName, StateData, get_timeout_interval(StateName)}
- end.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/3
-%% Returns: {next_state, NextStateName, NextStateData} |
-%% {next_state, NextStateName, NextStateData, Timeout} |
-%% {reply, Reply, NextStateName, NextStateData} |
-%% {reply, Reply, NextStateName, NextStateData, Timeout} |
-%% {stop, Reason, NewStateData} |
-%% {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-%%state_name(Event, From, StateData) ->
-%% Reply = ok,
-%% {reply, Reply, state_name, StateData}.
-
-handle_event(_Event, StateName, StateData) ->
- {next_state, StateName, StateData,
- get_timeout_interval(StateName)}.
-
-handle_sync_event(get_state_infos, _From, StateName,
- StateData) ->
- {Addr, Port} = try
- ejabberd_socket:peername(StateData#state.socket)
- of
- {ok, {A, P}} -> {A, P};
- {error, _} -> {unknown, unknown}
- catch
- _:_ -> {unknown, unknown}
- end,
- Infos = [{direction, out}, {statename, StateName},
- {addr, Addr}, {port, Port},
- {streamid, StateData#state.streamid},
- {use_v10, StateData#state.use_v10},
- {tls, StateData#state.tls},
- {tls_required, StateData#state.tls_required},
- {tls_enabled, StateData#state.tls_enabled},
- {tls_options, StateData#state.tls_options},
- {authenticated, StateData#state.authenticated},
- {db_enabled, StateData#state.db_enabled},
- {try_auth, StateData#state.try_auth},
- {myname, StateData#state.myname},
- {server, StateData#state.server},
- {delay_to_retry, StateData#state.delay_to_retry},
- {verify, StateData#state.verify}],
- Reply = {state_infos, Infos},
- {reply, Reply, StateName, StateData};
-%%----------------------------------------------------------------------
-%% Func: handle_sync_event/4
-%% Returns: {next_state, NextStateName, NextStateData} |
-%% {next_state, NextStateName, NextStateData, Timeout} |
-%% {reply, Reply, NextStateName, NextStateData} |
-%% {reply, Reply, NextStateName, NextStateData, Timeout} |
-%% {stop, Reason, NewStateData} |
-%% {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-handle_sync_event(_Event, _From, StateName,
- StateData) ->
- Reply = ok,
- {reply, Reply, StateName, StateData,
- get_timeout_interval(StateName)}.
-
-code_change(_OldVsn, StateName, StateData, _Extra) ->
- {ok, StateName, StateData}.
-
-handle_info({send_text, Text}, StateName, StateData) ->
- send_text(StateData, Text),
- cancel_timer(StateData#state.timer),
- Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- {next_state, StateName, StateData#state{timer = Timer},
- get_timeout_interval(StateName)};
-handle_info({send_element, El}, StateName, StateData) ->
- case StateName of
- stream_established ->
- cancel_timer(StateData#state.timer),
- Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
- send_element(StateData, El),
- {next_state, StateName, StateData#state{timer = Timer}};
- %% In this state we bounce all message: We are waiting before
- %% trying to reconnect
- wait_before_retry ->
- bounce_element(El, xmpp:err_remote_server_not_found()),
- {next_state, StateName, StateData};
- relay_to_bridge ->
- {Mod, Fun} = StateData#state.bridge,
- ?DEBUG("relaying stanza via ~p:~p/1", [Mod, Fun]),
- case catch Mod:Fun(El) of
- {'EXIT', Reason} ->
- ?ERROR_MSG("Error while relaying to bridge: ~p",
- [Reason]),
- bounce_element(El, xmpp:err_internal_server_error()),
- wait_before_reconnect(StateData);
- _ -> {next_state, StateName, StateData}
- end;
- _ ->
- Q = queue:in(El, StateData#state.queue),
- {next_state, StateName, StateData#state{queue = Q},
- get_timeout_interval(StateName)}
+handle_cast(Msg, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_handle_cast, LServer, State, [Msg]).
+
+handle_info({route, Pkt}, #{queue := Q, on_route := Action} = State) ->
+ case Action of
+ queue -> State#{queue => queue:in(Pkt, Q)};
+ bounce -> bounce_packet(Pkt, State);
+ send -> set_idle_timeout(send(State, Pkt))
end;
-handle_info({timeout, Timer, _}, wait_before_retry,
- #state{timer = Timer} = StateData) ->
- ?INFO_MSG("Reconnect delay expired: Will now retry "
- "to connect to ~s when needed.",
- [StateData#state.server]),
- {stop, normal, StateData};
-handle_info({timeout, Timer, _}, _StateName,
- #state{timer = Timer} = StateData) ->
- ?INFO_MSG("Closing connection with ~s: timeout",
- [StateData#state.server]),
- {stop, normal, StateData};
-handle_info(terminate_if_waiting_before_retry,
- wait_before_retry, StateData) ->
- {stop, normal, StateData};
-handle_info(terminate_if_waiting_before_retry,
- StateName, StateData) ->
- {next_state, StateName, StateData,
- get_timeout_interval(StateName)};
-handle_info(_, StateName, StateData) ->
- {next_state, StateName, StateData,
- get_timeout_interval(StateName)}.
-
-terminate(Reason, StateName, StateData) ->
- ?DEBUG("terminated: ~p", [{Reason, StateName}]),
- case StateData#state.new of
- false -> ok;
- true ->
- ejabberd_s2s:remove_connection({StateData#state.myname,
- StateData#state.server},
- self())
- end,
- bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()),
- bounce_messages(xmpp:err_remote_server_not_found()),
- case StateData#state.socket of
- undefined -> ok;
- _Socket ->
- catch send_trailer(StateData),
- ejabberd_socket:close(StateData#state.socket)
- end,
- ok.
-
-print_state(State) -> State.
-
-%%%----------------------------------------------------------------------
+handle_info(Info, #{server := LServer} = State) ->
+ ejabberd_hooks:run_fold(s2s_out_handle_info, LServer, State, [Info]).
+
+terminate(Reason, #{server := LServer,
+ remote_server := RServer} = State) ->
+ ejabberd_s2s:remove_connection({LServer, RServer}, self()),
+ State1 = case Reason of
+ normal -> State;
+ _ -> State#{stop_reason => {error, internal_failure}}
+ end,
+ bounce_queue(State1),
+ bounce_message_queue(State1).
+
+code_change(_OldVsn, State, _Extra) ->
+ {ok, State}.
+
+%%%===================================================================
%%% Internal functions
-%%%----------------------------------------------------------------------
-
--spec send_text(state(), iodata()) -> ok.
-send_text(StateData, Text) ->
- ?DEBUG("Send Text on stream = ~s", [Text]),
- ejabberd_socket:send(StateData#state.socket, Text).
-
--spec send_element(state(), xmpp_element()) -> ok.
-send_element(StateData, El) ->
- El1 = xmpp:encode(El, ?NS_SERVER),
- send_text(StateData, fxml:element_to_binary(El1)).
-
--spec send_header(state(), undefined | {integer(), integer()}) -> ok.
-send_header(StateData, Version) ->
- Header = xmpp:encode(
- #stream_start{xmlns = ?NS_SERVER,
- stream_xmlns = ?NS_STREAM,
- db_xmlns = ?NS_SERVER_DIALBACK,
- from = jid:make(StateData#state.myname),
- to = jid:make(StateData#state.server),
- version = Version}),
- send_text(StateData, fxml:element_to_header(Header)).
-
--spec send_trailer(state()) -> ok.
-send_trailer(StateData) ->
- send_text(StateData, <<"</stream:stream>">>).
-
--spec send_queue(state(), queue:queue()) -> ok.
-send_queue(StateData, Q) ->
- case queue:out(Q) of
- {{value, El}, Q1} ->
- send_element(StateData, El), send_queue(StateData, Q1);
- {empty, _Q1} -> ok
- end.
-
-%% Bounce a single message (xmlelement)
--spec bounce_element(stanza(), stanza_error()) -> ok.
-bounce_element(El, Error) ->
- From = xmpp:get_from(El),
- To = xmpp:get_to(El),
- ejabberd_router:route_error(To, From, El, Error).
-
--spec bounce_queue(queue:queue(), stanza_error()) -> ok.
-bounce_queue(Q, Error) ->
- case queue:out(Q) of
- {{value, El}, Q1} ->
- bounce_element(El, Error), bounce_queue(Q1, Error);
- {empty, _} -> ok
- end.
-
--spec new_id() -> binary().
-new_id() -> randoms:get_string().
-
--spec cancel_timer(reference()) -> ok.
-cancel_timer(Timer) ->
- erlang:cancel_timer(Timer),
- receive {timeout, Timer, _} -> ok after 0 -> ok end.
-
--spec bounce_messages(stanza_error()) -> ok.
-bounce_messages(Error) ->
+%%%===================================================================
+-spec resend_queue(state()) -> state().
+resend_queue(#{queue := Q} = State) ->
+ State1 = State#{queue => queue:new()},
+ jlib:queue_foldl(
+ fun(Pkt, AccState) ->
+ send(AccState, Pkt)
+ end, State1, Q).
+
+-spec bounce_queue(state()) -> state().
+bounce_queue(#{queue := Q} = State) ->
+ State1 = State#{queue => queue:new()},
+ jlib:queue_foldl(
+ fun(Pkt, AccState) ->
+ bounce_packet(Pkt, AccState)
+ end, State1, Q).
+
+-spec bounce_message_queue(state()) -> state().
+bounce_message_queue(State) ->
receive
- {send_element, El} ->
- bounce_element(El, Error), bounce_messages(Error)
- after 0 -> ok
- end.
-
--spec send_db_request(state()) -> fsm_transition().
-send_db_request(StateData) ->
- Server = StateData#state.server,
- New = case StateData#state.new of
- false ->
- ejabberd_s2s:try_register({StateData#state.myname, Server});
- true ->
- true
- end,
- NewStateData = StateData#state{new = New},
- try case New of
- false -> ok;
- true ->
- Key1 = ejabberd_s2s:make_key(
- {StateData#state.myname, Server},
- StateData#state.remote_streamid),
- send_element(StateData,
- #db_result{from = StateData#state.myname,
- to = Server,
- key = Key1})
- end,
- case StateData#state.verify of
- false -> ok;
- {_Pid, Key2, SID} ->
- send_element(StateData,
- #db_verify{from = StateData#state.myname,
- to = StateData#state.server,
- id = SID,
- key = Key2})
- end,
- {next_state, wait_for_validation, NewStateData,
- (?FSMTIMEOUT) * 6}
- catch
- _:_ -> {stop, normal, NewStateData}
+ {route, Pkt} ->
+ State1 = bounce_packet(Pkt, State),
+ bounce_message_queue(State1)
+ after 0 ->
+ State
end.
-%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
-%% SRV support
-
--include_lib("kernel/include/inet.hrl").
-
--spec get_addr_port(binary()) -> [{binary(), inet:port_number()}].
-get_addr_port(Server) ->
- Res = srv_lookup(Server),
- case Res of
- {error, Reason} ->
- ?DEBUG("srv lookup of '~s' failed: ~p~n",
- [Server, Reason]),
- [{Server, outgoing_s2s_port()}];
- {ok, HEnt} ->
- ?DEBUG("srv lookup of '~s': ~p~n",
- [Server, HEnt#hostent.h_addr_list]),
- AddrList = HEnt#hostent.h_addr_list,
- case catch lists:map(fun ({Priority, Weight, Port,
- Host}) ->
- N = case Weight of
- 0 -> 0;
- _ ->
- (Weight + 1) * randoms:uniform()
- end,
- {Priority * 65536 - N, Host, Port}
- end,
- AddrList)
- of
- SortedList = [_ | _] ->
- List = lists:map(fun ({_, Host, Port}) ->
- {list_to_binary(Host), Port}
- end,
- lists:keysort(1, SortedList)),
- ?DEBUG("srv lookup of '~s': ~p~n", [Server, List]),
- List;
- _ -> [{Server, outgoing_s2s_port()}]
- end
- end.
-
-srv_lookup(Server) ->
- TimeoutMs = timer:seconds(
- ejabberd_config:get_option(
- s2s_dns_timeout,
- fun(I) when is_integer(I), I>=0 -> I end,
- 10)),
- Retries = ejabberd_config:get_option(
- s2s_dns_retries,
- fun(I) when is_integer(I), I>=0 -> I end,
- 2),
- srv_lookup(binary_to_list(Server), TimeoutMs, Retries).
-
-%% XXX - this behaviour is suboptimal in the case that the domain
-%% has a "_xmpp-server._tcp." but not a "_jabber._tcp." record and
-%% we don't get a DNS reply for the "_xmpp-server._tcp." lookup. In this
-%% case we'll give up when we get the "_jabber._tcp." nxdomain reply.
-srv_lookup(_Server, _Timeout, Retries)
- when Retries < 1 ->
- {error, timeout};
-srv_lookup(Server, Timeout, Retries) ->
- case inet_res:getbyname("_xmpp-server._tcp." ++ Server,
- srv, Timeout)
- of
- {error, _Reason} ->
- case inet_res:getbyname("_jabber._tcp." ++ Server, srv,
- Timeout)
- of
- {error, timeout} ->
- ?ERROR_MSG("The DNS servers~n ~p~ntimed out on "
- "request for ~p IN SRV. You should check "
- "your DNS configuration.",
- [inet_db:res_option(nameserver), Server]),
- srv_lookup(Server, Timeout, Retries - 1);
- R -> R
- end;
- {ok, _HEnt} = R -> R
- end.
-
-test_get_addr_port(Server) ->
- lists:foldl(fun (_, Acc) ->
- [HostPort | _] = get_addr_port(Server),
- case lists:keysearch(HostPort, 1, Acc) of
- false -> [{HostPort, 1} | Acc];
- {value, {_, Num}} ->
- lists:keyreplace(HostPort, 1, Acc,
- {HostPort, Num + 1})
- end
- end,
- [], lists:seq(1, 100000)).
-
-get_addrs(Host, Family) ->
- Type = case Family of
- inet4 -> inet;
- ipv4 -> inet;
- inet6 -> inet6;
- ipv6 -> inet6
- end,
- case inet:gethostbyname(binary_to_list(Host), Type) of
- {ok, #hostent{h_addr_list = Addrs}} ->
- ?DEBUG("~s of ~s resolved to: ~p~n",
- [Type, Host, Addrs]),
- Addrs;
- {error, Reason} ->
- ?DEBUG("~s lookup of '~s' failed: ~p~n",
- [Type, Host, Reason]),
- []
+-spec bounce_packet(xmpp_element(), state()) -> state().
+bounce_packet(Pkt, State) when ?is_stanza(Pkt) ->
+ From = xmpp:get_from(Pkt),
+ To = xmpp:get_to(Pkt),
+ Lang = xmpp:get_lang(Pkt),
+ Err = mk_bounce_error(Lang, State),
+ ejabberd_router:route_error(To, From, Pkt, Err),
+ State;
+bounce_packet(_, State) ->
+ State.
+
+-spec mk_bounce_error(binary(), state()) -> stanza_error().
+mk_bounce_error(Lang, State) ->
+ try maps:get(stop_reason, State) of
+ {error, internal_failure} ->
+ xmpp:err_internal_server_error();
+ {error, Why} ->
+ Reason = xmpp_stream_out:format_error(Why),
+ case Why of
+ {dns, _} ->
+ xmpp:err_remote_server_timeout(Reason, Lang);
+ _ ->
+ xmpp:err_remote_server_not_found(Reason, Lang)
+ end
+ catch _:{badkey, _} ->
+ xmpp:err_remote_server_not_found()
end.
--spec outgoing_s2s_port() -> pos_integer().
-outgoing_s2s_port() ->
- ejabberd_config:get_option(
- outgoing_s2s_port,
- fun(I) when is_integer(I), I > 0, I =< 65536 -> I end,
- 5269).
-
--spec outgoing_s2s_families() -> [ipv4 | ipv6].
-outgoing_s2s_families() ->
- ejabberd_config:get_option(
- outgoing_s2s_families,
- fun(Families) ->
- true = lists:all(
- fun(ipv4) -> true;
- (ipv6) -> true
- end, Families),
- Families
- end, [ipv4, ipv6]).
-
--spec outgoing_s2s_timeout() -> pos_integer().
-outgoing_s2s_timeout() ->
- ejabberd_config:get_option(
- outgoing_s2s_timeout,
- fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 ->
- TimeOut;
- (infinity) ->
- infinity
- end, 10000).
+-spec get_delay() -> non_neg_integer().
+get_delay() ->
+ MaxDelay = ejabberd_config:get_option(
+ s2s_max_retry_delay,
+ fun(I) when is_integer(I), I > 0 -> I end,
+ 300),
+ crypto:rand_uniform(0, MaxDelay).
+
+-spec set_idle_timeout(state()) -> state().
+set_idle_timeout(#{on_route := send, server := LServer} = State) ->
+ Timeout = ejabberd_s2s:get_idle_timeout(LServer),
+ xmpp_stream_out:set_timeout(State, Timeout);
+set_idle_timeout(State) ->
+ State.
transform_options(Opts) ->
lists:foldl(fun transform_options/2, [], Opts).
transform_options(Opt, Opts) ->
[Opt|Opts].
-%% Human readable S2S logging: Log only new outgoing connections as INFO
-%% Do not log dialback
-log_s2s_out(false, _, _, _) -> ok;
-%% Log new outgoing connections:
-log_s2s_out(_, Myname, Server, Tls) ->
- ?INFO_MSG("Trying to open s2s connection: ~s -> "
- "~s with TLS=~p",
- [Myname, Server, Tls]).
-
-%% Calculate timeout depending on which state we are in:
-%% Can return integer > 0 | infinity
--spec get_timeout_interval(state_name()) -> pos_integer() | infinity.
-get_timeout_interval(StateName) ->
- case StateName of
- %% Validation implies dialback: Networking can take longer:
- wait_for_validation -> (?FSMTIMEOUT) * 6;
- %% When stream is established, we only rely on S2S Timeout timer:
- stream_established -> infinity;
- relay_to_bridge -> infinity;
- open_socket -> infinity;
- _ -> ?FSMTIMEOUT
- end.
-
-%% This function is intended to be called at the end of a state
-%% function that want to wait for a reconnect delay before stopping.
--spec wait_before_reconnect(state()) -> fsm_next().
-wait_before_reconnect(StateData) ->
- bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()),
- bounce_messages(xmpp:err_remote_server_not_found()),
- cancel_timer(StateData#state.timer),
- Delay = case StateData#state.delay_to_retry of
- undefined_delay ->
- {_, _, MicroSecs} = p1_time_compat:timestamp(), MicroSecs rem 14000 + 1000;
- D1 -> lists:min([D1 * 2, get_max_retry_delay()])
- end,
- Timer = erlang:start_timer(Delay, self(), []),
- {next_state, wait_before_retry,
- StateData#state{timer = Timer, delay_to_retry = Delay,
- queue = queue:new()}}.
-
--spec get_max_retry_delay() -> pos_integer().
-get_max_retry_delay() ->
- case ejabberd_config:get_option(
- s2s_max_retry_delay,
- fun(I) when is_integer(I), I > 0 -> I end) of
- undefined -> ?MAX_RETRY_DELAY;
- Seconds -> Seconds * 1000
- end.
-
-%% Terminate s2s_out connections that are in state wait_before_retry
--spec terminate_if_waiting_delay(binary(), binary()) -> ok.
-terminate_if_waiting_delay(From, To) ->
- FromTo = {From, To},
- Pids = ejabberd_s2s:get_connections_pids(FromTo),
- lists:foreach(fun (Pid) ->
- Pid ! terminate_if_waiting_before_retry
- end,
- Pids).
-
--spec fsm_limit_opts() -> [{max_queue, pos_integer()}].
-fsm_limit_opts() ->
- case ejabberd_config:get_option(
- max_fsm_queue,
- fun(I) when is_integer(I), I > 0 -> I end) of
- undefined -> [];
- N -> [{max_queue, N}]
- end.
-
--spec decode_element(xmlel(), state_name(), state()) -> fsm_next().
-decode_element(#xmlel{} = El, StateName, StateData) ->
- Opts = if StateName == stream_established ->
- [ignore_els];
- true ->
- []
- end,
- try xmpp:decode(El, ?NS_SERVER, Opts) of
- Pkt -> ?MODULE:StateName(Pkt, StateData)
- catch error:{xmpp_codec, Why} ->
- Type = xmpp:get_type(El),
- case xmpp:is_stanza(El) of
- true when Type /= <<"result">>, Type /= <<"error">> ->
- Lang = xmpp:get_lang(El),
- Txt = xmpp:format_error(Why),
- Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)),
- send_element(StateData, Err);
- false ->
- ok
- end,
- {next_state, StateName, StateData, get_timeout_interval(StateName)}
- end.
-
-opt_type(domain_certfile) -> fun iolist_to_binary/1;
-opt_type(max_fsm_queue) ->
- fun (I) when is_integer(I), I > 0 -> I end;
opt_type(outgoing_s2s_families) ->
fun (Families) ->
true = lists:all(fun (ipv4) -> true;
TimeOut;
(infinity) -> infinity
end;
-opt_type(s2s_certfile) -> fun iolist_to_binary/1;
-opt_type(s2s_ciphers) -> fun iolist_to_binary/1;
-opt_type(s2s_dhfile) -> fun iolist_to_binary/1;
opt_type(s2s_dns_retries) ->
fun (I) when is_integer(I), I >= 0 -> I end;
opt_type(s2s_dns_timeout) ->
fun (I) when is_integer(I), I >= 0 -> I end;
opt_type(s2s_max_retry_delay) ->
fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(s2s_protocol_options) ->
- fun (Options) ->
- [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [],
- [["|" | binary_to_list(Opt)]
- || Opt <- Options, is_binary(Opt)]),
- iolist_to_binary(O)
- end;
-opt_type(s2s_tls_compression) ->
- fun (true) -> true;
- (false) -> false
- end;
-opt_type(s2s_use_starttls) ->
- fun (true) -> true;
- (false) -> false;
- (optional) -> optional;
- (required) -> required;
- (required_trusted) -> required_trusted
- end;
opt_type(_) ->
- [domain_certfile, max_fsm_queue, outgoing_s2s_families,
- outgoing_s2s_port, outgoing_s2s_timeout, s2s_certfile,
- s2s_ciphers, s2s_dhfile, s2s_dns_retries, s2s_dns_timeout,
- s2s_max_retry_delay, s2s_protocol_options,
- s2s_tls_compression, s2s_use_starttls].
+ [outgoing_s2s_families, outgoing_s2s_port, outgoing_s2s_timeout,
+ s2s_dns_retries, s2s_dns_timeout, s2s_max_retry_delay].
-module(ejabberd_service).
-behaviour(xmpp_stream_in).
-behaviour(ejabberd_config).
+-behaviour(ejabberd_socket).
-protocol({xep, 114, '1.6'}).
%% ejabberd_socket callbacks
--export([start/2, socket_type/0]).
+-export([start/2, start_link/2, socket_type/0]).
%% ejabberd_config callbacks
-export([opt_type/1, transform_listen_option/2]).
%% xmpp_stream_in callbacks
--export([init/1, handle_call/3, handle_cast/2, handle_info/2,
- terminate/2, code_change/3]).
--export([handshake/2, handle_stream_start/1, handle_authenticated_packet/2]).
+-export([init/1, handle_info/2, terminate/2, code_change/3]).
+-export([handle_stream_start/2, handle_auth_success/4, handle_auth_failure/4,
+ handle_authenticated_packet/2, get_password_fun/1]).
%% API
-export([send/2]).
-include("xmpp.hrl").
-include("logger.hrl").
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
-
-type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
--export_type([state/0, next_state/0]).
+-export_type([state/0]).
%%%===================================================================
%%% API
%%%===================================================================
start(SockData, Opts) ->
xmpp_stream_in:start(?MODULE, [SockData, Opts],
- fsm_limit_opts(Opts) ++ ?FSMOPTS).
+ ejabberd_config:fsm_limit_opts(Opts)).
+
+start_link(SockData, Opts) ->
+ xmpp_stream_in:start_link(?MODULE, [SockData, Opts],
+ ejabberd_config:fsm_limit_opts(Opts)).
socket_type() ->
xml_stream.
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
- xmpp_stream_in:send(State, Pkt).
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Stream, Pkt) ->
+ xmpp_stream_in:send(Stream, Pkt).
%%%===================================================================
%%% xmpp_stream_in callbacks
%%%===================================================================
-init([#{socket := Socket} = State, Opts]) ->
- ?INFO_MSG("(~w) External service connected", [Socket]),
+init([State, Opts]) ->
Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all),
Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none),
HostOpts = case lists:keyfind(hosts, 1, Opts) of
server => ?MYNAME,
host_opts => HostOpts,
check_from => CheckFrom},
- ejabberd_hooks:run_fold(component_init, {ok, State1}, []).
+ ejabberd_hooks:run_fold(component_init, {ok, State1}, [Opts]).
-handle_stream_start(#{remote_server := RemoteServer,
+handle_stream_start(_StreamStart,
+ #{remote_server := RemoteServer,
+ lang := Lang,
host_opts := HostOpts} = State) ->
- NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
- true ->
- HostOpts;
- false ->
- case dict:find(global, HostOpts) of
- {ok, GlobalPass} ->
- dict:from_list([{RemoteServer, GlobalPass}]);
- error ->
- HostOpts
- end
- end,
- {noreply, State#{host_opts => NewHostOpts}}.
-
-handshake(Digest, #{remote_server := RemoteServer,
- stream_id := StreamID,
- host_opts := HostOpts} = State) ->
- case dict:find(RemoteServer, HostOpts) of
- {ok, Password} ->
- case p1_sha:sha(<<StreamID/binary, Password/binary>>) of
- Digest ->
- lists:foreach(
- fun (H) ->
- ejabberd_router:register_route(H, ?MYNAME),
- ?INFO_MSG("Route registered for service ~p~n", [H]),
- ejabberd_hooks:run(component_connected, [H])
- end, dict:fetch_keys(HostOpts)),
- {ok, State};
- _ ->
- ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
- {error, xmpp:serr_not_authorized(), State}
- end;
- _ ->
- ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]),
- {error, xmpp:serr_not_authorized(), State}
+ case lists:member(RemoteServer, ?MYHOSTS) of
+ true ->
+ Txt = <<"Unable to register route on existing local domain">>,
+ xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang));
+ false ->
+ NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of
+ true ->
+ HostOpts;
+ false ->
+ case dict:find(global, HostOpts) of
+ {ok, GlobalPass} ->
+ dict:from_list([{RemoteServer, GlobalPass}]);
+ error ->
+ HostOpts
+ end
+ end,
+ State#{host_opts => NewHostOpts}
+ end.
+
+get_password_fun(#{remote_server := RemoteServer,
+ socket := Socket,
+ ip := IP,
+ host_opts := HostOpts}) ->
+ fun(_) ->
+ case dict:find(RemoteServer, HostOpts) of
+ {ok, Password} ->
+ {Password, undefined};
+ error ->
+ ?ERROR_MSG("(~s) Domain ~s is unconfigured for "
+ "external component from ~s",
+ [ejabberd_socket:pp(Socket), RemoteServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+ {false, undefined}
+ end
end.
+handle_auth_success(_, Mech, _,
+ #{remote_server := RemoteServer, host_opts := HostOpts,
+ socket := Socket, ip := IP} = State) ->
+ ?INFO_MSG("(~s) Accepted external component ~s authentication "
+ "for ~s from ~s",
+ [ejabberd_socket:pp(Socket), Mech, RemoteServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+ lists:foreach(
+ fun (H) ->
+ ejabberd_router:register_route(H, ?MYNAME),
+ ejabberd_hooks:run(component_connected, [H])
+ end, dict:fetch_keys(HostOpts)),
+ State.
+
+handle_auth_failure(_, Mech, Reason,
+ #{remote_server := RemoteServer,
+ socket := Socket, ip := IP} = State) ->
+ ?ERROR_MSG("(~s) Failed external component ~s authentication "
+ "for ~s from ~s: ~s",
+ [ejabberd_socket:pp(Socket), Mech, RemoteServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP)),
+ Reason]),
+ State.
+
handle_authenticated_packet(Pkt, #{lang := Lang} = State) ->
From = xmpp:get_from(Pkt),
case check_from(From, State) of
true ->
To = xmpp:get_to(Pkt),
ejabberd_router:route(From, To, Pkt),
- {noreply, State};
+ State;
false ->
Txt = <<"Improper domain part of 'from' attribute">>,
Err = xmpp:serr_invalid_from(Txt, Lang),
xmpp_stream_in:send(State, Err)
end.
-handle_call(_Request, _From, State) ->
- Reply = ok,
- {reply, Reply, State}.
-
-handle_cast(_Msg, State) ->
- {noreply, State}.
-
handle_info({route, From, To, Packet}, #{access := Access} = State) ->
case acl:match_rule(global, Access, From) of
allow ->
Lang = xmpp:get_lang(Packet),
Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang),
ejabberd_router:route_error(To, From, Packet, Err),
- {noreply, State}
+ State
end;
handle_info(Info, State) ->
?ERROR_MSG("Unexpected info: ~p", [Info]),
- {noreply, State}.
+ State.
terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) ->
- ?INFO_MSG("External service disconnected: ~p", [Reason]),
case StreamState of
- session_established ->
+ established ->
lists:foreach(
fun(H) ->
ejabberd_router:unregister_route(H),
transform_listen_option(Opt, Opts) ->
[Opt|Opts].
-fsm_limit_opts(Opts) ->
- case lists:keysearch(max_fsm_queue, 1, Opts) of
- {value, {_, N}} when is_integer(N) ->
- [{max_queue, N}];
- _ ->
- case ejabberd_config:get_option(
- max_fsm_queue,
- fun(I) when is_integer(I), I > 0 -> I end) of
- undefined -> [];
- N -> [{max_queue, N}]
- end
- end.
-
-opt_type(max_fsm_queue) ->
- fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(_) -> [max_fsm_queue].
+opt_type(_) -> [].
ejabberd_sm ! {unregister_iq_handler, Host, XMLNS}.
%% Why the hell do we have so many similar kicks?
-c2s_handle_info({noreply, #{lang := Lang} = State}, replaced) ->
+c2s_handle_info(#{lang := Lang} = State, replaced) ->
State1 = State#{replaced => true},
Err = xmpp:serr_conflict(<<"Replaced by new connection">>, Lang),
- ejabberd_c2s:send(State1, Err);
-c2s_handle_info({noreply, #{lang := Lang} = State}, kick) ->
+ {stop, ejabberd_c2s:send(State1, Err)};
+c2s_handle_info(#{lang := Lang} = State, kick) ->
Err = xmpp:serr_policy_violation(<<"has been kicked">>, Lang),
- c2s_handle_info({noreply, State}, {kick, kicked_by_admin, Err});
-c2s_handle_info({noreply, State}, {kick, _Reason, Err}) ->
- ejabberd_c2s:send(State, Err);
-c2s_handle_info({noreply, #{lang := Lang} = State}, {exit, Reason}) ->
+ c2s_handle_info(State, {kick, kicked_by_admin, Err});
+c2s_handle_info(State, {kick, _Reason, Err}) ->
+ {stop, ejabberd_c2s:send(State, Err)};
+c2s_handle_info(#{lang := Lang} = State, {exit, Reason}) ->
Err = xmpp:serr_conflict(Reason, Lang),
- ejabberd_c2s:send(State, Err);
-c2s_handle_info(Acc, _) ->
- Acc.
+ {stop, ejabberd_c2s:send(State, Err)};
+c2s_handle_info(State, _) ->
+ State.
%%====================================================================
%% gen_server callbacks
get_peer_certificate/1,
get_verify_result/1,
close/1,
+ pp/1,
sockname/1, peername/1]).
-include("ejabberd.hrl").
-export_type([socket/0, socket_state/0, sockmod/0]).
+-callback start({module(), socket_state()},
+ [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
+-callback start_link({module(), socket_state()},
+ [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
+-callback socket_type() -> xml_stream | independent | raw.
%%====================================================================
%% API
{error, _Reason} -> SockMod:close(Socket)
end,
ReceiverMod:become_controller(Receiver, Pid);
- {error, _Reason} ->
+ _ ->
SockMod:close(Socket),
case ReceiverMod of
ejabberd_receiver -> ReceiverMod:close(Receiver);
-spec send(socket_state(), iodata()) -> ok.
send(SocketData, Data) ->
+ ?DEBUG("Send XML on stream = ~p", [Data]),
case catch (SocketData#socket_state.sockmod):send(
SocketData#socket_state.socket, Data) of
ok -> ok;
fast_tls -> tls;
ezlib ->
case ezlib:get_sockmod(Socket) of
- tcp -> tcp_zlib;
- tls -> tls_zlib
+ gen_tcp -> tcp_zlib;
+ fast_tls -> tls_zlib
end;
ejabberd_bosh -> http_bind;
ejabberd_http_bind -> http_bind;
gen_tcp -> inet:peername(Socket);
_ -> SockMod:peername(Socket)
end.
+
+pp(#socket_state{receiver = Receiver} = State) ->
+ Transport = get_transport(State),
+ io_lib:format("~s|~w", [Transport, Receiver]).
-export([tolower/1, term_to_base64/1, base64_to_term/1,
decode_base64/1, encode_base64/1, ip_to_list/1,
atom_to_binary/1, binary_to_atom/1, tuple_to_binary/1,
- l2i/1, i2l/1, i2l/2, queue_drop_while/2,
- expr_to_term/1, term_to_expr/1]).
+ l2i/1, i2l/1, i2l/2, expr_to_term/1, term_to_expr/1,
+ queue_drop_while/2, queue_foldl/3, queue_foldr/3, queue_foreach/2]).
%% The following functions are used by gen_iq_handler.erl for providing backward
%% compatibility and must not be used in other parts of the code
empty ->
Q
end.
+
+-spec queue_foldl(fun((term(), T) -> T), T, ?TQUEUE) -> T.
+queue_foldl(F, Acc, Q) ->
+ case queue:out(Q) of
+ {{value, Item}, Q1} ->
+ Acc1 = F(Item, Acc),
+ queue_foldl(F, Acc1, Q1);
+ {empty, _} ->
+ Acc
+ end.
+
+-spec queue_foldr(fun((term(), T) -> T), T, ?TQUEUE) -> T.
+queue_foldr(F, Acc, Q) ->
+ case queue:out_r(Q) of
+ {{value, Item}, Q1} ->
+ Acc1 = F(Item, Acc),
+ queue_foldr(F, Acc1, Q1);
+ {empty, _} ->
+ Acc
+ end.
+
+-spec queue_foreach(fun((_) -> _), ?TQUEUE) -> ok.
+queue_foreach(F, Q) ->
+ case queue:out(Q) of
+ {{value, Item}, Q1} ->
+ F(Item),
+ queue_foreach(F, Q1);
+ {empty, _} ->
+ ok
+ end.
process_iq_set, 40),
ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
c2s_handle_info, 40),
- ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
- c2s_handle_info, 40),
mod_disco:register_feature(Host, ?NS_BLOCKING),
gen_iq_handler:add_iq_handler(ejabberd_sm, Host,
?NS_BLOCKING, ?MODULE, process_iq, IQDisc).
process_iq_get, 40),
ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE,
process_iq_set, 40),
+ ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
+ c2s_handle_info, 40),
mod_disco:unregister_feature(Host, ?NS_BLOCKING),
gen_iq_handler:remove_iq_handler(ejabberd_sm, Host,
?NS_BLOCKING).
{result, #block_list{items = Items}}
end.
--spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
-c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State},
+-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+c2s_handle_info(#{user := U, server := S, resource := R} = State,
{blocking, Action}) ->
SubEl = case Action of
{block, JIDs} ->
%% No need to replace active privacy list here,
%% blocking pushes are always accompanied by
%% Privacy List pushes
- ejabberd_c2s:send(State, PushIQ).
+ {stop, ejabberd_c2s:send(State, PushIQ)};
+c2s_handle_info(State, _) ->
+ State.
-spec db_mod(binary()) -> module().
db_mod(LServer) ->
mod_opt_type(_) ->
[].
-c2s_unauthenticated_packet({noreply, State}, #iq{type = T, sub_els = [_]} = IQ)
+c2s_unauthenticated_packet(State, #iq{type = T, sub_els = [_]} = IQ)
when T == get; T == set ->
case xmpp:get_subtag(IQ, #legacy_auth{}) of
#legacy_auth{} = Auth ->
{stop, authenticate(State, xmpp:set_els(IQ, [Auth]))};
false ->
- {noreply, State}
+ State
end;
-c2s_unauthenticated_packet(Acc, _) ->
- Acc.
+c2s_unauthenticated_packet(State, _) ->
+ State.
c2s_stream_features(Acc, LServer) ->
case gen_mod:is_loaded(LServer, ?MODULE) of
case ejabberd_auth:check_password_with_authmodule(
U, U, JID#jid.lserver, P, D, DGen) of
{true, AuthModule} ->
- case ejabberd_c2s:handle_auth_success(
- U, <<"legacy">>, AuthModule, State) of
- {noreply, State1} ->
- State2 = State1#{user := U},
- open_session(State2, IQ, R);
- Err ->
- Err
- end;
+ State1 = ejabberd_c2s:handle_auth_success(
+ U, <<"legacy">>, AuthModule, State),
+ State2 = State1#{user := U},
+ open_session(State2, IQ, R);
_ ->
Err = xmpp:make_error(IQ, xmpp:err_not_authorized()),
process_auth_failure(State, U, Err, 'not-authorized')
case ejabberd_c2s:bind(R, State) of
{ok, State1} ->
Res = xmpp:make_iq_result(IQ),
- case ejabberd_c2s:send(State1, Res) of
- {noreply, State2} ->
- {noreply, State2#{stream_authenticated := true,
- stream_state := session_established}};
- Err ->
- Err
- end;
+ State2 = ejabberd_c2s:send(State1, Res),
+ ejabberd_c2s:establish(State2);
{error, Err, State1} ->
Res = xmpp:make_error(IQ, Err),
ejabberd_c2s:send(State1, Res)
end.
process_auth_failure(State, User, StanzaErr, Reason) ->
- case ejabberd_c2s:send(State, StanzaErr) of
- {noreply, State1} ->
- ejabberd_c2s:handle_auth_failure(
- User, <<"legacy">>, Reason, State1);
- Err ->
- Err
- end.
+ State1 = ejabberd_c2s:send(State, StanzaErr),
+ ejabberd_c2s:handle_auth_failure(User, <<"legacy">>, Reason, State1).
get_info(Acc, _From, _To, _Node, _Lang) ->
Acc.
--spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
-c2s_handle_info({noreply, State}, {resend_offline, Flag}) ->
- {noreply, State#{resend_offline => Flag}};
-c2s_handle_info(Acc, _) ->
- Acc.
+-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+c2s_handle_info(State, {resend_offline, Flag}) ->
+ {stop, State#{resend_offline => Flag}};
+c2s_handle_info(State, _) ->
+ State.
-spec handle_offline_query(iq()) -> iq().
handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1},
Mod = gen_mod:db_mod(LServer, ?MODULE),
Mod:remove_user(LUser, LServer).
-c2s_handle_info({noreply, #{privacy_list := Old,
- user := U, server := S, resource := R} = State},
+c2s_handle_info(#{privacy_list := Old,
+ user := U, server := S, resource := R} = State,
{privacy_list, New, Name}) ->
List = if Old#userlist.name == New#userlist.name -> New;
true -> Old
sub_els = [#privacy_query{
lists = [#privacy_list{name = Name}]}]},
State1 = State#{privacy_list => List},
- ejabberd_c2s:send(State1, PushIQ);
-c2s_handle_info(Acc, _) ->
- Acc.
+ {stop, ejabberd_c2s:send(State1, PushIQ)};
+c2s_handle_info(State, _) ->
+ State.
-spec updated_list(userlist(), userlist(), userlist()) -> userlist().
updated_list(_, #userlist{name = OldName} = Old,
broadcast_stanza(Host, _Publisher, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM) ->
broadcast_stanza(Host, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM).
--spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state().
-c2s_handle_info({noreply, #{server := Server} = C2SState},
+-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state().
+c2s_handle_info(#{server := Server} = C2SState,
{pep_message, Feature, From, Packet}) ->
LServer = jid:nameprep(Server),
lists:foreach(
ok
end
end, mod_caps:list_features(C2SState)),
- {noreply, C2SState};
-c2s_handle_info({noreply, #{server := Server} = C2SState},
+ {stop, C2SState};
+c2s_handle_info(#{server := Server} = C2SState,
{send_filtered, {pep_message, Feature}, From, To, Packet}) ->
LServer = jid:nameprep(Server),
case mod_caps:get_user_caps(To, C2SState) of
error ->
ok
end,
- {noreply, C2SState};
-c2s_handle_info(Acc, _) ->
- Acc.
+ {stop, C2SState};
+c2s_handle_info(C2SState, _) ->
+ C2SState.
subscribed_nodes_by_jid(NotifyType, SubsByDepth) ->
NodesToDeliver = fun (Depth, Node, Subs, Acc) ->
Acc
end.
-c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State},
+c2s_unauthenticated_packet(#{ip := IP, server := Server} = State,
#iq{type = T, sub_els = [_]} = IQ)
when T == set; T == get ->
case xmpp:get_subtag(IQ, #register{}) of
ResIQ1 = xmpp:set_from_to(ResIQ, jid:make(Server), undefined),
{stop, ejabberd_c2s:send(State, ResIQ1)};
false ->
- {noreply, State}
+ State
end;
-c2s_unauthenticated_packet(Acc, _) ->
- Acc.
+c2s_unauthenticated_packet(State, _) ->
+ State.
process_iq(#iq{from = From} = IQ) ->
process_iq(IQ, jid:tolower(From)).
end,
ejabberd_sm:get_user_resources(User, Server)).
-c2s_handle_info({noreply, State}, {item, JID, Sub}) ->
- {noreply, roster_change(State, JID, Sub)};
-c2s_handle_info(Acc, _) ->
- Acc.
+c2s_handle_info(State, {item, JID, Sub}) ->
+ {stop, roster_change(State, JID, Sub)};
+c2s_handle_info(State, _) ->
+ State.
-spec roster_change(ejabberd_c2s:state(), jid(), subscription()) -> ejabberd_c2s:state().
roster_change(#{user := U, server := S, resource := R} = State,
--- /dev/null
+%%%-------------------------------------------------------------------
+%%% Created : 16 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%-------------------------------------------------------------------
+-module(mod_s2s_dialback).
+-behaviour(gen_mod).
+
+-protocol({xep, 220, '1.1.1'}).
+-protocol({xep, 185, '1.0'}).
+
+%% gen_mod API
+-export([start/2, stop/1, depends/2, mod_opt_type/1]).
+%% Hooks
+-export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2,
+ s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
+
+-include("ejabberd.hrl").
+-include("xmpp.hrl").
+-include("logger.hrl").
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(Host, _Opts) ->
+ case ejabberd_s2s:tls_verify(Host) of
+ true ->
+ ?ERROR_MSG("disabling ~s for host ~s because option "
+ "'s2s_use_starttls' is set to 'required_trusted'",
+ [?MODULE, Host]);
+ false ->
+ ejabberd_hooks:add(s2s_out_init, Host, ?MODULE, s2s_out_init, 50),
+ ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50),
+ ejabberd_hooks:add(s2s_in_pre_auth_features, Host, ?MODULE,
+ s2s_in_features, 50),
+ ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE,
+ s2s_in_features, 50),
+ ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
+ s2s_in_packet, 50),
+ ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE,
+ s2s_in_packet, 50),
+ ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE,
+ s2s_out_packet, 50),
+ ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE,
+ s2s_out_auth_result, 50)
+ end.
+
+stop(Host) ->
+ ejabberd_hooks:delete(s2s_out_init, Host, ?MODULE, s2s_out_init, 50),
+ ejabberd_hooks:delete(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50),
+ ejabberd_hooks:delete(s2s_in_pre_auth_features, Host, ?MODULE,
+ s2s_in_features, 50),
+ ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE,
+ s2s_in_features, 50),
+ ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE,
+ s2s_in_packet, 50),
+ ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE,
+ s2s_in_packet, 50),
+ ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE,
+ s2s_out_packet, 50),
+ ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE,
+ s2s_out_auth_result, 50).
+
+depends(_Host, _Opts) ->
+ [].
+
+mod_opt_type(_) ->
+ [].
+
+s2s_in_features(Acc, _) ->
+ [#db_feature{errors = true}|Acc].
+
+s2s_out_init({ok, State}, Opts) ->
+ case proplists:get_value(db_verify, Opts) of
+ {StreamID, Key, Pid} ->
+ %% This is an outbound s2s connection created at step 1.
+ %% The purpose of this connection is to verify dialback key ONLY.
+ %% The connection is not registered in s2s table and thus is not
+ %% seen by anyone.
+ %% The connection will be closed immediately after receiving the
+ %% verification response (at step 3)
+ {ok, State#{db_verify => {StreamID, Key, Pid}}};
+ undefined ->
+ {ok, State#{db_enabled => true}}
+ end;
+s2s_out_init(Acc, _Opts) ->
+ Acc.
+
+s2s_out_closed(#{server := LServer,
+ remote_server := RServer,
+ db_verify := {StreamID, _Key, _Pid}} = State, _Reason) ->
+ %% Outbound s2s verificating connection (created at step 1) is
+ %% closed suddenly without receiving the response.
+ %% Building a response on our own
+ Response = #db_verify{from = RServer, to = LServer,
+ id = StreamID, type = error,
+ sub_els = [mk_error(internal_server_error)]},
+ s2s_out_packet(State, Response);
+s2s_out_closed(State, _Reason) ->
+ State.
+
+s2s_out_auth_result(#{server := LServer,
+ remote_server := RServer,
+ db_verify := {StreamID, Key, _Pid}} = State,
+ _) ->
+ %% The temporary outbound s2s connect (intended for verification)
+ %% has passed authentication state (either successfully or not, no matter)
+ %% and at this point we can send verification request as described
+ %% in section 2.1.2, step 2
+ Request = #db_verify{from = LServer, to = RServer,
+ key = Key, id = StreamID},
+ {stop, ejabberd_s2s_out:send(State, Request)};
+s2s_out_auth_result(#{db_enabled := true,
+ socket := Socket, ip := IP,
+ server := LServer,
+ remote_server := RServer,
+ stream_remote_id := StreamID} = State, false) ->
+ %% SASL authentication has failed, retrying with dialback
+ %% Sending dialback request, section 2.1.1, step 1
+ ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)",
+ [ejabberd_socket:pp(Socket), LServer, RServer,
+ ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]),
+ Key = make_key(LServer, RServer, StreamID),
+ State1 = maps:remove(stop_reason, State#{on_route => queue}),
+ State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer,
+ to = RServer,
+ key = Key}),
+ {stop, State2};
+s2s_out_auth_result(State, _) ->
+ State.
+
+s2s_in_packet(#{stream_id := StreamID} = State,
+ #db_result{from = From, to = To, key = Key, type = undefined}) ->
+ %% Received dialback request, section 2.2.1, step 1
+ try
+ ok = check_from_to(From, To),
+ %% We're creating a temporary outbound s2s connection to
+ %% send verification request and to receive verification response
+ {ok, Pid} = ejabberd_s2s_out:start(
+ To, From, [{db_verify, {StreamID, Key, self()}}]),
+ ejabberd_s2s_out:connect(Pid),
+ State
+ catch _:{badmatch, {error, Reason}} ->
+ send_db_result(State,
+ #db_verify{from = From, to = To, type = error,
+ sub_els = [mk_error(Reason)]})
+ end;
+s2s_in_packet(State, #db_verify{to = To, from = From, key = Key,
+ id = StreamID, type = undefined}) ->
+ %% Received verification request, section 2.2.2, step 2
+ Type = case make_key(To, From, StreamID) of
+ Key -> valid;
+ _ -> invalid
+ end,
+ Response = #db_verify{from = To, to = From, id = StreamID, type = Type},
+ ejabberd_s2s_in:send(State, Response);
+s2s_in_packet(State, Pkt) when is_record(Pkt, db_result);
+ is_record(Pkt, db_verify) ->
+ ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]),
+ State;
+s2s_in_packet(State, _) ->
+ State.
+
+s2s_out_packet(#{server := LServer,
+ remote_server := RServer,
+ db_verify := {StreamID, _Key, Pid}} = State,
+ #db_verify{from = RServer, to = LServer,
+ id = StreamID, type = Type} = Response)
+ when Type /= undefined ->
+ %% Received verification response, section 2.1.2, step 3
+ %% This is a response for the request sent at step 2
+ ejabberd_s2s_in:update_state(
+ Pid, fun(S) -> send_db_result(S, Response) end),
+ %% At this point the connection is no longer needed and we can terminate it
+ ejabberd_s2s_out:stop(State);
+s2s_out_packet(#{server := LServer, remote_server := RServer} = State,
+ #db_result{to = LServer, from = RServer,
+ type = Type} = Result) when Type /= undefined ->
+ %% Received dialback response, section 2.1.1, step 4
+ %% This is a response to the request sent at step 1
+ State1 = maps:remove(db_enabled, State),
+ case Type of
+ valid ->
+ State2 = ejabberd_s2s_out:handle_auth_success(<<"dialback">>, State1),
+ ejabberd_s2s_out:establish(State2);
+ _ ->
+ Reason = format_error(Result),
+ ejabberd_s2s_out:handle_auth_failure(<<"dialback">>, Reason, State1)
+ end;
+s2s_out_packet(State, Pkt) when is_record(Pkt, db_result);
+ is_record(Pkt, db_verify) ->
+ ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]),
+ State;
+s2s_out_packet(State, _) ->
+ State.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+-spec make_key(binary(), binary(), binary()) -> binary().
+make_key(From, To, StreamID) ->
+ Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
+ p1_sha:to_hexlist(
+ crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)),
+ [To, " ", From, " ", StreamID])).
+
+-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state().
+send_db_result(State, #db_verify{from = From, to = To,
+ type = Type, sub_els = Els}) ->
+ %% Sending dialback response, section 2.2.1, step 4
+ %% This is a response to the request received at step 1
+ Response = #db_result{from = To, to = From, type = Type, sub_els = Els},
+ State1 = ejabberd_s2s_in:send(State, Response),
+ case Type of
+ valid ->
+ State2 = ejabberd_s2s_in:handle_auth_success(
+ From, <<"dialback">>, undefined, State1),
+ ejabberd_s2s_in:establish(State2);
+ _ ->
+ Reason = format_error(Response),
+ ejabberd_s2s_in:handle_auth_failure(
+ From, <<"dialback">>, Reason, State1)
+ end.
+
+-spec check_from_to(binary(), binary()) -> ok | {error, forbidden | host_unknown}.
+check_from_to(From, To) ->
+ case ejabberd_router:is_my_route(To) of
+ false -> {error, host_unknown};
+ true ->
+ LServer = ejabberd_router:host_of_route(To),
+ case ejabberd_s2s:allow_host(LServer, From) of
+ true -> ok;
+ false -> {error, forbidden}
+ end
+ end.
+
+-spec mk_error(term()) -> stanza_error().
+mk_error(forbidden) ->
+ xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
+mk_error(host_unknown) ->
+ xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
+mk_error(_) ->
+ xmpp:err_internal_server_error().
+
+-spec format_error(db_result()) -> binary().
+format_error(#db_result{type = invalid}) ->
+ <<"invalid dialback key">>;
+format_error(#db_result{type = error, sub_els = Els}) ->
+ %% TODO: improve xmpp.erl
+ case xmpp:get_error(#message{sub_els = Els}) of
+ #stanza_error{reason = Reason} ->
+ erlang:atom_to_binary(Reason, latin1);
+ undefined ->
+ <<"unrecognized error">>
+ end;
+format_error(_) ->
+ <<"unexpected dialback result">>.
--- /dev/null
+%%%-------------------------------------------------------------------
+%%% Author : Holger Weiss <holger@zedat.fu-berlin.de>
+%%% Created : 25 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016 ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%-------------------------------------------------------------------
+-module(mod_sm).
+-behaviour(gen_mod).
+-author('holger@zedat.fu-berlin.de').
+-protocol({xep, 198, '1.5.2'}).
+
+%% gen_mod API
+-export([start/2, stop/1, depends/2, mod_opt_type/1]).
+%% hooks
+-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2,
+ c2s_authenticated_packet/2, c2s_unauthenticated_packet/2,
+ c2s_unbinded_packet/2, c2s_closed/2,
+ c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]).
+
+-include("xmpp.hrl").
+-include("logger.hrl").
+
+-define(is_sm_packet(Pkt),
+ is_record(Pkt, sm_enable) or
+ is_record(Pkt, sm_resume) or
+ is_record(Pkt, sm_a) or
+ is_record(Pkt, sm_r)).
+
+-type state() :: ejabberd_c2s:state().
+-type lqueue() :: {non_neg_integer(), queue:queue()}.
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(Host, _Opts) ->
+ ejabberd_hooks:add(c2s_init, ?MODULE, c2s_stream_init, 50),
+ ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE,
+ c2s_stream_started, 50),
+ ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE,
+ c2s_stream_features, 50),
+ ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE,
+ c2s_unauthenticated_packet, 50),
+ ejabberd_hooks:add(c2s_unbinded_packet, Host, ?MODULE,
+ c2s_unbinded_packet, 50),
+ ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE,
+ c2s_authenticated_packet, 50),
+ ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE,
+ c2s_handle_send, 50),
+ ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE,
+ c2s_filter_send, 50),
+ ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE,
+ c2s_handle_info, 50),
+ ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50).
+
+stop(Host) ->
+ %% TODO: do something with global 'c2s_init' hook
+ ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE,
+ c2s_stream_started, 50),
+ ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE,
+ c2s_stream_features, 50),
+ ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, ?MODULE,
+ c2s_unauthenticated_packet, 50),
+ ejabberd_hooks:delete(c2s_unbinded_packet, Host, ?MODULE,
+ c2s_unbinded_packet, 50),
+ ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE,
+ c2s_authenticated_packet, 50),
+ ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE,
+ c2s_handle_send, 50),
+ ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE,
+ c2s_filter_send, 50),
+ ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE,
+ c2s_handle_info, 50),
+ ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50).
+
+depends(_Host, _Opts) ->
+ [].
+
+c2s_stream_init({ok, State}, Opts) ->
+ MgmtOpts = lists:filter(
+ fun({stream_management, _}) -> true;
+ ({max_ack_queue, _}) -> true;
+ ({resume_timeout, _}) -> true;
+ ({max_resume_timeout, _}) -> true;
+ ({ack_timeout, _}) -> true;
+ ({resend_on_timeout, _}) -> true;
+ (_) -> false
+ end, Opts),
+ {ok, State#{mgmt_options => MgmtOpts}};
+c2s_stream_init(Acc, _Opts) ->
+ Acc.
+
+c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State,
+ _StreamStart) ->
+ State1 = maps:remove(mgmt_options, State),
+ ResumeTimeout = get_resume_timeout(LServer, Opts),
+ MaxResumeTimeout = get_max_resume_timeout(LServer, Opts, ResumeTimeout),
+ State1#{mgmt_state => inactive,
+ mgmt_max_queue => get_max_ack_queue(LServer, Opts),
+ mgmt_timeout => ResumeTimeout,
+ mgmt_max_timeout => MaxResumeTimeout,
+ mgmt_ack_timeout => get_ack_timeout(LServer, Opts),
+ mgmt_resend => get_resend_on_timeout(LServer, Opts)};
+c2s_stream_started(State, _StreamStart) ->
+ State.
+
+c2s_stream_features(Acc, Host) ->
+ case gen_mod:is_loaded(Host, ?MODULE) of
+ true ->
+ [#feature_sm{xmlns = ?NS_STREAM_MGMT_2},
+ #feature_sm{xmlns = ?NS_STREAM_MGMT_3}|Acc];
+ false ->
+ Acc
+ end.
+
+c2s_unauthenticated_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
+ %% XEP-0198 says: "For client-to-server connections, the client MUST NOT
+ %% attempt to enable stream management until after it has completed Resource
+ %% Binding unless it is resuming a previous session". However, it also
+ %% says: "Stream management errors SHOULD be considered recoverable", so we
+ %% won't bail out.
+ Err = #sm_failed{reason = 'unexpected-request', xmlns = ?NS_STREAM_MGMT_3},
+ {stop, send(State, Err)};
+c2s_unauthenticated_packet(State, _Pkt) ->
+ State.
+
+c2s_unbinded_packet(State, #sm_resume{} = Pkt) ->
+ case handle_resume(State, Pkt) of
+ {ok, ResumedState} ->
+ {stop, ResumedState};
+ error ->
+ {stop, State}
+ end;
+c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) ->
+ c2s_unauthenticated_packet(State, Pkt);
+c2s_unbinded_packet(State, _Pkt) ->
+ State.
+
+c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt)
+ when ?is_sm_packet(Pkt) ->
+ if MgmtState == pending; MgmtState == active ->
+ {stop, perform_stream_mgmt(Pkt, State)};
+ true ->
+ {stop, negotiate_stream_mgmt(Pkt, State)}
+ end;
+c2s_authenticated_packet(State, Pkt) ->
+ update_num_stanzas_in(State, Pkt).
+
+c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
+ when MgmtState == pending; MgmtState == active ->
+ State1 = mgmt_queue_add(State, Pkt),
+ case Result of
+ ok when ?is_stanza(Pkt) ->
+ send_ack(State1);
+ ok ->
+ State1;
+ {error, _} ->
+ transition_to_pending(State1)
+ end;
+c2s_handle_send(State, _Pkt, _Result) ->
+ State.
+
+c2s_filter_send(Pkt, _State) ->
+ Pkt.
+
+c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State,
+ {timeout, T, ack_timeout}) ->
+ ?DEBUG("Timeout waiting for stream management acknowledgement of ~s",
+ [jid:to_string(JID)]),
+ State1 = ejabberd_c2s:close(State, _SendTrailer = false),
+ c2s_closed(State1, ack_timeout);
+c2s_handle_info(State, _) ->
+ State.
+
+c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal ->
+ {stop, transition_to_pending(State)};
+c2s_closed(State, _) ->
+ State.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+-spec negotiate_stream_mgmt(xmpp_element(), state()) -> state().
+negotiate_stream_mgmt(Pkt, State) ->
+ Xmlns = xmpp:get_ns(Pkt),
+ case Pkt of
+ #sm_enable{} ->
+ handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt);
+ _ ->
+ Res = if is_record(Pkt, sm_a);
+ is_record(Pkt, sm_r);
+ is_record(Pkt, sm_resume) ->
+ #sm_failed{reason = 'unexpected-request',
+ xmlns = Xmlns};
+ true ->
+ #sm_failed{reason = 'bad-request',
+ xmlns = Xmlns}
+ end,
+ send(State, Res)
+ end.
+
+-spec perform_stream_mgmt(xmpp_element(), state()) -> state().
+perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) ->
+ case xmpp:get_ns(Pkt) of
+ Xmlns ->
+ case Pkt of
+ #sm_r{} ->
+ handle_r(State);
+ #sm_a{} ->
+ handle_a(State, Pkt);
+ _ ->
+ Res = if is_record(Pkt, sm_enable);
+ is_record(Pkt, sm_resume) ->
+ #sm_failed{reason = 'unexpected-request',
+ xmlns = Xmlns};
+ true ->
+ #sm_failed{reason = 'bad-request',
+ xmlns = Xmlns}
+ end,
+ send(State, Res)
+ end;
+ _ ->
+ send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns})
+ end.
+
+-spec handle_enable(state(), sm_enable()) -> state().
+handle_enable(#{mgmt_timeout := DefaultTimeout,
+ mgmt_max_timeout := MaxTimeout,
+ xmlns := Xmlns, jid := JID} = State,
+ #sm_enable{resume = Resume, max = Max}) ->
+ Timeout = if Resume == false ->
+ 0;
+ Max /= undefined, Max > 0, Max =< MaxTimeout ->
+ Max;
+ true ->
+ DefaultTimeout
+ end,
+ Res = if Timeout > 0 ->
+ ?INFO_MSG("Stream management with resumption enabled for ~s",
+ [jid:to_string(JID)]),
+ #sm_enabled{xmlns = Xmlns,
+ id = make_resume_id(State),
+ resume = true,
+ max = Timeout};
+ true ->
+ ?INFO_MSG("Stream management without resumption enabled for ~s",
+ [jid:to_string(JID)]),
+ #sm_enabled{xmlns = Xmlns}
+ end,
+ State1 = State#{mgmt_state => active,
+ mgmt_queue => queue_new(),
+ mgmt_timeout => Timeout * 1000},
+ send(State1, Res).
+
+-spec handle_r(state()) -> state().
+handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) ->
+ Res = #sm_a{xmlns = Xmlns, h = H},
+ send(State, Res).
+
+-spec handle_a(state(), sm_a()) -> state().
+handle_a(State, #sm_a{h = H}) ->
+ State1 = check_h_attribute(State, H),
+ resend_ack(State1).
+
+-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}.
+handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State,
+ #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) ->
+ R = case inherit_session_state(State, PrevID) of
+ {ok, InheritedState} ->
+ {ok, InheritedState, H};
+ {error, Err, InH} ->
+ {error, #sm_failed{reason = 'item-not-found',
+ h = InH, xmlns = Xmlns}, Err};
+ {error, Err} ->
+ {error, #sm_failed{reason = 'item-not-found',
+ xmlns = Xmlns}, Err}
+ end,
+ case R of
+ {ok, ResumedState, NumHandled} ->
+ State1 = check_h_attribute(ResumedState, NumHandled),
+ #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1,
+ AttrId = make_resume_id(State1),
+ State2 = send(State1, #sm_resumed{xmlns = AttrXmlns,
+ h = AttrH,
+ previd = AttrId}),
+ State3 = resend_unacked_stanzas(State2),
+ State4 = send(State3, #sm_r{xmlns = AttrXmlns}),
+ %% TODO: move this to mod_client_state
+ %% csi_flush_queue(State4),
+ State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []),
+ ?INFO_MSG("(~s) Resumed session for ~s",
+ [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
+ {ok, State5};
+ {error, El, Msg} ->
+ ?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]),
+ {error, send(State, El)}
+ end.
+
+-spec transition_to_pending(state()) -> state().
+transition_to_pending(#{mgmt_state := active} = State) ->
+ %% TODO
+ State;
+transition_to_pending(State) ->
+ State.
+
+-spec check_h_attribute(state(), non_neg_integer()) -> state().
+check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H)
+ when H > NumStanzasOut ->
+ ?DEBUG("~s acknowledged ~B stanzas, but only ~B were sent",
+ [jid:to_string(JID), H, NumStanzasOut]),
+ mgmt_queue_drop(State#{mgmt_stanzas_out => H}, NumStanzasOut);
+check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) ->
+ ?DEBUG("~s acknowledged ~B of ~B stanzas",
+ [jid:to_string(JID), H, NumStanzasOut]),
+ mgmt_queue_drop(State, H).
+
+-spec update_num_stanzas_in(state(), xmpp_element()) -> state().
+update_num_stanzas_in(#{mgmt_state := MgmtState,
+ mgmt_stanzas_in := NumStanzasIn} = State, El)
+ when MgmtState == active; MgmtState == pending ->
+ NewNum = case {xmpp:is_stanza(El), NumStanzasIn} of
+ {true, 4294967295} ->
+ 0;
+ {true, Num} ->
+ Num + 1;
+ {false, Num} ->
+ Num
+ end,
+ State#{mgmt_stanzas_in => NewNum};
+update_num_stanzas_in(State, _El) ->
+ State.
+
+send_ack(#{mgmt_ack_timer := _} = State) ->
+ State;
+send_ack(#{mgmt_xmlns := Xmlns,
+ mgmt_stanzas_out := NumStanzasOut,
+ mgmt_ack_timeout := AckTimeout} = State) ->
+ State1 = send(State, #sm_r{xmlns = Xmlns}),
+ TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
+ State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
+
+resend_ack(#{mgmt_ack_timer := _,
+ mgmt_queue := Queue,
+ mgmt_stanzas_out := NumStanzasOut,
+ mgmt_stanzas_req := NumStanzasReq} = State) ->
+ State1 = cancel_ack_timer(State),
+ case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of
+ true -> send_ack(State1);
+ false -> State1
+ end;
+resend_ack(State) ->
+ State.
+
+-spec mgmt_queue_add(state(), xmpp_element()) -> state().
+mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut,
+ mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) ->
+ NewNum = case NumStanzasOut of
+ 4294967295 -> 0;
+ Num -> Num + 1
+ end,
+ Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue),
+ State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
+ check_queue_length(State1);
+mgmt_queue_add(State, _Nonza) ->
+ State.
+
+-spec mgmt_queue_drop(state(), non_neg_integer()) -> state().
+mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) ->
+ NewQueue = queue_dropwhile(
+ fun({N, _T, _E}) -> N =< NumHandled end, Queue),
+ State#{mgmt_queue => NewQueue}.
+
+-spec check_queue_length(state()) -> state().
+check_queue_length(#{mgmt_max_queue := Limit} = State)
+ when Limit == infinity; Limit == exceeded ->
+ State;
+check_queue_length(#{mgmt_queue := Queue, mgmt_max_queue := Limit} = State) ->
+ case queue_len(Queue) > Limit of
+ true ->
+ State#{mgmt_max_queue => exceeded};
+ false ->
+ State
+ end.
+
+-spec resend_unacked_stanzas(state()) -> state().
+resend_unacked_stanzas(#{mgmt_state := MgmtState,
+ mgmt_queue := {QueueLen, _} = Queue,
+ jid := JID} = State)
+ when (MgmtState == active orelse
+ MgmtState == pending orelse
+ MgmtState == timeout) andalso QueueLen > 0 ->
+ ?DEBUG("Resending ~B unacknowledged stanza(s) to ~s",
+ [QueueLen, jid:to_string(JID)]),
+ queue_foldl(
+ fun({_, Time, Pkt}, AccState) ->
+ NewPkt = add_resent_delay_info(AccState, Pkt, Time),
+ send(AccState, NewPkt)
+ end, State, Queue);
+resend_unacked_stanzas(State) ->
+ State.
+
+-spec route_unacked_stanzas(state()) -> ok.
+route_unacked_stanzas(#{mgmt_state := MgmtState,
+ mgmt_resend := MgmtResend,
+ lang := Lang, user := User,
+ jid := JID, lserver := LServer,
+ mgmt_queue := {QueueLen, _} = Queue,
+ resource := Resource} = State)
+ when (MgmtState == active orelse
+ MgmtState == pending orelse
+ MgmtState == timeout) andalso QueueLen > 0 ->
+ ResendOnTimeout = case MgmtResend of
+ Resend when is_boolean(Resend) ->
+ Resend;
+ if_offline ->
+ case ejabberd_sm:get_user_resources(User, Resource) of
+ [Resource] ->
+ %% Same resource opened new session
+ true;
+ [] -> true;
+ _ -> false
+ end
+ end,
+ ?DEBUG("Re-routing ~B unacknowledged stanza(s) to ~s",
+ [QueueLen, jid:to_string(JID)]),
+ queue_foreach(
+ fun({_, _Time, #presence{from = From}}) ->
+ ?DEBUG("Dropping presence stanza from ~s", [jid:to_string(From)]);
+ ({_, _Time, #iq{} = El}) ->
+ Txt = <<"User session terminated">>,
+ route_error(El, xmpp:err_service_unavailable(Txt, Lang));
+ ({_, _Time, #message{from = From, meta = #{carbon_copy := true}}}) ->
+ %% XEP-0280 says: "When a receiving server attempts to deliver a
+ %% forked message, and that message bounces with an error for
+ %% any reason, the receiving server MUST NOT forward that error
+ %% back to the original sender." Resending such a stanza could
+ %% easily lead to unexpected results as well.
+ ?DEBUG("Dropping forwarded message stanza from ~s",
+ [jid:to_string(From)]);
+ ({_, Time, El}) ->
+ case ejabberd_hooks:run_fold(message_is_archived,
+ LServer, false,
+ [State, El]) of
+ true ->
+ ?DEBUG("Dropping archived message stanza from ~s",
+ [jid:to_string(xmpp:get_from(El))]);
+ false when ResendOnTimeout ->
+ NewEl = add_resent_delay_info(State, El, Time),
+ route(NewEl);
+ false ->
+ Txt = <<"User session terminated">>,
+ route_error(El, xmpp:err_service_unavailable(Txt, Lang))
+ end
+ end, Queue);
+route_unacked_stanzas(_State) ->
+ ok.
+
+-spec inherit_session_state(state(), binary()) -> {ok, state()} |
+ {error, binary()} |
+ {error, binary(), non_neg_integer()}.
+inherit_session_state(#{user := U, server := S} = State, ResumeID) ->
+ case jlib:base64_to_term(ResumeID) of
+ {term, {R, Time}} ->
+ case ejabberd_sm:get_session_pid(U, S, R) of
+ none ->
+ case ejabberd_sm:get_offline_info(Time, U, S, R) of
+ none ->
+ {error, <<"Previous session PID not found">>};
+ Info ->
+ case proplists:get_value(num_stanzas_in, Info) of
+ undefined ->
+ {error, <<"Previous session timed out">>};
+ H ->
+ {error, <<"Previous session timed out">>, H}
+ end
+ end;
+ OldPID ->
+ OldSID = {Time, OldPID},
+ try resume_session(OldSID, State) of
+ {resume, OldState} ->
+ State1 = ejabberd_c2s:copy_state(State, OldState),
+ State2 = ejabberd_c2s:open_session(State1),
+ {ok, State2};
+ {error, Msg} ->
+ {error, Msg}
+ catch exit:{noproc, _} ->
+ {error, <<"Previous session PID is dead">>};
+ exit:{timeout, _} ->
+ {error, <<"Session state copying timed out">>}
+ end
+ end;
+ _ ->
+ {error, <<"Invalid 'previd' value">>}
+ end.
+
+-spec resume_session({integer(), pid()}, state()) -> {resume, state()} |
+ {error, binary()}.
+resume_session({Time, Pid}, _State) ->
+ ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)).
+
+-spec make_resume_id(state()) -> binary().
+make_resume_id(#{sid := {Time, _}, resource := Resource}) ->
+ jlib:term_to_base64({Resource, Time}).
+
+-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza().
+add_resent_delay_info(_State, #iq{} = El, _Time) ->
+ El;
+add_resent_delay_info(#{lserver := LServer}, El, Time) ->
+ xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>).
+
+-spec route(stanza()) -> ok.
+route(Pkt) ->
+ From = xmpp:get_from(Pkt),
+ To = xmpp:get_to(Pkt),
+ ejabberd_router:route(From, To, Pkt).
+
+-spec route_error(stanza(), stanza_error()) -> ok.
+route_error(Pkt, Err) ->
+ From = xmpp:get_from(Pkt),
+ To = xmpp:get_to(Pkt),
+ ejabberd_router:route_error(To, From, Pkt, Err).
+
+-spec send(state(), xmpp_element()) -> state().
+send(#{mod := Mod} = State, Pkt) ->
+ Mod:send(State, Pkt).
+
+-spec queue_new() -> lqueue().
+queue_new() ->
+ {0, queue:new()}.
+
+-spec queue_in(term(), lqueue()) -> lqueue().
+queue_in(Elem, {N, Q}) ->
+ {N+1, queue:in(Elem, Q)}.
+
+-spec queue_len(lqueue()) -> non_neg_integer().
+queue_len({N, _}) ->
+ N.
+
+-spec queue_foldl(fun((term(), T) -> T), T, lqueue()) -> T.
+queue_foldl(F, Acc, {_N, Q}) ->
+ jlib:queue_foldl(F, Acc, Q).
+
+-spec queue_foreach(fun((_) -> _), lqueue()) -> ok.
+queue_foreach(F, {_N, Q}) ->
+ jlib:queue_foreach(F, Q).
+
+-spec queue_dropwhile(fun((term()) -> boolean()), lqueue()) -> lqueue().
+queue_dropwhile(F, {N, Q}) ->
+ case queue:peek(Q) of
+ {value, Item} ->
+ case F(Item) of
+ true ->
+ queue_dropwhile(F, {N-1, queue:drop(Q)});
+ false ->
+ {N, Q}
+ end;
+ empty ->
+ {N, Q}
+ end.
+
+-spec queue_is_empty(lqueue()) -> boolean().
+queue_is_empty({N, _Q}) ->
+ N == 0.
+
+-spec cancel_ack_timer(state()) -> state().
+cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) ->
+ case erlang:cancel_timer(TRef) of
+ false ->
+ receive {timeout, TRef, _} -> ok
+ after 0 -> ok
+ end;
+ _ ->
+ ok
+ end,
+ maps:remove(mgmt_ack_timer, State);
+cancel_ack_timer(State) ->
+ State.
+
+%%%===================================================================
+%%% Configuration processing
+%%%===================================================================
+get_max_ack_queue(Host, Opts) ->
+ VFun = mod_opt_type(max_ack_queue),
+ case gen_mod:get_module_opt(Host, ?MODULE, max_ack_queue, VFun) of
+ undefined -> gen_mod:get_opt(max_ack_queue, Opts, VFun, 1000);
+ Limit -> Limit
+ end.
+
+get_resume_timeout(Host, Opts) ->
+ VFun = mod_opt_type(resume_timeout),
+ case gen_mod:get_module_opt(Host, ?MODULE, resume_timeout, VFun) of
+ undefined -> gen_mod:get_opt(resume_timeout, Opts, VFun, 300);
+ Timeout -> Timeout
+ end.
+
+get_max_resume_timeout(Host, Opts, ResumeTimeout) ->
+ VFun = mod_opt_type(max_resume_timeout),
+ case gen_mod:get_module_opt(Host, ?MODULE, max_resume_timeout, VFun) of
+ undefined ->
+ case gen_mod:get_opt(max_resume_timeout, Opts, VFun) of
+ undefined -> ResumeTimeout;
+ Max when Max >= ResumeTimeout -> Max;
+ _ -> ResumeTimeout
+ end;
+ Max when Max >= ResumeTimeout -> Max;
+ _ -> ResumeTimeout
+ end.
+
+get_ack_timeout(Host, Opts) ->
+ VFun = mod_opt_type(ack_timeout),
+ T = case gen_mod:get_module_opt(Host, ?MODULE, ack_timeout, VFun) of
+ undefined -> gen_mod:get_opt(ack_timeout, Opts, VFun, 60);
+ AckTimeout -> AckTimeout
+ end,
+ case T of
+ infinity -> infinity;
+ _ -> timer:seconds(T)
+ end.
+
+get_resend_on_timeout(Host, Opts) ->
+ VFun = mod_opt_type(resend_on_timeout),
+ case gen_mod:get_module_opt(Host, ?MODULE, resend_on_timeout, VFun) of
+ undefined -> gen_mod:get_opt(resend_on_timeout, Opts, VFun, false);
+ Resend -> Resend
+ end.
+
+mod_opt_type(max_ack_queue) ->
+ fun(I) when is_integer(I), I > 0 -> I;
+ (infinity) -> infinity
+ end;
+mod_opt_type(resume_timeout) ->
+ fun(I) when is_integer(I), I >= 0 -> I end;
+mod_opt_type(max_resume_timeout) ->
+ fun(I) when is_integer(I), I >= 0 -> I end;
+mod_opt_type(ack_timeout) ->
+ fun(I) when is_integer(I), I > 0 -> I;
+ (infinity) -> infinity
+ end;
+mod_opt_type(resend_on_timeout) ->
+ fun(B) when is_boolean(B) -> B;
+ (if_offline) -> if_offline
+ end;
+mod_opt_type(_) ->
+ [max_ack_queue, resume_timeout, max_resume_timeout, ack_timeout,
+ resend_on_timeout].
-protocol({rfc, 6120}).
%% API
--export([start/3, call/3, cast/2, reply/2, send/2, send_error/3,
- get_transport/1, change_shaper/2]).
+-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
+ send/2, close/1, close/2, send_error/3, establish/1,
+ get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
%% gen_server callbacks
-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
terminate/2, code_change/3]).
+%%-define(DBGFSM, true).
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
-include("xmpp.hrl").
-type state() :: map().
--type next_state() :: {noreply, state()} | {stop, term(), state()}.
+-type stop_reason() :: {stream, reset | stream_error()} |
+ {tls, term()} |
+ {socket, inet:posix() | closed | timeout}.
-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
--callback handle_stream_start(state()) -> next_state().
--callback handle_stream_end(state()) -> next_state().
--callback handle_stream_close(state()) -> next_state().
--callback handle_cdata(binary(), state()) -> next_state().
--callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state().
--callback handle_authenticated_packet(xmpp_element(), state()) -> next_state().
--callback handle_unbinded_packet(xmpp_element(), state()) -> next_state().
--callback handle_auth_success(binary(), binary(), module(), state()) -> next_state().
--callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state().
--callback handle_send(ok | {error, atom()},
- xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state().
--callback init_sasl(state()) -> cyrsasl:sasl_state().
+-callback handle_cast(term(), state()) -> state().
+-callback handle_call(term(), term(), state()) -> state().
+-callback handle_info(term(), state()) -> state().
+-callback terminate(term(), state()) -> any().
+-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
+-callback handle_stream_start(state()) -> state().
+-callback handle_stream_end(stop_reason(), state()) -> state().
+-callback handle_stream_close(stop_reason(), state()) -> state().
+-callback handle_cdata(binary(), state()) -> state().
+-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
+-callback handle_authenticated_packet(xmpp_element(), state()) -> state().
+-callback handle_unbinded_packet(xmpp_element(), state()) -> state().
+-callback handle_auth_success(binary(), binary(), module(), state()) -> state().
+-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state().
+-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
+-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
+-callback get_password_fun(state()) -> fun().
+-callback check_password_fun(state()) -> fun().
+-callback check_password_digest_fun(state()) -> fun().
-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
--callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}.
-callback compress_methods(state()) -> [binary()].
-callback tls_options(state()) -> [proplists:property()].
-callback tls_required(state()) -> boolean().
--callback sasl_mechanisms(state()) -> [binary()].
+-callback tls_verify(state()) -> boolean().
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].
%% All callbacks are optional
-optional_callbacks([init/1,
+ handle_cast/2,
+ handle_call/3,
+ handle_info/2,
+ terminate/2,
+ code_change/3,
handle_stream_start/1,
- handle_stream_end/1,
- handle_stream_close/1,
+ handle_stream_end/2,
+ handle_stream_close/2,
handle_cdata/2,
handle_authenticated_packet/2,
handle_unauthenticated_packet/2,
handle_unbinded_packet/2,
handle_auth_success/4,
handle_auth_failure/4,
- handle_send/5,
- init_sasl/1,
+ handle_send/3,
+ handle_recv/3,
+ get_password_fun/1,
+ check_password_fun/1,
+ check_password_digest_fun/1,
bind/2,
- handshake/2,
compress_methods/1,
tls_options/1,
tls_required/1,
- sasl_mechanisms/1,
+ tls_verify/1,
unauthenticated_stream_features/1,
authenticated_stream_features/1]).
%%% API
%%%===================================================================
start(Mod, Args, Opts) ->
- gen_server:start(?MODULE, [Mod|Args], Opts).
+ gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+start_link(Mod, Args, Opts) ->
+ gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
call(Ref, Msg, Timeout) ->
gen_server:call(Ref, Msg, Timeout).
reply(Ref, Reply) ->
gen_server:reply(Ref, Reply).
--spec send(state(), xmpp_element()) -> next_state().
-send(State, Pkt) ->
- send_element(State, Pkt).
+-spec stop(pid()) -> ok;
+ (state()) -> no_return().
+stop(Pid) when is_pid(Pid) ->
+ cast(Pid, stop);
+stop(#{owner := Owner} = State) when Owner == self() ->
+ terminate(normal, State),
+ exit(normal);
+stop(_) ->
+ erlang:error(badarg).
-get_transport(#{sockmod := SockMod, socket := Socket}) ->
- SockMod:get_transport(Socket).
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+ cast(Pid, {send, Pkt});
+send(#{owner := Owner} = State, Pkt) when Owner == self() ->
+ send_element(State, Pkt);
+send(_, _) ->
+ erlang:error(badarg).
+
+-spec close(pid()) -> ok;
+ (state()) -> state().
+close(Ref) ->
+ close(Ref, true).
+
+-spec close(pid(), boolean()) -> ok;
+ (state(), boolean()) -> state().
+close(Pid, SendTrailer) when is_pid(Pid) ->
+ cast(Pid, {close, SendTrailer});
+close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
+ if SendTrailer -> send_trailer(State);
+ true -> close_socket(State)
+ end;
+close(_, _) ->
+ erlang:error(badarg).
+
+-spec establish(state()) -> state().
+establish(State) ->
+ process_stream_established(State).
+
+-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
+ case Timeout of
+ infinity -> State#{stream_timeout => infinity};
+ _ ->
+ Time = p1_time_compat:monotonic_time(milli_seconds),
+ State#{stream_timeout => {Timeout, Time}}
+ end;
+set_timeout(_, _) ->
+ erlang:error(badarg).
+
+get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
+ when Owner == self() ->
+ SockMod:get_transport(Socket);
+get_transport(_) ->
+ erlang:error(badarg).
-spec change_shaper(state(), shaper:shaper()) -> ok.
-change_shaper(#{sockmod := SockMod, socket := Socket}, Shaper) ->
- SockMod:change_shaper(Socket, Shaper).
+change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
+ when Owner == self() ->
+ SockMod:change_shaper(Socket, Shaper);
+change_shaper(_, _) ->
+ erlang:error(badarg).
+
+-spec format_error(stop_reason()) -> binary().
+format_error({socket, Reason}) ->
+ format("Connection failed: ~s", [format_inet_error(Reason)]);
+format_error({stream, reset}) ->
+ <<"Stream reset by peer">>;
+format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
+ format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({tls, Reason}) ->
+ format("TLS failed: ~w", [Reason]);
+format_error(Err) ->
+ format("Unrecognized error: ~w", [Err]).
%%%===================================================================
%%% gen_server callbacks
{_, XS} -> XS;
false -> false
end,
- TLSEnabled = proplists:get_bool(tls, Opts),
+ Encrypted = proplists:get_bool(tls, Opts),
SocketMonitor = SockMod:monitor(Socket),
case peername(SockMod, Socket) of
{ok, IP} ->
- State = #{mod => Module,
+ Time = p1_time_compat:monotonic_time(milli_seconds),
+ State = #{owner => self(),
+ mod => Module,
socket => Socket,
sockmod => SockMod,
socket_monitor => SocketMonitor,
+ stream_timeout => {timer:seconds(30), Time},
+ stream_direction => in,
stream_id => new_id(),
stream_state => wait_for_stream,
+ stream_header_sent => false,
stream_restarted => false,
stream_compressed => false,
- stream_tlsed => TLSEnabled,
+ stream_encrypted => Encrypted,
stream_version => {1,0},
stream_authenticated => false,
xml_socket => XMLSocket,
resource => <<"">>,
lserver => <<"">>,
ip => IP},
- try Module:init([State, Opts])
- catch _:undef -> {ok, State}
+ case try Module:init([State, Opts])
+ catch _:undef -> {ok, State}
+ end of
+ {ok, State1} ->
+ {_, State2, Timeout} = noreply(State1),
+ {ok, State2, Timeout};
+ Err ->
+ Err
end;
{error, Reason} ->
{stop, Reason}
end.
+handle_cast({send, Pkt}, State) ->
+ noreply(send_element(State, Pkt));
+handle_cast(stop, State) ->
+ {stop, normal, State};
handle_cast(Cast, #{mod := Mod} = State) ->
- try Mod:handle_cast(Cast, State)
- catch _:undef -> {noreply, State}
- end.
+ noreply(try Mod:handle_cast(Cast, State)
+ catch _:undef -> State
+ end).
handle_call(Call, From, #{mod := Mod} = State) ->
- try Mod:handle_call(Call, From, State)
- catch _:undef -> {reply, ok, State}
- end.
+ noreply(try Mod:handle_call(Call, From, State)
+ catch _:undef -> State
+ end).
handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
- #{stream_state := wait_for_stream, xmlns := XMLNS} = State) ->
- try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of
+ #{stream_state := wait_for_stream,
+ xmlns := XMLNS, lang := MyLang} = State) ->
+ El = #xmlel{name = Name, attrs = Attrs},
+ try xmpp:decode(El, XMLNS, []) of
#stream_start{} = Pkt ->
- case send_header(State, Pkt) of
- {noreply, State1} ->
- process_stream(Pkt, State1);
- Err ->
- Err
+ State1 = send_header(State, Pkt),
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(process_stream(Pkt, State1))
end;
_ ->
- case send_header(State) of
- {noreply, State1} ->
- send_element(State1, xmpp:serr_invalid_xml());
- Err ->
- Err
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(send_element(State1, xmpp:serr_invalid_xml()))
end
catch _:{xmpp_codec, Why} ->
- case send_header(State) of
- {noreply, State1} -> process_invalid_xml(Why, State1);
- Err -> Err
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ Err = xmpp:serr_invalid_xml(Txt, Lang),
+ noreply(send_element(State1, Err))
end
end;
-handle_info({'$gen_event', {xmlstreamend, _}}, #{mod := Mod} = State) ->
- try Mod:handle_stream_end(State)
- catch _:undef -> {stop, normal, State}
- end;
handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
- case send_header(State) of
- {noreply, State1} ->
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
Err = case Reason of
<<"XML stanza is too big">> ->
xmpp:serr_policy_violation(Reason, Lang);
_ ->
xmpp:serr_not_well_formed()
end,
- send_element(State1, Err);
- Err ->
- Err
+ noreply(send_element(State1, Err))
end;
handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS} = State) ->
+ #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
try xmpp:decode(El, NS, [ignore_els]) of
Pkt ->
- process_element(Pkt, State)
+ State1 = try Mod:handle_recv(El, Pkt, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(process_element(Pkt, State1))
+ end
catch _:{xmpp_codec, Why} ->
- process_invalid_xml(Why, State)
+ State1 = try Mod:handle_recv(El, {error, Why}, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
+ end
end;
handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
#{mod := Mod} = State) ->
- try Mod:handle_cdata(Data, State)
- catch _:undef -> {noreply, State}
- end;
-handle_info(closed, #{mod := Mod} = State) ->
- try Mod:handle_stream_close(State)
- catch _:undef -> {stop, normal, State}
- end;
+ noreply(try Mod:handle_cdata(Data, State)
+ catch _:undef -> State
+ end);
+handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
+ noreply(process_stream_end({error, {stream, reset}}, State));
+handle_info({'$gen_event', closed}, State) ->
+ noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(timeout, #{mod := Mod} = State) ->
+ Disconnected = is_disconnected(State),
+ noreply(try Mod:handle_timeout(State)
+ catch _:undef when not Disconnected ->
+ send_element(State, xmpp:serr_connection_timeout());
+ _:undef ->
+ stop(State)
+ end);
handle_info({'DOWN', MRef, _Type, _Object, _Info},
- #{socket_monitor := MRef, mod := Mod} = State) ->
- try Mod:handle_stream_close(State)
- catch _:undef -> {stop, normal, State}
- end;
+ #{socket_monitor := MRef} = State) ->
+ noreply(process_stream_close({error, {socket, closed}}, State));
handle_info(Info, #{mod := Mod} = State) ->
- try Mod:handle_info(Info, State)
- catch _:undef -> {noreply, State}
- end.
+ noreply(try Mod:handle_info(Info, State)
+ catch _:undef -> State
+ end).
-terminate(Reason, #{mod := Mod, socket := Socket,
- sockmod := SockMod} = State) ->
- try Mod:terminate(Reason, State)
- catch _:undef -> ok
- end,
- send_text(State, <<"</stream:stream>">>),
- SockMod:close(Socket).
+terminate(Reason, #{mod := Mod} = State) ->
+ case get(already_terminated) of
+ true ->
+ State;
+ _ ->
+ put(already_terminated, true),
+ try Mod:terminate(Reason, State)
+ catch _:undef -> ok
+ end,
+ send_trailer(State)
+ end.
code_change(OldVsn, #{mod := Mod} = State, Extra) ->
Mod:code_change(OldVsn, State, Extra).
%%%===================================================================
%%% Internal functions
%%%===================================================================
+-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+noreply(#{stream_timeout := infinity} = State) ->
+ {noreply, State, infinity};
+noreply(#{stream_timeout := {MSecs, StartTime}} = State) ->
+ CurrentTime = p1_time_compat:monotonic_time(milli_seconds),
+ Timeout = max(0, MSecs - CurrentTime + StartTime),
+ {noreply, State, Timeout}.
+
-spec new_id() -> binary().
new_id() ->
randoms:get_string().
+-spec is_disconnected(state()) -> boolean().
+is_disconnected(#{stream_state := StreamState}) ->
+ StreamState == disconnected.
+
+-spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}|
+ {error, inet:posix()}.
peername(SockMod, Socket) ->
case SockMod of
gen_tcp -> inet:peername(Socket);
_ -> SockMod:peername(Socket)
end.
-process_invalid_xml(Reason, #{lang := Lang} = State) ->
- Txt = xmpp:io_format_error(Reason),
- send_element(State, xmpp:serr_invalid_xml(Txt, Lang)).
+-spec process_stream_close(stop_reason(), state()) -> state().
+process_stream_close(_, #{stream_state := disconnected} = State) ->
+ State;
+process_stream_close(Reason, #{mod := Mod} = State) ->
+ State1 = send_trailer(State),
+ try Mod:handle_stream_close(Reason, State1)
+ catch _:undef -> stop(State1)
+ end.
+
+-spec process_stream_end(stop_reason(), state()) -> state().
+process_stream_end(_, #{stream_state := disconnected} = State) ->
+ State;
+process_stream_end(Reason, #{mod := Mod} = State) ->
+ State1 = send_trailer(State),
+ try Mod:handle_stream_end(Reason, State1)
+ catch _:undef -> stop(State1)
+ end.
+-spec process_stream(stream_start(), state()) -> state().
process_stream(#stream_start{xmlns = XML_NS,
stream_xmlns = STREAM_NS},
#{xmlns := NS} = State)
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
process_stream(#stream_start{from = undefined, version = {1,0}},
#{lang := Lang, xmlns := ?NS_SERVER,
- stream_tlsed := true} = State) ->
+ stream_encrypted := true} = State) ->
Txt = <<"Missing 'from' attribute">>,
send_element(State, xmpp:serr_invalid_from(Txt, Lang));
process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
#{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
Txt = <<"Improper 'to' attribute">>,
send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{to = #jid{lserver = RemoteServer}},
+process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
#{xmlns := ?NS_COMPONENT, mod := Mod} = State) ->
- State1 = State#{remote_server => RemoteServer},
- case try Mod:handle_stream_start(State1)
- catch _:undef -> {noreply, State1}
- end of
- {noreply, State2} ->
- {noreply, State2#{stream_state => wait_for_handshake}};
- Err ->
- Err
+ State1 = State#{remote_server => RemoteServer,
+ stream_state => wait_for_handshake},
+ try Mod:handle_stream_start(StreamStart, State1)
+ catch _:undef -> State1
end;
process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
- from = From},
+ from = From} = StreamStart,
#{stream_authenticated := Authenticated,
stream_restarted := StreamWasRestarted,
mod := Mod, xmlns := NS, resource := Resource,
- stream_tlsed := TLSEnabled} = State) ->
- case if not StreamWasRestarted ->
- State1 = State#{server => Server, lserver => LServer},
- try Mod:handle_stream_start(State1)
- catch _:undef -> {noreply, State1}
- end;
- true ->
- {noreply, State}
- end of
- {noreply, State2} ->
- State3 = if NS == ?NS_SERVER andalso TLSEnabled ->
- State2#{remote_server => From#jid.lserver};
- true ->
- State2
- end,
- case send_features(State3) of
- {noreply, State4} ->
+ stream_encrypted := Encrypted} = State) ->
+ State1 = if not StreamWasRestarted ->
+ State#{server => Server, lserver => LServer};
+ true ->
+ State
+ end,
+ State2 = if NS == ?NS_SERVER andalso Encrypted ->
+ State1#{remote_server => From#jid.lserver};
+ true ->
+ State1
+ end,
+ State3 = try Mod:handle_stream_start(StreamStart, State2)
+ catch _:undef -> State2
+ end,
+ case is_disconnected(State3) of
+ true -> State3;
+ false ->
+ State4 = send_features(State3),
+ case is_disconnected(State4) of
+ true -> State4;
+ false ->
TLSRequired = is_starttls_required(State4),
- NewStreamState =
- if not Authenticated and
- (not TLSEnabled and TLSRequired) ->
- wait_for_starttls;
- not Authenticated ->
- wait_for_sasl_request;
- (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
- wait_for_bind;
- true ->
- session_established
- end,
- {noreply, State4#{stream_state => NewStreamState}};
- Err ->
- Err
- end;
- Err ->
- Err
+ if not Authenticated and (TLSRequired and not Encrypted) ->
+ State4#{stream_state => wait_for_starttls};
+ not Authenticated ->
+ State4#{stream_state => wait_for_sasl_request};
+ (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
+ State4#{stream_state => wait_for_bind};
+ true ->
+ process_stream_established(State4)
+ end
+ end
end.
+-spec process_element(xmpp_element(), state()) -> state().
process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
case Pkt of
#starttls{} when StateName == wait_for_starttls;
StateName == wait_for_sasl_request ->
process_starttls(State);
#starttls{} ->
- send_element(State, #starttls_failure{});
+ process_starttls_failure(unexpected_starttls_request, State);
#sasl_auth{} when StateName == wait_for_starttls ->
send_element(State, #sasl_failure{reason = 'encryption-required'});
#sasl_auth{} when StateName == wait_for_sasl_request ->
#sasl_abort{} ->
send_element(State, #sasl_failure{reason = 'aborted'});
#sasl_success{} ->
- {noreply, State};
+ State;
#compress{} when StateName == wait_for_sasl_response ->
send_element(State, #compress_failure{reason = 'setup-failed'});
#compress{} ->
#handshake{} when StateName == wait_for_handshake ->
process_handshake(Pkt, State);
#handshake{} ->
- {noreply, State};
+ State;
+ #stream_error{} ->
+ process_stream_end({error, {stream, Pkt}}, State);
_ when StateName == wait_for_sasl_request;
StateName == wait_for_handshake;
StateName == wait_for_sasl_response ->
send_error(State, Pkt, Err);
_ when StateName == wait_for_bind ->
process_bind(Pkt, State);
- _ when StateName == session_established ->
+ _ when StateName == established ->
process_authenticated_packet(Pkt, State)
end.
+-spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
process_unauthenticated_packet(Pkt, #{mod := Mod} = State) ->
NewPkt = set_lang(Pkt, State),
try Mod:handle_unauthenticated_packet(NewPkt, State)
send_error(State, Pkt, Err)
end.
+-spec process_authenticated_packet(xmpp_element(), state()) -> state().
process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) ->
Pkt1 = set_lang(Pkt, State),
case set_from_to(Pkt1, State) of
send_element(State, Err)
end.
+-spec process_bind(xmpp_element(), state()) -> state().
process_bind(#iq{type = set, sub_els = [_]} = Pkt,
#{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) ->
case xmpp:get_subtag(Pkt, #bind{}) of
server := S,
resource := NewR} = State1} when NewR /= <<"">> ->
Reply = #bind{jid = jid:make(U, S, NewR)},
- State2 = State1#{stream_state => session_established},
- send_element(State2, xmpp:make_iq_result(Pkt, Reply));
+ State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)),
+ process_stream_established(State2);
{error, #stanza_error{}, State1} = Err ->
send_error(State1, Pkt, Err)
end
send_error(State, Pkt, Err)
end.
-process_handshake(#handshake{data = Data}, #{mod := Mod} = State) ->
- case Mod:handshake(Data, State) of
- {ok, State1} ->
- State2 = State1#{stream_state => session_established,
- stream_authenticated => true},
- send_element(State2, #handshake{});
- {error, #stream_error{} = Err, State1} ->
- send_element(State1, Err)
+-spec process_handshake(handshake(), state()) -> state().
+process_handshake(#handshake{data = Digest},
+ #{mod := Mod, stream_id := StreamID,
+ remote_server := RemoteServer} = State) ->
+ GetPW = try Mod:get_password_fun(State)
+ catch _:undef -> fun(_) -> {false, undefined} end
+ end,
+ AuthRes = case GetPW(<<"">>) of
+ {false, _} ->
+ false;
+ {Password, _} ->
+ p1_sha:sha(<<StreamID/binary, Password/binary>>) == Digest
+ end,
+ case AuthRes of
+ true ->
+ State1 = try Mod:handle_auth_success(
+ RemoteServer, <<"handshake">>, undefined, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ State2 = send_element(State1, #handshake{}),
+ process_stream_established(State2)
+ end;
+ false ->
+ State1 = try Mod:handle_auth_failure(
+ RemoteServer, <<"handshake">>, 'not-authorized', State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> send_element(State1, xmpp:serr_not_authorized())
+ end
+ end.
+
+-spec process_stream_established(state()) -> state().
+process_stream_established(#{stream_state := StateName} = State)
+ when StateName == disconnected; StateName == established ->
+ State;
+process_stream_established(#{mod := Mod} = State) ->
+ State1 = State#{stream_authenticated := true,
+ stream_state => established,
+ stream_timeout => infinity},
+ try Mod:handle_stream_established(State1)
+ catch _:undef -> State1
end.
+-spec process_compress(compress(), state()) -> state().
process_compress(#compress{}, #{stream_compressed := true} = State) ->
send_element(State, #compress_failure{reason = 'setup-failed'});
process_compress(#compress{methods = HisMethods},
true ->
BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})),
ZlibSocket = SockMod:compress(Socket, BCompressed),
- State1 = State#{socket => ZlibSocket,
- stream_id => new_id(),
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_compressed => true},
- {noreply, State1};
+ State#{socket => ZlibSocket,
+ stream_id => new_id(),
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_compressed => true};
false ->
send_element(State, #compress_failure{reason = 'unsupported-method'})
end.
+-spec process_starttls(state()) -> state().
process_starttls(#{socket := Socket,
sockmod := SockMod, mod := Mod} = State) ->
TLSOpts = try Mod:tls_options(State)
end,
case SockMod:starttls(Socket, TLSOpts) of
{ok, TLSSocket} ->
- case send_element(State, #starttls_proceed{}) of
- {noreply, State1} ->
- {noreply, State1#{socket => TLSSocket,
- stream_id => new_id(),
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_tlsed => true}};
- Err ->
- Err
+ State1 = send_element(State, #starttls_proceed{}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ State1#{socket => TLSSocket,
+ stream_id => new_id(),
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_encrypted => true}
end;
- {error, _Reason} ->
- send_element(State, #starttls_failure{})
+ {error, Reason} ->
+ process_starttls_failure(Reason, State)
end.
-process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>},
- #{stream_tlsed := false} = State) ->
- process_sasl_failure('encryption-required', <<"">>, State);
-process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
- #{mod := Mod} = State) ->
- try Mod:init_sasl(State) of
- SASLState ->
- SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn),
- process_sasl_result(SASLResult, State)
- catch _:undef ->
- process_sasl_failure('temporary-auth-failure', <<"">>, State)
+-spec process_starttls_failure(term(), state()) -> state().
+process_starttls_failure(Why, State) ->
+ State1 = send_element(State, #starttls_failure{}),
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> process_stream_end({error, {tls, Why}}, State1)
end.
+-spec process_sasl_request(sasl_auth(), state()) -> state().
+process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
+ #{mod := Mod, lserver := LServer} = State) ->
+ GetPW = try Mod:get_password_fun(State)
+ catch _:undef -> fun(_) -> false end
+ end,
+ CheckPW = try Mod:check_password_fun(State)
+ catch _:undef -> fun(_, _, _) -> false end
+ end,
+ CheckPWDigest = try Mod:check_password_digest_fun(State)
+ catch _:undef -> fun(_, _, _, _, _) -> false end
+ end,
+ SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
+ GetPW, CheckPW, CheckPWDigest),
+ State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
+ Mechs = get_sasl_mechanisms(State1),
+ SASLResult = case lists:member(Mech, Mechs) of
+ true when Mech == <<"EXTERNAL">> ->
+ case xmpp_stream_pkix:authenticate(State1, ClientIn) of
+ {ok, Peer} ->
+ {ok, [{auth_module, pkix},
+ {username, Peer}]};
+ {error, _Reason, Peer} ->
+ %% TODO: return meaningful error
+ {error, 'not-authorized', Peer}
+ end;
+ true ->
+ cyrsasl:server_start(SASLState, Mech, ClientIn);
+ false ->
+ {error, 'invalid-mechanism'}
+ end,
+ process_sasl_result(SASLResult, State1).
+
+-spec process_sasl_response(sasl_response(), state()) -> state().
process_sasl_response(#sasl_response{text = ClientIn},
#{sasl_state := SASLState} = State) ->
SASLResult = cyrsasl:server_step(SASLState, ClientIn),
process_sasl_result(SASLResult, State).
+-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
process_sasl_result({ok, Props}, State) ->
process_sasl_success(Props, <<"">>, State);
process_sasl_result({ok, Props, ServerOut}, State) ->
process_sasl_result({error, Reason}, State) ->
process_sasl_failure(Reason, <<"">>, State).
+-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
process_sasl_success(Props, ServerOut,
#{socket := Socket, sockmod := SockMod,
- mod := Mod, sasl_state := SASLState} = State) ->
- Mech = cyrsasl:get_mech(SASLState),
+ mod := Mod, sasl_mech := Mech} = State) ->
User = identity(Props),
AuthModule = proplists:get_value(auth_module, Props),
- case try Mod:handle_auth_success(User, Mech, AuthModule, State)
- catch _:undef -> {noreply, State}
- end of
- {noreply, State1} ->
+ State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
SockMod:reset_stream(Socket),
- case send_element(State1, #sasl_success{text = ServerOut}) of
- {noreply, State2} ->
- State3 = maps:remove(sasl_state, State2),
- {noreply, State3#{stream_id => new_id(),
- stream_authenticated => true,
- stream_restarted => true,
- stream_state => wait_for_stream,
- user => User}};
- Err ->
- Err
- end;
- Err ->
- Err
+ State2 = send_element(State1, #sasl_success{text = ServerOut}),
+ case is_disconnected(State2) of
+ true -> State2;
+ false ->
+ State3 = maps:remove(sasl_state,
+ maps:remove(sasl_mech, State2)),
+ State3#{stream_id => new_id(),
+ stream_authenticated => true,
+ stream_header_sent => false,
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ user => User}
+ end
end.
+-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
process_sasl_continue(ServerOut, NewSASLState, State) ->
- send_element(State, #sasl_challenge{text = ServerOut}),
- {noreply, State#{sasl_state => NewSASLState,
- stream_state => wait_for_sasl_response}}.
+ State1 = State#{sasl_state => NewSASLState,
+ stream_state => wait_for_sasl_response},
+ send_element(State1, #sasl_challenge{text = ServerOut}).
+-spec process_sasl_failure(atom(), binary(), state()) -> state().
process_sasl_failure(Reason, User,
- #{mod := Mod, sasl_state := SASLState} = State) ->
- Mech = cyrsasl:get_mech(SASLState),
- case try Mod:handle_auth_failure(User, Mech, Reason, State)
- catch _:undef -> {noreply, State}
- end of
- {noreply, State1} ->
- State2 = maps:remove(sasl_state, State1),
- State3 = State2#{stream_state => wait_for_sasl_request},
- send_element(State3, #sasl_failure{reason = Reason});
- Err ->
- Err
- end.
+ #{mod := Mod, sasl_mech := Mech} = State) ->
+ State1 = try Mod:handle_auth_failure(User, Mech, Reason, State)
+ catch _:undef -> State
+ end,
+ State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
+ State3 = State2#{stream_state => wait_for_sasl_request},
+ send_element(State3, #sasl_failure{reason = Reason}).
+-spec process_sasl_abort(state()) -> state().
process_sasl_abort(State) ->
process_sasl_failure('aborted', <<"">>, State).
+-spec send_features(state()) -> state().
send_features(#{stream_version := {1,0},
- stream_tlsed := TLSEnabled} = State) ->
+ stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
- Features = if TLSRequired and not TLSEnabled ->
+ Features = if TLSRequired and not Encrypted ->
get_tls_feature(State);
true ->
get_sasl_feature(State) ++ get_compress_feature(State)
end,
send_element(State, #stream_features{sub_els = Features});
send_features(State) ->
- %% clients from stone age
- {noreply, State}.
+ %% clients and servers from stone age
+ State.
+-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
+get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod,
+ xmlns := NS, lserver := LServer} = State) ->
+ Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
+ true -> []
+ end,
+ TLSVerify = try Mod:tls_verify(State)
+ catch _:undef -> false
+ end,
+ if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
+ [<<"EXTERNAL">>|Mechs];
+ true ->
+ Mechs
+ end.
+
+-spec get_sasl_feature(state()) -> [sasl_mechanisms()].
get_sasl_feature(#{stream_authenticated := false,
- mod := Mod,
- stream_tlsed := TLSEnabled} = State) ->
+ stream_encrypted := Encrypted} = State) ->
TLSRequired = is_starttls_required(State),
- if TLSEnabled or not TLSRequired ->
- try Mod:sasl_mechanisms(State) of
- [] -> [];
- List -> [#sasl_mechanisms{list = List}]
- catch _:undef ->
- []
- end;
+ if Encrypted or not TLSRequired ->
+ Mechs = get_sasl_mechanisms(State),
+ [#sasl_mechanisms{list = Mechs}];
true ->
[]
end;
get_sasl_feature(_) ->
[].
+-spec get_compress_feature(state()) -> [compression()].
get_compress_feature(#{stream_compressed := false, mod := Mod} = State) ->
try Mod:compress_methods(State) of
[] -> [];
get_compress_feature(_) ->
[].
+-spec get_tls_feature(state()) -> [starttls()].
get_tls_feature(#{stream_authenticated := false,
- stream_tlsed := false} = State) ->
+ stream_encrypted := false} = State) ->
TLSRequired = is_starttls_required(State),
[#starttls{required = TLSRequired}];
get_tls_feature(_) ->
[].
-get_bind_feature(#{stream_authenticated := true, resource := <<"">>}) ->
+-spec get_bind_feature(state()) -> [bind()].
+get_bind_feature(#{xmlns := ?NS_CLIENT,
+ stream_authenticated := true,
+ resource := <<"">>}) ->
[#bind{}];
get_bind_feature(_) ->
[].
-get_session_feature(#{stream_authenticated := true, resource := <<"">>}) ->
+-spec get_session_feature(state()) -> [xmpp_session()].
+get_session_feature(#{xmlns := ?NS_CLIENT,
+ stream_authenticated := true,
+ resource := <<"">>}) ->
[#xmpp_session{optional = true}];
get_session_feature(_) ->
[].
+-spec get_other_features(state()) -> [xmpp_element()].
get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) ->
try
if Auth -> Mod:authenticated_stream_features(State);
[]
end.
+-spec is_starttls_required(state()) -> boolean().
is_starttls_required(#{mod := Mod} = State) ->
try Mod:tls_required(State)
catch _:undef -> false
end.
+-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
+ {error, stream_error()}.
set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
{ok, Pkt};
set_from_to(Pkt, #{user := U, server := S, resource := R,
- xmlns := ?NS_CLIENT}) ->
+ lang := Lang, xmlns := ?NS_CLIENT}) ->
JID = jid:make(U, S, R),
From = case xmpp:get_from(Pkt) of
undefined -> JID;
end,
{ok, xmpp:set_from_to(Pkt, JID, To)};
true ->
- {error, xmpp:serr_invalid_from()}
+ Txt = <<"Improper 'from' attribute">>,
+ {error, xmpp:serr_invalid_from(Txt, Lang)}
end;
set_from_to(Pkt, #{lang := Lang}) ->
From = xmpp:get_from(Pkt),
{ok, Pkt}
end.
+-spec send_header(state()) -> state().
send_header(State) ->
send_header(State, #stream_start{}).
-send_header(#{stream_state := wait_for_stream,
- stream_id := StreamID,
+-spec send_header(state(), stream_start()) -> state().
+send_header(#{stream_id := StreamID,
stream_version := MyVersion,
+ stream_header_sent := false,
lang := MyLang,
xmlns := NS,
server := DefaultServer} = State,
#stream_start{to = To, lang = HisLang, version = HisVersion}) ->
- Lang = choose_lang(MyLang, HisLang),
+ Lang = select_lang(MyLang, HisLang),
+ NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
+ true -> <<"">>
+ end,
From = case To of
#jid{} -> To;
undefined -> jid:make(DefaultServer)
lang = Lang,
xmlns = NS,
stream_xmlns = ?NS_STREAM,
+ db_xmlns = NS_DB,
id = StreamID,
from = From}),
- State1 = State#{lang => Lang},
+ State1 = State#{lang => Lang, stream_header_sent => true},
case send_text(State1, fxml:element_to_header(Header)) of
- ok -> {noreply, State1};
- {error, _} -> {stop, normal, State1}
+ ok -> State1;
+ {error, Why} -> process_stream_close({error, {socket, Why}}, State1)
end;
send_header(State, _) ->
- {noreply, State}.
+ State.
+-spec send_element(state(), xmpp_element()) -> state().
send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
El = xmpp:encode(Pkt, NS),
Data = fxml:element_to_binary(El),
- case send_text(State, Data) of
- ok when is_record(Pkt, stream_error) ->
- {stop, normal, State};
- ok when is_record(Pkt, starttls_failure) ->
- {stop, normal, State};
- Res ->
- try Mod:handle_send(Res, Pkt, El, Data, State)
- catch _:undef when Res == ok ->
- {noreply, State};
- _:undef ->
- {stop, normal, State}
- end
+ Result = send_text(State, Data),
+ State1 = try Mod:handle_send(Pkt, Result, State)
+ catch _:undef -> State
+ end,
+ case Result of
+ _ when is_record(Pkt, stream_error) ->
+ process_stream_end({error, {stream, Pkt}}, State1);
+ ok ->
+ State1;
+ {error, Why} ->
+ process_stream_close({error, {socket, Why}}, State1)
end.
-send_error(State, Pkt, Err) when ?is_stanza(Pkt) ->
- case xmpp:get_type(Pkt) of
- result -> {noreply, State};
- error -> {noreply, State};
- _ ->
- ErrPkt = xmpp:make_error(Pkt, Err),
- send_element(State, ErrPkt)
- end;
-send_error(State, _, _) ->
- {noreply, State}.
+-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+send_error(State, Pkt, Err) ->
+ case xmpp:is_stanza(Pkt) of
+ true ->
+ case xmpp:get_type(Pkt) of
+ result -> State;
+ error -> State;
+ <<"result">> -> State;
+ <<"error">> -> State;
+ _ ->
+ ErrPkt = xmpp:make_error(Pkt, Err),
+ send_element(State, ErrPkt)
+ end;
+ false ->
+ State
+ end.
+
+-spec send_trailer(state()) -> state().
+send_trailer(State) ->
+ send_text(State, <<"</stream:stream>">>),
+ close_socket(State).
-send_text(#{socket := Sock, sockmod := SockMod}, Data) ->
- SockMod:send(Sock, Data).
+-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
+send_text(#{socket := Sock, sockmod := SockMod,
+ stream_state := StateName,
+ stream_header_sent := true}, Data) when StateName /= disconnected ->
+ SockMod:send(Sock, Data);
+send_text(_, _) ->
+ {error, einval}.
-choose_lang(Lang, <<"">>) -> Lang;
-choose_lang(_, Lang) -> Lang.
+-spec close_socket(state()) -> state().
+close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
+ SockMod:close(Socket),
+ State#{stream_timeout => infinity,
+ stream_state => disconnected}.
+-spec select_lang(binary(), binary()) -> binary().
+select_lang(Lang, <<"">>) -> Lang;
+select_lang(_, Lang) -> Lang.
+
+-spec set_lang(xmpp_element(), state()) -> xmpp_element().
set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
HisLang = xmpp:get_lang(Pkt),
- Lang = choose_lang(MyLang, HisLang),
+ Lang = select_lang(MyLang, HisLang),
xmpp:set_lang(Pkt, Lang);
set_lang(Pkt, _) ->
Pkt.
+-spec format_inet_error(atom()) -> string().
+format_inet_error(Reason) ->
+ case inet:format_error(Reason) of
+ "unknown POSIX error" -> atom_to_list(Reason);
+ Txt -> Txt
+ end.
+
+-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
+format_stream_error(Reason, Txt) ->
+ Slogan = case Reason of
+ #'see-other-host'{} -> "see-other-host";
+ _ -> atom_to_list(Reason)
+ end,
+ case Txt of
+ undefined -> Slogan;
+ #text{data = <<"">>} -> Slogan;
+ #text{data = Data} ->
+ binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
+ end.
+
+-spec format(io:format(), list()) -> binary().
+format(Fmt, Args) ->
+ iolist_to_binary(io_lib:format(Fmt, Args)).
+
+-spec lists_intersection(list(), list()) -> list().
lists_intersection(L1, L2) ->
lists:filter(
fun(E) ->
lists:member(E, L2)
end, L1).
+-spec identity([cyrsasl:sasl_property()]) -> binary().
identity(Props) ->
case proplists:get_value(authzid, Props, <<>>) of
<<>> -> proplists:get_value(username, Props, <<>>);
--- /dev/null
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2016, Evgeny Khramtsov
+%%% @doc
+%%%
+%%% @end
+%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%-------------------------------------------------------------------
+-module(xmpp_stream_out).
+-behaviour(gen_server).
+
+-protocol({rfc, 6120}).
+
+%% API
+-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
+ stop/1, send/2, close/1, close/2, establish/1, format_error/1,
+ set_timeout/2, get_transport/1, change_shaper/2]).
+%% gen_server callbacks
+-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
+ terminate/2, code_change/3]).
+
+%%-define(DBGFSM, true).
+-ifdef(DBGFSM).
+-define(FSMOPTS, [{debug, [trace]}]).
+-else.
+-define(FSMOPTS, []).
+-endif.
+
+-define(TCP_SEND_TIMEOUT, 15000).
+
+-include("xmpp.hrl").
+-include("logger.hrl").
+-include_lib("kernel/include/inet.hrl").
+
+-type state() :: map().
+-type host_port() :: {inet:hostname(), inet:port_number()}.
+-type ip_port() :: {inet:ip_address(), inet:port_number()}.
+-type network_error() :: {error, inet:posix() | inet_res:res_error()}.
+-type stop_reason() :: {idna, bad_string} |
+ {dns, inet:posix() | inet_res:res_error()} |
+ {stream, reset | stream_error()} |
+ {tls, term()} |
+ {pkix, binary()} |
+ {auth, atom() | binary() | string()} |
+ {socket, inet:posix() | closed | timeout}.
+
+-callback init(list()) -> {ok, state()} | {stop, term()} | ignore.
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+start(Mod, Args, Opts) ->
+ gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+start_link(Mod, Args, Opts) ->
+ gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
+
+call(Ref, Msg, Timeout) ->
+ gen_server:call(Ref, Msg, Timeout).
+
+cast(Ref, Msg) ->
+ gen_server:cast(Ref, Msg).
+
+reply(Ref, Reply) ->
+ gen_server:reply(Ref, Reply).
+
+-spec connect(pid()) -> ok.
+connect(Ref) ->
+ cast(Ref, connect).
+
+-spec stop(pid()) -> ok;
+ (state()) -> no_return().
+stop(Pid) when is_pid(Pid) ->
+ cast(Pid, stop);
+stop(#{owner := Owner} = State) when Owner == self() ->
+ terminate(normal, State),
+ exit(normal);
+stop(_) ->
+ erlang:error(badarg).
+
+-spec send(pid(), xmpp_element()) -> ok;
+ (state(), xmpp_element()) -> state().
+send(Pid, Pkt) when is_pid(Pid) ->
+ cast(Pid, {send, Pkt});
+send(#{owner := Owner} = State, Pkt) when Owner == self() ->
+ send_element(State, Pkt);
+send(_, _) ->
+ erlang:error(badarg).
+
+-spec close(pid()) -> ok;
+ (state()) -> state().
+close(Ref) ->
+ close(Ref, true).
+
+-spec close(pid(), boolean()) -> ok;
+ (state(), boolean()) -> state().
+close(Pid, SendTrailer) when is_pid(Pid) ->
+ cast(Pid, {close, SendTrailer});
+close(#{owner := Owner} = State, SendTrailer) when Owner == self() ->
+ if SendTrailer -> send_trailer(State);
+ true -> close_socket(State)
+ end;
+close(_, _) ->
+ erlang:error(badarg).
+
+-spec establish(state()) -> state().
+establish(State) ->
+ process_stream_established(State).
+
+-spec set_timeout(state(), non_neg_integer() | infinity) -> state().
+set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
+ case Timeout of
+ infinity -> State#{stream_timeout => infinity};
+ _ ->
+ Time = p1_time_compat:monotonic_time(milli_seconds),
+ State#{stream_timeout => {Timeout, Time}}
+ end;
+set_timeout(_, _) ->
+ erlang:error(badarg).
+
+get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner})
+ when Owner == self() ->
+ SockMod:get_transport(Socket);
+get_transport(_) ->
+ erlang:error(badarg).
+
+-spec change_shaper(state(), shaper:shaper()) -> ok.
+change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper)
+ when Owner == self() ->
+ SockMod:change_shaper(Socket, Shaper);
+change_shaper(_, _) ->
+ erlang:error(badarg).
+
+-spec format_error(stop_reason()) -> binary().
+format_error({idna, _}) ->
+ <<"Not an IDN hostname">>;
+format_error({dns, Reason}) ->
+ format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
+format_error({socket, Reason}) ->
+ format("Connection failed: ~s", [format_inet_error(Reason)]);
+format_error({pkix, Reason}) ->
+ format("Peer certificate rejected: ~s", [Reason]);
+format_error({stream, reset}) ->
+ <<"Stream reset by peer">>;
+format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
+ format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({tls, Reason}) ->
+ format("TLS failed: ~w", [Reason]);
+format_error({auth, Reason}) ->
+ format("Authentication failed: ~s", [Reason]);
+format_error(Err) ->
+ format("Unrecognized error: ~w", [Err]).
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([Mod, SockMod, From, To, Opts]) ->
+ Time = p1_time_compat:monotonic_time(milli_seconds),
+ State = #{owner => self(),
+ mod => Mod,
+ sockmod => SockMod,
+ server => From,
+ user => <<"">>,
+ resource => <<"">>,
+ lang => <<"">>,
+ remote_server => To,
+ xmlns => ?NS_SERVER,
+ stream_direction => out,
+ stream_timeout => {timer:seconds(30), Time},
+ stream_id => new_id(),
+ stream_encrypted => false,
+ stream_verified => false,
+ stream_authenticated => false,
+ stream_restarted => false,
+ stream_state => connecting},
+ case try Mod:init([State, Opts])
+ catch _:undef -> {ok, State}
+ end of
+ {ok, State1} ->
+ {_, State2, Timeout} = noreply(State1),
+ {ok, State2, Timeout};
+ Err ->
+ Err
+ end.
+
+handle_call(Call, From, #{mod := Mod} = State) ->
+ noreply(try Mod:handle_call(Call, From, State)
+ catch _:undef -> State
+ end).
+
+handle_cast(connect, #{remote_server := RemoteServer,
+ sockmod := SockMod,
+ stream_state := connecting} = State) ->
+ case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of
+ false ->
+ noreply(process_stream_close({error, {idna, bad_string}}, State));
+ ASCIIName ->
+ case resolve(binary_to_list(ASCIIName), State) of
+ {ok, AddrPorts} ->
+ case connect(AddrPorts, State) of
+ {ok, Socket, AddrPort} ->
+ SocketMonitor = SockMod:monitor(Socket),
+ State1 = State#{ip => AddrPort,
+ socket => Socket,
+ socket_monitor => SocketMonitor},
+ State2 = State1#{stream_state => wait_for_stream},
+ noreply(send_header(State2));
+ {error, Why} ->
+ Err = {error, {socket, Why}},
+ noreply(process_stream_close(Err, State))
+ end;
+ {error, Why} ->
+ noreply(process_stream_close({error, {dns, Why}}, State))
+ end
+ end;
+handle_cast(connect, State) ->
+ %% Ignoring connection attempts in other states
+ noreply(State);
+handle_cast({send, Pkt}, State) ->
+ noreply(send_element(State, Pkt));
+handle_cast(stop, State) ->
+ {stop, normal, State};
+handle_cast(Cast, #{mod := Mod} = State) ->
+ noreply(try Mod:handle_cast(Cast, State)
+ catch _:undef -> State
+ end).
+
+handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
+ #{stream_state := wait_for_stream,
+ xmlns := XMLNS, lang := MyLang} = State) ->
+ El = #xmlel{name = Name, attrs = Attrs},
+ try xmpp:decode(El, XMLNS, []) of
+ #stream_start{} = Pkt ->
+ noreply(process_stream(Pkt, State));
+ _ ->
+ noreply(send_element(State, xmpp:serr_invalid_xml()))
+ catch _:{xmpp_codec, Why} ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ Err = xmpp:serr_invalid_xml(Txt, Lang),
+ noreply(send_element(State, Err))
+ end;
+handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
+ State1 = send_header(State),
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Err = case Reason of
+ <<"XML stanza is too big">> ->
+ xmpp:serr_policy_violation(Reason, Lang);
+ _ ->
+ xmpp:serr_not_well_formed()
+ end,
+ noreply(send_element(State1, Err))
+ end;
+handle_info({'$gen_event', {xmlstreamelement, El}},
+ #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+ try xmpp:decode(El, NS, [ignore_els]) of
+ Pkt ->
+ State1 = try Mod:handle_recv(El, Pkt, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false -> noreply(process_element(Pkt, State1))
+ end
+ catch _:{xmpp_codec, Why} ->
+ State1 = try Mod:handle_recv(El, undefined, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ Txt = xmpp:io_format_error(Why),
+ Lang = select_lang(MyLang, xmpp:get_lang(El)),
+ noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang)))
+ end
+ end;
+handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
+ #{mod := Mod} = State) ->
+ noreply(try Mod:handle_cdata(Data, State)
+ catch _:undef -> State
+ end);
+handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
+ noreply(process_stream_end({error, {stream, reset}}, State));
+handle_info({'$gen_event', closed}, State) ->
+ noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(timeout, #{mod := Mod} = State) ->
+ Disconnected = is_disconnected(State),
+ noreply(try Mod:handle_timeout(State)
+ catch _:undef when not Disconnected ->
+ send_element(State, xmpp:serr_connection_timeout());
+ _:undef ->
+ stop(State)
+ end);
+handle_info({'DOWN', MRef, _Type, _Object, _Info},
+ #{socket_monitor := MRef} = State) ->
+ noreply(process_stream_close({error, {socket, closed}}, State));
+handle_info(Info, #{mod := Mod} = State) ->
+ noreply(try Mod:handle_info(Info, State)
+ catch _:undef -> State
+ end).
+
+terminate(Reason, #{mod := Mod} = State) ->
+ case get(already_terminated) of
+ true ->
+ State;
+ _ ->
+ put(already_terminated, true),
+ try Mod:terminate(Reason, State)
+ catch _:undef -> ok
+ end,
+ send_trailer(State)
+ end.
+
+code_change(OldVsn, #{mod := Mod} = State, Extra) ->
+ Mod:code_change(OldVsn, State, Extra).
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
+noreply(#{stream_timeout := infinity} = State) ->
+ {noreply, State, infinity};
+noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
+ NewTime = p1_time_compat:monotonic_time(milli_seconds),
+ Timeout = max(0, MSecs - NewTime + OldTime),
+ {noreply, State, Timeout}.
+
+-spec new_id() -> binary().
+new_id() ->
+ randoms:get_string().
+
+-spec is_disconnected(state()) -> boolean().
+is_disconnected(#{stream_state := StreamState}) ->
+ StreamState == disconnected.
+
+-spec process_stream_close(stop_reason(), state()) -> state().
+process_stream_close(_, #{stream_state := disconnected} = State) ->
+ State;
+process_stream_close(Reason, #{mod := Mod} = State) ->
+ State1 = send_trailer(State),
+ try Mod:handle_stream_close(Reason, State1)
+ catch _:undef -> stop(State1)
+ end.
+
+-spec process_stream_end(stop_reason(), state()) -> state().
+process_stream_end(_, #{stream_state := disconnected} = State) ->
+ State;
+process_stream_end(Reason, #{mod := Mod} = State) ->
+ State1 = send_trailer(State),
+ try Mod:handle_stream_end(Reason, State1)
+ catch _:undef -> stop(State1)
+ end.
+
+-spec process_stream(stream_start(), state()) -> state().
+process_stream(#stream_start{xmlns = XML_NS,
+ stream_xmlns = STREAM_NS},
+ #{xmlns := NS} = State)
+ when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
+ send_element(State, xmpp:serr_invalid_namespace());
+process_stream(#stream_start{lang = Lang, id = ID,
+ version = Version} = StreamStart,
+ #{mod := Mod} = State) ->
+ State1 = State#{stream_remote_id => ID, lang => Lang},
+ State2 = try Mod:handle_stream_start(StreamStart, State1)
+ catch _:undef -> State1
+ end,
+ case is_disconnected(State2) of
+ true -> State2;
+ false ->
+ case Version of
+ {1,0} -> State2#{stream_state => wait_for_features};
+ _ -> process_stream_downgrade(StreamStart, State)
+ end
+ end.
+
+-spec process_element(xmpp_element(), state()) -> state().
+process_element(Pkt, #{stream_state := StateName} = State) ->
+ case Pkt of
+ #stream_features{} when StateName == wait_for_features ->
+ process_features(Pkt, State);
+ #starttls_proceed{} when StateName == wait_for_starttls_response ->
+ process_starttls(State);
+ #sasl_success{} when StateName == wait_for_sasl_response ->
+ process_sasl_success(State);
+ #sasl_failure{} when StateName == wait_for_sasl_response ->
+ process_sasl_failure(Pkt, State);
+ #stream_error{} ->
+ process_stream_end({error, {stream, Pkt}}, State);
+ _ when is_record(Pkt, stream_features);
+ is_record(Pkt, starttls_proceed);
+ is_record(Pkt, starttls);
+ is_record(Pkt, sasl_auth);
+ is_record(Pkt, sasl_success);
+ is_record(Pkt, sasl_failure);
+ is_record(Pkt, sasl_response);
+ is_record(Pkt, sasl_abort);
+ is_record(Pkt, compress);
+ is_record(Pkt, handshake) ->
+ %% Do not pass this crap upstream
+ State;
+ _ ->
+ process_packet(Pkt, State)
+ end.
+
+-spec process_features(stream_features(), state()) -> state().
+process_features(StreamFeatures,
+ #{stream_authenticated := true, mod := Mod} = State) ->
+ State1 = try Mod:handle_authenticated_features(StreamFeatures, State)
+ catch _:undef -> State
+ end,
+ process_stream_established(State1);
+process_features(#stream_features{sub_els = Els} = StreamFeatures,
+ #{stream_encrypted := Encrypted,
+ mod := Mod, lang := Lang} = State) ->
+ State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ TLSRequired = is_starttls_required(State1),
+ %% TODO: improve xmpp.erl
+ Msg = #message{sub_els = Els},
+ case xmpp:get_subtag(Msg, #starttls{}) of
+ false when TLSRequired and not Encrypted ->
+ Txt = <<"Use of STARTTLS required">>,
+ send_element(State1, xmpp:err_policy_violation(Txt, Lang));
+ #starttls{} when not Encrypted ->
+ State2 = State1#{stream_state => wait_for_starttls_response},
+ send_element(State2, #starttls{});
+ _ ->
+ State2 = process_cert_verification(State1),
+ case is_disconnected(State2) of
+ true -> State2;
+ false ->
+ case xmpp:get_subtag(Msg, #sasl_mechanisms{}) of
+ #sasl_mechanisms{list = Mechs} ->
+ process_sasl_mechanisms(Mechs, State2);
+ false ->
+ process_sasl_failure(
+ #sasl_failure{reason = 'invalid-mechanism'},
+ State2)
+ end
+ end
+ end
+ end.
+
+-spec process_stream_established(state()) -> state().
+process_stream_established(#{stream_state := StateName} = State)
+ when StateName == disconnected; StateName == established ->
+ State;
+process_stream_established(#{mod := Mod} = State) ->
+ State1 = State#{stream_authenticated := true,
+ stream_state => established,
+ stream_timeout => infinity},
+ try Mod:handle_stream_established(State1)
+ catch _:undef -> State1
+ end.
+
+-spec process_sasl_mechanisms([binary()], state()) -> state().
+process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
+ %% TODO: support other mechanisms
+ Mech = <<"EXTERNAL">>,
+ case lists:member(<<"EXTERNAL">>, Mechs) of
+ true ->
+ State1 = State#{stream_state => wait_for_sasl_response},
+ Authzid = jid:to_string(jid:make(User, Server)),
+ send_element(State1, #sasl_auth{mechanism = Mech, text = Authzid});
+ false ->
+ process_sasl_failure(
+ #sasl_failure{reason = 'invalid-mechanism'}, State)
+ end.
+
+-spec process_starttls(state()) -> state().
+process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) ->
+ TLSOpts = try Mod:tls_options(State)
+ catch _:undef -> []
+ end,
+ case SockMod:starttls(Socket, [connect|TLSOpts]) of
+ {ok, TLSSocket} ->
+ State1 = State#{socket => TLSSocket,
+ stream_id => new_id(),
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_encrypted => true},
+ send_header(State1);
+ {error, Why} ->
+ process_stream_close({error, {tls, Why}}, State)
+ end.
+
+-spec process_stream_downgrade(stream_start(), state()) -> state().
+process_stream_downgrade(StreamStart, #{mod := Mod} = State) ->
+ try Mod:downgrade_stream(StreamStart, State)
+ catch _:undef ->
+ send_element(State, xmpp:serr_unsupported_version())
+ end.
+
+-spec process_cert_verification(state()) -> state().
+process_cert_verification(#{stream_encrypted := true,
+ stream_verified := false,
+ mod := Mod} = State) ->
+ case try Mod:tls_verify(State)
+ catch _:undef -> true
+ end of
+ true ->
+ case xmpp_stream_pkix:authenticate(State) of
+ {ok, _} ->
+ State#{stream_verified => true};
+ {error, Why, _Peer} ->
+ process_stream_close({error, {pkix, Why}}, State)
+ end;
+ false ->
+ State#{stream_verified => true}
+ end;
+process_cert_verification(State) ->
+ State.
+
+-spec process_sasl_success(state()) -> state().
+process_sasl_success(#{mod := Mod,
+ sockmod := SockMod,
+ socket := Socket} = State) ->
+ State1 = try Mod:handle_auth_success(<<"EXTERNAL">>, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ SockMod:reset_stream(Socket),
+ State2 = State1#{stream_id => new_id(),
+ stream_restarted => true,
+ stream_state => wait_for_stream,
+ stream_authenticated => true},
+ send_header(State2)
+ end.
+
+-spec process_sasl_failure(sasl_failure(), state()) -> state().
+process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) ->
+ try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State)
+ catch _:undef -> process_stream_close({error, {auth, Reason}}, State)
+ end.
+
+-spec process_packet(xmpp_element(), state()) -> state().
+process_packet(Pkt, #{mod := Mod} = State) ->
+ try Mod:handle_packet(Pkt, State)
+ catch _:undef -> State
+ end.
+
+-spec is_starttls_required(state()) -> boolean().
+is_starttls_required(#{mod := Mod} = State) ->
+ try Mod:tls_required(State)
+ catch _:undef -> false
+ end.
+
+-spec send_header(state()) -> state().
+send_header(#{remote_server := RemoteServer,
+ stream_encrypted := Encrypted,
+ lang := Lang,
+ xmlns := NS,
+ user := User,
+ resource := Resource,
+ server := Server} = State) ->
+ NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
+ true -> <<"">>
+ end,
+ From = if Encrypted ->
+ jid:make(User, Server, Resource);
+ NS == ?NS_SERVER ->
+ jid:make(Server);
+ true ->
+ undefined
+ end,
+ Header = xmpp:encode(
+ #stream_start{xmlns = NS,
+ lang = Lang,
+ stream_xmlns = ?NS_STREAM,
+ db_xmlns = NS_DB,
+ from = From,
+ to = jid:make(RemoteServer),
+ version = {1,0}}),
+ case send_text(State, fxml:element_to_header(Header)) of
+ ok -> State;
+ {error, Why} -> process_stream_close({error, {socket, Why}}, State)
+ end.
+
+-spec send_element(state(), xmpp_element()) -> state().
+send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
+ El = xmpp:encode(Pkt, NS),
+ Data = fxml:element_to_binary(El),
+ State1 = try Mod:handle_send(Pkt, El, Data, State)
+ catch _:undef -> State
+ end,
+ case is_disconnected(State1) of
+ true -> State1;
+ false ->
+ case send_text(State1, Data) of
+ _ when is_record(Pkt, stream_error) ->
+ process_stream_end({error, {stream, Pkt}}, State1);
+ ok ->
+ State1;
+ {error, Why} ->
+ process_stream_close({error, {socket, Why}}, State1)
+ end
+ end.
+
+-spec send_error(state(), xmpp_element(), stanza_error()) -> state().
+send_error(State, Pkt, Err) ->
+ case xmpp:is_stanza(Pkt) of
+ true ->
+ case xmpp:get_type(Pkt) of
+ result -> State;
+ error -> State;
+ <<"result">> -> State;
+ <<"error">> -> State;
+ _ ->
+ ErrPkt = xmpp:make_error(Pkt, Err),
+ send_element(State, ErrPkt)
+ end;
+ false ->
+ State
+ end.
+
+-spec send_text(state(), binary()) -> ok | {error, inet:posix()}.
+send_text(#{sockmod := SockMod, socket := Socket,
+ stream_state := StateName}, Data) when StateName /= disconnected ->
+ SockMod:send(Socket, Data);
+send_text(_, _) ->
+ {error, einval}.
+
+-spec send_trailer(state()) -> state().
+send_trailer(State) ->
+ send_text(State, <<"</stream:stream>">>),
+ close_socket(State).
+
+-spec close_socket(state()) -> state().
+close_socket(State) ->
+ case State of
+ #{sockmod := SockMod, socket := Socket} ->
+ SockMod:close(Socket);
+ _ ->
+ ok
+ end,
+ State#{stream_timeout => infinity,
+ stream_state => disconnected}.
+
+-spec select_lang(binary(), binary()) -> binary().
+select_lang(Lang, <<"">>) -> Lang;
+select_lang(_, Lang) -> Lang.
+
+-spec format_inet_error(atom()) -> string().
+format_inet_error(Reason) ->
+ case inet:format_error(Reason) of
+ "unknown POSIX error" -> atom_to_list(Reason);
+ Txt -> Txt
+ end.
+
+-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
+format_stream_error(Reason, Txt) ->
+ Slogan = case Reason of
+ #'see-other-host'{} -> "see-other-host";
+ _ -> atom_to_list(Reason)
+ end,
+ case Txt of
+ undefined -> Slogan;
+ #text{data = <<"">>} -> Slogan;
+ #text{data = Data} ->
+ binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
+ end.
+
+-spec format(io:format(), list()) -> binary().
+format(Fmt, Args) ->
+ iolist_to_binary(io_lib:format(Fmt, Args)).
+
+%%%===================================================================
+%%% Connection stuff
+%%%===================================================================
+-spec resolve(string(), state()) -> {ok, [host_port()]} | network_error().
+resolve(Host, State) ->
+ case srv_lookup(Host, State) of
+ {error, _Reason} ->
+ DefaultPort = get_default_port(State),
+ a_lookup([{Host, DefaultPort}], State);
+ {ok, HostPorts} ->
+ a_lookup(HostPorts, State)
+ end.
+
+-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
+srv_lookup(Host, State) ->
+ %% Only perform SRV lookups for FQDN names
+ case string:chr(Host, $.) of
+ 0 ->
+ {error, nxdomain};
+ _ ->
+ case inet_parse:address(Host) of
+ {ok, _} ->
+ {error, nxdomain};
+ {error, _} ->
+ Timeout = get_dns_timeout(State),
+ Retries = get_dns_retries(State),
+ srv_lookup(Host, Timeout, Retries)
+ end
+ end.
+
+-spec srv_lookup(string(), non_neg_integer(), integer()) ->
+ {ok, [host_port()]} | network_error().
+srv_lookup(_Host, _Timeout, Retries) when Retries < 1 ->
+ {error, timeout};
+srv_lookup(Host, Timeout, Retries) ->
+ SRVName = "_xmpp-server._tcp." ++ Host,
+ case inet_res:getbyname(SRVName, srv, Timeout) of
+ {ok, HostEntry} ->
+ host_entry_to_host_ports(HostEntry);
+ {error, _} ->
+ LegacySRVName = "_jabber._tcp." ++ Host,
+ case inet_res:getbyname(LegacySRVName, srv, Timeout) of
+ {error, timeout} ->
+ srv_lookup(Host, Timeout, Retries - 1);
+ {error, _} = Err ->
+ Err;
+ {ok, HostEntry} ->
+ host_entry_to_host_ports(HostEntry)
+ end
+ end.
+
+-spec a_lookup([{inet:hostname(), inet:port_number()}], state()) ->
+ {ok, [ip_port()]} | network_error().
+a_lookup(HostPorts, State) ->
+ HostPortFamilies = [{Host, Port, Family}
+ || {Host, Port} <- HostPorts,
+ Family <- get_address_families(State)],
+ a_lookup(HostPortFamilies, State, {error, nxdomain}).
+
+-spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}],
+ state(), network_error()) -> {ok, [ip_port()]} | network_error().
+a_lookup([{Host, Port, Family}|HostPortFamilies], State, _) ->
+ Timeout = get_dns_timeout(State),
+ Retries = get_dns_retries(State),
+ case a_lookup(Host, Port, Family, Timeout, Retries) of
+ {error, _} = Err ->
+ a_lookup(HostPortFamilies, State, Err);
+ {ok, AddrPorts} ->
+ {ok, AddrPorts}
+ end;
+a_lookup([], _State, Err) ->
+ Err.
+
+-spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(),
+ non_neg_integer(), integer()) -> {ok, [ip_port()]} | network_error().
+a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 ->
+ {error, timeout};
+a_lookup(Host, Port, Family, Timeout, Retries) ->
+ case inet:gethostbyname(Host, Family, Timeout) of
+ {error, timeout} ->
+ a_lookup(Host, Port, Family, Timeout, Retries - 1);
+ {error, _} = Err ->
+ Err;
+ {ok, HostEntry} ->
+ host_entry_to_addr_ports(HostEntry, Port)
+ end.
+
+-spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} |
+ {error, nxdomain}.
+host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) ->
+ PrioHostPorts = lists:flatmap(
+ fun({Priority, Weight, Port, Host}) ->
+ N = case Weight of
+ 0 -> 0;
+ _ -> (Weight + 1) * randoms:uniform()
+ end,
+ [{Priority * 65536 - N, Host, Port}];
+ (_) ->
+ []
+ end, AddrList),
+ HostPorts = [{Host, Port}
+ || {_Priority, Host, Port} <- lists:usort(PrioHostPorts)],
+ case HostPorts of
+ [] -> {error, nxdomain};
+ _ -> {ok, HostPorts}
+ end.
+
+-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) ->
+ {ok, [ip_port()]} | {error, nxdomain}.
+host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) ->
+ AddrPorts = lists:flatmap(
+ fun(Addr) ->
+ try get_addr_type(Addr) of
+ _ -> [{Addr, Port}]
+ catch _:_ ->
+ []
+ end
+ end, AddrList),
+ case AddrPorts of
+ [] -> {error, nxdomain};
+ _ -> {ok, AddrPorts}
+ end.
+
+-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | network_error().
+connect(AddrPorts, #{sockmod := SockMod} = State) ->
+ Timeout = get_connect_timeout(State),
+ connect(AddrPorts, SockMod, Timeout, {error, nxdomain}).
+
+-spec connect([ip_port()], module(), non_neg_integer(), network_error()) ->
+ {ok, term(), ip_port()} | network_error().
+connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
+ Type = get_addr_type(Addr),
+ case SockMod:connect(Addr, Port,
+ [binary, {packet, 0},
+ {send_timeout, ?TCP_SEND_TIMEOUT},
+ {send_timeout_close, true},
+ {active, false}, Type],
+ Timeout) of
+ {ok, Socket} ->
+ {ok, Socket, {Addr, Port}};
+ Err ->
+ connect(AddrPorts, SockMod, Timeout, Err)
+ end;
+connect([], _SockMod, _Timeout, Err) ->
+ Err.
+
+-spec get_addr_type(inet:ip_address()) -> inet:address_family().
+get_addr_type({_, _, _, _}) -> inet;
+get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
+
+-spec get_dns_timeout(state()) -> non_neg_integer().
+get_dns_timeout(#{mod := Mod} = State) ->
+ timer:seconds(
+ try Mod:dns_timeout(State)
+ catch _:undef -> 10
+ end).
+
+-spec get_dns_retries(state()) -> non_neg_integer().
+get_dns_retries(#{mod := Mod} = State) ->
+ try Mod:dns_retries(State)
+ catch _:undef -> 2
+ end.
+
+-spec get_default_port(state()) -> inet:port_number().
+get_default_port(#{mod := Mod, xmlns := NS} = State) ->
+ try Mod:default_port(State)
+ catch _:undef when NS == ?NS_SERVER -> 5269;
+ _:undef when NS == ?NS_CLIENT -> 5222
+ end.
+
+-spec get_address_families(state()) -> [inet:address_family()].
+get_address_families(#{mod := Mod} = State) ->
+ try Mod:address_families(State)
+ catch _:undef -> [inet, inet6]
+ end.
+
+-spec get_connect_timeout(state()) -> non_neg_integer().
+get_connect_timeout(#{mod := Mod} = State) ->
+ timer:seconds(
+ try Mod:connect_timeout(State)
+ catch _:undef -> 10
+ end).
--- /dev/null
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% @copyright (C) 2016, Evgeny Khramtsov
+%%% @doc
+%%%
+%%% @end
+%%% Created : 13 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%-------------------------------------------------------------------
+-module(xmpp_stream_pkix).
+
+%% API
+-export([authenticate/1, authenticate/2]).
+
+-include("xmpp.hrl").
+-include_lib("public_key/include/public_key.hrl").
+-include("XmppAddr.hrl").
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state())
+ -> {ok, binary()} | {error, binary(), binary()}.
+authenticate(State) ->
+ authenticate(State, <<"">>).
+
+-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary())
+ -> {ok, binary()} | {error, binary(), binary()}.
+authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
+ sockmod := SockMod, socket := Socket}, _Authzid) ->
+ case SockMod:get_peer_certificate(Socket) of
+ {ok, Cert} ->
+ case SockMod:get_verify_result(Socket) of
+ 0 ->
+ case ejabberd_idna:domain_utf8_to_ascii(Peer) of
+ false ->
+ {error, <<"Cannot decode remote server name">>, Peer};
+ AsciiPeer ->
+ case lists:any(
+ fun(D) -> match_domain(AsciiPeer, D) end,
+ get_cert_domains(Cert)) of
+ true ->
+ {ok, Peer};
+ false ->
+ {error, <<"Certificate host name mismatch">>, Peer}
+ end
+ end;
+ VerifyRes ->
+ {error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer}
+ end;
+ {error, _Reason} ->
+ {error, <<"Cannot get peer certificate">>, Peer};
+ error ->
+ {error, <<"Cannot get peer certificate">>, Peer}
+ end;
+authenticate(_State, _Authzid) ->
+ %% TODO: client PKIX authentication
+ {error, <<"Client certificate verification not implemented">>, <<"">>}.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+get_cert_domains(Cert) ->
+ TBSCert = Cert#'Certificate'.tbsCertificate,
+ Subject = case TBSCert#'TBSCertificate'.subject of
+ {rdnSequence, Subj} -> lists:flatten(Subj);
+ _ -> []
+ end,
+ Extensions = case TBSCert#'TBSCertificate'.extensions of
+ Exts when is_list(Exts) -> Exts;
+ _ -> []
+ end,
+ lists:flatmap(
+ fun(#'AttributeTypeAndValue'{type = ?'id-at-commonName',value = Val}) ->
+ case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
+ {ok, {_, D1}} ->
+ D = if is_binary(D1) -> D1;
+ is_list(D1) -> list_to_binary(D1);
+ true -> error
+ end,
+ if D /= error ->
+ case jid:from_string(D) of
+ #jid{luser = <<"">>, lserver = LD,
+ lresource = <<"">>} ->
+ [LD];
+ _ -> []
+ end;
+ true -> []
+ end;
+ _ -> []
+ end;
+ (_) -> []
+ end, Subject) ++
+ lists:flatmap(
+ fun(#'Extension'{extnID = ?'id-ce-subjectAltName',
+ extnValue = Val}) ->
+ BVal = if is_list(Val) -> list_to_binary(Val);
+ true -> Val
+ end,
+ case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) of
+ {ok, SANs} ->
+ lists:flatmap(
+ fun({otherName, #'AnotherName'{'type-id' = ?'id-on-xmppAddr',
+ value = XmppAddr}}) ->
+ case 'XmppAddr':decode('XmppAddr', XmppAddr) of
+ {ok, D} when is_binary(D) ->
+ case jid:from_string(D) of
+ #jid{luser = <<"">>,
+ lserver = LD,
+ lresource = <<"">>} ->
+ case ejabberd_idna:domain_utf8_to_ascii(LD) of
+ false ->
+ [];
+ PCLD ->
+ [PCLD]
+ end;
+ _ -> []
+ end;
+ _ -> []
+ end;
+ ({dNSName, D}) when is_list(D) ->
+ case jid:from_string(list_to_binary(D)) of
+ #jid{luser = <<"">>,
+ lserver = LD,
+ lresource = <<"">>} ->
+ [LD];
+ _ -> []
+ end;
+ (_) -> []
+ end, SANs);
+ _ -> []
+ end;
+ (_) -> []
+ end, Extensions).
+
+match_domain(Domain, Domain) -> true;
+match_domain(Domain, Pattern) ->
+ DLabels = str:tokens(Domain, <<".">>),
+ PLabels = str:tokens(Pattern, <<".">>),
+ match_labels(DLabels, PLabels).
+
+match_labels([], []) -> true;
+match_labels([], [_ | _]) -> false;
+match_labels([_ | _], []) -> false;
+match_labels([DL | DLabels], [PL | PLabels]) ->
+ case lists:all(fun (C) ->
+ $a =< C andalso C =< $z orelse
+ $0 =< C andalso C =< $9 orelse
+ C == $- orelse C == $*
+ end,
+ binary_to_list(PL))
+ of
+ true ->
+ Regexp = ejabberd_regexp:sh_to_awk(PL),
+ case ejabberd_regexp:run(DL, Regexp) of
+ match -> match_labels(DLabels, PLabels);
+ nomatch -> false
+ end;
+ false -> false
+ end.