From: Evgeniy Khramtsov Date: Wed, 28 Dec 2016 06:47:11 +0000 (+0300) Subject: Add xmpp_stream_out behaviour and rewrite s2s/SM code X-Git-Tag: 17.03-beta~91^2~33 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=309bdfbe285c82726d2ce1406fc26c19a6b37bd9;p=ejabberd Add xmpp_stream_out behaviour and rewrite s2s/SM code --- diff --git a/include/ejabberd.hrl b/include/ejabberd.hrl index 391089a0e..ddf41f094 100644 --- a/include/ejabberd.hrl +++ b/include/ejabberd.hrl @@ -41,8 +41,6 @@ -define(COPYRIGHT, "Copyright (c) 2002-2016 ProcessOne"). --define(S2STIMEOUT, timer:minutes(10)). - %%-define(DBGFSM, true). -record(scram, diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index c49f8a3cb..874a417a1 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -36,13 +36,6 @@ -include("ejabberd.hrl"). -include("logger.hrl"). -%% --export_type([ - mechanism/0, - mechanisms/0, - sasl_mechanism/0 -]). - -record(sasl_mechanism, {mechanism = <<"">> :: mechanism() | '$1', module :: atom(), @@ -51,10 +44,15 @@ -type(mechanism() :: binary()). -type(mechanisms() :: [mechanism(),...]). -type(password_type() :: plain | digest | scram). --type(props() :: [{username, binary()} | - {authzid, binary()} | - {mechanism, binary()} | - {auth_module, atom()}]). +-type sasl_property() :: {username, binary()} | + {authzid, binary()} | + {mechanism, binary()} | + {auth_module, atom()}. +-type sasl_return() :: {ok, [sasl_property()]} | + {ok, [sasl_property()], binary()} | + {continue, binary(), any()} | + {error, atom()} | + {error, atom(), binary()}. -type(sasl_mechanism() :: #sasl_mechanism{}). @@ -71,14 +69,11 @@ mech_state }). -type sasl_state() :: #sasl_state{}. --export_type([sasl_state/0]). +-export_type([mechanism/0, mechanisms/0, sasl_mechanism/0, + sasl_state/0, sasl_return/0, sasl_property/0]). -callback mech_new(binary(), fun(), fun(), fun()) -> any(). --callback mech_step(any(), binary()) -> {ok, props()} | - {ok, props(), binary()} | - {continue, binary(), any()} | - {error, atom()} | - {error, atom(), binary()}. +-callback mech_step(any(), binary()) -> sasl_return(). start() -> ets:new(sasl_mechanism, diff --git a/src/ejabberd_app.erl b/src/ejabberd_app.erl index e4333c816..eb25fe656 100644 --- a/src/ejabberd_app.erl +++ b/src/ejabberd_app.erl @@ -169,7 +169,7 @@ broadcast_c2s_shutdown() -> Children = ejabberd_sm:get_all_pids(), lists:foreach( fun(C2SPid) when node(C2SPid) == node() -> - C2SPid ! system_shutdown; + ejabberd_c2s:send(C2SPid, xmpp:serr_system_shutdown()); (_) -> ok end, Children). diff --git a/src/ejabberd_auth.erl b/src/ejabberd_auth.erl index 74c8009c2..eba0a4038 100644 --- a/src/ejabberd_auth.erl +++ b/src/ejabberd_auth.erl @@ -42,7 +42,7 @@ get_password_s/2, get_password_with_authmodule/2, is_user_exists/2, is_user_exists_in_other_modules/3, remove_user/2, remove_user/3, plain_password_required/1, - store_type/1, entropy/1]). + store_type/1, entropy/1, backend_type/1]). -export([auth_modules/1, opt_type/1]). @@ -412,6 +412,13 @@ entropy(B) -> length(S) * math:log(lists:sum(Set)) / math:log(2) end. +-spec backend_type(atom()) -> atom(). +backend_type(Mod) -> + case atom_to_list(Mod) of + "ejabberd_auth_" ++ T -> list_to_atom(T); + _ -> Mod + end. + %%%---------------------------------------------------------------------- %%% Internal functions %%%---------------------------------------------------------------------- diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index b5113c34b..07d04fbc4 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -22,26 +22,32 @@ -module(ejabberd_c2s). -behaviour(xmpp_stream_in). -behaviour(ejabberd_config). +-behaviour(ejabberd_socket). -protocol({rfc, 6121}). %% ejabberd_socket callbacks --export([start/2, socket_type/0]). +-export([start/2, start_link/2, socket_type/0]). %% ejabberd_config callbacks -export([opt_type/1, transform_listen_option/2]). %% xmpp_stream_in callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([tls_options/1, tls_required/1, compress_methods/1, - sasl_mechanisms/1, init_sasl/1, bind/2, handshake/2, +-export([tls_options/1, tls_required/1, tls_verify/1, + compress_methods/1, bind/2, get_password_fun/1, + check_password_fun/1, check_password_digest_fun/1, unauthenticated_stream_features/1, authenticated_stream_features/1, - handle_stream_start/1, handle_stream_end/1, handle_stream_close/1, + handle_stream_start/2, handle_stream_end/2, handle_stream_close/2, handle_unauthenticated_packet/2, handle_authenticated_packet/2, - handle_auth_success/4, handle_auth_failure/4, handle_send/5, - handle_unbinded_packet/2, handle_cdata/2]). + handle_auth_success/4, handle_auth_failure/4, handle_send/3, + handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]). +%% Hooks +-export([handle_unexpected_info/2, handle_unexpected_cast/2, + reject_unauthenticated_packet/2, process_closed/2]). %% API -export([get_presence/1, get_subscription/2, get_subscribed/1, - send/2, close/1]). + open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1, + copy_state/2, add_hooks/0]). -include("ejabberd.hrl"). -include("xmpp.hrl"). @@ -49,30 +55,30 @@ -define(SETS, gb_sets). -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. - -type state() :: map(). --type next_state() :: {noreply, state()} | {stop, term(), state()}. --export_type([state/0, next_state/0]). +-export_type([state/0]). %%%=================================================================== %%% ejabberd_socket API %%%=================================================================== start(SockData, Opts) -> xmpp_stream_in:start(?MODULE, [SockData, Opts], - fsm_limit_opts(Opts) ++ ?FSMOPTS). + ejabberd_config:fsm_limit_opts(Opts)). + +start_link(SockData, Opts) -> + xmpp_stream_in:start_link(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). socket_type() -> xml_stream. +-spec call(pid(), term(), non_neg_integer() | infinity) -> term(). +call(Ref, Msg, Timeout) -> + xmpp_stream_in:call(Ref, Msg, Timeout). + -spec get_presence(pid()) -> presence(). get_presence(Ref) -> - xmpp_stream_in:call(Ref, get_presence, 1000). + call(Ref, get_presence, 1000). -spec get_subscription(jid() | ljid(), state()) -> both | from | to | none. get_subscription(#jid{} = From, State) -> @@ -90,15 +96,85 @@ get_subscription(LFrom, #{pres_f := PresF, pres_t := PresT}) -> -spec get_subscribed(pid()) -> [ljid()]. %% Return list of all available resources of contacts get_subscribed(Ref) -> - xmpp_stream_in:call(Ref, get_subscribed, 1000). + call(Ref, get_subscribed, 1000). --spec close(pid()) -> ok. close(Ref) -> - xmpp_stream_in:cast(Ref, closed). + xmpp_stream_in:close(Ref). + +close(Ref, SendTrailer) -> + xmpp_stream_in:close(Ref, SendTrailer). + +stop(Ref) -> + xmpp_stream_in:stop(Ref). + +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + xmpp_stream_in:send(Pid, Pkt); +send(#{lserver := LServer} = State, Pkt) -> + case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of + drop -> State; + Pkt1 -> xmpp_stream_in:send(State, Pkt1) + end. + +-spec establish(state()) -> state(). +establish(State) -> + xmpp_stream_in:establish(State). + +-spec add_hooks() -> ok. +add_hooks() -> + lists:foreach( + fun(Host) -> + ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100), + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, + reject_unauthenticated_packet, 100), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, + handle_unexpected_info, 100), + ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100) + + end, ?MYHOSTS). + +%% Copies content of one c2s state to another. +%% This is needed for session migration from one pid to another. +-spec copy_state(state(), state()) -> state(). +copy_state(#{owner := Owner} = NewState, + #{jid := JID, resource := Resource, sid := {Time, _}, + auth_module := AuthModule, lserver := LServer, + pres_t := PresT, pres_a := PresA, + pres_f := PresF} = OldState) -> + State1 = case OldState of + #{pres_last := Pres, pres_timestamp := PresTS} -> + NewState#{pres_last => Pres, pres_timestamp => PresTS}; + _ -> + NewState + end, + Conn = get_conn_type(State1), + State2 = State1#{jid => JID, resource => Resource, + conn => Conn, + sid => {Time, Owner}, + auth_module => AuthModule, + pres_t => PresT, pres_a => PresA, + pres_f => PresF}, + ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]). + +%%%=================================================================== +%%% Hooks +%%%=================================================================== +handle_unexpected_info(State, Info) -> + ?WARNING_MSG("got unexpected info: ~p", [Info]), + State. --spec send(state(), xmpp_element()) -> next_state(). -send(State, Pkt) -> - xmpp_stream_in:send(State, Pkt). +handle_unexpected_cast(State, Msg) -> + ?WARNING_MSG("got unexpected cast: ~p", [Msg]), + State. + +reject_unauthenticated_packet(State, Pkt) -> + Err = xmpp:err_not_authorized(), + xmpp_stream_in:send_error(State, Pkt, Err). + +process_closed(State, _Reason) -> + stop(State). %%%=================================================================== %%% xmpp_stream_in callbacks @@ -115,128 +191,158 @@ tls_options(#{lserver := LServer, tls_options := TLSOpts}) -> tls_required(#{tls_required := TLSRequired}) -> TLSRequired. +tls_verify(#{tls_verify := TLSVerify}) -> + TLSVerify. + compress_methods(#{zlib := true}) -> [<<"zlib">>]; compress_methods(_) -> []. -sasl_mechanisms(#{lserver := LServer}) -> - cyrsasl:listmech(LServer). - unauthenticated_stream_features(#{lserver := LServer}) -> ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]). authenticated_stream_features(#{lserver := LServer}) -> ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]). -init_sasl(#{lserver := LServer}) -> - cyrsasl:server_new( - <<"jabber">>, LServer, <<"">>, [], - fun(U) -> - ejabberd_auth:get_password_with_authmodule(U, LServer) - end, - fun(U, AuthzId, P) -> - ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P) - end, - fun(U, AuthzId, P, D, DG) -> - ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG) - end). +get_password_fun(#{lserver := LServer}) -> + fun(U) -> + ejabberd_auth:get_password_with_authmodule(U, LServer) + end. + +check_password_fun(#{lserver := LServer}) -> + fun(U, AuthzId, P) -> + ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P) + end. + +check_password_digest_fun(#{lserver := LServer}) -> + fun(U, AuthzId, P, D, DG) -> + ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG) + end. bind(<<"">>, State) -> bind(new_uniq_id(), State); -bind(R, #{user := U, server := S} = State) -> +bind(R, #{user := U, server := S, access := Access, lang := Lang, + lserver := LServer, socket := Socket, ip := IP} = State) -> case resource_conflict_action(U, S, R) of closenew -> {error, xmpp:err_conflict(), State}; {accept_resource, Resource} -> - open_session(State, Resource) + JID = jid:make(U, S, Resource), + case acl:access_matches(Access, + #{usr => jid:split(JID), ip => IP}, + LServer) of + allow -> + State1 = open_session(State#{resource => Resource}), + State2 = ejabberd_hooks:run_fold( + c2s_session_opened, LServer, State1, []), + ?INFO_MSG("(~s) Opened session for ~s", + [ejabberd_socket:pp(Socket), jid:to_string(JID)]), + {ok, State2}; + deny -> + ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]), + ?INFO_MSG("(~s) Forbidden session for ~s", + [ejabberd_socket:pp(Socket), jid:to_string(JID)]), + Txt = <<"Denied by ACL">>, + {error, xmpp:err_not_allowed(Txt, Lang), State} + end end. -handshake(_Data, State) -> - %% This is only for jabber component - {ok, State}. +-spec open_session(state()) -> {ok, state()} | state(). +open_session(#{user := U, server := S, resource := R, + sid := SID, ip := IP, auth_module := AuthModule} = State) -> + JID = jid:make(U, S, R), + change_shaper(State), + Conn = get_conn_type(State), + State1 = State#{conn => Conn, resource => R, jid => JID}, + Prio = try maps:get(pres_last, State) of + Pres -> get_priority_from_presence(Pres) + catch _:{badkey, _} -> + undefined + end, + Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}], + ejabberd_sm:open_session(SID, U, S, R, Prio, Info), + State1. -handle_stream_start(#{lserver := LServer, ip := IP, lang := Lang} = State) -> +handle_stream_start(StreamStart, + #{lserver := LServer, ip := IP, lang := Lang} = State) -> case lists:member(LServer, ?MYHOSTS) of false -> - xmpp_stream_in:send(State, xmpp:serr_host_unknown()); + send(State, xmpp:serr_host_unknown()); true -> case check_bl_c2s(IP, Lang) of false -> change_shaper(State), - {noreply, State}; + ejabberd_hooks:run_fold( + c2s_stream_started, LServer, State, [StreamStart]); {true, LogReason, ReasonT} -> ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s", [jlib:ip_to_list(IP), LogReason]), Err = xmpp:serr_policy_violation(ReasonT, Lang), - xmpp_stream_in:send(State, Err) + send(State, Err) end end. -handle_stream_end(State) -> - {stop, normal, State}. +handle_stream_end(Reason, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]). -handle_stream_close(State) -> - {stop, normal, State}. +handle_stream_close(_Reason, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]). handle_auth_success(User, Mech, AuthModule, #{socket := Socket, ip := IP, lserver := LServer} = State) -> - ?INFO_MSG("(~w) Accepted ~s authentication for ~s@~s by ~p from ~s", - [Socket, Mech, User, LServer, AuthModule, + ?INFO_MSG("(~s) Accepted c2s ~s authentication for ~s@~s by ~s backend from ~s", + [ejabberd_socket:pp(Socket), Mech, User, LServer, + ejabberd_auth:backend_type(AuthModule), ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), State1 = State#{auth_module => AuthModule}, ejabberd_hooks:run_fold(c2s_auth_result, LServer, - {noreply, State1}, [true, User]). + State1, [true, User]). handle_auth_failure(User, Mech, Reason, #{socket := Socket, ip := IP, lserver := LServer} = State) -> - ?INFO_MSG("(~w) Failed ~s authentication ~sfrom ~s: ~s", - [Socket, Mech, + ?INFO_MSG("(~s) Failed c2s ~s authentication ~sfrom ~s: ~s", + [ejabberd_socket:pp(Socket), Mech, if User /= <<"">> -> ["for ", User, "@", LServer, " "]; true -> "" end, ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), ejabberd_hooks:run_fold(c2s_auth_result, LServer, - {noreply, State}, [false, User]). + State, [false, User]). handle_unbinded_packet(Pkt, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, - {noreply, State}, [Pkt]). + State, [Pkt]). handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_unauthenticated_packet, - LServer, {noreply, State}, [Pkt]). + LServer, State, [Pkt]). handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) -> ejabberd_hooks:run_fold(c2s_authenticated_packet, - LServer, {noreply, State}, [Pkt]); + LServer, State, [Pkt]); handle_authenticated_packet(Pkt, #{lserver := LServer} = State) -> - case ejabberd_hooks:run_fold(c2s_authenticated_packet, - LServer, {noreply, State}, [Pkt]) of - {noreply, State1} -> - Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]), - Res = case Pkt1 of - #presence{to = #jid{lresource = <<"">>}} -> - process_self_presence(State1, Pkt1); - #presence{} -> - process_presence_out(State1, Pkt1); - _ -> - check_privacy_then_route(State1, Pkt1) - end, - ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]), - Res; - Err -> - ejabberd_hooks:run(c2s_loop_debug, [{xmlstreamelement, Pkt}]), - Err + State1 = ejabberd_hooks:run_fold(c2s_authenticated_packet, + LServer, State, [Pkt]), + Pkt1 = ejabberd_hooks:run_fold(user_send_packet, LServer, Pkt, [State1]), + case Pkt1 of + #presence{to = #jid{lresource = <<"">>}} -> + process_self_presence(State1, Pkt1); + #presence{} -> + process_presence_out(State1, Pkt1); + _ -> + check_privacy_then_route(State1, Pkt1) end. handle_cdata(Data, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_handle_cdata, LServer, - {noreply, State}, [Data]). + State, [Data]). + +handle_recv(El, Pkt, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_recv, LServer, State, [El, Pkt]). -handle_send(Reason, Pkt, El, Data, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_handle_send, LServer, - {noreply, State}, [Reason, Pkt, El, Data]). +handle_send(Pkt, Result, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_send, LServer, State, [Pkt, Result]). init([State, Opts]) -> Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), @@ -262,15 +368,13 @@ init([State, Opts]) -> server => ?MYNAME, access => Access, shaper => Shaper}, - ejabberd_hooks:run_fold(c2s_init, {ok, State1}, []). + ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]). handle_call(get_presence, _From, #{jid := JID} = State) -> - Pres = case maps:get(pres_last, State, undefined) of - undefined -> + Pres = try maps:get(pres_last, State) + catch _:{badkey, _} -> BareJID = jid:remove_resource(JID), - #presence{from = JID, to = BareJID, type = unavailable}; - P -> - P + #presence{from = JID, to = BareJID, type = unavailable} end, {reply, Pres, State}; handle_call(get_subscribed, _From, #{pres_f := PresF} = State) -> @@ -278,12 +382,10 @@ handle_call(get_subscribed, _From, #{pres_f := PresF} = State) -> {reply, Subscribed, State}; handle_call(Request, From, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold( - c2s_handle_call, LServer, {noreply, State}, [Request, From]). + c2s_handle_call, LServer, State, [Request, From]). -handle_cast(closed, State) -> - handle_stream_close(State); handle_cast(Msg, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_handle_cast, LServer, {noreply, State}, [Msg]). + ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]). handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) -> Packet = xmpp:set_from_to(Packet0, From, To), @@ -299,15 +401,13 @@ handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) -> Packet1 = ejabberd_hooks:run_fold( user_receive_packet, LServer, Packet, [NewState]), ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - xmpp_stream_in:send(NewState, Packet1); + send(NewState, Packet1); true -> ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - {noreply, NewState} + NewState end; -handle_info(system_shutdown, State) -> - xmpp_stream_in:send(State, xmpp:serr_system_shutdown()); handle_info(Info, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_handle_info, LServer, {noreply, State}, [Info]). + ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]). terminate(_Reason, _State) -> ok. @@ -323,33 +423,6 @@ code_change(_OldVsn, State, _Extra) -> check_bl_c2s({IP, _Port}, Lang) -> ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]). --spec open_session(state(), binary()) -> {ok, state()} | {error, stanza_error(), state()}. -open_session(#{user := U, server := S, lserver := LServer, sid := SID, - socket := Socket, ip := IP, auth_module := AuthMod, - access := Access, lang := Lang} = State, R) -> - JID = jid:make(U, S, R), - case acl:access_matches(Access, - #{usr => jid:split(JID), ip => IP}, - LServer) of - allow -> - ?INFO_MSG("(~w) Opened session for ~s", - [Socket, jid:to_string(JID)]), - change_shaper(State), - Conn = get_conn_type(State), - Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}], - ejabberd_sm:open_session(SID, U, LServer, R, Info), - State1 = State#{conn => Conn, resource => R, jid => JID}, - State2 = ejabberd_hooks:run_fold( - c2s_session_opened, LServer, State1, []), - {ok, State2}; - deny -> - ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]), - ?INFO_MSG("(~w) Forbidden session for ~s", - [Socket, jid:to_string(JID)]), - Txt = <<"Denied by ACL">>, - {error, xmpp:err_not_allowed(Txt, Lang), State} - end. - -spec process_iq_in(state(), iq()) -> {boolean(), state()}. process_iq_in(State, #iq{} = IQ) -> case privacy_check_packet(State, IQ, in) of @@ -433,7 +506,7 @@ route_probe_reply(From, To, #{lserver := LServer, pres_f := PresF, route_probe_reply(_, _, _) -> ok. --spec process_presence_out(state(), presence()) -> next_state(). +-spec process_presence_out(state(), presence()) -> state(). process_presence_out(#{user := User, server := Server, lserver := LServer, jid := JID, lang := Lang, pres_a := PresA} = State, #presence{from = From, to = To, type = Type} = Pres) -> @@ -461,21 +534,21 @@ process_presence_out(#{user := User, server := Server, lserver := LServer, [User, Server, To, Type]), BareFrom = jid:remove_resource(From), route(xmpp:set_from_to(Pres, BareFrom, To)), - {noreply, State} + State end; allow when Type == error; Type == probe -> route(Pres), - {noreply, State}; + State; allow -> route(Pres), A = case Type of available -> ?SETS:add_element(LTo, PresA); unavailable -> ?SETS:del_element(LTo, PresA) end, - {noreply, State#{pres_a => A}} + State#{pres_a => A} end. --spec process_self_presence(state(), presence()) -> {noreply, state()}. +-spec process_self_presence(state(), presence()) -> state(). process_self_presence(#{ip := IP, conn := Conn, auth_module := AuthMod, sid := SID, user := U, server := S, resource := R} = State, @@ -484,8 +557,7 @@ process_self_presence(#{ip := IP, conn := Conn, Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthMod}], ejabberd_sm:unset_presence(SID, U, S, R, Status, Info), State1 = broadcast_presence_unavailable(State, Pres), - State2 = maps:remove(pres_last, maps:remove(pres_timestamp, State1)), - {noreply, State2}; + maps:remove(pres_last, maps:remove(pres_timestamp, State1)); process_self_presence(#{lserver := LServer} = State, #presence{type = available} = Pres) -> PreviousPres = maps:get(pres_last, State, undefined), @@ -494,10 +566,9 @@ process_self_presence(#{lserver := LServer} = State, State2 = State1#{pres_last => Pres, pres_timestamp => p1_time_compat:timestamp()}, FromUnavailable = PreviousPres == undefined, - State3 = broadcast_presence_available(State2, Pres, FromUnavailable), - {noreply, State3}; + broadcast_presence_available(State2, Pres, FromUnavailable); process_self_presence(State, _Pres) -> - {noreply, State}. + State. -spec update_priority(state(), presence()) -> ok. update_priority(#{ip := IP, conn := Conn, auth_module := AuthMod, @@ -529,7 +600,7 @@ broadcast_presence_available(#{pres_a := PresA, pres_f := PresF} = State, route_multiple(State, JIDs, Pres), State. --spec check_privacy_then_route(state(), stanza()) -> next_state(). +-spec check_privacy_then_route(state(), stanza()) -> state(). check_privacy_then_route(#{lang := Lang} = State, Pkt) -> case privacy_check_packet(State, Pkt, out) of deny -> @@ -539,7 +610,7 @@ check_privacy_then_route(#{lang := Lang} = State, Pkt) -> xmpp_stream_in:send_error(State, Pkt, Err); allow -> route(Pkt), - {noreply, State} + State end. -spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny. @@ -664,25 +735,10 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) -> end end. --spec fsm_limit_opts([proplists:property()]) -> [proplists:property()]. -fsm_limit_opts(Opts) -> - case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> [{max_queue, N}]; - _ -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end - end. - transform_listen_option(Opt, Opts) -> [Opt|Opts]. opt_type(domain_certfile) -> fun iolist_to_binary/1; -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; opt_type(resource_conflict) -> fun (setresource) -> setresource; (closeold) -> closeold; @@ -690,4 +746,4 @@ opt_type(resource_conflict) -> (acceptnew) -> acceptnew end; opt_type(_) -> - [domain_certfile, max_fsm_queue, resource_conflict]. + [domain_certfile, resource_conflict]. diff --git a/src/ejabberd_config.erl b/src/ejabberd_config.erl index e930e36b1..9014bfabd 100644 --- a/src/ejabberd_config.erl +++ b/src/ejabberd_config.erl @@ -38,7 +38,8 @@ transform_options/1, collect_options/1, default_db/2, convert_to_yaml/1, convert_to_yaml/2, v_db/2, env_binary_to_list/2, opt_type/1, may_hide_data/1, - is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1]). + is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1, + fsm_limit_opts/1]). -export([start/2]). @@ -1403,6 +1404,8 @@ opt_type(hosts) -> end; opt_type(language) -> fun iolist_to_binary/1; +opt_type(max_fsm_queue) -> + fun (I) when is_integer(I), I > 0 -> I end; opt_type(_) -> [hide_sensitive_log_data, hosts, language]. @@ -1421,3 +1424,17 @@ may_hide_data(Data) -> true -> "hidden_by_ejabberd" end. + +-spec fsm_limit_opts([proplists:property()]) -> [{max_queue, pos_integer()}]. +fsm_limit_opts(Opts) -> + case lists:keyfind(max_fsm_queue, 1, Opts) of + {_, I} when is_integer(I), I>0 -> + [{max_queue, I}]; + false -> + case get_option( + max_fsm_queue, + fun(I) when is_integer(I), I>0 -> I end) of + undefined -> []; + N -> [{max_queue, N}] + end + end. diff --git a/src/ejabberd_hooks.erl b/src/ejabberd_hooks.erl index c1daa4c0e..612d5afe5 100644 --- a/src/ejabberd_hooks.erl +++ b/src/ejabberd_hooks.erl @@ -376,8 +376,11 @@ run_fold1([{_Seq, Module, Function} | Ls], Hook, Val, Args) -> end. safe_apply(Module, Function, Args) -> - if is_function(Function) -> - catch apply(Function, Args); - true -> - catch apply(Module, Function, Args) + try if is_function(Function) -> + apply(Function, Args); + true -> + apply(Module, Function, Args) + end + catch E:R when E /= exit, R /= normal -> + {'EXIT', {E, {R, erlang:get_stacktrace()}}} end. diff --git a/src/ejabberd_listener.erl b/src/ejabberd_listener.erl index a9cc441e9..f720fc585 100644 --- a/src/ejabberd_listener.erl +++ b/src/ejabberd_listener.erl @@ -330,9 +330,9 @@ accept(ListenSocket, Module, Opts, Interval) -> {ok, Socket} -> case {inet:sockname(Socket), inet:peername(Socket)} of {{ok, {Addr, Port}}, {ok, {PAddr, PPort}}} -> - ?INFO_MSG("(~w) Accepted connection ~s:~p -> ~s:~p", - [Socket, ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), PPort, - inet_parse:ntoa(Addr), Port]); + ?INFO_MSG("Accepted connection ~s:~p -> ~s:~p", + [ejabberd_config:may_hide_data(inet_parse:ntoa(PAddr)), + PPort, inet_parse:ntoa(Addr), Port]); _ -> ok end, diff --git a/src/ejabberd_router.erl b/src/ejabberd_router.erl index 33093abb0..5ce8a8afb 100644 --- a/src/ejabberd_router.erl +++ b/src/ejabberd_router.erl @@ -43,7 +43,9 @@ unregister_route/1, unregister_routes/1, dirty_get_all_routes/0, - dirty_get_all_domains/0 + dirty_get_all_domains/0, + is_my_route/1, + is_my_host/1 ]). -export([start_link/0]). @@ -110,12 +112,12 @@ register_route(Domain) -> [?MODULE, ?MODULE]), register_route(Domain, ?MYNAME). --spec register_route(binary(), binary()) -> term(). +-spec register_route(binary(), binary()) -> ok. register_route(Domain, ServerHost) -> register_route(Domain, ServerHost, undefined). --spec register_route(binary(), binary(), local_hint()) -> term(). +-spec register_route(binary(), binary(), local_hint()) -> ok. register_route(Domain, ServerHost, LocalHint) -> case {jid:nameprep(Domain), jid:nameprep(ServerHost)} of @@ -165,6 +167,11 @@ register_route(Domain, ServerHost, LocalHint) -> end end, mnesia:transaction(F) + end, + if LocalHint == undefined -> + ?INFO_MSG("Route registered: ~s", [LDomain]); + true -> + ok end end. @@ -175,7 +182,7 @@ register_routes(Domains) -> end, Domains). --spec unregister_route(binary()) -> term(). +-spec unregister_route(binary()) -> ok. unregister_route(Domain) -> case jid:nameprep(Domain) of @@ -210,7 +217,8 @@ unregister_route(Domain) -> end end, mnesia:transaction(F) - end + end, + ?INFO_MSG("Route unregistered: ~s", [LDomain]) end. -spec unregister_routes([binary()]) -> ok. @@ -245,6 +253,29 @@ host_of_route(Domain) -> end end. +-spec is_my_route(binary()) -> boolean(). +is_my_route(Domain) -> + case jid:nameprep(Domain) of + error -> + erlang:error({invalid_domain, Domain}); + LDomain -> + mnesia:dirty_read(route, LDomain) /= [] + end. + +-spec is_my_host(binary()) -> boolean(). +is_my_host(Domain) -> + case jid:nameprep(Domain) of + error -> + erlang:error({invalid_domain, Domain}); + LDomain -> + case mnesia:dirty_read(route, LDomain) of + [#route{server_host = Host}|_] -> + Host == LDomain; + [] -> + false + end + end. + -spec process_iq(jid(), jid(), iq() | xmlel()) -> any(). process_iq(From, To, #iq{} = IQ) -> if To#jid.luser == <<"">> -> diff --git a/src/ejabberd_s2s.erl b/src/ejabberd_s2s.erl index 4df1761cb..af4d6a662 100644 --- a/src/ejabberd_s2s.erl +++ b/src/ejabberd_s2s.erl @@ -35,16 +35,16 @@ %% API -export([start_link/0, route/3, have_connection/1, - make_key/2, get_connections_pids/1, try_register/1, - remove_connection/2, find_connection/2, + get_connections_pids/1, try_register/1, + remove_connection/2, start_connection/2, start_connection/3, dirty_get_connections/0, allow_host/2, incoming_s2s_number/0, outgoing_s2s_number/0, stop_all_connections/0, clean_temporarily_blocked_table/0, list_temporarily_blocked_hosts/0, external_host_overloaded/1, is_temporarly_blocked/1, - check_peer_certificate/3, - get_commands_spec/0]). + get_commands_spec/0, zlib_enabled/1, get_idle_timeout/1, + tls_required/1, tls_verify/1, tls_enabled/1, tls_options/2]). %% gen_server callbacks -export([init/1, handle_call/3, handle_cast/2, @@ -196,39 +196,94 @@ try_register(FromTo) -> dirty_get_connections() -> mnesia:dirty_all_keys(s2s). -check_peer_certificate(SockMod, Sock, Peer) -> - case SockMod:get_peer_certificate(Sock) of - {ok, Cert} -> - case SockMod:get_verify_result(Sock) of - 0 -> - case ejabberd_idna:domain_utf8_to_ascii(Peer) of - false -> - {error, <<"Cannot decode remote server name">>}; - AsciiPeer -> - case - lists:any(fun(D) -> match_domain(AsciiPeer, D) end, - get_cert_domains(Cert)) of - true -> - {ok, <<"Verification successful">>}; - false -> - {error, <<"Certificate host name mismatch">>} - end - end; - VerifyRes -> - {error, fast_tls:get_cert_verify_string(VerifyRes, Cert)} - end; - {error, _Reason} -> - {error, <<"Cannot get peer certificate">>}; - error -> - {error, <<"Cannot get peer certificate">>} +-spec tls_options(binary(), [proplists:property()]) -> [proplists:property()]. +tls_options(LServer, DefaultOpts) -> + TLSOpts1 = case ejabberd_config:get_option( + {s2s_certfile, LServer}, + fun iolist_to_binary/1, + ejabberd_config:get_option( + {domain_certfile, LServer}, + fun iolist_to_binary/1)) of + undefined -> []; + CertFile -> lists:keystore(certfile, 1, DefaultOpts, + {certfile, CertFile}) + end, + TLSOpts2 = case ejabberd_config:get_option( + {s2s_ciphers, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts1; + Ciphers -> lists:keystore(ciphers, 1, TLSOpts1, + {ciphers, Ciphers}) + end, + TLSOpts3 = case ejabberd_config:get_option( + {s2s_protocol_options, LServer}, + fun (Options) -> str:join(Options, <<$|>>) end) of + undefined -> TLSOpts2; + ProtoOpts -> lists:keystore(protocol_options, 1, TLSOpts2, + {protocol_options, ProtoOpts}) + end, + TLSOpts4 = case ejabberd_config:get_option( + {s2s_dhfile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts3; + DHFile -> lists:keystore(dhfile, 1, TLSOpts3, + {dhfile, DHFile}) + end, + TLSOpts5 = case ejabberd_config:get_option( + {s2s_cafile, LServer}, + fun iolist_to_binary/1) of + undefined -> TLSOpts4; + CAFile -> lists:keystore(cafile, 1, TLSOpts4, + {cafile, CAFile}) + end, + case ejabberd_config:get_option( + {s2s_tls_compression, LServer}, + fun(B) when is_boolean(B) -> B end) of + undefined -> TLSOpts5; + false -> [compression_none | TLSOpts5]; + true -> lists:delete(compression_none, TLSOpts5) end. --spec make_key({binary(), binary()}, binary()) -> binary(). -make_key({From, To}, StreamID) -> - Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end), - p1_sha:to_hexlist( - crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)), - [To, " ", From, " ", StreamID])). +-spec tls_required(binary()) -> boolean(). +tls_required(LServer) -> + TLS = use_starttls(LServer), + TLS == required orelse TLS == required_trusted. + +-spec tls_verify(binary()) -> boolean(). +tls_verify(LServer) -> + TLS = use_starttls(LServer), + TLS == required_trusted. + +-spec tls_enabled(binary()) -> boolean(). +tls_enabled(LServer) -> + TLS = use_starttls(LServer), + TLS == true orelse TLS == optional. + +-spec zlib_enabled(binary()) -> boolean(). +zlib_enabled(LServer) -> + ejabberd_config:get_option( + {s2s_zlib, LServer}, + fun(B) when is_boolean(B) -> B end, + false). + +-spec use_starttls(binary()) -> boolean() | optional | required | required_trusted. +use_starttls(LServer) -> + ejabberd_config:get_option( + {s2s_use_starttls, LServer}, + fun(true) -> true; + (false) -> false; + (optional) -> optional; + (required) -> required; + (required_trusted) -> required_trusted + end, false). + +-spec get_idle_timeout(binary()) -> non_neg_integer() | infinity. +get_idle_timeout(LServer) -> + ejabberd_config:get_option( + {s2s_timeout, LServer}, + fun(I) when is_integer(I), I >= 0 -> timer:seconds(I); + (infinity) -> infinity + end, timer:minutes(10)). %%==================================================================== %% gen_server callbacks @@ -246,6 +301,8 @@ init([]) -> ejabberd_mnesia:create(?MODULE, temporarily_blocked, [{ram_copies, [node()]}, {attributes, record_info(fields, temporarily_blocked)}]), + ejabberd_s2s_in:add_hooks(), + ejabberd_s2s_out:add_hooks(), {ok, #state{}}. handle_call(_Request, _From, State) -> @@ -291,30 +348,36 @@ clean_table_from_bad_node(Node) -> end, mnesia:async_dirty(F). --spec do_route(jid(), jid(), stanza()) -> ok | false. +-spec do_route(jid(), jid(), stanza()) -> ok. do_route(From, To, Packet) -> ?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket " "~P~n", [From, To, Packet, 8]), - case find_connection(From, To) of - {atomic, Pid} when is_pid(Pid) -> - ?DEBUG("sending to process ~p~n", [Pid]), - #jid{lserver = MyServer} = From, - ejabberd_hooks:run(s2s_send_packet, MyServer, - [From, To, Packet]), - send_element(Pid, xmpp:set_from_to(Packet, From, To)), - ok; - {aborted, _Reason} -> - Lang = xmpp:get_lang(Packet), - Txt = <<"No s2s connection found">>, - Err = xmpp:err_service_unavailable(Txt, Lang), - ejabberd_router:route_error(To, From, Packet, Err), - false + case start_connection(From, To) of + {ok, Pid} when is_pid(Pid) -> + ?DEBUG("sending to process ~p~n", [Pid]), + #jid{lserver = MyServer} = From, + ejabberd_hooks:run(s2s_send_packet, MyServer, [From, To, Packet]), + ejabberd_s2s_out:route(Pid, xmpp:set_from_to(Packet, From, To)); + {error, Reason} -> + Err = case Reason of + forbidden -> + Lang = xmpp:get_lang(Packet), + xmpp:err_forbidden(<<"Denied by ACL">>, Lang); + internal_server_error -> + xmpp:err_internal_server_error() + end, + ejabberd_router:route_error(To, From, Packet, Err) end. --spec find_connection(jid(), jid()) -> {aborted, any()} | {atomic, pid()}. +-spec start_connection(jid(), jid()) -> {ok, pid()} | + {error, forbidden | internal_server_error}. +start_connection(From, To) -> + start_connection(From, To, []). -find_connection(From, To) -> +-spec start_connection(jid(), jid(), [proplists:property()]) + -> {ok, pid()} | {error, forbidden | internal_server_error}. +start_connection(From, To, Opts) -> #jid{lserver = MyServer} = From, #jid{lserver = Server} = To, FromTo = {MyServer, Server}, @@ -323,15 +386,13 @@ find_connection(From, To) -> MaxS2SConnectionsNumberPerNode = max_s2s_connections_number_per_node(FromTo), ?DEBUG("Finding connection for ~p~n", [FromTo]), - case catch mnesia:dirty_read(s2s, FromTo) of - {'EXIT', Reason} -> {aborted, Reason}; + case mnesia:dirty_read(s2s, FromTo) of [] -> %% We try to establish all the connections if the host is not a %% service and if the s2s host is not blacklisted or %% is in whitelist: - case not is_service(From, To) andalso - allow_host(MyServer, Server) - of + LServer = ejabberd_router:host_of_route(MyServer), + case not is_service(From, To) andalso allow_host(LServer, Server) of true -> NeededConnections = needed_connections_number([], MaxS2SConnectionsNumber, @@ -339,8 +400,8 @@ find_connection(From, To) -> open_several_connections(NeededConnections, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode); - false -> {aborted, error} + MaxS2SConnectionsNumberPerNode, Opts); + false -> {error, forbidden} end; L when is_list(L) -> NeededConnections = needed_connections_number(L, @@ -351,10 +412,10 @@ find_connection(From, To) -> open_several_connections(NeededConnections, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode); + MaxS2SConnectionsNumberPerNode, Opts); true -> %% We choose a connexion from the pool of opened ones. - {atomic, choose_connection(From, L)} + {ok, choose_connection(From, L)} end end. @@ -377,20 +438,22 @@ choose_pid(From, Pids) -> open_several_connections(N, MyServer, Server, From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode) -> - ConnectionsResult = [new_connection(MyServer, Server, - From, FromTo, MaxS2SConnectionsNumber, - MaxS2SConnectionsNumberPerNode) - || _N <- lists:seq(1, N)], - case [PID || {atomic, PID} <- ConnectionsResult] of - [] -> hd(ConnectionsResult); - PIDs -> {atomic, choose_pid(From, PIDs)} + MaxS2SConnectionsNumberPerNode, Opts) -> + case lists:flatmap( + fun(_) -> + new_connection(MyServer, Server, + From, FromTo, MaxS2SConnectionsNumber, + MaxS2SConnectionsNumberPerNode, Opts) + end, lists:seq(1, N)) of + [] -> + {error, internal_server_error}; + PIDs -> + {ok, choose_pid(From, PIDs)} end. new_connection(MyServer, Server, From, FromTo, - MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode) -> - {ok, Pid} = ejabberd_s2s_out:start( - MyServer, Server, new), + MaxS2SConnectionsNumber, MaxS2SConnectionsNumberPerNode, Opts) -> + {ok, Pid} = ejabberd_s2s_out:start(MyServer, Server, Opts), F = fun() -> L = mnesia:read({s2s, FromTo}), NeededConnections = needed_connections_number(L, @@ -398,17 +461,21 @@ new_connection(MyServer, Server, From, FromTo, MaxS2SConnectionsNumberPerNode), if NeededConnections > 0 -> mnesia:write(#s2s{fromto = FromTo, pid = Pid}), - ?INFO_MSG("New s2s connection started ~p", [Pid]), Pid; true -> choose_connection(From, L) end end, TRes = mnesia:transaction(F), case TRes of - {atomic, Pid} -> ejabberd_s2s_out:start_connection(Pid); - _ -> ejabberd_s2s_out:stop_connection(Pid) - end, - TRes. + {atomic, Pid} -> + ejabberd_s2s_out:connect(Pid), + [Pid]; + {aborted, Reason} -> + ?ERROR_MSG("failed to register connection ~s -> ~s: ~p", + [MyServer, Server, Reason]), + ejabberd_s2s_out:stop(Pid), + [] + end. -spec max_s2s_connections_number({binary(), binary()}) -> integer(). max_s2s_connections_number({From, To}) -> @@ -459,9 +526,6 @@ parent_domains(Domain) -> end, [], lists:reverse(str:tokens(Domain, <<".">>))). -send_element(Pid, El) -> - Pid ! {send_element, El}. - %%%---------------------------------------------------------------------- %%% ejabberd commands @@ -536,24 +600,13 @@ update_tables() -> %% Check if host is in blacklist or white list allow_host(MyServer, S2SHost) -> - allow_host2(MyServer, S2SHost) andalso + allow_host1(MyServer, S2SHost) andalso not is_temporarly_blocked(S2SHost). -allow_host2(MyServer, S2SHost) -> - Hosts = (?MYHOSTS), - case lists:dropwhile(fun (ParentDomain) -> - not lists:member(ParentDomain, Hosts) - end, - parent_domains(MyServer)) - of - [MyHost | _] -> allow_host1(MyHost, S2SHost); - [] -> allow_host1(MyServer, S2SHost) - end. - allow_host1(MyHost, S2SHost) -> Rule = ejabberd_config:get_option( - s2s_access, - fun(A) -> A end, + {s2s_access, MyHost}, + fun acl:access_rules_validator/1, all), JID = jid:make(S2SHost), case acl:match_rule(MyHost, Rule, JID) of @@ -624,133 +677,34 @@ get_s2s_state(S2sPid) -> end, [{s2s_pid, S2sPid} | Infos]. -get_cert_domains(Cert) -> - TBSCert = Cert#'Certificate'.tbsCertificate, - Subject = case TBSCert#'TBSCertificate'.subject of - {rdnSequence, Subj} -> lists:flatten(Subj); - _ -> [] - end, - Extensions = case TBSCert#'TBSCertificate'.extensions of - Exts when is_list(Exts) -> Exts; - _ -> [] - end, - lists:flatmap(fun (#'AttributeTypeAndValue'{type = - ?'id-at-commonName', - value = Val}) -> - case 'OTP-PUB-KEY':decode('X520CommonName', Val) of - {ok, {_, D1}} -> - D = if is_binary(D1) -> D1; - is_list(D1) -> list_to_binary(D1); - true -> error - end, - if D /= error -> - case jid:from_string(D) of - #jid{luser = <<"">>, lserver = LD, - lresource = <<"">>} -> - [LD]; - _ -> [] - end; - true -> [] - end; - _ -> [] - end; - (_) -> [] - end, - Subject) - ++ - lists:flatmap(fun (#'Extension'{extnID = - ?'id-ce-subjectAltName', - extnValue = Val}) -> - BVal = if is_list(Val) -> list_to_binary(Val); - true -> Val - end, - case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) - of - {ok, SANs} -> - lists:flatmap(fun ({otherName, - #'AnotherName'{'type-id' = - ?'id-on-xmppAddr', - value = - XmppAddr}}) -> - case - 'XmppAddr':decode('XmppAddr', - XmppAddr) - of - {ok, D} - when - is_binary(D) -> - case - jid:from_string((D)) - of - #jid{luser = - <<"">>, - lserver = - LD, - lresource = - <<"">>} -> - case - ejabberd_idna:domain_utf8_to_ascii(LD) - of - false -> - []; - PCLD -> - [PCLD] - end; - _ -> [] - end; - _ -> [] - end; - ({dNSName, D}) - when is_list(D) -> - case - jid:from_string(list_to_binary(D)) - of - #jid{luser = <<"">>, - lserver = LD, - lresource = - <<"">>} -> - [LD]; - _ -> [] - end; - (_) -> [] - end, - SANs); - _ -> [] - end; - (_) -> [] - end, - Extensions). - -match_domain(Domain, Domain) -> true; -match_domain(Domain, Pattern) -> - DLabels = str:tokens(Domain, <<".">>), - PLabels = str:tokens(Pattern, <<".">>), - match_labels(DLabels, PLabels). - -match_labels([], []) -> true; -match_labels([], [_ | _]) -> false; -match_labels([_ | _], []) -> false; -match_labels([DL | DLabels], [PL | PLabels]) -> - case lists:all(fun (C) -> - $a =< C andalso C =< $z orelse - $0 =< C andalso C =< $9 orelse - C == $- orelse C == $* - end, - binary_to_list(PL)) - of - true -> - Regexp = ejabberd_regexp:sh_to_awk(PL), - case ejabberd_regexp:run(DL, Regexp) of - match -> match_labels(DLabels, PLabels); - nomatch -> false - end; - false -> false - end. - opt_type(route_subdomains) -> fun (s2s) -> s2s; (local) -> local end; opt_type(s2s_access) -> fun acl:access_rules_validator/1; -opt_type(_) -> [route_subdomains, s2s_access]. +opt_type(domain_certfile) -> fun iolist_to_binary/1; +opt_type(s2s_certfile) -> fun iolist_to_binary/1; +opt_type(s2s_ciphers) -> fun iolist_to_binary/1; +opt_type(s2s_dhfile) -> fun iolist_to_binary/1; +opt_type(s2s_protocol_options) -> + fun (Options) -> str:join(Options, <<"|">>) end; +opt_type(s2s_tls_compression) -> + fun (true) -> true; + (false) -> false + end; +opt_type(s2s_use_starttls) -> + fun (true) -> true; + (false) -> false; + (optional) -> optional; + (required) -> required; + (required_trusted) -> required_trusted + end; +opt_type(s2s_timeout) -> + fun(I) when is_integer(I), I>=0 -> I; + (infinity) -> infinity + end; +opt_type(_) -> + [route_subdomains, s2s_access, s2s_certfile, + s2s_ciphers, s2s_dhfile, s2s_protocol_options, + s2s_tls_compression, s2s_use_starttls, s2s_timeout]. diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 395a0fce7..93f75bfcf 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -1,8 +1,5 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_s2s_in.erl -%%% Author : Alexey Shchepin -%%% Purpose : Serve incoming s2s connection -%%% Created : 6 Dec 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% Created : 12 Dec 2016 by Evgeny Khramtsov %%% %%% %%% ejabberd, Copyright (C) 2002-2016 ProcessOne @@ -21,645 +18,280 @@ %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% -%%%---------------------------------------------------------------------- - +%%%------------------------------------------------------------------- -module(ejabberd_s2s_in). - +-behaviour(xmpp_stream_in). -behaviour(ejabberd_config). +-behaviour(ejabberd_socket). --author('alexey@process-one.net'). - --behaviour(p1_fsm). - -%% External exports +%% ejabberd_socket callbacks -export([start/2, start_link/2, socket_type/0]). - --export([init/1, wait_for_stream/2, - wait_for_feature_request/2, stream_established/2, - handle_event/3, handle_sync_event/4, code_change/4, - handle_info/3, print_state/1, terminate/3, opt_type/1]). +%% ejabberd_config callbacks +-export([opt_type/1]). +%% xmpp_stream_in callbacks +-export([init/1, handle_call/3, handle_cast/2, + handle_info/2, terminate/2, code_change/3]). +-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + compress_methods/1, + unauthenticated_stream_features/1, authenticated_stream_features/1, + handle_stream_start/2, handle_stream_end/2, handle_stream_close/2, + handle_stream_established/1, handle_auth_success/4, + handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2, + handle_unauthenticated_packet/2, handle_authenticated_packet/2]). +%% Hooks +-export([handle_unexpected_info/2, handle_unexpected_cast/2, + reject_unauthenticated_packet/2, process_closed/2]). +%% API +-export([stop/1, close/1, send/2, update_state/2, establish/1, add_hooks/0]). -include("ejabberd.hrl"). --include("logger.hrl"). - -include("xmpp.hrl"). +-include("logger.hrl"). --define(DICT, dict). - --record(state, - {socket :: ejabberd_socket:socket_state(), - sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket, - streamid = <<"">> :: binary(), - shaper = none :: shaper:shaper(), - tls = false :: boolean(), - tls_enabled = false :: boolean(), - tls_required = false :: boolean(), - tls_certverify = false :: boolean(), - tls_options = [] :: list(), - server = <<"">> :: binary(), - authenticated = false :: boolean(), - auth_domain = <<"">> :: binary(), - connections = (?DICT):new() :: ?TDICT, - timer = make_ref() :: reference()}). - --type state_name() :: wait_for_stream | wait_for_feature_request | stream_established. --type state() :: #state{}. --type fsm_next() :: {next_state, state_name(), state()}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_transition() :: fsm_stop() | fsm_next(). - -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. +-type state() :: map(). +-export_type([state/0]). +%%%=================================================================== +%%% API +%%%=================================================================== start(SockData, Opts) -> - supervisor:start_child(ejabberd_s2s_in_sup, - [SockData, Opts]). + xmpp_stream_in:start(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). start_link(SockData, Opts) -> - p1_fsm:start_link(ejabberd_s2s_in, [SockData, Opts], - ?FSMOPTS ++ fsm_limit_opts(Opts)). + xmpp_stream_in:start_link(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). + +close(Ref) -> + xmpp_stream_in:close(Ref). + +stop(Ref) -> + xmpp_stream_in:stop(Ref). + +socket_type() -> + xml_stream. + +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Stream, Pkt) -> + xmpp_stream_in:send(Stream, Pkt). + +-spec establish(state()) -> state(). +establish(State) -> + xmpp_stream_in:establish(State). + +-spec update_state(pid(), fun((state()) -> state()) | + {module(), atom(), list()}) -> ok. +update_state(Ref, Callback) -> + xmpp_stream_in:cast(Ref, {update_state, Callback}). + +-spec add_hooks() -> ok. +add_hooks() -> + lists:foreach( + fun(Host) -> + ejabberd_hooks:add(s2s_in_closed, Host, ?MODULE, + process_closed, 100), + ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE, + reject_unauthenticated_packet, 100), + ejabberd_hooks:add(s2s_in_handle_info, Host, ?MODULE, + handle_unexpected_info, 100), + ejabberd_hooks:add(s2s_in_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100) + end, ?MYHOSTS). + +%%%=================================================================== +%%% Hooks +%%%=================================================================== +handle_unexpected_info(State, Info) -> + ?WARNING_MSG("got unexpected info: ~p", [Info]), + State. + +handle_unexpected_cast(State, Msg) -> + ?WARNING_MSG("got unexpected cast: ~p", [Msg]), + State. + +reject_unauthenticated_packet(State, Pkt) -> + Err = xmpp:err_not_authorized(), + xmpp_stream_in:send_error(State, Pkt, Err). + +process_closed(State, _Reason) -> + stop(State). + +%%%=================================================================== +%%% xmpp_stream_in callbacks +%%%=================================================================== +tls_options(#{tls_compression := Compression, server_host := LServer}) -> + Opts = case Compression of + false -> [compression_none]; + true -> [] + end, + ejabberd_s2s:tls_options(LServer, Opts). + +tls_required(#{server_host := LServer}) -> + ejabberd_s2s:tls_required(LServer). -socket_type() -> xml_stream. +tls_verify(#{server_host := LServer}) -> + ejabberd_s2s:tls_verify(LServer). -%%%---------------------------------------------------------------------- -%%% Callback functions from gen_fsm -%%%---------------------------------------------------------------------- +tls_enabled(#{server_host := LServer}) -> + ejabberd_s2s:tls_enabled(LServer). -init([{SockMod, Socket}, Opts]) -> - ?DEBUG("started: ~p", [{SockMod, Socket}]), - Shaper = case lists:keysearch(shaper, 1, Opts) of - {value, {_, S}} -> S; - _ -> none +compress_methods(#{server_host := LServer}) -> + case ejabberd_s2s:zlib_enabled(LServer) of + true -> [<<"zlib">>]; + false -> [] + end. + +unauthenticated_stream_features(#{server_host := LServer}) -> + ejabberd_hooks:run_fold(s2s_in_pre_auth_features, LServer, [], [LServer]). + +authenticated_stream_features(#{server_host := LServer}) -> + ejabberd_hooks:run_fold(s2s_in_post_auth_features, LServer, [], [LServer]). + +handle_stream_start(_StreamStart, #{lserver := LServer} = State) -> + case check_to(jid:make(LServer), State) of + false -> + send(State, xmpp:serr_host_unknown()); + true -> + ServerHost = ejabberd_router:host_of_route(LServer), + State#{server_host => ServerHost} + end. + +handle_stream_end(Reason, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]). + +handle_stream_close(_Reason, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]). + +handle_stream_established(State) -> + set_idle_timeout(State#{established => true}). + +handle_auth_success(RServer, Mech, _AuthModule, + #{socket := Socket, ip := IP, + auth_domains := AuthDomains, + server_host := ServerHost, + lserver := LServer} = State) -> + ?INFO_MSG("(~s) Accepted inbound s2s ~s authentication ~s -> ~s (~s)", + [ejabberd_socket:pp(Socket), Mech, RServer, LServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + State1 = case ejabberd_s2s:allow_host(ServerHost, RServer) of + true -> + AuthDomains1 = sets:add_element(RServer, AuthDomains), + State#{auth_domains => AuthDomains1}; + false -> + State end, - {StartTLS, TLSRequired, TLSCertverify} = - case ejabberd_config:get_option( - s2s_use_starttls, - fun(false) -> false; - (true) -> true; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end, - false) of - UseTls - when (UseTls == undefined) or - (UseTls == false) -> - {false, false, false}; - UseTls - when (UseTls == true) or - (UseTls == - optional) -> - {true, false, false}; - required -> {true, true, false}; - required_trusted -> - {true, true, true} - end, - TLSOpts1 = case ejabberd_config:get_option( - s2s_certfile, - fun iolist_to_binary/1) of - undefined -> []; - CertFile -> [{certfile, CertFile}] - end, - TLSOpts2 = case ejabberd_config:get_option( - s2s_ciphers, fun iolist_to_binary/1) of - undefined -> TLSOpts1; - Ciphers -> [{ciphers, Ciphers} | TLSOpts1] - end, - TLSOpts3 = case ejabberd_config:get_option( - s2s_protocol_options, - fun (Options) -> - [_|O] = lists:foldl( - fun(X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)] - ), - iolist_to_binary(O) - end) of - undefined -> TLSOpts2; - ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2] - end, - TLSOpts4 = case ejabberd_config:get_option( - s2s_dhfile, fun iolist_to_binary/1) of - undefined -> TLSOpts3; - DHFile -> [{dhfile, DHFile} | TLSOpts3] - end, - TLSOpts = case proplists:get_bool(tls_compression, Opts) of - false -> [compression_none | TLSOpts4]; - true -> TLSOpts4 - end, - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {ok, wait_for_stream, - #state{socket = Socket, sockmod = SockMod, - streamid = new_id(), shaper = Shaper, tls = StartTLS, - tls_enabled = false, tls_required = TLSRequired, - tls_certverify = TLSCertverify, tls_options = TLSOpts, - timer = Timer}}. - -%%---------------------------------------------------------------------- -%% Func: StateName/2 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} -%%---------------------------------------------------------------------- -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} - when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{to = #jid{lserver = Server}, - from = From, version = {1,0}} - when StateData#state.tls and not StateData#state.authenticated -> - send_header(StateData, {1,0}), - Auth = if StateData#state.tls_enabled -> - case From of - #jid{} -> - {Result, Message} = - ejabberd_s2s:check_peer_certificate( - StateData#state.sockmod, - StateData#state.socket, - From#jid.lserver), - {Result, From#jid.lserver, Message}; - undefined -> - {error, <<"(unknown)">>, - <<"Got no valid 'from' attribute">>} - end; - true -> - {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>} - end, - StartTLS = if StateData#state.tls_enabled -> []; - not StateData#state.tls_enabled and - not StateData#state.tls_required -> - [#starttls{required = false}]; - not StateData#state.tls_enabled and - StateData#state.tls_required -> - [#starttls{required = true}] - end, - case Auth of - {error, RemoteServer, CertError} - when StateData#state.tls_certverify -> - ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)", - [StateData#state.server, RemoteServer, CertError]), - send_element(StateData, - xmpp:serr_policy_violation(CertError, ?MYLANG)), - {stop, normal, StateData}; - {VerifyResult, RemoteServer, Msg} -> - {SASL, NewStateData} = - case VerifyResult of - ok -> - {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}], - StateData#state{auth_domain = RemoteServer}}; - error -> - ?DEBUG("Won't accept certificate of ~s: ~s", - [RemoteServer, Msg]), - {[], StateData}; - no_verify -> - {[], StateData} - end, - send_element(NewStateData, - #stream_features{ - sub_els = SASL ++ StartTLS ++ - ejabberd_hooks:run_fold( - s2s_stream_features, Server, [], - [Server])}), - {next_state, wait_for_feature_request, - NewStateData#state{server = Server}} - end; - #stream_start{to = #jid{lserver = Server}, - version = {1,0}} when StateData#state.authenticated -> - send_header(StateData, {1,0}), - send_element(StateData, - #stream_features{ - sub_els = ejabberd_hooks:run_fold( - s2s_stream_features, Server, [], - [Server])}), - {next_state, stream_established, StateData}; - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK} - when (StateData#state.tls_required and StateData#state.tls_enabled) - or (not StateData#state.tls_required) -> - send_header(StateData, undefined), - {next_state, stream_established, StateData}; - #stream_start{} -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_undefined_condition()), - {stop, normal, StateData}; - _ -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream(timeout, StateData) -> - send_header(StateData, {1,0}), - send_element(StateData, xmpp:serr_connection_timeout()), - {stop, normal, StateData}; -wait_for_stream(closed, StateData) -> - {stop, normal, StateData}. - -wait_for_feature_request({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_feature_request, StateData); -wait_for_feature_request(#starttls{}, - #state{tls = true, tls_enabled = false} = StateData) -> - case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of - gen_tcp -> - ?DEBUG("starttls", []), - Socket = StateData#state.socket, - TLSOpts1 = case - ejabberd_config:get_option( - {domain_certfile, StateData#state.server}, - fun iolist_to_binary/1) of - undefined -> StateData#state.tls_options; - CertFile -> - lists:keystore(certfile, 1, - StateData#state.tls_options, - {certfile, CertFile}) - end, - TLSOpts2 = case ejabberd_config:get_option( - {s2s_cafile, StateData#state.server}, - fun iolist_to_binary/1) of - undefined -> TLSOpts1; - CAFile -> - lists:keystore(cafile, 1, TLSOpts1, - {cafile, CAFile}) - end, - TLSOpts = case ejabberd_config:get_option( - {s2s_tls_compression, StateData#state.server}, - fun(true) -> true; - (false) -> false - end, false) of - true -> lists:delete(compression_none, TLSOpts2); - false -> [compression_none | TLSOpts2] - end, - TLSSocket = (StateData#state.sockmod):starttls( - Socket, TLSOpts, - fxml:element_to_binary( - xmpp:encode(#starttls_proceed{}))), - {next_state, wait_for_stream, - StateData#state{socket = TLSSocket, streamid = new_id(), - tls_enabled = true, tls_options = TLSOpts}}; - _ -> - send_element(StateData, #starttls_failure{}), - {stop, normal, StateData} - end; -wait_for_feature_request(#sasl_auth{mechanism = Mech}, - #state{tls_enabled = true} = StateData) -> - case Mech of - <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> -> - AuthDomain = StateData#state.auth_domain, - AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain), - if AllowRemoteHost -> - (StateData#state.sockmod):reset_stream(StateData#state.socket), - send_element(StateData, #sasl_success{}), - ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)", - [AuthDomain, StateData#state.tls_enabled]), - change_shaper(StateData, <<"">>, jid:make(AuthDomain)), - {next_state, wait_for_stream, - StateData#state{streamid = new_id(), - authenticated = true}}; - true -> - Txt = xmpp:mk_text(<<"Denied by ACL">>, ?MYLANG), - send_element(StateData, - #sasl_failure{reason = 'not-authorized', - text = Txt}), - {stop, normal, StateData} - end; - _ -> - send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}), - {stop, normal, StateData} - end; -wait_for_feature_request({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_feature_request(closed, StateData) -> - {stop, normal, StateData}; -wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired, - tls_enabled = TLSEnabled} = StateData) - when TLSRequired and not TLSEnabled -> - Txt = <<"Use of STARTTLS required">>, - send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)), - {stop, normal, StateData}; -wait_for_feature_request(El, StateData) -> - stream_established({xmlstreamelement, El}, StateData). - -stream_established({xmlstreamelement, El}, StateData) -> - cancel_timer(StateData#state.timer), - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - decode_element(El, stream_established, StateData#state{timer = Timer}); -stream_established(#db_result{to = To, from = From, key = Key}, - StateData) -> - ?DEBUG("GET KEY: ~p", [{To, From, Key}]), - case {ejabberd_s2s:allow_host(To, From), - lists:member(To, ejabberd_router:dirty_get_all_domains())} of - {true, true} -> - ejabberd_s2s_out:terminate_if_waiting_delay(To, From), - ejabberd_s2s_out:start(To, From, - {verify, self(), Key, - StateData#state.streamid}), - Conns = (?DICT):store({From, To}, - wait_for_verification, - StateData#state.connections), - change_shaper(StateData, To, jid:make(From)), - {next_state, stream_established, - StateData#state{connections = Conns}}; - {_, false} -> - send_element(StateData, xmpp:serr_host_unknown()), - {stop, normal, StateData}; - {false, _} -> - send_element(StateData, xmpp:serr_invalid_from()), - {stop, normal, StateData} - end; -stream_established(#db_verify{to = To, from = From, id = Id, key = Key}, - StateData) -> - ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]), - Type = case ejabberd_s2s:make_key({To, From}, Id) of - Key -> valid; - _ -> invalid - end, - send_element(StateData, - #db_verify{from = To, to = From, id = Id, type = Type}), - {next_state, stream_established, StateData}; -stream_established(Pkt, StateData) when ?is_stanza(Pkt) -> + ejabberd_hooks:run_fold(s2s_in_auth_result, ServerHost, State1, [true, RServer]). + +handle_auth_failure(RServer, Mech, Reason, + #{socket := Socket, ip := IP, + server_host := ServerHost, + lserver := LServer} = State) -> + ?INFO_MSG("(~s) Failed inbound s2s ~s authentication ~s -> ~s (~s): ~s", + [ejabberd_socket:pp(Socket), Mech, RServer, LServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), + ejabberd_hooks:run_fold(s2s_in_auth_result, + ServerHost, State, [false, RServer]). + +handle_unauthenticated_packet(Pkt, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_unauthenticated_packet, + LServer, State, [Pkt]). + +handle_authenticated_packet(Pkt, #{server_host := LServer} = State) when not ?is_stanza(Pkt) -> + ejabberd_hooks:run_fold(s2s_in_authenticated_packet, LServer, State, [Pkt]); +handle_authenticated_packet(Pkt, State) -> From = xmpp:get_from(Pkt), To = xmpp:get_to(Pkt), - if To /= undefined, From /= undefined -> - LFrom = From#jid.lserver, - LTo = To#jid.lserver, - if StateData#state.authenticated -> - case LFrom == StateData#state.auth_domain andalso - lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of - true -> - ejabberd_hooks:run(s2s_receive_packet, LTo, - [From, To, Pkt]), - ejabberd_router:route(From, To, Pkt); - false -> - send_error(StateData, Pkt, xmpp:err_not_authorized()) - end; - true -> - case (?DICT):find({LFrom, LTo}, StateData#state.connections) of - {ok, established} -> - ejabberd_hooks:run(s2s_receive_packet, LTo, - [From, To, Pkt]), - ejabberd_router:route(From, To, Pkt); - _ -> - send_error(StateData, Pkt, xmpp:err_not_authorized()) - end - end; - true -> - send_error(StateData, Pkt, xmpp:err_jid_malformed()) - end, - ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]), - {next_state, stream_established, StateData}; -stream_established({valid, From, To}, StateData) -> - send_element(StateData, - #db_result{from = To, to = From, type = valid}), - ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)", - [From, StateData#state.tls_enabled]), - NSD = StateData#state{connections = - (?DICT):store({From, To}, established, - StateData#state.connections)}, - {next_state, stream_established, NSD}; -stream_established({invalid, From, To}, StateData) -> - send_element(StateData, - #db_result{from = To, to = From, type = invalid}), - NSD = StateData#state{connections = - (?DICT):erase({From, To}, - StateData#state.connections)}, - {next_state, stream_established, NSD}; -stream_established({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -stream_established({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -stream_established(timeout, StateData) -> - send_element(StateData, xmpp:serr_connection_timeout()), - {stop, normal, StateData}; -stream_established(closed, StateData) -> - {stop, normal, StateData}; -stream_established(Pkt, StateData) -> - ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]), - {next_state, stream_established, StateData}. - -%%---------------------------------------------------------------------- -%% Func: StateName/3 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -%state_name(Event, From, StateData) -> -% Reply = ok, -% {reply, Reply, state_name, StateData}. - -handle_event(_Event, StateName, StateData) -> - {next_state, StateName, StateData}. - -handle_sync_event(get_state_infos, _From, StateName, - StateData) -> - SockMod = StateData#state.sockmod, - {Addr, Port} = try - SockMod:peername(StateData#state.socket) - of - {ok, {A, P}} -> {A, P}; - {error, _} -> {unknown, unknown} - catch - _:_ -> {unknown, unknown} - end, - Domains = get_external_hosts(StateData), - Infos = [{direction, in}, {statename, StateName}, - {addr, Addr}, {port, Port}, - {streamid, StateData#state.streamid}, - {tls, StateData#state.tls}, - {tls_enabled, StateData#state.tls_enabled}, - {tls_options, StateData#state.tls_options}, - {authenticated, StateData#state.authenticated}, - {shaper, StateData#state.shaper}, {sockmod, SockMod}, - {domains, Domains}], - Reply = {state_infos, Infos}, - {reply, Reply, StateName, StateData}; -%%---------------------------------------------------------------------- -%% Func: handle_sync_event/4 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, {reply, Reply, StateName, StateData}. - -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. - -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - {next_state, StateName, StateData}; -handle_info({timeout, Timer, _}, StateName, - #state{timer = Timer} = StateData) -> - if StateName == wait_for_stream -> - send_header(StateData, undefined); - true -> - ok - end, - send_element(StateData, xmpp:serr_connection_timeout()), - {stop, normal, StateData}; -handle_info(_, StateName, StateData) -> - {next_state, StateName, StateData}. - -terminate(Reason, _StateName, StateData) -> - ?DEBUG("terminated: ~p", [Reason]), - case Reason of - {process_limit, _} -> - [ejabberd_s2s:external_host_overloaded(Host) - || Host <- get_external_hosts(StateData)]; - _ -> ok - end, - catch send_trailer(StateData), - (StateData#state.sockmod):close(StateData#state.socket), - ok. - -get_external_hosts(StateData) -> - case StateData#state.authenticated of - true -> [StateData#state.auth_domain]; - false -> - Connections = StateData#state.connections, - [D - || {{D, _}, established} <- dict:to_list(Connections)] + case check_from_to(From, To, State) of + ok -> + LServer = ejabberd_router:host_of_route(To#jid.lserver), + State1 = ejabberd_hooks:run_fold(s2s_in_authenticated_packet, + LServer, State, [Pkt]), + Pkt1 = ejabberd_hooks:run_fold(s2s_receive_packet, LServer, + Pkt, [State1]), + ejabberd_router:route(From, To, Pkt1), + State1; + {error, Err} -> + send(State, Err) end. -print_state(State) -> State. +handle_cdata(Data, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_cdata, LServer, State, [Data]). + +handle_recv(El, Pkt, #{server_host := LServer} = State) -> + State1 = set_idle_timeout(State), + ejabberd_hooks:run_fold(s2s_in_handle_recv, LServer, State1, [El, Pkt]). + +handle_send(Pkt, Result, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_send, LServer, + State, [Pkt, Result]). + +init([State, Opts]) -> + Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none), + TLSCompression = proplists:get_bool(tls_compression, Opts), + State1 = State#{tls_compression => TLSCompression, + auth_domains => sets:new(), + xmlns => ?NS_SERVER, + lang => ?MYLANG, + server => ?MYNAME, + lserver => ?MYNAME, + server_host => ?MYNAME, + established => false, + shaper => Shaper}, + ejabberd_hooks:run_fold(s2s_in_init, {ok, State1}, [Opts]). + +handle_call(Request, From, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_call, LServer, State, [Request, From]). + +handle_cast({update_state, Fun}, State) -> + case Fun of + {M, F, A} -> erlang:apply(M, F, [State|A]); + _ when is_function(Fun) -> Fun(State) + end; +handle_cast(Msg, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_cast, LServer, State, [Msg]). + +handle_info(Info, #{server_host := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_in_handle_info, LServer, State, [Info]). + +terminate(_Reason, _State) -> + ok. + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. -%%%---------------------------------------------------------------------- +%%%=================================================================== %%% Internal functions -%%%---------------------------------------------------------------------- - --spec send_text(state(), iodata()) -> ok. -send_text(StateData, Text) -> - (StateData#state.sockmod):send(StateData#state.socket, - Text). - --spec send_element(state(), xmpp_element()) -> ok. -send_element(StateData, El) -> - El1 = xmpp:encode(El, ?NS_SERVER), - send_text(StateData, fxml:element_to_binary(El1)). - --spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok. -send_error(StateData, Stanza, Error) -> - Type = xmpp:get_type(Stanza), - if Type == error; Type == result; - Type == <<"error">>; Type == <<"result">> -> - ok; - true -> - send_element(StateData, xmpp:make_error(Stanza, Error)) +%%%=================================================================== +-spec check_from_to(jid(), jid(), state()) -> ok | {error, stream_error()}. +check_from_to(From, To, State) -> + case check_from(From, State) of + true -> + case check_to(To, State) of + true -> + ok; + false -> + {error, xmpp:serr_improper_addressing()} + end; + false -> + {error, xmpp:serr_invalid_from()} end. --spec send_trailer(state()) -> ok. -send_trailer(StateData) -> - send_text(StateData, <<"">>). - --spec send_header(state(), undefined | {integer(), integer()}) -> ok. -send_header(StateData, Version) -> - Header = xmpp:encode( - #stream_start{xmlns = ?NS_SERVER, - stream_xmlns = ?NS_STREAM, - db_xmlns = ?NS_SERVER_DIALBACK, - id = StateData#state.streamid, - version = Version}), - send_text(StateData, fxml:element_to_header(Header)). - --spec change_shaper(state(), binary(), jid()) -> ok. -change_shaper(StateData, Host, JID) -> - Shaper = acl:match_rule(Host, StateData#state.shaper, - JID), - (StateData#state.sockmod):change_shaper(StateData#state.socket, - Shaper). - --spec new_id() -> binary(). -new_id() -> randoms:get_string(). - --spec cancel_timer(reference()) -> ok. -cancel_timer(Timer) -> - erlang:cancel_timer(Timer), - receive {timeout, Timer, _} -> ok after 0 -> ok end. - -fsm_limit_opts(Opts) -> - case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> [{max_queue, N}]; - _ -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end - end. +-spec check_from(jid(), state()) -> boolean(). +check_from(#jid{lserver = S1}, #{auth_domains := AuthDomains}) -> + sets:is_element(S1, AuthDomains). + +-spec check_to(jid(), state()) -> boolean(). +check_to(#jid{lserver = LServer}, _State) -> + ejabberd_router:is_my_route(LServer). + +-spec set_idle_timeout(state()) -> state(). +set_idle_timeout(#{server_host := LServer, + established := true} = State) -> + Timeout = ejabberd_s2s:get_idle_timeout(LServer), + xmpp_stream_in:set_timeout(State, Timeout); +set_idle_timeout(State) -> + State. --spec decode_element(xmlel() | xmpp_element(), state_name(), state()) -> fsm_transition(). -decode_element(#xmlel{} = El, StateName, StateData) -> - Opts = if StateName == stream_established -> - [ignore_els]; - true -> - [] - end, - try xmpp:decode(El, ?NS_SERVER, Opts) of - Pkt -> ?MODULE:StateName(Pkt, StateData) - catch error:{xmpp_codec, Why} -> - case xmpp:is_stanza(El) of - true -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - send_error(StateData, El, xmpp:err_bad_request(Txt, Lang)); - false -> - ok - end, - {next_state, StateName, StateData} - end; -decode_element(Pkt, StateName, StateData) -> - ?MODULE:StateName(Pkt, StateData). - -opt_type(domain_certfile) -> fun iolist_to_binary/1; -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; -opt_type(s2s_certfile) -> fun iolist_to_binary/1; -opt_type(s2s_cafile) -> fun iolist_to_binary/1; -opt_type(s2s_ciphers) -> fun iolist_to_binary/1; -opt_type(s2s_dhfile) -> fun iolist_to_binary/1; -opt_type(s2s_protocol_options) -> - fun (Options) -> - [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] - || Opt <- Options, is_binary(Opt)]), - iolist_to_binary(O) - end; -opt_type(s2s_tls_compression) -> - fun (true) -> true; - (false) -> false - end; -opt_type(s2s_use_starttls) -> - fun (false) -> false; - (true) -> true; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end; opt_type(_) -> - [domain_certfile, max_fsm_queue, s2s_certfile, s2s_cafile, - s2s_ciphers, s2s_dhfile, s2s_protocol_options, - s2s_tls_compression, s2s_use_starttls]. + []. diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index b9ce47830..72d9dfea8 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -1,973 +1,313 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_s2s_out.erl -%%% Author : Alexey Shchepin -%%% Purpose : Manage outgoing server-to-server connections -%%% Created : 6 Dec 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2016, Evgeny Khramtsov +%%% @doc %%% -%%% -%%% ejabberd, Copyright (C) 2002-2016 ProcessOne -%%% -%%% This program is free software; you can redistribute it and/or -%%% modify it under the terms of the GNU General Public License as -%%% published by the Free Software Foundation; either version 2 of the -%%% License, or (at your option) any later version. -%%% -%%% This program is distributed in the hope that it will be useful, -%%% but WITHOUT ANY WARRANTY; without even the implied warranty of -%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -%%% General Public License for more details. -%%% -%%% You should have received a copy of the GNU General Public License along -%%% with this program; if not, write to the Free Software Foundation, Inc., -%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -%%% -%%%---------------------------------------------------------------------- - +%%% @end +%%% Created : 16 Dec 2016 by Evgeny Khramtsov +%%%------------------------------------------------------------------- -module(ejabberd_s2s_out). - +-behaviour(xmpp_stream_out). -behaviour(ejabberd_config). --author('alexey@process-one.net'). - --behaviour(p1_fsm). - -%% External exports --export([start/3, - start_link/3, - start_connection/1, - terminate_if_waiting_delay/2, - stop_connection/1, - transform_options/1]). - --export([init/1, open_socket/2, wait_for_stream/2, - wait_for_validation/2, wait_for_features/2, - wait_for_auth_result/2, wait_for_starttls_proceed/2, - relay_to_bridge/2, reopen_socket/2, wait_before_retry/2, - stream_established/2, handle_event/3, - handle_sync_event/4, handle_info/3, terminate/3, - print_state/1, code_change/4, test_get_addr_port/1, - get_addr_port/1, opt_type/1]). +%% ejabberd_config callbacks +-export([opt_type/1, transform_options/1]). +%% xmpp_stream_out callbacks +-export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, + handle_auth_success/2, handle_auth_failure/3, handle_packet/2, + handle_stream_end/2, handle_stream_close/2, + handle_recv/3, handle_send/4, handle_cdata/2, + handle_stream_established/1, handle_timeout/1]). +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). +%% Hooks +-export([process_auth_result/2, process_closed/2, handle_unexpected_info/2, + handle_unexpected_cast/2]). +%% API +-export([start/3, start_link/3, connect/1, close/1, stop/1, send/2, + route/2, establish/1, update_state/2, add_hooks/0]). -include("ejabberd.hrl"). --include("logger.hrl"). -include("xmpp.hrl"). +-include("logger.hrl"). --record(state, - {socket :: ejabberd_socket:socket_state(), - streamid = <<"">> :: binary(), - remote_streamid = <<"">> :: binary(), - use_v10 = true :: boolean(), - tls = false :: boolean(), - tls_required = false :: boolean(), - tls_certverify = false :: boolean(), - tls_enabled = false :: boolean(), - tls_options = [connect] :: list(), - authenticated = false :: boolean(), - db_enabled = true :: boolean(), - try_auth = true :: boolean(), - myname = <<"">> :: binary(), - server = <<"">> :: binary(), - queue = queue:new() :: ?TQUEUE, - delay_to_retry = undefined_delay :: undefined_delay | non_neg_integer(), - new = false :: boolean(), - verify = false :: false | {pid(), binary(), binary()}, - bridge :: {atom(), atom()}, - timer = make_ref() :: reference()}). - --type state_name() :: open_socket | wait_for_stream | - wait_for_validation | wait_for_features | - wait_for_auth_result | wait_for_starttls_proceed | - relay_to_bridge | reopen_socket | wait_before_retry | - stream_established. --type state() :: #state{}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()} | - {next_state, state_name(), state()}. --type fsm_transition() :: fsm_stop() | fsm_next(). - -%%-define(DBGFSM, true). - --ifdef(DBGFSM). - --define(FSMOPTS, [{debug, [trace]}]). - --else. - --define(FSMOPTS, []). - --endif. - --define(FSMTIMEOUT, 30000). - -%% We do not block on send anymore. --define(TCP_SEND_TIMEOUT, 15000). - -%% Maximum delay to wait before retrying to connect after a failed attempt. -%% Specified in miliseconds. Default value is 5 minutes. --define(MAX_RETRY_DELAY, 300000). - --define(SOCKET_DEFAULT_RESULT, {error, badarg}). +-type state() :: map(). +-export_type([state/0]). -%%%---------------------------------------------------------------------- +%%%=================================================================== %%% API -%%%---------------------------------------------------------------------- -start(From, Host, Type) -> - supervisor:start_child(ejabberd_s2s_out_sup, - [From, Host, Type]). - -start_link(From, Host, Type) -> - p1_fsm:start_link(ejabberd_s2s_out, [From, Host, Type], - fsm_limit_opts() ++ (?FSMOPTS)). - -start_connection(Pid) -> p1_fsm:send_event(Pid, init). - -stop_connection(Pid) -> p1_fsm:send_event(Pid, closed). - -%%%---------------------------------------------------------------------- -%%% Callback functions from p1_fsm -%%%---------------------------------------------------------------------- - -init([From, Server, Type]) -> - process_flag(trap_exit, true), - ?DEBUG("started: ~p", [{From, Server, Type}]), - {TLS, TLSRequired, TLSCertverify} = - case ejabberd_config:get_option( - s2s_use_starttls, - fun(true) -> true; - (false) -> false; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end) - of - UseTls - when (UseTls == undefined) or (UseTls == false) -> - {false, false, false}; - UseTls - when (UseTls == true) or (UseTls == optional) -> - {true, false, false}; - required -> - {true, true, false}; - required_trusted -> - {true, true, true} - end, - UseV10 = TLS, - TLSOpts1 = case - ejabberd_config:get_option( - s2s_certfile, fun iolist_to_binary/1) - of - undefined -> [connect]; - CertFile -> [{certfile, CertFile}, connect] - end, - TLSOpts2 = case ejabberd_config:get_option( - s2s_ciphers, fun iolist_to_binary/1) of - undefined -> TLSOpts1; - Ciphers -> [{ciphers, Ciphers} | TLSOpts1] - end, - TLSOpts3 = case ejabberd_config:get_option( - s2s_protocol_options, - fun (Options) -> - [_|O] = lists:foldl( - fun(X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] || Opt <- Options, is_binary(Opt)] - ), - iolist_to_binary(O) - end) of - undefined -> TLSOpts2; - ProtocolOpts -> [{protocol_options, ProtocolOpts} | TLSOpts2] - end, - TLSOpts4 = case ejabberd_config:get_option( - s2s_dhfile, fun iolist_to_binary/1) of - undefined -> TLSOpts3; - DHFile -> [{dhfile, DHFile} | TLSOpts3] - end, - TLSOpts = case ejabberd_config:get_option( - {s2s_tls_compression, From}, - fun(true) -> true; - (false) -> false - end, false) of - false -> [compression_none | TLSOpts4]; - true -> TLSOpts4 - end, - {New, Verify} = case Type of - new -> {true, false}; - {verify, Pid, Key, SID} -> - start_connection(self()), {false, {Pid, Key, SID}} - end, - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {ok, open_socket, - #state{use_v10 = UseV10, tls = TLS, - tls_required = TLSRequired, tls_certverify = TLSCertverify, - tls_options = TLSOpts, queue = queue:new(), myname = From, - server = Server, new = New, verify = Verify, timer = Timer}}. - -open_socket(init, StateData) -> - log_s2s_out(StateData#state.new, StateData#state.myname, - StateData#state.server, StateData#state.tls), - ?DEBUG("open_socket: ~p", - [{StateData#state.myname, StateData#state.server, - StateData#state.new, StateData#state.verify}]), - AddrList = case - ejabberd_idna:domain_utf8_to_ascii(StateData#state.server) - of - false -> []; - ASCIIAddr -> get_addr_port(ASCIIAddr) +%%%=================================================================== +start(From, To, Opts) -> + xmpp_stream_out:start(?MODULE, [ejabberd_socket, From, To, Opts], + ejabberd_config:fsm_limit_opts([])). + +start_link(From, To, Opts) -> + xmpp_stream_out:start_link(?MODULE, [ejabberd_socket, From, To, Opts], + ejabberd_config:fsm_limit_opts([])). + +connect(Ref) -> + xmpp_stream_out:connect(Ref). + +close(Ref) -> + xmpp_stream_out:close(Ref). + +stop(Ref) -> + xmpp_stream_out:stop(Ref). + +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Stream, Pkt) -> + xmpp_stream_out:send(Stream, Pkt). + +-spec route(pid(), xmpp_element()) -> ok. +route(Ref, Pkt) -> + Ref ! {route, Pkt}. + +-spec establish(state()) -> state(). +establish(State) -> + xmpp_stream_out:establish(State). + +-spec update_state(pid(), fun((state()) -> state()) | + {module(), atom(), list()}) -> ok. +update_state(Ref, Callback) -> + xmpp_stream_out:cast(Ref, {update_state, Callback}). + +-spec add_hooks() -> ok. +add_hooks() -> + lists:foreach( + fun(Host) -> + ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE, + process_auth_result, 100), + ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, + process_closed, 100), + ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE, + handle_unexpected_info, 100), + ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100) + end, ?MYHOSTS). + +%%%=================================================================== +%%% Hooks +%%%=================================================================== +process_auth_result(#{server := LServer, remote_server := RServer} = State, + false) -> + Delay = get_delay(), + ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;" + " bouncing for ~p seconds", + [LServer, RServer, Delay]), + State1 = close(State), + State2 = bounce_queue(State1), + xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)); +process_auth_result(State, true) -> + State. + +process_closed(#{server := LServer, remote_server := RServer} = State, + _Reason) -> + Delay = get_delay(), + ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; " + "bouncing for ~p seconds", + [LServer, RServer, + try maps:get(stop_reason, State) of + {error, Why} -> xmpp_stream_out:format_error(Why) + catch _:undef -> <<"unexplained reason">> end, - case lists:foldl(fun ({Addr, Port}, Acc) -> - case Acc of - {ok, Socket} -> {ok, Socket}; - _ -> open_socket1(Addr, Port) - end - end, - ?SOCKET_DEFAULT_RESULT, AddrList) - of - {ok, Socket} -> - Version = if StateData#state.use_v10 -> {1,0}; - true -> undefined - end, - NewStateData = StateData#state{socket = Socket, - tls_enabled = false, - streamid = new_id()}, - send_header(NewStateData, Version), - {next_state, wait_for_stream, NewStateData, - ?FSMTIMEOUT}; - {error, Reason} -> - ?INFO_MSG("s2s connection: ~s -> ~s (remote server " - "not found: ~p)", - [StateData#state.myname, StateData#state.server, Reason]), - case ejabberd_hooks:run_fold(find_s2s_bridge, undefined, - [StateData#state.myname, - StateData#state.server]) - of - {Mod, Fun, Type} -> - ?INFO_MSG("found a bridge to ~s for: ~s -> ~s", - [Type, StateData#state.myname, - StateData#state.server]), - NewStateData = StateData#state{bridge = {Mod, Fun}}, - {next_state, relay_to_bridge, NewStateData}; - _ -> wait_before_reconnect(StateData) - end - end; -open_socket(Event, StateData) -> - handle_unexpected_event(Event, open_socket, StateData). - -open_socket1({_, _, _, _} = Addr, Port) -> - open_socket2(inet, Addr, Port); -%% IPv6 -open_socket1({_, _, _, _, _, _, _, _} = Addr, Port) -> - open_socket2(inet6, Addr, Port); -%% Hostname -open_socket1(Host, Port) -> - lists:foldl(fun (_Family, {ok, _Socket} = R) -> R; - (Family, _) -> - Addrs = get_addrs(Host, Family), - lists:foldl(fun (_Addr, {ok, _Socket} = R) -> R; - (Addr, _) -> open_socket1(Addr, Port) - end, - ?SOCKET_DEFAULT_RESULT, Addrs) - end, - ?SOCKET_DEFAULT_RESULT, outgoing_s2s_families()). - -open_socket2(Type, Addr, Port) -> - ?DEBUG("s2s_out: connecting to ~p:~p~n", [Addr, Port]), - Timeout = outgoing_s2s_timeout(), - case catch ejabberd_socket:connect(Addr, Port, - [binary, {packet, 0}, - {send_timeout, ?TCP_SEND_TIMEOUT}, - {send_timeout_close, true}, - {active, false}, Type], - Timeout) - of - {ok, _Socket} = R -> R; - {error, Reason} = R -> - ?DEBUG("s2s_out: connect return ~p~n", [Reason]), R; - {'EXIT', Reason} -> - ?DEBUG("s2s_out: connect crashed ~p~n", [Reason]), - {error, Reason} + Delay]), + State1 = bounce_queue(State), + xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)). + +handle_unexpected_info(State, Info) -> + ?WARNING_MSG("got unexpected info: ~p", [Info]), + State. + +handle_unexpected_cast(State, Msg) -> + ?WARNING_MSG("got unexpected cast: ~p", [Msg]), + State. + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +tls_options(#{server := LServer}) -> + ejabberd_s2s:tls_options(LServer, []). + +tls_required(#{server := LServer}) -> + ejabberd_s2s:tls_required(LServer). + +tls_verify(#{server := LServer}) -> + ejabberd_s2s:tls_verify(LServer). + +tls_enabled(#{server := LServer}) -> + ejabberd_s2s:tls_enabled(LServer). + +handle_auth_success(Mech, #{socket := Socket, ip := IP, + remote_server := RServer, + server := LServer} = State) -> + ?INFO_MSG("(~s) Accepted outbound s2s ~s authentication ~s -> ~s (~s)", + [ejabberd_socket:pp(Socket), Mech, LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State, [true]). + +handle_auth_failure(Mech, Reason, + #{socket := Socket, ip := IP, + remote_server := RServer, + server := LServer} = State) -> + ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s", + [ejabberd_socket:pp(Socket), Mech, LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), + State1 = State#{on_route => bounce, + stop_reason => {error, {auth, Reason}}}, + ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]). + +handle_packet(Pkt, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]). + +handle_stream_end(Reason, #{server := LServer} = State) -> + State1 = State#{on_route => bounce, stop_reason => Reason}, + ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]). + +handle_stream_close(Reason, #{server := LServer} = State) -> + State1 = State#{on_route => bounce, stop_reason => Reason}, + ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]). + +handle_stream_established(State) -> + State1 = State#{on_route => send}, + State2 = resend_queue(State1), + set_idle_timeout(State2). + +handle_cdata(Data, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_cdata, LServer, State, [Data]). + +handle_recv(El, Pkt, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_recv, LServer, State, [El, Pkt]). + +handle_send(Pkt, El, Data, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, + State, [Pkt, El, Data]). + +handle_timeout(#{server := LServer, remote_server := RServer, + on_route := Action} = State) -> + case Action of + bounce -> stop(State); + queue -> send(State, xmpp:serr_connection_timeout()); + send -> + ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive", + [LServer, RServer]), + stop(State) end. -%%---------------------------------------------------------------------- - -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData0) -> - {CertCheckRes, CertCheckMsg, StateData} = - if StateData0#state.tls_certverify, StateData0#state.tls_enabled -> - {Res, Msg} = - ejabberd_s2s:check_peer_certificate(ejabberd_socket, - StateData0#state.socket, - StateData0#state.server), - ?DEBUG("Certificate verification result for ~s: ~s", - [StateData0#state.server, Msg]), - {Res, Msg, StateData0#state{tls_certverify = false}}; - true -> - {no_verify, <<"Not verified">>, StateData0} - end, - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - _ when CertCheckRes == error -> - send_element(StateData, - xmpp:serr_policy_violation(CertCheckMsg, ?MYLANG)), - ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)", - [StateData#state.myname, StateData#state.server, - CertCheckMsg]), - {stop, normal, StateData}; - #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM} - when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM -> - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID, - version = V} when V /= {1,0} -> - send_db_request(StateData#state{remote_streamid = ID}); - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID} - when StateData#state.use_v10 -> - {next_state, wait_for_features, - StateData#state{remote_streamid = ID}, ?FSMTIMEOUT}; - #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID} - when not StateData#state.use_v10 -> - %% Handle Tigase's workaround for an old ejabberd bug: - send_db_request(StateData#state{remote_streamid = ID}); - #stream_start{id = ID} when StateData#state.use_v10 -> - {next_state, wait_for_features, - StateData#state{db_enabled = false, remote_streamid = ID}, - ?FSMTIMEOUT}; - #stream_start{} -> - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - _ -> - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream(Event, StateData) -> - handle_unexpected_event(Event, wait_for_stream, StateData). - -wait_for_validation({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_validation, StateData); -wait_for_validation(#db_result{to = To, from = From, type = Type}, StateData) -> - ?DEBUG("recv result: ~p", [{From, To, Type}]), - case {Type, StateData#state.tls_enabled, StateData#state.tls_required} of - {valid, Enabled, Required} when (Enabled == true) or (Required == false) -> - send_queue(StateData, StateData#state.queue), - ?INFO_MSG("Connection established: ~s -> ~s with " - "TLS=~p", - [StateData#state.myname, StateData#state.server, - StateData#state.tls_enabled]), - ejabberd_hooks:run(s2s_connect_hook, - [StateData#state.myname, - StateData#state.server]), - {next_state, stream_established, StateData#state{queue = queue:new()}}; - {valid, Enabled, Required} when (Enabled == false) and (Required == true) -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS " - "is required but unavailable)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; - _ -> - ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid " - "dialback key result)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData} - end; -wait_for_validation(#db_verify{to = To, from = From, id = Id, type = Type}, - StateData) -> - ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]), - case StateData#state.verify of - false -> - NextState = wait_for_validation, - {next_state, NextState, StateData, get_timeout_interval(NextState)}; - {Pid, _Key, _SID} -> - case Type of - valid -> - p1_fsm:send_event(Pid, - {valid, StateData#state.server, - StateData#state.myname}); - _ -> - p1_fsm:send_event(Pid, - {invalid, StateData#state.server, - StateData#state.myname}) - end, - if StateData#state.verify == false -> - {stop, normal, StateData}; - true -> - NextState = wait_for_validation, - {next_state, NextState, StateData, get_timeout_interval(NextState)} - end - end; -wait_for_validation(timeout, - #state{verify = {VPid, VKey, SID}} = StateData) - when is_pid(VPid) and is_binary(VKey) and is_binary(SID) -> - ?DEBUG("wait_for_validation: ~s -> ~s (timeout in verify connection)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -wait_for_validation(Event, StateData) -> - handle_unexpected_event(Event, wait_for_validation, StateData). - -wait_for_features({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_features, StateData); -wait_for_features(#stream_features{sub_els = Els}, StateData) -> - {SASLEXT, StartTLS, StartTLSRequired} = - lists:foldl( - fun(#sasl_mechanisms{list = Mechs}, {_, STLS, STLSReq}) -> - {lists:member(<<"EXTERNAL">>, Mechs), STLS, STLSReq}; - (#starttls{required = Required}, {SEXT, _, _}) -> - {SEXT, true, Required}; - (_, Acc) -> - Acc - end, {false, false, false}, Els), - if not SASLEXT and not StartTLS and StateData#state.authenticated -> - send_queue(StateData, StateData#state.queue), - ?INFO_MSG("Connection established: ~s -> ~s with " - "SASL EXTERNAL and TLS=~p", - [StateData#state.myname, StateData#state.server, - StateData#state.tls_enabled]), - ejabberd_hooks:run(s2s_connect_hook, - [StateData#state.myname, - StateData#state.server]), - {next_state, stream_established, - StateData#state{queue = queue:new()}}; - SASLEXT and StateData#state.try_auth and - (StateData#state.new /= false) and - (StateData#state.tls_enabled or - not StateData#state.tls_required) -> - send_element(StateData, - #sasl_auth{mechanism = <<"EXTERNAL">>, - text = StateData#state.myname}), - {next_state, wait_for_auth_result, - StateData#state{try_auth = false}, ?FSMTIMEOUT}; - StartTLS and StateData#state.tls and - not StateData#state.tls_enabled -> - send_element(StateData, #starttls{}), - {next_state, wait_for_starttls_proceed, StateData, ?FSMTIMEOUT}; - StartTLSRequired and not StateData#state.tls -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined, use_v10 = false}, - ?FSMTIMEOUT}; - StateData#state.db_enabled -> - send_db_request(StateData); - true -> - ?DEBUG("restarted: ~p", - [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined, use_v10 = false}, ?FSMTIMEOUT} +init([#{server := LServer, remote_server := RServer} = State, Opts]) -> + State1 = State#{on_route => queue, + queue => queue:new(), + xmlns => ?NS_SERVER, + lang => ?MYLANG, + shaper => none}, + ?INFO_MSG("Outbound s2s connection started: ~s -> ~s", + [LServer, RServer]), + ejabberd_hooks:run_fold(s2s_out_init, LServer, {ok, State1}, [Opts]). + +handle_call(Request, From, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_call, LServer, State, [Request, From]). + +handle_cast({update_state, Fun}, State) -> + case Fun of + {M, F, A} -> erlang:apply(M, F, [State|A]); + _ when is_function(Fun) -> Fun(State) end; -wait_for_features(Event, StateData) -> - handle_unexpected_event(Event, wait_for_features, StateData). - -wait_for_auth_result({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_auth_result, StateData); -wait_for_auth_result(#sasl_success{}, StateData) -> - ?DEBUG("auth: ~p", [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:reset_stream(StateData#state.socket), - send_header(StateData, {1,0}), - {next_state, wait_for_stream, - StateData#state{streamid = new_id(), authenticated = true}, - ?FSMTIMEOUT}; -wait_for_auth_result(#sasl_failure{}, StateData) -> - ?DEBUG("restarted: ~p", [{StateData#state.myname, StateData#state.server}]), - ejabberd_socket:close(StateData#state.socket), - {next_state, reopen_socket, - StateData#state{socket = undefined}, ?FSMTIMEOUT}; -wait_for_auth_result(Event, StateData) -> - handle_unexpected_event(Event, wait_for_auth_result, StateData). - -wait_for_starttls_proceed({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_starttls_proceed, StateData); -wait_for_starttls_proceed(#starttls_proceed{}, StateData) -> - ?DEBUG("starttls: ~p", [{StateData#state.myname, StateData#state.server}]), - Socket = StateData#state.socket, - TLSOpts = case ejabberd_config:get_option( - {domain_certfile, StateData#state.myname}, - fun iolist_to_binary/1) of - undefined -> StateData#state.tls_options; - CertFile -> - [{certfile, CertFile} - | lists:keydelete(certfile, 1, - StateData#state.tls_options)] - end, - TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts), - NewStateData = StateData#state{socket = TLSSocket, - streamid = new_id(), - tls_enabled = true, - tls_options = TLSOpts}, - send_header(NewStateData, {1,0}), - {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT}; -wait_for_starttls_proceed(Event, StateData) -> - handle_unexpected_event(Event, wait_for_starttls_proceed, StateData). - -reopen_socket({xmlstreamelement, _El}, StateData) -> - {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; -reopen_socket({xmlstreamend, _Name}, StateData) -> - {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; -reopen_socket({xmlstreamerror, _}, StateData) -> - {next_state, reopen_socket, StateData, ?FSMTIMEOUT}; -reopen_socket(timeout, StateData) -> - ?INFO_MSG("reopen socket: timeout", []), - {stop, normal, StateData}; -reopen_socket(closed, StateData) -> - p1_fsm:send_event(self(), init), - {next_state, open_socket, StateData, ?FSMTIMEOUT}. - -%% This state is use to avoid reconnecting to often to bad sockets -wait_before_retry(_Event, StateData) -> - {next_state, wait_before_retry, StateData, ?FSMTIMEOUT}. - -relay_to_bridge(stop, StateData) -> - wait_before_reconnect(StateData); -relay_to_bridge(closed, StateData) -> - ?INFO_MSG("relay to bridge: ~s -> ~s (closed)", - [StateData#state.myname, StateData#state.server]), - {stop, normal, StateData}; -relay_to_bridge(_Event, StateData) -> - {next_state, relay_to_bridge, StateData}. - -stream_established({xmlstreamelement, El}, StateData) -> - decode_element(El, stream_established, StateData); -stream_established(#db_verify{to = VTo, from = VFrom, id = VId, type = VType}, - StateData) -> - ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]), - case StateData#state.verify of - {VPid, _VKey, _SID} -> - case VType of - valid -> - p1_fsm:send_event(VPid, - {valid, StateData#state.server, - StateData#state.myname}); - _ -> - p1_fsm:send_event(VPid, - {invalid, StateData#state.server, - StateData#state.myname}) - end; - _ -> ok - end, - {next_state, stream_established, StateData}; -stream_established(Event, StateData) -> - handle_unexpected_event(Event, stream_established, StateData). - --spec handle_unexpected_event(term(), state_name(), state()) -> fsm_transition(). -handle_unexpected_event(Event, StateName, StateData) -> - case Event of - {xmlstreamerror, _} -> - send_element(StateData, xmpp:serr_not_well_formed()), - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "got invalid XML from peer", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - {xmlstreamend, _} -> - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "XML stream closed by peer", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - timeout -> - send_element(StateData, xmpp:serr_connection_timeout()), - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "timed out during establishing an XML stream", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - closed -> - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "connection socket closed", - [StateData#state.myname, StateData#state.server, - StateName]), - {stop, normal, StateData}; - Pkt when StateName == wait_for_stream; - StateName == wait_for_features; - StateName == wait_for_auth_result; - StateName == wait_for_starttls_proceed -> - send_element(StateData, xmpp:serr_bad_format()), - ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: " - "got unexpected event ~p", - [StateData#state.myname, StateData#state.server, - StateName, Pkt]), - {stop, normal, StateData}; - _ -> - {next_state, StateName, StateData, get_timeout_interval(StateName)} - end. - -%%---------------------------------------------------------------------- -%% Func: StateName/3 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -%%state_name(Event, From, StateData) -> -%% Reply = ok, -%% {reply, Reply, state_name, StateData}. - -handle_event(_Event, StateName, StateData) -> - {next_state, StateName, StateData, - get_timeout_interval(StateName)}. - -handle_sync_event(get_state_infos, _From, StateName, - StateData) -> - {Addr, Port} = try - ejabberd_socket:peername(StateData#state.socket) - of - {ok, {A, P}} -> {A, P}; - {error, _} -> {unknown, unknown} - catch - _:_ -> {unknown, unknown} - end, - Infos = [{direction, out}, {statename, StateName}, - {addr, Addr}, {port, Port}, - {streamid, StateData#state.streamid}, - {use_v10, StateData#state.use_v10}, - {tls, StateData#state.tls}, - {tls_required, StateData#state.tls_required}, - {tls_enabled, StateData#state.tls_enabled}, - {tls_options, StateData#state.tls_options}, - {authenticated, StateData#state.authenticated}, - {db_enabled, StateData#state.db_enabled}, - {try_auth, StateData#state.try_auth}, - {myname, StateData#state.myname}, - {server, StateData#state.server}, - {delay_to_retry, StateData#state.delay_to_retry}, - {verify, StateData#state.verify}], - Reply = {state_infos, Infos}, - {reply, Reply, StateName, StateData}; -%%---------------------------------------------------------------------- -%% Func: handle_sync_event/4 -%% Returns: {next_state, NextStateName, NextStateData} | -%% {next_state, NextStateName, NextStateData, Timeout} | -%% {reply, Reply, NextStateName, NextStateData} | -%% {reply, Reply, NextStateName, NextStateData, Timeout} | -%% {stop, Reason, NewStateData} | -%% {stop, Reason, Reply, NewStateData} -%%---------------------------------------------------------------------- -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, - {reply, Reply, StateName, StateData, - get_timeout_interval(StateName)}. - -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. - -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - cancel_timer(StateData#state.timer), - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - {next_state, StateName, StateData#state{timer = Timer}, - get_timeout_interval(StateName)}; -handle_info({send_element, El}, StateName, StateData) -> - case StateName of - stream_established -> - cancel_timer(StateData#state.timer), - Timer = erlang:start_timer(?S2STIMEOUT, self(), []), - send_element(StateData, El), - {next_state, StateName, StateData#state{timer = Timer}}; - %% In this state we bounce all message: We are waiting before - %% trying to reconnect - wait_before_retry -> - bounce_element(El, xmpp:err_remote_server_not_found()), - {next_state, StateName, StateData}; - relay_to_bridge -> - {Mod, Fun} = StateData#state.bridge, - ?DEBUG("relaying stanza via ~p:~p/1", [Mod, Fun]), - case catch Mod:Fun(El) of - {'EXIT', Reason} -> - ?ERROR_MSG("Error while relaying to bridge: ~p", - [Reason]), - bounce_element(El, xmpp:err_internal_server_error()), - wait_before_reconnect(StateData); - _ -> {next_state, StateName, StateData} - end; - _ -> - Q = queue:in(El, StateData#state.queue), - {next_state, StateName, StateData#state{queue = Q}, - get_timeout_interval(StateName)} +handle_cast(Msg, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_cast, LServer, State, [Msg]). + +handle_info({route, Pkt}, #{queue := Q, on_route := Action} = State) -> + case Action of + queue -> State#{queue => queue:in(Pkt, Q)}; + bounce -> bounce_packet(Pkt, State); + send -> set_idle_timeout(send(State, Pkt)) end; -handle_info({timeout, Timer, _}, wait_before_retry, - #state{timer = Timer} = StateData) -> - ?INFO_MSG("Reconnect delay expired: Will now retry " - "to connect to ~s when needed.", - [StateData#state.server]), - {stop, normal, StateData}; -handle_info({timeout, Timer, _}, _StateName, - #state{timer = Timer} = StateData) -> - ?INFO_MSG("Closing connection with ~s: timeout", - [StateData#state.server]), - {stop, normal, StateData}; -handle_info(terminate_if_waiting_before_retry, - wait_before_retry, StateData) -> - {stop, normal, StateData}; -handle_info(terminate_if_waiting_before_retry, - StateName, StateData) -> - {next_state, StateName, StateData, - get_timeout_interval(StateName)}; -handle_info(_, StateName, StateData) -> - {next_state, StateName, StateData, - get_timeout_interval(StateName)}. - -terminate(Reason, StateName, StateData) -> - ?DEBUG("terminated: ~p", [{Reason, StateName}]), - case StateData#state.new of - false -> ok; - true -> - ejabberd_s2s:remove_connection({StateData#state.myname, - StateData#state.server}, - self()) - end, - bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()), - bounce_messages(xmpp:err_remote_server_not_found()), - case StateData#state.socket of - undefined -> ok; - _Socket -> - catch send_trailer(StateData), - ejabberd_socket:close(StateData#state.socket) - end, - ok. - -print_state(State) -> State. - -%%%---------------------------------------------------------------------- +handle_info(Info, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_handle_info, LServer, State, [Info]). + +terminate(Reason, #{server := LServer, + remote_server := RServer} = State) -> + ejabberd_s2s:remove_connection({LServer, RServer}, self()), + State1 = case Reason of + normal -> State; + _ -> State#{stop_reason => {error, internal_failure}} + end, + bounce_queue(State1), + bounce_message_queue(State1). + +code_change(_OldVsn, State, _Extra) -> + {ok, State}. + +%%%=================================================================== %%% Internal functions -%%%---------------------------------------------------------------------- - --spec send_text(state(), iodata()) -> ok. -send_text(StateData, Text) -> - ?DEBUG("Send Text on stream = ~s", [Text]), - ejabberd_socket:send(StateData#state.socket, Text). - --spec send_element(state(), xmpp_element()) -> ok. -send_element(StateData, El) -> - El1 = xmpp:encode(El, ?NS_SERVER), - send_text(StateData, fxml:element_to_binary(El1)). - --spec send_header(state(), undefined | {integer(), integer()}) -> ok. -send_header(StateData, Version) -> - Header = xmpp:encode( - #stream_start{xmlns = ?NS_SERVER, - stream_xmlns = ?NS_STREAM, - db_xmlns = ?NS_SERVER_DIALBACK, - from = jid:make(StateData#state.myname), - to = jid:make(StateData#state.server), - version = Version}), - send_text(StateData, fxml:element_to_header(Header)). - --spec send_trailer(state()) -> ok. -send_trailer(StateData) -> - send_text(StateData, <<"">>). - --spec send_queue(state(), queue:queue()) -> ok. -send_queue(StateData, Q) -> - case queue:out(Q) of - {{value, El}, Q1} -> - send_element(StateData, El), send_queue(StateData, Q1); - {empty, _Q1} -> ok - end. - -%% Bounce a single message (xmlelement) --spec bounce_element(stanza(), stanza_error()) -> ok. -bounce_element(El, Error) -> - From = xmpp:get_from(El), - To = xmpp:get_to(El), - ejabberd_router:route_error(To, From, El, Error). - --spec bounce_queue(queue:queue(), stanza_error()) -> ok. -bounce_queue(Q, Error) -> - case queue:out(Q) of - {{value, El}, Q1} -> - bounce_element(El, Error), bounce_queue(Q1, Error); - {empty, _} -> ok - end. - --spec new_id() -> binary(). -new_id() -> randoms:get_string(). - --spec cancel_timer(reference()) -> ok. -cancel_timer(Timer) -> - erlang:cancel_timer(Timer), - receive {timeout, Timer, _} -> ok after 0 -> ok end. - --spec bounce_messages(stanza_error()) -> ok. -bounce_messages(Error) -> +%%%=================================================================== +-spec resend_queue(state()) -> state(). +resend_queue(#{queue := Q} = State) -> + State1 = State#{queue => queue:new()}, + jlib:queue_foldl( + fun(Pkt, AccState) -> + send(AccState, Pkt) + end, State1, Q). + +-spec bounce_queue(state()) -> state(). +bounce_queue(#{queue := Q} = State) -> + State1 = State#{queue => queue:new()}, + jlib:queue_foldl( + fun(Pkt, AccState) -> + bounce_packet(Pkt, AccState) + end, State1, Q). + +-spec bounce_message_queue(state()) -> state(). +bounce_message_queue(State) -> receive - {send_element, El} -> - bounce_element(El, Error), bounce_messages(Error) - after 0 -> ok - end. - --spec send_db_request(state()) -> fsm_transition(). -send_db_request(StateData) -> - Server = StateData#state.server, - New = case StateData#state.new of - false -> - ejabberd_s2s:try_register({StateData#state.myname, Server}); - true -> - true - end, - NewStateData = StateData#state{new = New}, - try case New of - false -> ok; - true -> - Key1 = ejabberd_s2s:make_key( - {StateData#state.myname, Server}, - StateData#state.remote_streamid), - send_element(StateData, - #db_result{from = StateData#state.myname, - to = Server, - key = Key1}) - end, - case StateData#state.verify of - false -> ok; - {_Pid, Key2, SID} -> - send_element(StateData, - #db_verify{from = StateData#state.myname, - to = StateData#state.server, - id = SID, - key = Key2}) - end, - {next_state, wait_for_validation, NewStateData, - (?FSMTIMEOUT) * 6} - catch - _:_ -> {stop, normal, NewStateData} + {route, Pkt} -> + State1 = bounce_packet(Pkt, State), + bounce_message_queue(State1) + after 0 -> + State end. -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -%% SRV support - --include_lib("kernel/include/inet.hrl"). - --spec get_addr_port(binary()) -> [{binary(), inet:port_number()}]. -get_addr_port(Server) -> - Res = srv_lookup(Server), - case Res of - {error, Reason} -> - ?DEBUG("srv lookup of '~s' failed: ~p~n", - [Server, Reason]), - [{Server, outgoing_s2s_port()}]; - {ok, HEnt} -> - ?DEBUG("srv lookup of '~s': ~p~n", - [Server, HEnt#hostent.h_addr_list]), - AddrList = HEnt#hostent.h_addr_list, - case catch lists:map(fun ({Priority, Weight, Port, - Host}) -> - N = case Weight of - 0 -> 0; - _ -> - (Weight + 1) * randoms:uniform() - end, - {Priority * 65536 - N, Host, Port} - end, - AddrList) - of - SortedList = [_ | _] -> - List = lists:map(fun ({_, Host, Port}) -> - {list_to_binary(Host), Port} - end, - lists:keysort(1, SortedList)), - ?DEBUG("srv lookup of '~s': ~p~n", [Server, List]), - List; - _ -> [{Server, outgoing_s2s_port()}] - end - end. - -srv_lookup(Server) -> - TimeoutMs = timer:seconds( - ejabberd_config:get_option( - s2s_dns_timeout, - fun(I) when is_integer(I), I>=0 -> I end, - 10)), - Retries = ejabberd_config:get_option( - s2s_dns_retries, - fun(I) when is_integer(I), I>=0 -> I end, - 2), - srv_lookup(binary_to_list(Server), TimeoutMs, Retries). - -%% XXX - this behaviour is suboptimal in the case that the domain -%% has a "_xmpp-server._tcp." but not a "_jabber._tcp." record and -%% we don't get a DNS reply for the "_xmpp-server._tcp." lookup. In this -%% case we'll give up when we get the "_jabber._tcp." nxdomain reply. -srv_lookup(_Server, _Timeout, Retries) - when Retries < 1 -> - {error, timeout}; -srv_lookup(Server, Timeout, Retries) -> - case inet_res:getbyname("_xmpp-server._tcp." ++ Server, - srv, Timeout) - of - {error, _Reason} -> - case inet_res:getbyname("_jabber._tcp." ++ Server, srv, - Timeout) - of - {error, timeout} -> - ?ERROR_MSG("The DNS servers~n ~p~ntimed out on " - "request for ~p IN SRV. You should check " - "your DNS configuration.", - [inet_db:res_option(nameserver), Server]), - srv_lookup(Server, Timeout, Retries - 1); - R -> R - end; - {ok, _HEnt} = R -> R - end. - -test_get_addr_port(Server) -> - lists:foldl(fun (_, Acc) -> - [HostPort | _] = get_addr_port(Server), - case lists:keysearch(HostPort, 1, Acc) of - false -> [{HostPort, 1} | Acc]; - {value, {_, Num}} -> - lists:keyreplace(HostPort, 1, Acc, - {HostPort, Num + 1}) - end - end, - [], lists:seq(1, 100000)). - -get_addrs(Host, Family) -> - Type = case Family of - inet4 -> inet; - ipv4 -> inet; - inet6 -> inet6; - ipv6 -> inet6 - end, - case inet:gethostbyname(binary_to_list(Host), Type) of - {ok, #hostent{h_addr_list = Addrs}} -> - ?DEBUG("~s of ~s resolved to: ~p~n", - [Type, Host, Addrs]), - Addrs; - {error, Reason} -> - ?DEBUG("~s lookup of '~s' failed: ~p~n", - [Type, Host, Reason]), - [] +-spec bounce_packet(xmpp_element(), state()) -> state(). +bounce_packet(Pkt, State) when ?is_stanza(Pkt) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + Lang = xmpp:get_lang(Pkt), + Err = mk_bounce_error(Lang, State), + ejabberd_router:route_error(To, From, Pkt, Err), + State; +bounce_packet(_, State) -> + State. + +-spec mk_bounce_error(binary(), state()) -> stanza_error(). +mk_bounce_error(Lang, State) -> + try maps:get(stop_reason, State) of + {error, internal_failure} -> + xmpp:err_internal_server_error(); + {error, Why} -> + Reason = xmpp_stream_out:format_error(Why), + case Why of + {dns, _} -> + xmpp:err_remote_server_timeout(Reason, Lang); + _ -> + xmpp:err_remote_server_not_found(Reason, Lang) + end + catch _:{badkey, _} -> + xmpp:err_remote_server_not_found() end. --spec outgoing_s2s_port() -> pos_integer(). -outgoing_s2s_port() -> - ejabberd_config:get_option( - outgoing_s2s_port, - fun(I) when is_integer(I), I > 0, I =< 65536 -> I end, - 5269). - --spec outgoing_s2s_families() -> [ipv4 | ipv6]. -outgoing_s2s_families() -> - ejabberd_config:get_option( - outgoing_s2s_families, - fun(Families) -> - true = lists:all( - fun(ipv4) -> true; - (ipv6) -> true - end, Families), - Families - end, [ipv4, ipv6]). - --spec outgoing_s2s_timeout() -> pos_integer(). -outgoing_s2s_timeout() -> - ejabberd_config:get_option( - outgoing_s2s_timeout, - fun(TimeOut) when is_integer(TimeOut), TimeOut > 0 -> - TimeOut; - (infinity) -> - infinity - end, 10000). +-spec get_delay() -> non_neg_integer(). +get_delay() -> + MaxDelay = ejabberd_config:get_option( + s2s_max_retry_delay, + fun(I) when is_integer(I), I > 0 -> I end, + 300), + crypto:rand_uniform(0, MaxDelay). + +-spec set_idle_timeout(state()) -> state(). +set_idle_timeout(#{on_route := send, server := LServer} = State) -> + Timeout = ejabberd_s2s:get_idle_timeout(LServer), + xmpp_stream_out:set_timeout(State, Timeout); +set_idle_timeout(State) -> + State. transform_options(Opts) -> lists:foldl(fun transform_options/2, [], Opts). @@ -998,100 +338,6 @@ transform_options({s2s_dns_options, S2SDNSOpts}, AllOpts) -> transform_options(Opt, Opts) -> [Opt|Opts]. -%% Human readable S2S logging: Log only new outgoing connections as INFO -%% Do not log dialback -log_s2s_out(false, _, _, _) -> ok; -%% Log new outgoing connections: -log_s2s_out(_, Myname, Server, Tls) -> - ?INFO_MSG("Trying to open s2s connection: ~s -> " - "~s with TLS=~p", - [Myname, Server, Tls]). - -%% Calculate timeout depending on which state we are in: -%% Can return integer > 0 | infinity --spec get_timeout_interval(state_name()) -> pos_integer() | infinity. -get_timeout_interval(StateName) -> - case StateName of - %% Validation implies dialback: Networking can take longer: - wait_for_validation -> (?FSMTIMEOUT) * 6; - %% When stream is established, we only rely on S2S Timeout timer: - stream_established -> infinity; - relay_to_bridge -> infinity; - open_socket -> infinity; - _ -> ?FSMTIMEOUT - end. - -%% This function is intended to be called at the end of a state -%% function that want to wait for a reconnect delay before stopping. --spec wait_before_reconnect(state()) -> fsm_next(). -wait_before_reconnect(StateData) -> - bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()), - bounce_messages(xmpp:err_remote_server_not_found()), - cancel_timer(StateData#state.timer), - Delay = case StateData#state.delay_to_retry of - undefined_delay -> - {_, _, MicroSecs} = p1_time_compat:timestamp(), MicroSecs rem 14000 + 1000; - D1 -> lists:min([D1 * 2, get_max_retry_delay()]) - end, - Timer = erlang:start_timer(Delay, self(), []), - {next_state, wait_before_retry, - StateData#state{timer = Timer, delay_to_retry = Delay, - queue = queue:new()}}. - --spec get_max_retry_delay() -> pos_integer(). -get_max_retry_delay() -> - case ejabberd_config:get_option( - s2s_max_retry_delay, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> ?MAX_RETRY_DELAY; - Seconds -> Seconds * 1000 - end. - -%% Terminate s2s_out connections that are in state wait_before_retry --spec terminate_if_waiting_delay(binary(), binary()) -> ok. -terminate_if_waiting_delay(From, To) -> - FromTo = {From, To}, - Pids = ejabberd_s2s:get_connections_pids(FromTo), - lists:foreach(fun (Pid) -> - Pid ! terminate_if_waiting_before_retry - end, - Pids). - --spec fsm_limit_opts() -> [{max_queue, pos_integer()}]. -fsm_limit_opts() -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end. - --spec decode_element(xmlel(), state_name(), state()) -> fsm_next(). -decode_element(#xmlel{} = El, StateName, StateData) -> - Opts = if StateName == stream_established -> - [ignore_els]; - true -> - [] - end, - try xmpp:decode(El, ?NS_SERVER, Opts) of - Pkt -> ?MODULE:StateName(Pkt, StateData) - catch error:{xmpp_codec, Why} -> - Type = xmpp:get_type(El), - case xmpp:is_stanza(El) of - true when Type /= <<"result">>, Type /= <<"error">> -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)), - send_element(StateData, Err); - false -> - ok - end, - {next_state, StateName, StateData, get_timeout_interval(StateName)} - end. - -opt_type(domain_certfile) -> fun iolist_to_binary/1; -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; opt_type(outgoing_s2s_families) -> fun (Families) -> true = lists:all(fun (ipv4) -> true; @@ -1107,36 +353,12 @@ opt_type(outgoing_s2s_timeout) -> TimeOut; (infinity) -> infinity end; -opt_type(s2s_certfile) -> fun iolist_to_binary/1; -opt_type(s2s_ciphers) -> fun iolist_to_binary/1; -opt_type(s2s_dhfile) -> fun iolist_to_binary/1; opt_type(s2s_dns_retries) -> fun (I) when is_integer(I), I >= 0 -> I end; opt_type(s2s_dns_timeout) -> fun (I) when is_integer(I), I >= 0 -> I end; opt_type(s2s_max_retry_delay) -> fun (I) when is_integer(I), I > 0 -> I end; -opt_type(s2s_protocol_options) -> - fun (Options) -> - [_ | O] = lists:foldl(fun (X, Acc) -> X ++ Acc end, [], - [["|" | binary_to_list(Opt)] - || Opt <- Options, is_binary(Opt)]), - iolist_to_binary(O) - end; -opt_type(s2s_tls_compression) -> - fun (true) -> true; - (false) -> false - end; -opt_type(s2s_use_starttls) -> - fun (true) -> true; - (false) -> false; - (optional) -> optional; - (required) -> required; - (required_trusted) -> required_trusted - end; opt_type(_) -> - [domain_certfile, max_fsm_queue, outgoing_s2s_families, - outgoing_s2s_port, outgoing_s2s_timeout, s2s_certfile, - s2s_ciphers, s2s_dhfile, s2s_dns_retries, s2s_dns_timeout, - s2s_max_retry_delay, s2s_protocol_options, - s2s_tls_compression, s2s_use_starttls]. + [outgoing_s2s_families, outgoing_s2s_port, outgoing_s2s_timeout, + s2s_dns_retries, s2s_dns_timeout, s2s_max_retry_delay]. diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index c48cd536c..13efd15e7 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -22,17 +22,18 @@ -module(ejabberd_service). -behaviour(xmpp_stream_in). -behaviour(ejabberd_config). +-behaviour(ejabberd_socket). -protocol({xep, 114, '1.6'}). %% ejabberd_socket callbacks --export([start/2, socket_type/0]). +-export([start/2, start_link/2, socket_type/0]). %% ejabberd_config callbacks -export([opt_type/1, transform_listen_option/2]). %% xmpp_stream_in callbacks --export([init/1, handle_call/3, handle_cast/2, handle_info/2, - terminate/2, code_change/3]). --export([handshake/2, handle_stream_start/1, handle_authenticated_packet/2]). +-export([init/1, handle_info/2, terminate/2, code_change/3]). +-export([handle_stream_start/2, handle_auth_success/4, handle_auth_failure/4, + handle_authenticated_packet/2, get_password_fun/1]). %% API -export([send/2]). @@ -40,36 +41,32 @@ -include("xmpp.hrl"). -include("logger.hrl"). -%%-define(DBGFSM, true). --ifdef(DBGFSM). --define(FSMOPTS, [{debug, [trace]}]). --else. --define(FSMOPTS, []). --endif. - -type state() :: map(). --type next_state() :: {noreply, state()} | {stop, term(), state()}. --export_type([state/0, next_state/0]). +-export_type([state/0]). %%%=================================================================== %%% API %%%=================================================================== start(SockData, Opts) -> xmpp_stream_in:start(?MODULE, [SockData, Opts], - fsm_limit_opts(Opts) ++ ?FSMOPTS). + ejabberd_config:fsm_limit_opts(Opts)). + +start_link(SockData, Opts) -> + xmpp_stream_in:start_link(?MODULE, [SockData, Opts], + ejabberd_config:fsm_limit_opts(Opts)). socket_type() -> xml_stream. --spec send(state(), xmpp_element()) -> next_state(). -send(State, Pkt) -> - xmpp_stream_in:send(State, Pkt). +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Stream, Pkt) -> + xmpp_stream_in:send(Stream, Pkt). %%%=================================================================== %%% xmpp_stream_in callbacks %%%=================================================================== -init([#{socket := Socket} = State, Opts]) -> - ?INFO_MSG("(~w) External service connected", [Socket]), +init([State, Opts]) -> Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none), HostOpts = case lists:keyfind(hosts, 1, Opts) of @@ -96,66 +93,85 @@ init([#{socket := Socket} = State, Opts]) -> server => ?MYNAME, host_opts => HostOpts, check_from => CheckFrom}, - ejabberd_hooks:run_fold(component_init, {ok, State1}, []). + ejabberd_hooks:run_fold(component_init, {ok, State1}, [Opts]). -handle_stream_start(#{remote_server := RemoteServer, +handle_stream_start(_StreamStart, + #{remote_server := RemoteServer, + lang := Lang, host_opts := HostOpts} = State) -> - NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of - true -> - HostOpts; - false -> - case dict:find(global, HostOpts) of - {ok, GlobalPass} -> - dict:from_list([{RemoteServer, GlobalPass}]); - error -> - HostOpts - end - end, - {noreply, State#{host_opts => NewHostOpts}}. - -handshake(Digest, #{remote_server := RemoteServer, - stream_id := StreamID, - host_opts := HostOpts} = State) -> - case dict:find(RemoteServer, HostOpts) of - {ok, Password} -> - case p1_sha:sha(<>) of - Digest -> - lists:foreach( - fun (H) -> - ejabberd_router:register_route(H, ?MYNAME), - ?INFO_MSG("Route registered for service ~p~n", [H]), - ejabberd_hooks:run(component_connected, [H]) - end, dict:fetch_keys(HostOpts)), - {ok, State}; - _ -> - ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]), - {error, xmpp:serr_not_authorized(), State} - end; - _ -> - ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]), - {error, xmpp:serr_not_authorized(), State} + case lists:member(RemoteServer, ?MYHOSTS) of + true -> + Txt = <<"Unable to register route on existing local domain">>, + xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang)); + false -> + NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of + true -> + HostOpts; + false -> + case dict:find(global, HostOpts) of + {ok, GlobalPass} -> + dict:from_list([{RemoteServer, GlobalPass}]); + error -> + HostOpts + end + end, + State#{host_opts => NewHostOpts} + end. + +get_password_fun(#{remote_server := RemoteServer, + socket := Socket, + ip := IP, + host_opts := HostOpts}) -> + fun(_) -> + case dict:find(RemoteServer, HostOpts) of + {ok, Password} -> + {Password, undefined}; + error -> + ?ERROR_MSG("(~s) Domain ~s is unconfigured for " + "external component from ~s", + [ejabberd_socket:pp(Socket), RemoteServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + {false, undefined} + end end. +handle_auth_success(_, Mech, _, + #{remote_server := RemoteServer, host_opts := HostOpts, + socket := Socket, ip := IP} = State) -> + ?INFO_MSG("(~s) Accepted external component ~s authentication " + "for ~s from ~s", + [ejabberd_socket:pp(Socket), Mech, RemoteServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + lists:foreach( + fun (H) -> + ejabberd_router:register_route(H, ?MYNAME), + ejabberd_hooks:run(component_connected, [H]) + end, dict:fetch_keys(HostOpts)), + State. + +handle_auth_failure(_, Mech, Reason, + #{remote_server := RemoteServer, + socket := Socket, ip := IP} = State) -> + ?ERROR_MSG("(~s) Failed external component ~s authentication " + "for ~s from ~s: ~s", + [ejabberd_socket:pp(Socket), Mech, RemoteServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), + Reason]), + State. + handle_authenticated_packet(Pkt, #{lang := Lang} = State) -> From = xmpp:get_from(Pkt), case check_from(From, State) of true -> To = xmpp:get_to(Pkt), ejabberd_router:route(From, To, Pkt), - {noreply, State}; + State; false -> Txt = <<"Improper domain part of 'from' attribute">>, Err = xmpp:serr_invalid_from(Txt, Lang), xmpp_stream_in:send(State, Err) end. -handle_call(_Request, _From, State) -> - Reply = ok, - {reply, Reply, State}. - -handle_cast(_Msg, State) -> - {noreply, State}. - handle_info({route, From, To, Packet}, #{access := Access} = State) -> case acl:match_rule(global, Access, From) of allow -> @@ -165,16 +181,15 @@ handle_info({route, From, To, Packet}, #{access := Access} = State) -> Lang = xmpp:get_lang(Packet), Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang), ejabberd_router:route_error(To, From, Packet, Err), - {noreply, State} + State end; handle_info(Info, State) -> ?ERROR_MSG("Unexpected info: ~p", [Info]), - {noreply, State}. + State. terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) -> - ?INFO_MSG("External service disconnected: ~p", [Reason]), case StreamState of - session_established -> + established -> lists:foreach( fun(H) -> ejabberd_router:unregister_route(H), @@ -220,19 +235,4 @@ transform_listen_option({host, Host, Os}, Opts) -> transform_listen_option(Opt, Opts) -> [Opt|Opts]. -fsm_limit_opts(Opts) -> - case lists:keysearch(max_fsm_queue, 1, Opts) of - {value, {_, N}} when is_integer(N) -> - [{max_queue, N}]; - _ -> - case ejabberd_config:get_option( - max_fsm_queue, - fun(I) when is_integer(I), I > 0 -> I end) of - undefined -> []; - N -> [{max_queue, N}] - end - end. - -opt_type(max_fsm_queue) -> - fun (I) when is_integer(I), I > 0 -> I end; -opt_type(_) -> [max_fsm_queue]. +opt_type(_) -> []. diff --git a/src/ejabberd_sm.erl b/src/ejabberd_sm.erl index 11b829a94..46008bec4 100644 --- a/src/ejabberd_sm.erl +++ b/src/ejabberd_sm.erl @@ -359,20 +359,20 @@ unregister_iq_handler(Host, XMLNS) -> ejabberd_sm ! {unregister_iq_handler, Host, XMLNS}. %% Why the hell do we have so many similar kicks? -c2s_handle_info({noreply, #{lang := Lang} = State}, replaced) -> +c2s_handle_info(#{lang := Lang} = State, replaced) -> State1 = State#{replaced => true}, Err = xmpp:serr_conflict(<<"Replaced by new connection">>, Lang), - ejabberd_c2s:send(State1, Err); -c2s_handle_info({noreply, #{lang := Lang} = State}, kick) -> + {stop, ejabberd_c2s:send(State1, Err)}; +c2s_handle_info(#{lang := Lang} = State, kick) -> Err = xmpp:serr_policy_violation(<<"has been kicked">>, Lang), - c2s_handle_info({noreply, State}, {kick, kicked_by_admin, Err}); -c2s_handle_info({noreply, State}, {kick, _Reason, Err}) -> - ejabberd_c2s:send(State, Err); -c2s_handle_info({noreply, #{lang := Lang} = State}, {exit, Reason}) -> + c2s_handle_info(State, {kick, kicked_by_admin, Err}); +c2s_handle_info(State, {kick, _Reason, Err}) -> + {stop, ejabberd_c2s:send(State, Err)}; +c2s_handle_info(#{lang := Lang} = State, {exit, Reason}) -> Err = xmpp:serr_conflict(Reason, Lang), - ejabberd_c2s:send(State, Err); -c2s_handle_info(Acc, _) -> - Acc. + {stop, ejabberd_c2s:send(State, Err)}; +c2s_handle_info(State, _) -> + State. %%==================================================================== %% gen_server callbacks diff --git a/src/ejabberd_socket.erl b/src/ejabberd_socket.erl index 3f01dae85..4e523a7e5 100644 --- a/src/ejabberd_socket.erl +++ b/src/ejabberd_socket.erl @@ -46,6 +46,7 @@ get_peer_certificate/1, get_verify_result/1, close/1, + pp/1, sockname/1, peername/1]). -include("ejabberd.hrl"). @@ -71,6 +72,11 @@ -export_type([socket/0, socket_state/0, sockmod/0]). +-callback start({module(), socket_state()}, + [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. +-callback start_link({module(), socket_state()}, + [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore. +-callback socket_type() -> xml_stream | independent | raw. %%==================================================================== %% API @@ -109,7 +115,7 @@ start(Module, SockMod, Socket, Opts) -> {error, _Reason} -> SockMod:close(Socket) end, ReceiverMod:become_controller(Receiver, Pid); - {error, _Reason} -> + _ -> SockMod:close(Socket), case ReceiverMod of ejabberd_receiver -> ReceiverMod:close(Receiver); @@ -190,6 +196,7 @@ reset_stream(SocketData) -spec send(socket_state(), iodata()) -> ok. send(SocketData, Data) -> + ?DEBUG("Send XML on stream = ~p", [Data]), case catch (SocketData#socket_state.sockmod):send( SocketData#socket_state.socket, Data) of ok -> ok; @@ -238,8 +245,8 @@ get_transport(#socket_state{sockmod = SockMod, fast_tls -> tls; ezlib -> case ezlib:get_sockmod(Socket) of - tcp -> tcp_zlib; - tls -> tls_zlib + gen_tcp -> tcp_zlib; + fast_tls -> tls_zlib end; ejabberd_bosh -> http_bind; ejabberd_http_bind -> http_bind; @@ -268,3 +275,7 @@ peername(#socket_state{sockmod = SockMod, gen_tcp -> inet:peername(Socket); _ -> SockMod:peername(Socket) end. + +pp(#socket_state{receiver = Receiver} = State) -> + Transport = get_transport(State), + io_lib:format("~s|~w", [Transport, Receiver]). diff --git a/src/jlib.erl b/src/jlib.erl index 096ef4012..939baae84 100644 --- a/src/jlib.erl +++ b/src/jlib.erl @@ -38,8 +38,8 @@ -export([tolower/1, term_to_base64/1, base64_to_term/1, decode_base64/1, encode_base64/1, ip_to_list/1, atom_to_binary/1, binary_to_atom/1, tuple_to_binary/1, - l2i/1, i2l/1, i2l/2, queue_drop_while/2, - expr_to_term/1, term_to_expr/1]). + l2i/1, i2l/1, i2l/2, expr_to_term/1, term_to_expr/1, + queue_drop_while/2, queue_foldl/3, queue_foldr/3, queue_foreach/2]). %% The following functions are used by gen_iq_handler.erl for providing backward %% compatibility and must not be used in other parts of the code @@ -974,3 +974,33 @@ queue_drop_while(F, Q) -> empty -> Q end. + +-spec queue_foldl(fun((term(), T) -> T), T, ?TQUEUE) -> T. +queue_foldl(F, Acc, Q) -> + case queue:out(Q) of + {{value, Item}, Q1} -> + Acc1 = F(Item, Acc), + queue_foldl(F, Acc1, Q1); + {empty, _} -> + Acc + end. + +-spec queue_foldr(fun((term(), T) -> T), T, ?TQUEUE) -> T. +queue_foldr(F, Acc, Q) -> + case queue:out_r(Q) of + {{value, Item}, Q1} -> + Acc1 = F(Item, Acc), + queue_foldr(F, Acc1, Q1); + {empty, _} -> + Acc + end. + +-spec queue_foreach(fun((_) -> _), ?TQUEUE) -> ok. +queue_foreach(F, Q) -> + case queue:out(Q) of + {{value, Item}, Q1} -> + F(Item), + queue_foreach(F, Q1); + {empty, _} -> + ok + end. diff --git a/src/mod_blocking.erl b/src/mod_blocking.erl index 826a7bba3..45564daf4 100644 --- a/src/mod_blocking.erl +++ b/src/mod_blocking.erl @@ -54,8 +54,6 @@ start(Host, Opts) -> process_iq_set, 40), ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 40), - ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 40), mod_disco:register_feature(Host, ?NS_BLOCKING), gen_iq_handler:add_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING, ?MODULE, process_iq, IQDisc). @@ -65,6 +63,8 @@ stop(Host) -> process_iq_get, 40), ejabberd_hooks:delete(privacy_iq_set, Host, ?MODULE, process_iq_set, 40), + ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, + c2s_handle_info, 40), mod_disco:unregister_feature(Host, ?NS_BLOCKING), gen_iq_handler:remove_iq_handler(ejabberd_sm, Host, ?NS_BLOCKING). @@ -253,8 +253,8 @@ process_blocklist_get(LUser, LServer, Lang) -> {result, #block_list{items = Items}} end. --spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state(). -c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State}, +-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state(). +c2s_handle_info(#{user := U, server := S, resource := R} = State, {blocking, Action}) -> SubEl = case Action of {block, JIDs} -> @@ -272,7 +272,9 @@ c2s_handle_info({noreply, #{user := U, server := S, resource := R} = State}, %% No need to replace active privacy list here, %% blocking pushes are always accompanied by %% Privacy List pushes - ejabberd_c2s:send(State, PushIQ). + {stop, ejabberd_c2s:send(State, PushIQ)}; +c2s_handle_info(State, _) -> + State. -spec db_mod(binary()) -> module(). db_mod(LServer) -> diff --git a/src/mod_legacy_auth.erl b/src/mod_legacy_auth.erl index 3e83680e5..f93b67e05 100644 --- a/src/mod_legacy_auth.erl +++ b/src/mod_legacy_auth.erl @@ -52,16 +52,16 @@ depends(_Host, _Opts) -> mod_opt_type(_) -> []. -c2s_unauthenticated_packet({noreply, State}, #iq{type = T, sub_els = [_]} = IQ) +c2s_unauthenticated_packet(State, #iq{type = T, sub_els = [_]} = IQ) when T == get; T == set -> case xmpp:get_subtag(IQ, #legacy_auth{}) of #legacy_auth{} = Auth -> {stop, authenticate(State, xmpp:set_els(IQ, [Auth]))}; false -> - {noreply, State} + State end; -c2s_unauthenticated_packet(Acc, _) -> - Acc. +c2s_unauthenticated_packet(State, _) -> + State. c2s_stream_features(Acc, LServer) -> case gen_mod:is_loaded(LServer, ?MODULE) of @@ -112,14 +112,10 @@ authenticate(#{stream_id := StreamID, server := Server, case ejabberd_auth:check_password_with_authmodule( U, U, JID#jid.lserver, P, D, DGen) of {true, AuthModule} -> - case ejabberd_c2s:handle_auth_success( - U, <<"legacy">>, AuthModule, State) of - {noreply, State1} -> - State2 = State1#{user := U}, - open_session(State2, IQ, R); - Err -> - Err - end; + State1 = ejabberd_c2s:handle_auth_success( + U, <<"legacy">>, AuthModule, State), + State2 = State1#{user := U}, + open_session(State2, IQ, R); _ -> Err = xmpp:make_error(IQ, xmpp:err_not_authorized()), process_auth_failure(State, U, Err, 'not-authorized') @@ -137,23 +133,13 @@ open_session(State, IQ, R) -> case ejabberd_c2s:bind(R, State) of {ok, State1} -> Res = xmpp:make_iq_result(IQ), - case ejabberd_c2s:send(State1, Res) of - {noreply, State2} -> - {noreply, State2#{stream_authenticated := true, - stream_state := session_established}}; - Err -> - Err - end; + State2 = ejabberd_c2s:send(State1, Res), + ejabberd_c2s:establish(State2); {error, Err, State1} -> Res = xmpp:make_error(IQ, Err), ejabberd_c2s:send(State1, Res) end. process_auth_failure(State, User, StanzaErr, Reason) -> - case ejabberd_c2s:send(State, StanzaErr) of - {noreply, State1} -> - ejabberd_c2s:handle_auth_failure( - User, <<"legacy">>, Reason, State1); - Err -> - Err - end. + State1 = ejabberd_c2s:send(State, StanzaErr), + ejabberd_c2s:handle_auth_failure(User, <<"legacy">>, Reason, State1). diff --git a/src/mod_offline.erl b/src/mod_offline.erl index e0e36a1da..8d58b14c9 100644 --- a/src/mod_offline.erl +++ b/src/mod_offline.erl @@ -309,11 +309,11 @@ get_info(_Acc, #jid{luser = U, lserver = S} = JID, get_info(Acc, _From, _To, _Node, _Lang) -> Acc. --spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state(). -c2s_handle_info({noreply, State}, {resend_offline, Flag}) -> - {noreply, State#{resend_offline => Flag}}; -c2s_handle_info(Acc, _) -> - Acc. +-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state(). +c2s_handle_info(State, {resend_offline, Flag}) -> + {stop, State#{resend_offline => Flag}}; +c2s_handle_info(State, _) -> + State. -spec handle_offline_query(iq()) -> iq(). handle_offline_query(#iq{from = #jid{luser = U1, lserver = S1}, diff --git a/src/mod_privacy.erl b/src/mod_privacy.erl index d76edc91d..b28bbcea2 100644 --- a/src/mod_privacy.erl +++ b/src/mod_privacy.erl @@ -535,8 +535,8 @@ remove_user(User, Server) -> Mod = gen_mod:db_mod(LServer, ?MODULE), Mod:remove_user(LUser, LServer). -c2s_handle_info({noreply, #{privacy_list := Old, - user := U, server := S, resource := R} = State}, +c2s_handle_info(#{privacy_list := Old, + user := U, server := S, resource := R} = State, {privacy_list, New, Name}) -> List = if Old#userlist.name == New#userlist.name -> New; true -> Old @@ -548,9 +548,9 @@ c2s_handle_info({noreply, #{privacy_list := Old, sub_els = [#privacy_query{ lists = [#privacy_list{name = Name}]}]}, State1 = State#{privacy_list => List}, - ejabberd_c2s:send(State1, PushIQ); -c2s_handle_info(Acc, _) -> - Acc. + {stop, ejabberd_c2s:send(State1, PushIQ)}; +c2s_handle_info(State, _) -> + State. -spec updated_list(userlist(), userlist(), userlist()) -> userlist(). updated_list(_, #userlist{name = OldName} = Old, diff --git a/src/mod_pubsub.erl b/src/mod_pubsub.erl index 98d50660c..8819e3a99 100644 --- a/src/mod_pubsub.erl +++ b/src/mod_pubsub.erl @@ -3026,8 +3026,8 @@ broadcast_stanza({LUser, LServer, LResource}, Publisher, Node, Nidx, Type, NodeO broadcast_stanza(Host, _Publisher, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM) -> broadcast_stanza(Host, Node, Nidx, Type, NodeOptions, SubsByDepth, NotifyType, BaseStanza, SHIM). --spec c2s_handle_info(ejabberd_c2s:next_state(), term()) -> ejabberd_c2s:next_state(). -c2s_handle_info({noreply, #{server := Server} = C2SState}, +-spec c2s_handle_info(ejabberd_c2s:state(), term()) -> ejabberd_c2s:state(). +c2s_handle_info(#{server := Server} = C2SState, {pep_message, Feature, From, Packet}) -> LServer = jid:nameprep(Server), lists:foreach( @@ -3042,8 +3042,8 @@ c2s_handle_info({noreply, #{server := Server} = C2SState}, ok end end, mod_caps:list_features(C2SState)), - {noreply, C2SState}; -c2s_handle_info({noreply, #{server := Server} = C2SState}, + {stop, C2SState}; +c2s_handle_info(#{server := Server} = C2SState, {send_filtered, {pep_message, Feature}, From, To, Packet}) -> LServer = jid:nameprep(Server), case mod_caps:get_user_caps(To, C2SState) of @@ -3059,9 +3059,9 @@ c2s_handle_info({noreply, #{server := Server} = C2SState}, error -> ok end, - {noreply, C2SState}; -c2s_handle_info(Acc, _) -> - Acc. + {stop, C2SState}; +c2s_handle_info(C2SState, _) -> + C2SState. subscribed_nodes_by_jid(NotifyType, SubsByDepth) -> NodesToDeliver = fun (Depth, Node, Subs, Acc) -> diff --git a/src/mod_register.erl b/src/mod_register.erl index 515cb1066..8917d4c5c 100644 --- a/src/mod_register.erl +++ b/src/mod_register.erl @@ -86,7 +86,7 @@ stream_feature_register(Acc, Host) -> Acc end. -c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State}, +c2s_unauthenticated_packet(#{ip := IP, server := Server} = State, #iq{type = T, sub_els = [_]} = IQ) when T == set; T == get -> case xmpp:get_subtag(IQ, #register{}) of @@ -97,10 +97,10 @@ c2s_unauthenticated_packet({noreply, #{ip := IP, server := Server} = State}, ResIQ1 = xmpp:set_from_to(ResIQ, jid:make(Server), undefined), {stop, ejabberd_c2s:send(State, ResIQ1)}; false -> - {noreply, State} + State end; -c2s_unauthenticated_packet(Acc, _) -> - Acc. +c2s_unauthenticated_packet(State, _) -> + State. process_iq(#iq{from = From} = IQ) -> process_iq(IQ, jid:tolower(From)). diff --git a/src/mod_roster.erl b/src/mod_roster.erl index a896ef055..5c207f3a4 100644 --- a/src/mod_roster.erl +++ b/src/mod_roster.erl @@ -464,10 +464,10 @@ push_item_version(Server, User, From, Item, end, ejabberd_sm:get_user_resources(User, Server)). -c2s_handle_info({noreply, State}, {item, JID, Sub}) -> - {noreply, roster_change(State, JID, Sub)}; -c2s_handle_info(Acc, _) -> - Acc. +c2s_handle_info(State, {item, JID, Sub}) -> + {stop, roster_change(State, JID, Sub)}; +c2s_handle_info(State, _) -> + State. -spec roster_change(ejabberd_c2s:state(), jid(), subscription()) -> ejabberd_c2s:state(). roster_change(#{user := U, server := S, resource := R} = State, diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl new file mode 100644 index 000000000..ce9d2705b --- /dev/null +++ b/src/mod_s2s_dialback.erl @@ -0,0 +1,273 @@ +%%%------------------------------------------------------------------- +%%% Created : 16 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_s2s_dialback). +-behaviour(gen_mod). + +-protocol({xep, 220, '1.1.1'}). +-protocol({xep, 185, '1.0'}). + +%% gen_mod API +-export([start/2, stop/1, depends/2, mod_opt_type/1]). +%% Hooks +-export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2, + s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]). + +-include("ejabberd.hrl"). +-include("xmpp.hrl"). +-include("logger.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Host, _Opts) -> + case ejabberd_s2s:tls_verify(Host) of + true -> + ?ERROR_MSG("disabling ~s for host ~s because option " + "'s2s_use_starttls' is set to 'required_trusted'", + [?MODULE, Host]); + false -> + ejabberd_hooks:add(s2s_out_init, Host, ?MODULE, s2s_out_init, 50), + ejabberd_hooks:add(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50), + ejabberd_hooks:add(s2s_in_pre_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE, + s2s_out_packet, 50), + ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE, + s2s_out_auth_result, 50) + end. + +stop(Host) -> + ejabberd_hooks:delete(s2s_out_init, Host, ?MODULE, s2s_out_init, 50), + ejabberd_hooks:delete(s2s_out_closed, Host, ?MODULE, s2s_out_closed, 50), + ejabberd_hooks:delete(s2s_in_pre_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE, + s2s_in_features, 50), + ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE, + s2s_in_packet, 50), + ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE, + s2s_out_packet, 50), + ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE, + s2s_out_auth_result, 50). + +depends(_Host, _Opts) -> + []. + +mod_opt_type(_) -> + []. + +s2s_in_features(Acc, _) -> + [#db_feature{errors = true}|Acc]. + +s2s_out_init({ok, State}, Opts) -> + case proplists:get_value(db_verify, Opts) of + {StreamID, Key, Pid} -> + %% This is an outbound s2s connection created at step 1. + %% The purpose of this connection is to verify dialback key ONLY. + %% The connection is not registered in s2s table and thus is not + %% seen by anyone. + %% The connection will be closed immediately after receiving the + %% verification response (at step 3) + {ok, State#{db_verify => {StreamID, Key, Pid}}}; + undefined -> + {ok, State#{db_enabled => true}} + end; +s2s_out_init(Acc, _Opts) -> + Acc. + +s2s_out_closed(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, _Key, _Pid}} = State, _Reason) -> + %% Outbound s2s verificating connection (created at step 1) is + %% closed suddenly without receiving the response. + %% Building a response on our own + Response = #db_verify{from = RServer, to = LServer, + id = StreamID, type = error, + sub_els = [mk_error(internal_server_error)]}, + s2s_out_packet(State, Response); +s2s_out_closed(State, _Reason) -> + State. + +s2s_out_auth_result(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, Key, _Pid}} = State, + _) -> + %% The temporary outbound s2s connect (intended for verification) + %% has passed authentication state (either successfully or not, no matter) + %% and at this point we can send verification request as described + %% in section 2.1.2, step 2 + Request = #db_verify{from = LServer, to = RServer, + key = Key, id = StreamID}, + {stop, ejabberd_s2s_out:send(State, Request)}; +s2s_out_auth_result(#{db_enabled := true, + socket := Socket, ip := IP, + server := LServer, + remote_server := RServer, + stream_remote_id := StreamID} = State, false) -> + %% SASL authentication has failed, retrying with dialback + %% Sending dialback request, section 2.1.1, step 1 + ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)", + [ejabberd_socket:pp(Socket), LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + Key = make_key(LServer, RServer, StreamID), + State1 = maps:remove(stop_reason, State#{on_route => queue}), + State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer, + to = RServer, + key = Key}), + {stop, State2}; +s2s_out_auth_result(State, _) -> + State. + +s2s_in_packet(#{stream_id := StreamID} = State, + #db_result{from = From, to = To, key = Key, type = undefined}) -> + %% Received dialback request, section 2.2.1, step 1 + try + ok = check_from_to(From, To), + %% We're creating a temporary outbound s2s connection to + %% send verification request and to receive verification response + {ok, Pid} = ejabberd_s2s_out:start( + To, From, [{db_verify, {StreamID, Key, self()}}]), + ejabberd_s2s_out:connect(Pid), + State + catch _:{badmatch, {error, Reason}} -> + send_db_result(State, + #db_verify{from = From, to = To, type = error, + sub_els = [mk_error(Reason)]}) + end; +s2s_in_packet(State, #db_verify{to = To, from = From, key = Key, + id = StreamID, type = undefined}) -> + %% Received verification request, section 2.2.2, step 2 + Type = case make_key(To, From, StreamID) of + Key -> valid; + _ -> invalid + end, + Response = #db_verify{from = To, to = From, id = StreamID, type = Type}, + ejabberd_s2s_in:send(State, Response); +s2s_in_packet(State, Pkt) when is_record(Pkt, db_result); + is_record(Pkt, db_verify) -> + ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]), + State; +s2s_in_packet(State, _) -> + State. + +s2s_out_packet(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, _Key, Pid}} = State, + #db_verify{from = RServer, to = LServer, + id = StreamID, type = Type} = Response) + when Type /= undefined -> + %% Received verification response, section 2.1.2, step 3 + %% This is a response for the request sent at step 2 + ejabberd_s2s_in:update_state( + Pid, fun(S) -> send_db_result(S, Response) end), + %% At this point the connection is no longer needed and we can terminate it + ejabberd_s2s_out:stop(State); +s2s_out_packet(#{server := LServer, remote_server := RServer} = State, + #db_result{to = LServer, from = RServer, + type = Type} = Result) when Type /= undefined -> + %% Received dialback response, section 2.1.1, step 4 + %% This is a response to the request sent at step 1 + State1 = maps:remove(db_enabled, State), + case Type of + valid -> + State2 = ejabberd_s2s_out:handle_auth_success(<<"dialback">>, State1), + ejabberd_s2s_out:establish(State2); + _ -> + Reason = format_error(Result), + ejabberd_s2s_out:handle_auth_failure(<<"dialback">>, Reason, State1) + end; +s2s_out_packet(State, Pkt) when is_record(Pkt, db_result); + is_record(Pkt, db_verify) -> + ?WARNING_MSG("Got stray dialback packet:~n~s", [xmpp:pp(Pkt)]), + State; +s2s_out_packet(State, _) -> + State. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec make_key(binary(), binary(), binary()) -> binary(). +make_key(From, To, StreamID) -> + Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end), + p1_sha:to_hexlist( + crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)), + [To, " ", From, " ", StreamID])). + +-spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state(). +send_db_result(State, #db_verify{from = From, to = To, + type = Type, sub_els = Els}) -> + %% Sending dialback response, section 2.2.1, step 4 + %% This is a response to the request received at step 1 + Response = #db_result{from = To, to = From, type = Type, sub_els = Els}, + State1 = ejabberd_s2s_in:send(State, Response), + case Type of + valid -> + State2 = ejabberd_s2s_in:handle_auth_success( + From, <<"dialback">>, undefined, State1), + ejabberd_s2s_in:establish(State2); + _ -> + Reason = format_error(Response), + ejabberd_s2s_in:handle_auth_failure( + From, <<"dialback">>, Reason, State1) + end. + +-spec check_from_to(binary(), binary()) -> ok | {error, forbidden | host_unknown}. +check_from_to(From, To) -> + case ejabberd_router:is_my_route(To) of + false -> {error, host_unknown}; + true -> + LServer = ejabberd_router:host_of_route(To), + case ejabberd_s2s:allow_host(LServer, From) of + true -> ok; + false -> {error, forbidden} + end + end. + +-spec mk_error(term()) -> stanza_error(). +mk_error(forbidden) -> + xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG); +mk_error(host_unknown) -> + xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG); +mk_error(_) -> + xmpp:err_internal_server_error(). + +-spec format_error(db_result()) -> binary(). +format_error(#db_result{type = invalid}) -> + <<"invalid dialback key">>; +format_error(#db_result{type = error, sub_els = Els}) -> + %% TODO: improve xmpp.erl + case xmpp:get_error(#message{sub_els = Els}) of + #stanza_error{reason = Reason} -> + erlang:atom_to_binary(Reason, latin1); + undefined -> + <<"unrecognized error">> + end; +format_error(_) -> + <<"unexpected dialback result">>. diff --git a/src/mod_sm.erl b/src/mod_sm.erl new file mode 100644 index 000000000..82d68702d --- /dev/null +++ b/src/mod_sm.erl @@ -0,0 +1,660 @@ +%%%------------------------------------------------------------------- +%%% Author : Holger Weiss +%%% Created : 25 Dec 2016 by Evgeny Khramtsov +%%% +%%% +%%% ejabberd, Copyright (C) 2002-2016 ProcessOne +%%% +%%% This program is free software; you can redistribute it and/or +%%% modify it under the terms of the GNU General Public License as +%%% published by the Free Software Foundation; either version 2 of the +%%% License, or (at your option) any later version. +%%% +%%% This program is distributed in the hope that it will be useful, +%%% but WITHOUT ANY WARRANTY; without even the implied warranty of +%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +%%% General Public License for more details. +%%% +%%% You should have received a copy of the GNU General Public License along +%%% with this program; if not, write to the Free Software Foundation, Inc., +%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +%%% +%%%------------------------------------------------------------------- +-module(mod_sm). +-behaviour(gen_mod). +-author('holger@zedat.fu-berlin.de'). +-protocol({xep, 198, '1.5.2'}). + +%% gen_mod API +-export([start/2, stop/1, depends/2, mod_opt_type/1]). +%% hooks +-export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2, + c2s_authenticated_packet/2, c2s_unauthenticated_packet/2, + c2s_unbinded_packet/2, c2s_closed/2, + c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]). + +-include("xmpp.hrl"). +-include("logger.hrl"). + +-define(is_sm_packet(Pkt), + is_record(Pkt, sm_enable) or + is_record(Pkt, sm_resume) or + is_record(Pkt, sm_a) or + is_record(Pkt, sm_r)). + +-type state() :: ejabberd_c2s:state(). +-type lqueue() :: {non_neg_integer(), queue:queue()}. + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Host, _Opts) -> + ejabberd_hooks:add(c2s_init, ?MODULE, c2s_stream_init, 50), + ejabberd_hooks:add(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), + ejabberd_hooks:add(c2s_post_auth_features, Host, ?MODULE, + c2s_stream_features, 50), + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, + c2s_unauthenticated_packet, 50), + ejabberd_hooks:add(c2s_unbinded_packet, Host, ?MODULE, + c2s_unbinded_packet, 50), + ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, + c2s_handle_send, 50), + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, + c2s_filter_send, 50), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, + c2s_handle_info, 50), + ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50). + +stop(Host) -> + %% TODO: do something with global 'c2s_init' hook + ejabberd_hooks:delete(c2s_stream_started, Host, ?MODULE, + c2s_stream_started, 50), + ejabberd_hooks:delete(c2s_post_auth_features, Host, ?MODULE, + c2s_stream_features, 50), + ejabberd_hooks:delete(c2s_unauthenticated_packet, Host, ?MODULE, + c2s_unauthenticated_packet, 50), + ejabberd_hooks:delete(c2s_unbinded_packet, Host, ?MODULE, + c2s_unbinded_packet, 50), + ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE, + c2s_authenticated_packet, 50), + ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, + c2s_handle_send, 50), + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, + c2s_filter_send, 50), + ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, + c2s_handle_info, 50), + ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50). + +depends(_Host, _Opts) -> + []. + +c2s_stream_init({ok, State}, Opts) -> + MgmtOpts = lists:filter( + fun({stream_management, _}) -> true; + ({max_ack_queue, _}) -> true; + ({resume_timeout, _}) -> true; + ({max_resume_timeout, _}) -> true; + ({ack_timeout, _}) -> true; + ({resend_on_timeout, _}) -> true; + (_) -> false + end, Opts), + {ok, State#{mgmt_options => MgmtOpts}}; +c2s_stream_init(Acc, _Opts) -> + Acc. + +c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State, + _StreamStart) -> + State1 = maps:remove(mgmt_options, State), + ResumeTimeout = get_resume_timeout(LServer, Opts), + MaxResumeTimeout = get_max_resume_timeout(LServer, Opts, ResumeTimeout), + State1#{mgmt_state => inactive, + mgmt_max_queue => get_max_ack_queue(LServer, Opts), + mgmt_timeout => ResumeTimeout, + mgmt_max_timeout => MaxResumeTimeout, + mgmt_ack_timeout => get_ack_timeout(LServer, Opts), + mgmt_resend => get_resend_on_timeout(LServer, Opts)}; +c2s_stream_started(State, _StreamStart) -> + State. + +c2s_stream_features(Acc, Host) -> + case gen_mod:is_loaded(Host, ?MODULE) of + true -> + [#feature_sm{xmlns = ?NS_STREAM_MGMT_2}, + #feature_sm{xmlns = ?NS_STREAM_MGMT_3}|Acc]; + false -> + Acc + end. + +c2s_unauthenticated_packet(State, Pkt) when ?is_sm_packet(Pkt) -> + %% XEP-0198 says: "For client-to-server connections, the client MUST NOT + %% attempt to enable stream management until after it has completed Resource + %% Binding unless it is resuming a previous session". However, it also + %% says: "Stream management errors SHOULD be considered recoverable", so we + %% won't bail out. + Err = #sm_failed{reason = 'unexpected-request', xmlns = ?NS_STREAM_MGMT_3}, + {stop, send(State, Err)}; +c2s_unauthenticated_packet(State, _Pkt) -> + State. + +c2s_unbinded_packet(State, #sm_resume{} = Pkt) -> + case handle_resume(State, Pkt) of + {ok, ResumedState} -> + {stop, ResumedState}; + error -> + {stop, State} + end; +c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) -> + c2s_unauthenticated_packet(State, Pkt); +c2s_unbinded_packet(State, _Pkt) -> + State. + +c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt) + when ?is_sm_packet(Pkt) -> + if MgmtState == pending; MgmtState == active -> + {stop, perform_stream_mgmt(Pkt, State)}; + true -> + {stop, negotiate_stream_mgmt(Pkt, State)} + end; +c2s_authenticated_packet(State, Pkt) -> + update_num_stanzas_in(State, Pkt). + +c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result) + when MgmtState == pending; MgmtState == active -> + State1 = mgmt_queue_add(State, Pkt), + case Result of + ok when ?is_stanza(Pkt) -> + send_ack(State1); + ok -> + State1; + {error, _} -> + transition_to_pending(State1) + end; +c2s_handle_send(State, _Pkt, _Result) -> + State. + +c2s_filter_send(Pkt, _State) -> + Pkt. + +c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State, + {timeout, T, ack_timeout}) -> + ?DEBUG("Timeout waiting for stream management acknowledgement of ~s", + [jid:to_string(JID)]), + State1 = ejabberd_c2s:close(State, _SendTrailer = false), + c2s_closed(State1, ack_timeout); +c2s_handle_info(State, _) -> + State. + +c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal -> + {stop, transition_to_pending(State)}; +c2s_closed(State, _) -> + State. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec negotiate_stream_mgmt(xmpp_element(), state()) -> state(). +negotiate_stream_mgmt(Pkt, State) -> + Xmlns = xmpp:get_ns(Pkt), + case Pkt of + #sm_enable{} -> + handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt); + _ -> + Res = if is_record(Pkt, sm_a); + is_record(Pkt, sm_r); + is_record(Pkt, sm_resume) -> + #sm_failed{reason = 'unexpected-request', + xmlns = Xmlns}; + true -> + #sm_failed{reason = 'bad-request', + xmlns = Xmlns} + end, + send(State, Res) + end. + +-spec perform_stream_mgmt(xmpp_element(), state()) -> state(). +perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) -> + case xmpp:get_ns(Pkt) of + Xmlns -> + case Pkt of + #sm_r{} -> + handle_r(State); + #sm_a{} -> + handle_a(State, Pkt); + _ -> + Res = if is_record(Pkt, sm_enable); + is_record(Pkt, sm_resume) -> + #sm_failed{reason = 'unexpected-request', + xmlns = Xmlns}; + true -> + #sm_failed{reason = 'bad-request', + xmlns = Xmlns} + end, + send(State, Res) + end; + _ -> + send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns}) + end. + +-spec handle_enable(state(), sm_enable()) -> state(). +handle_enable(#{mgmt_timeout := DefaultTimeout, + mgmt_max_timeout := MaxTimeout, + xmlns := Xmlns, jid := JID} = State, + #sm_enable{resume = Resume, max = Max}) -> + Timeout = if Resume == false -> + 0; + Max /= undefined, Max > 0, Max =< MaxTimeout -> + Max; + true -> + DefaultTimeout + end, + Res = if Timeout > 0 -> + ?INFO_MSG("Stream management with resumption enabled for ~s", + [jid:to_string(JID)]), + #sm_enabled{xmlns = Xmlns, + id = make_resume_id(State), + resume = true, + max = Timeout}; + true -> + ?INFO_MSG("Stream management without resumption enabled for ~s", + [jid:to_string(JID)]), + #sm_enabled{xmlns = Xmlns} + end, + State1 = State#{mgmt_state => active, + mgmt_queue => queue_new(), + mgmt_timeout => Timeout * 1000}, + send(State1, Res). + +-spec handle_r(state()) -> state(). +handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) -> + Res = #sm_a{xmlns = Xmlns, h = H}, + send(State, Res). + +-spec handle_a(state(), sm_a()) -> state(). +handle_a(State, #sm_a{h = H}) -> + State1 = check_h_attribute(State, H), + resend_ack(State1). + +-spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}. +handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State, + #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) -> + R = case inherit_session_state(State, PrevID) of + {ok, InheritedState} -> + {ok, InheritedState, H}; + {error, Err, InH} -> + {error, #sm_failed{reason = 'item-not-found', + h = InH, xmlns = Xmlns}, Err}; + {error, Err} -> + {error, #sm_failed{reason = 'item-not-found', + xmlns = Xmlns}, Err} + end, + case R of + {ok, ResumedState, NumHandled} -> + State1 = check_h_attribute(ResumedState, NumHandled), + #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1, + AttrId = make_resume_id(State1), + State2 = send(State1, #sm_resumed{xmlns = AttrXmlns, + h = AttrH, + previd = AttrId}), + State3 = resend_unacked_stanzas(State2), + State4 = send(State3, #sm_r{xmlns = AttrXmlns}), + %% TODO: move this to mod_client_state + %% csi_flush_queue(State4), + State5 = ejabberd_hooks:run_fold(c2s_session_resumed, LServer, State4, []), + ?INFO_MSG("(~s) Resumed session for ~s", + [ejabberd_socket:pp(Socket), jid:to_string(JID)]), + {ok, State5}; + {error, El, Msg} -> + ?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]), + {error, send(State, El)} + end. + +-spec transition_to_pending(state()) -> state(). +transition_to_pending(#{mgmt_state := active} = State) -> + %% TODO + State; +transition_to_pending(State) -> + State. + +-spec check_h_attribute(state(), non_neg_integer()) -> state(). +check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) + when H > NumStanzasOut -> + ?DEBUG("~s acknowledged ~B stanzas, but only ~B were sent", + [jid:to_string(JID), H, NumStanzasOut]), + mgmt_queue_drop(State#{mgmt_stanzas_out => H}, NumStanzasOut); +check_h_attribute(#{mgmt_stanzas_out := NumStanzasOut, jid := JID} = State, H) -> + ?DEBUG("~s acknowledged ~B of ~B stanzas", + [jid:to_string(JID), H, NumStanzasOut]), + mgmt_queue_drop(State, H). + +-spec update_num_stanzas_in(state(), xmpp_element()) -> state(). +update_num_stanzas_in(#{mgmt_state := MgmtState, + mgmt_stanzas_in := NumStanzasIn} = State, El) + when MgmtState == active; MgmtState == pending -> + NewNum = case {xmpp:is_stanza(El), NumStanzasIn} of + {true, 4294967295} -> + 0; + {true, Num} -> + Num + 1; + {false, Num} -> + Num + end, + State#{mgmt_stanzas_in => NewNum}; +update_num_stanzas_in(State, _El) -> + State. + +send_ack(#{mgmt_ack_timer := _} = State) -> + State; +send_ack(#{mgmt_xmlns := Xmlns, + mgmt_stanzas_out := NumStanzasOut, + mgmt_ack_timeout := AckTimeout} = State) -> + State1 = send(State, #sm_r{xmlns = Xmlns}), + TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), + State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}. + +resend_ack(#{mgmt_ack_timer := _, + mgmt_queue := Queue, + mgmt_stanzas_out := NumStanzasOut, + mgmt_stanzas_req := NumStanzasReq} = State) -> + State1 = cancel_ack_timer(State), + case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of + true -> send_ack(State1); + false -> State1 + end; +resend_ack(State) -> + State. + +-spec mgmt_queue_add(state(), xmpp_element()) -> state(). +mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut, + mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) -> + NewNum = case NumStanzasOut of + 4294967295 -> 0; + Num -> Num + 1 + end, + Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue), + State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum}, + check_queue_length(State1); +mgmt_queue_add(State, _Nonza) -> + State. + +-spec mgmt_queue_drop(state(), non_neg_integer()) -> state(). +mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) -> + NewQueue = queue_dropwhile( + fun({N, _T, _E}) -> N =< NumHandled end, Queue), + State#{mgmt_queue => NewQueue}. + +-spec check_queue_length(state()) -> state(). +check_queue_length(#{mgmt_max_queue := Limit} = State) + when Limit == infinity; Limit == exceeded -> + State; +check_queue_length(#{mgmt_queue := Queue, mgmt_max_queue := Limit} = State) -> + case queue_len(Queue) > Limit of + true -> + State#{mgmt_max_queue => exceeded}; + false -> + State + end. + +-spec resend_unacked_stanzas(state()) -> state(). +resend_unacked_stanzas(#{mgmt_state := MgmtState, + mgmt_queue := {QueueLen, _} = Queue, + jid := JID} = State) + when (MgmtState == active orelse + MgmtState == pending orelse + MgmtState == timeout) andalso QueueLen > 0 -> + ?DEBUG("Resending ~B unacknowledged stanza(s) to ~s", + [QueueLen, jid:to_string(JID)]), + queue_foldl( + fun({_, Time, Pkt}, AccState) -> + NewPkt = add_resent_delay_info(AccState, Pkt, Time), + send(AccState, NewPkt) + end, State, Queue); +resend_unacked_stanzas(State) -> + State. + +-spec route_unacked_stanzas(state()) -> ok. +route_unacked_stanzas(#{mgmt_state := MgmtState, + mgmt_resend := MgmtResend, + lang := Lang, user := User, + jid := JID, lserver := LServer, + mgmt_queue := {QueueLen, _} = Queue, + resource := Resource} = State) + when (MgmtState == active orelse + MgmtState == pending orelse + MgmtState == timeout) andalso QueueLen > 0 -> + ResendOnTimeout = case MgmtResend of + Resend when is_boolean(Resend) -> + Resend; + if_offline -> + case ejabberd_sm:get_user_resources(User, Resource) of + [Resource] -> + %% Same resource opened new session + true; + [] -> true; + _ -> false + end + end, + ?DEBUG("Re-routing ~B unacknowledged stanza(s) to ~s", + [QueueLen, jid:to_string(JID)]), + queue_foreach( + fun({_, _Time, #presence{from = From}}) -> + ?DEBUG("Dropping presence stanza from ~s", [jid:to_string(From)]); + ({_, _Time, #iq{} = El}) -> + Txt = <<"User session terminated">>, + route_error(El, xmpp:err_service_unavailable(Txt, Lang)); + ({_, _Time, #message{from = From, meta = #{carbon_copy := true}}}) -> + %% XEP-0280 says: "When a receiving server attempts to deliver a + %% forked message, and that message bounces with an error for + %% any reason, the receiving server MUST NOT forward that error + %% back to the original sender." Resending such a stanza could + %% easily lead to unexpected results as well. + ?DEBUG("Dropping forwarded message stanza from ~s", + [jid:to_string(From)]); + ({_, Time, El}) -> + case ejabberd_hooks:run_fold(message_is_archived, + LServer, false, + [State, El]) of + true -> + ?DEBUG("Dropping archived message stanza from ~s", + [jid:to_string(xmpp:get_from(El))]); + false when ResendOnTimeout -> + NewEl = add_resent_delay_info(State, El, Time), + route(NewEl); + false -> + Txt = <<"User session terminated">>, + route_error(El, xmpp:err_service_unavailable(Txt, Lang)) + end + end, Queue); +route_unacked_stanzas(_State) -> + ok. + +-spec inherit_session_state(state(), binary()) -> {ok, state()} | + {error, binary()} | + {error, binary(), non_neg_integer()}. +inherit_session_state(#{user := U, server := S} = State, ResumeID) -> + case jlib:base64_to_term(ResumeID) of + {term, {R, Time}} -> + case ejabberd_sm:get_session_pid(U, S, R) of + none -> + case ejabberd_sm:get_offline_info(Time, U, S, R) of + none -> + {error, <<"Previous session PID not found">>}; + Info -> + case proplists:get_value(num_stanzas_in, Info) of + undefined -> + {error, <<"Previous session timed out">>}; + H -> + {error, <<"Previous session timed out">>, H} + end + end; + OldPID -> + OldSID = {Time, OldPID}, + try resume_session(OldSID, State) of + {resume, OldState} -> + State1 = ejabberd_c2s:copy_state(State, OldState), + State2 = ejabberd_c2s:open_session(State1), + {ok, State2}; + {error, Msg} -> + {error, Msg} + catch exit:{noproc, _} -> + {error, <<"Previous session PID is dead">>}; + exit:{timeout, _} -> + {error, <<"Session state copying timed out">>} + end + end; + _ -> + {error, <<"Invalid 'previd' value">>} + end. + +-spec resume_session({integer(), pid()}, state()) -> {resume, state()} | + {error, binary()}. +resume_session({Time, Pid}, _State) -> + ejabberd_c2s:call(Pid, {resume_session, Time}, timer:seconds(15)). + +-spec make_resume_id(state()) -> binary(). +make_resume_id(#{sid := {Time, _}, resource := Resource}) -> + jlib:term_to_base64({Resource, Time}). + +-spec add_resent_delay_info(state(), stanza(), erlang:timestamp()) -> stanza(). +add_resent_delay_info(_State, #iq{} = El, _Time) -> + El; +add_resent_delay_info(#{lserver := LServer}, El, Time) -> + xmpp_util:add_delay_info(El, jid:make(LServer), Time, <<"Resent">>). + +-spec route(stanza()) -> ok. +route(Pkt) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + ejabberd_router:route(From, To, Pkt). + +-spec route_error(stanza(), stanza_error()) -> ok. +route_error(Pkt, Err) -> + From = xmpp:get_from(Pkt), + To = xmpp:get_to(Pkt), + ejabberd_router:route_error(To, From, Pkt, Err). + +-spec send(state(), xmpp_element()) -> state(). +send(#{mod := Mod} = State, Pkt) -> + Mod:send(State, Pkt). + +-spec queue_new() -> lqueue(). +queue_new() -> + {0, queue:new()}. + +-spec queue_in(term(), lqueue()) -> lqueue(). +queue_in(Elem, {N, Q}) -> + {N+1, queue:in(Elem, Q)}. + +-spec queue_len(lqueue()) -> non_neg_integer(). +queue_len({N, _}) -> + N. + +-spec queue_foldl(fun((term(), T) -> T), T, lqueue()) -> T. +queue_foldl(F, Acc, {_N, Q}) -> + jlib:queue_foldl(F, Acc, Q). + +-spec queue_foreach(fun((_) -> _), lqueue()) -> ok. +queue_foreach(F, {_N, Q}) -> + jlib:queue_foreach(F, Q). + +-spec queue_dropwhile(fun((term()) -> boolean()), lqueue()) -> lqueue(). +queue_dropwhile(F, {N, Q}) -> + case queue:peek(Q) of + {value, Item} -> + case F(Item) of + true -> + queue_dropwhile(F, {N-1, queue:drop(Q)}); + false -> + {N, Q} + end; + empty -> + {N, Q} + end. + +-spec queue_is_empty(lqueue()) -> boolean(). +queue_is_empty({N, _Q}) -> + N == 0. + +-spec cancel_ack_timer(state()) -> state(). +cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) -> + case erlang:cancel_timer(TRef) of + false -> + receive {timeout, TRef, _} -> ok + after 0 -> ok + end; + _ -> + ok + end, + maps:remove(mgmt_ack_timer, State); +cancel_ack_timer(State) -> + State. + +%%%=================================================================== +%%% Configuration processing +%%%=================================================================== +get_max_ack_queue(Host, Opts) -> + VFun = mod_opt_type(max_ack_queue), + case gen_mod:get_module_opt(Host, ?MODULE, max_ack_queue, VFun) of + undefined -> gen_mod:get_opt(max_ack_queue, Opts, VFun, 1000); + Limit -> Limit + end. + +get_resume_timeout(Host, Opts) -> + VFun = mod_opt_type(resume_timeout), + case gen_mod:get_module_opt(Host, ?MODULE, resume_timeout, VFun) of + undefined -> gen_mod:get_opt(resume_timeout, Opts, VFun, 300); + Timeout -> Timeout + end. + +get_max_resume_timeout(Host, Opts, ResumeTimeout) -> + VFun = mod_opt_type(max_resume_timeout), + case gen_mod:get_module_opt(Host, ?MODULE, max_resume_timeout, VFun) of + undefined -> + case gen_mod:get_opt(max_resume_timeout, Opts, VFun) of + undefined -> ResumeTimeout; + Max when Max >= ResumeTimeout -> Max; + _ -> ResumeTimeout + end; + Max when Max >= ResumeTimeout -> Max; + _ -> ResumeTimeout + end. + +get_ack_timeout(Host, Opts) -> + VFun = mod_opt_type(ack_timeout), + T = case gen_mod:get_module_opt(Host, ?MODULE, ack_timeout, VFun) of + undefined -> gen_mod:get_opt(ack_timeout, Opts, VFun, 60); + AckTimeout -> AckTimeout + end, + case T of + infinity -> infinity; + _ -> timer:seconds(T) + end. + +get_resend_on_timeout(Host, Opts) -> + VFun = mod_opt_type(resend_on_timeout), + case gen_mod:get_module_opt(Host, ?MODULE, resend_on_timeout, VFun) of + undefined -> gen_mod:get_opt(resend_on_timeout, Opts, VFun, false); + Resend -> Resend + end. + +mod_opt_type(max_ack_queue) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(resume_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(max_resume_timeout) -> + fun(I) when is_integer(I), I >= 0 -> I end; +mod_opt_type(ack_timeout) -> + fun(I) when is_integer(I), I > 0 -> I; + (infinity) -> infinity + end; +mod_opt_type(resend_on_timeout) -> + fun(B) when is_boolean(B) -> B; + (if_offline) -> if_offline + end; +mod_opt_type(_) -> + [max_ack_queue, resume_timeout, max_resume_timeout, ack_timeout, + resend_on_timeout]. diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 1307f9da4..e9c1b3339 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -25,58 +25,81 @@ -protocol({rfc, 6120}). %% API --export([start/3, call/3, cast/2, reply/2, send/2, send_error/3, - get_transport/1, change_shaper/2]). +-export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1, + send/2, close/1, close/2, send_error/3, establish/1, + get_transport/1, change_shaper/2, set_timeout/2, format_error/1]). %% gen_server callbacks -export([init/1, handle_cast/2, handle_call/3, handle_info/2, terminate/2, code_change/3]). +%%-define(DBGFSM, true). +-ifdef(DBGFSM). +-define(FSMOPTS, [{debug, [trace]}]). +-else. +-define(FSMOPTS, []). +-endif. + -include("xmpp.hrl"). -type state() :: map(). --type next_state() :: {noreply, state()} | {stop, term(), state()}. +-type stop_reason() :: {stream, reset | stream_error()} | + {tls, term()} | + {socket, inet:posix() | closed | timeout}. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore. --callback handle_stream_start(state()) -> next_state(). --callback handle_stream_end(state()) -> next_state(). --callback handle_stream_close(state()) -> next_state(). --callback handle_cdata(binary(), state()) -> next_state(). --callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state(). --callback handle_authenticated_packet(xmpp_element(), state()) -> next_state(). --callback handle_unbinded_packet(xmpp_element(), state()) -> next_state(). --callback handle_auth_success(binary(), binary(), module(), state()) -> next_state(). --callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state(). --callback handle_send(ok | {error, atom()}, - xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state(). --callback init_sasl(state()) -> cyrsasl:sasl_state(). +-callback handle_cast(term(), state()) -> state(). +-callback handle_call(term(), term(), state()) -> state(). +-callback handle_info(term(), state()) -> state(). +-callback terminate(term(), state()) -> any(). +-callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. +-callback handle_stream_start(state()) -> state(). +-callback handle_stream_end(stop_reason(), state()) -> state(). +-callback handle_stream_close(stop_reason(), state()) -> state(). +-callback handle_cdata(binary(), state()) -> state(). +-callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). +-callback handle_authenticated_packet(xmpp_element(), state()) -> state(). +-callback handle_unbinded_packet(xmpp_element(), state()) -> state(). +-callback handle_auth_success(binary(), binary(), module(), state()) -> state(). +-callback handle_auth_failure(binary(), binary(), atom(), state()) -> state(). +-callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state(). +-callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state(). +-callback get_password_fun(state()) -> fun(). +-callback check_password_fun(state()) -> fun(). +-callback check_password_digest_fun(state()) -> fun(). -callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}. --callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}. -callback compress_methods(state()) -> [binary()]. -callback tls_options(state()) -> [proplists:property()]. -callback tls_required(state()) -> boolean(). --callback sasl_mechanisms(state()) -> [binary()]. +-callback tls_verify(state()) -> boolean(). -callback unauthenticated_stream_features(state()) -> [xmpp_element()]. -callback authenticated_stream_features(state()) -> [xmpp_element()]. %% All callbacks are optional -optional_callbacks([init/1, + handle_cast/2, + handle_call/3, + handle_info/2, + terminate/2, + code_change/3, handle_stream_start/1, - handle_stream_end/1, - handle_stream_close/1, + handle_stream_end/2, + handle_stream_close/2, handle_cdata/2, handle_authenticated_packet/2, handle_unauthenticated_packet/2, handle_unbinded_packet/2, handle_auth_success/4, handle_auth_failure/4, - handle_send/5, - init_sasl/1, + handle_send/3, + handle_recv/3, + get_password_fun/1, + check_password_fun/1, + check_password_digest_fun/1, bind/2, - handshake/2, compress_methods/1, tls_options/1, tls_required/1, - sasl_mechanisms/1, + tls_verify/1, unauthenticated_stream_features/1, authenticated_stream_features/1]). @@ -84,7 +107,10 @@ %%% API %%%=================================================================== start(Mod, Args, Opts) -> - gen_server:start(?MODULE, [Mod|Args], Opts). + gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +start_link(Mod, Args, Opts) -> + gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). call(Ref, Msg, Timeout) -> gen_server:call(Ref, Msg, Timeout). @@ -95,16 +121,80 @@ cast(Ref, Msg) -> reply(Ref, Reply) -> gen_server:reply(Ref, Reply). --spec send(state(), xmpp_element()) -> next_state(). -send(State, Pkt) -> - send_element(State, Pkt). +-spec stop(pid()) -> ok; + (state()) -> no_return(). +stop(Pid) when is_pid(Pid) -> + cast(Pid, stop); +stop(#{owner := Owner} = State) when Owner == self() -> + terminate(normal, State), + exit(normal); +stop(_) -> + erlang:error(badarg). -get_transport(#{sockmod := SockMod, socket := Socket}) -> - SockMod:get_transport(Socket). +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + cast(Pid, {send, Pkt}); +send(#{owner := Owner} = State, Pkt) when Owner == self() -> + send_element(State, Pkt); +send(_, _) -> + erlang:error(badarg). + +-spec close(pid()) -> ok; + (state()) -> state(). +close(Ref) -> + close(Ref, true). + +-spec close(pid(), boolean()) -> ok; + (state(), boolean()) -> state(). +close(Pid, SendTrailer) when is_pid(Pid) -> + cast(Pid, {close, SendTrailer}); +close(#{owner := Owner} = State, SendTrailer) when Owner == self() -> + if SendTrailer -> send_trailer(State); + true -> close_socket(State) + end; +close(_, _) -> + erlang:error(badarg). + +-spec establish(state()) -> state(). +establish(State) -> + process_stream_established(State). + +-spec set_timeout(state(), non_neg_integer() | infinity) -> state(). +set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> + case Timeout of + infinity -> State#{stream_timeout => infinity}; + _ -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State#{stream_timeout => {Timeout, Time}} + end; +set_timeout(_, _) -> + erlang:error(badarg). + +get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner}) + when Owner == self() -> + SockMod:get_transport(Socket); +get_transport(_) -> + erlang:error(badarg). -spec change_shaper(state(), shaper:shaper()) -> ok. -change_shaper(#{sockmod := SockMod, socket := Socket}, Shaper) -> - SockMod:change_shaper(Socket, Shaper). +change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper) + when Owner == self() -> + SockMod:change_shaper(Socket, Shaper); +change_shaper(_, _) -> + erlang:error(badarg). + +-spec format_error(stop_reason()) -> binary(). +format_error({socket, Reason}) -> + format("Connection failed: ~s", [format_inet_error(Reason)]); +format_error({stream, reset}) -> + <<"Stream reset by peer">>; +format_error({stream, #stream_error{reason = Reason, text = Txt}}) -> + format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); +format_error({tls, Reason}) -> + format("TLS failed: ~w", [Reason]); +format_error(Err) -> + format("Unrecognized error: ~w", [Err]). %%%=================================================================== %%% gen_server callbacks @@ -114,19 +204,24 @@ init([Module, {SockMod, Socket}, Opts]) -> {_, XS} -> XS; false -> false end, - TLSEnabled = proplists:get_bool(tls, Opts), + Encrypted = proplists:get_bool(tls, Opts), SocketMonitor = SockMod:monitor(Socket), case peername(SockMod, Socket) of {ok, IP} -> - State = #{mod => Module, + Time = p1_time_compat:monotonic_time(milli_seconds), + State = #{owner => self(), + mod => Module, socket => Socket, sockmod => SockMod, socket_monitor => SocketMonitor, + stream_timeout => {timer:seconds(30), Time}, + stream_direction => in, stream_id => new_id(), stream_state => wait_for_stream, + stream_header_sent => false, stream_restarted => false, stream_compressed => false, - stream_tlsed => TLSEnabled, + stream_encrypted => Encrypted, stream_version => {1,0}, stream_authenticated => false, xml_socket => XMLSocket, @@ -137,97 +232,133 @@ init([Module, {SockMod, Socket}, Opts]) -> resource => <<"">>, lserver => <<"">>, ip => IP}, - try Module:init([State, Opts]) - catch _:undef -> {ok, State} + case try Module:init([State, Opts]) + catch _:undef -> {ok, State} + end of + {ok, State1} -> + {_, State2, Timeout} = noreply(State1), + {ok, State2, Timeout}; + Err -> + Err end; {error, Reason} -> {stop, Reason} end. +handle_cast({send, Pkt}, State) -> + noreply(send_element(State, Pkt)); +handle_cast(stop, State) -> + {stop, normal, State}; handle_cast(Cast, #{mod := Mod} = State) -> - try Mod:handle_cast(Cast, State) - catch _:undef -> {noreply, State} - end. + noreply(try Mod:handle_cast(Cast, State) + catch _:undef -> State + end). handle_call(Call, From, #{mod := Mod} = State) -> - try Mod:handle_call(Call, From, State) - catch _:undef -> {reply, ok, State} - end. + noreply(try Mod:handle_call(Call, From, State) + catch _:undef -> State + end). handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, - #{stream_state := wait_for_stream, xmlns := XMLNS} = State) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of + #{stream_state := wait_for_stream, + xmlns := XMLNS, lang := MyLang} = State) -> + El = #xmlel{name = Name, attrs = Attrs}, + try xmpp:decode(El, XMLNS, []) of #stream_start{} = Pkt -> - case send_header(State, Pkt) of - {noreply, State1} -> - process_stream(Pkt, State1); - Err -> - Err + State1 = send_header(State, Pkt), + case is_disconnected(State1) of + true -> State1; + false -> noreply(process_stream(Pkt, State1)) end; _ -> - case send_header(State) of - {noreply, State1} -> - send_element(State1, xmpp:serr_invalid_xml()); - Err -> - Err + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> noreply(send_element(State1, xmpp:serr_invalid_xml())) end catch _:{xmpp_codec, Why} -> - case send_header(State) of - {noreply, State1} -> process_invalid_xml(Why, State1); - Err -> Err + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + noreply(send_element(State1, Err)) end end; -handle_info({'$gen_event', {xmlstreamend, _}}, #{mod := Mod} = State) -> - try Mod:handle_stream_end(State) - catch _:undef -> {stop, normal, State} - end; handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> - case send_header(State) of - {noreply, State1} -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> Err = case Reason of <<"XML stanza is too big">> -> xmpp:serr_policy_violation(Reason, Lang); _ -> xmpp:serr_not_well_formed() end, - send_element(State1, Err); - Err -> - Err + noreply(send_element(State1, Err)) end; handle_info({'$gen_event', {xmlstreamelement, El}}, - #{xmlns := NS} = State) -> + #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> try xmpp:decode(El, NS, [ignore_els]) of Pkt -> - process_element(Pkt, State) + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> noreply(process_element(Pkt, State1)) + end catch _:{xmpp_codec, Why} -> - process_invalid_xml(Why, State) + State1 = try Mod:handle_recv(El, {error, Why}, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) + end end; handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, #{mod := Mod} = State) -> - try Mod:handle_cdata(Data, State) - catch _:undef -> {noreply, State} - end; -handle_info(closed, #{mod := Mod} = State) -> - try Mod:handle_stream_close(State) - catch _:undef -> {stop, normal, State} - end; + noreply(try Mod:handle_cdata(Data, State) + catch _:undef -> State + end); +handle_info({'$gen_event', {xmlstreamend, _}}, State) -> + noreply(process_stream_end({error, {stream, reset}}, State)); +handle_info({'$gen_event', closed}, State) -> + noreply(process_stream_close({error, {socket, closed}}, State)); +handle_info(timeout, #{mod := Mod} = State) -> + Disconnected = is_disconnected(State), + noreply(try Mod:handle_timeout(State) + catch _:undef when not Disconnected -> + send_element(State, xmpp:serr_connection_timeout()); + _:undef -> + stop(State) + end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, - #{socket_monitor := MRef, mod := Mod} = State) -> - try Mod:handle_stream_close(State) - catch _:undef -> {stop, normal, State} - end; + #{socket_monitor := MRef} = State) -> + noreply(process_stream_close({error, {socket, closed}}, State)); handle_info(Info, #{mod := Mod} = State) -> - try Mod:handle_info(Info, State) - catch _:undef -> {noreply, State} - end. + noreply(try Mod:handle_info(Info, State) + catch _:undef -> State + end). -terminate(Reason, #{mod := Mod, socket := Socket, - sockmod := SockMod} = State) -> - try Mod:terminate(Reason, State) - catch _:undef -> ok - end, - send_text(State, <<"">>), - SockMod:close(Socket). +terminate(Reason, #{mod := Mod} = State) -> + case get(already_terminated) of + true -> + State; + _ -> + put(already_terminated, true), + try Mod:terminate(Reason, State) + catch _:undef -> ok + end, + send_trailer(State) + end. code_change(OldVsn, #{mod := Mod} = State, Extra) -> Mod:code_change(OldVsn, State, Extra). @@ -235,20 +366,49 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== +-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +noreply(#{stream_timeout := infinity} = State) -> + {noreply, State, infinity}; +noreply(#{stream_timeout := {MSecs, StartTime}} = State) -> + CurrentTime = p1_time_compat:monotonic_time(milli_seconds), + Timeout = max(0, MSecs - CurrentTime + StartTime), + {noreply, State, Timeout}. + -spec new_id() -> binary(). new_id() -> randoms:get_string(). +-spec is_disconnected(state()) -> boolean(). +is_disconnected(#{stream_state := StreamState}) -> + StreamState == disconnected. + +-spec peername(term(), term()) -> {ok, {inet:ip_address(), inet:port_number()}}| + {error, inet:posix()}. peername(SockMod, Socket) -> case SockMod of gen_tcp -> inet:peername(Socket); _ -> SockMod:peername(Socket) end. -process_invalid_xml(Reason, #{lang := Lang} = State) -> - Txt = xmpp:io_format_error(Reason), - send_element(State, xmpp:serr_invalid_xml(Txt, Lang)). +-spec process_stream_close(stop_reason(), state()) -> state(). +process_stream_close(_, #{stream_state := disconnected} = State) -> + State; +process_stream_close(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_close(Reason, State1) + catch _:undef -> stop(State1) + end. + +-spec process_stream_end(stop_reason(), state()) -> state(). +process_stream_end(_, #{stream_state := disconnected} = State) -> + State; +process_stream_end(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_end(Reason, State1) + catch _:undef -> stop(State1) + end. +-spec process_stream(stream_start(), state()) -> state(). process_stream(#stream_start{xmlns = XML_NS, stream_xmlns = STREAM_NS}, #{xmlns := NS} = State) @@ -268,73 +428,67 @@ process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) -> send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); process_stream(#stream_start{from = undefined, version = {1,0}}, #{lang := Lang, xmlns := ?NS_SERVER, - stream_tlsed := true} = State) -> + stream_encrypted := true} = State) -> Txt = <<"Missing 'from' attribute">>, send_element(State, xmpp:serr_invalid_from(Txt, Lang)); process_stream(#stream_start{to = #jid{luser = U, lresource = R}}, #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> -> Txt = <<"Improper 'to' attribute">>, send_element(State, xmpp:serr_improper_addressing(Txt, Lang)); -process_stream(#stream_start{to = #jid{lserver = RemoteServer}}, +process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart, #{xmlns := ?NS_COMPONENT, mod := Mod} = State) -> - State1 = State#{remote_server => RemoteServer}, - case try Mod:handle_stream_start(State1) - catch _:undef -> {noreply, State1} - end of - {noreply, State2} -> - {noreply, State2#{stream_state => wait_for_handshake}}; - Err -> - Err + State1 = State#{remote_server => RemoteServer, + stream_state => wait_for_handshake}, + try Mod:handle_stream_start(StreamStart, State1) + catch _:undef -> State1 end; process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, - from = From}, + from = From} = StreamStart, #{stream_authenticated := Authenticated, stream_restarted := StreamWasRestarted, mod := Mod, xmlns := NS, resource := Resource, - stream_tlsed := TLSEnabled} = State) -> - case if not StreamWasRestarted -> - State1 = State#{server => Server, lserver => LServer}, - try Mod:handle_stream_start(State1) - catch _:undef -> {noreply, State1} - end; - true -> - {noreply, State} - end of - {noreply, State2} -> - State3 = if NS == ?NS_SERVER andalso TLSEnabled -> - State2#{remote_server => From#jid.lserver}; - true -> - State2 - end, - case send_features(State3) of - {noreply, State4} -> + stream_encrypted := Encrypted} = State) -> + State1 = if not StreamWasRestarted -> + State#{server => Server, lserver => LServer}; + true -> + State + end, + State2 = if NS == ?NS_SERVER andalso Encrypted -> + State1#{remote_server => From#jid.lserver}; + true -> + State1 + end, + State3 = try Mod:handle_stream_start(StreamStart, State2) + catch _:undef -> State2 + end, + case is_disconnected(State3) of + true -> State3; + false -> + State4 = send_features(State3), + case is_disconnected(State4) of + true -> State4; + false -> TLSRequired = is_starttls_required(State4), - NewStreamState = - if not Authenticated and - (not TLSEnabled and TLSRequired) -> - wait_for_starttls; - not Authenticated -> - wait_for_sasl_request; - (NS == ?NS_CLIENT) and (Resource == <<"">>) -> - wait_for_bind; - true -> - session_established - end, - {noreply, State4#{stream_state => NewStreamState}}; - Err -> - Err - end; - Err -> - Err + if not Authenticated and (TLSRequired and not Encrypted) -> + State4#{stream_state => wait_for_starttls}; + not Authenticated -> + State4#{stream_state => wait_for_sasl_request}; + (NS == ?NS_CLIENT) and (Resource == <<"">>) -> + State4#{stream_state => wait_for_bind}; + true -> + process_stream_established(State4) + end + end end. +-spec process_element(xmpp_element(), state()) -> state(). process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> case Pkt of #starttls{} when StateName == wait_for_starttls; StateName == wait_for_sasl_request -> process_starttls(State); #starttls{} -> - send_element(State, #starttls_failure{}); + process_starttls_failure(unexpected_starttls_request, State); #sasl_auth{} when StateName == wait_for_starttls -> send_element(State, #sasl_failure{reason = 'encryption-required'}); #sasl_auth{} when StateName == wait_for_sasl_request -> @@ -356,7 +510,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #sasl_abort{} -> send_element(State, #sasl_failure{reason = 'aborted'}); #sasl_success{} -> - {noreply, State}; + State; #compress{} when StateName == wait_for_sasl_response -> send_element(State, #compress_failure{reason = 'setup-failed'}); #compress{} -> @@ -364,7 +518,9 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #handshake{} when StateName == wait_for_handshake -> process_handshake(Pkt, State); #handshake{} -> - {noreply, State}; + State; + #stream_error{} -> + process_stream_end({error, {stream, Pkt}}, State); _ when StateName == wait_for_sasl_request; StateName == wait_for_handshake; StateName == wait_for_sasl_response -> @@ -375,10 +531,11 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> send_error(State, Pkt, Err); _ when StateName == wait_for_bind -> process_bind(Pkt, State); - _ when StateName == session_established -> + _ when StateName == established -> process_authenticated_packet(Pkt, State) end. +-spec process_unauthenticated_packet(xmpp_element(), state()) -> state(). process_unauthenticated_packet(Pkt, #{mod := Mod} = State) -> NewPkt = set_lang(Pkt, State), try Mod:handle_unauthenticated_packet(NewPkt, State) @@ -387,6 +544,7 @@ process_unauthenticated_packet(Pkt, #{mod := Mod} = State) -> send_error(State, Pkt, Err) end. +-spec process_authenticated_packet(xmpp_element(), state()) -> state(). process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> Pkt1 = set_lang(Pkt, State), case set_from_to(Pkt1, State) of @@ -411,6 +569,7 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> send_element(State, Err) end. +-spec process_bind(xmpp_element(), state()) -> state(). process_bind(#iq{type = set, sub_els = [_]} = Pkt, #{xmlns := ?NS_CLIENT, mod := Mod, lang := Lang} = State) -> case xmpp:get_subtag(Pkt, #bind{}) of @@ -426,8 +585,8 @@ process_bind(#iq{type = set, sub_els = [_]} = Pkt, server := S, resource := NewR} = State1} when NewR /= <<"">> -> Reply = #bind{jid = jid:make(U, S, NewR)}, - State2 = State1#{stream_state => session_established}, - send_element(State2, xmpp:make_iq_result(Pkt, Reply)); + State2 = send_element(State1, xmpp:make_iq_result(Pkt, Reply)), + process_stream_established(State2); {error, #stanza_error{}, State1} = Err -> send_error(State1, Pkt, Err) end @@ -446,16 +605,55 @@ process_bind(Pkt, #{mod := Mod} = State) -> send_error(State, Pkt, Err) end. -process_handshake(#handshake{data = Data}, #{mod := Mod} = State) -> - case Mod:handshake(Data, State) of - {ok, State1} -> - State2 = State1#{stream_state => session_established, - stream_authenticated => true}, - send_element(State2, #handshake{}); - {error, #stream_error{} = Err, State1} -> - send_element(State1, Err) +-spec process_handshake(handshake(), state()) -> state(). +process_handshake(#handshake{data = Digest}, + #{mod := Mod, stream_id := StreamID, + remote_server := RemoteServer} = State) -> + GetPW = try Mod:get_password_fun(State) + catch _:undef -> fun(_) -> {false, undefined} end + end, + AuthRes = case GetPW(<<"">>) of + {false, _} -> + false; + {Password, _} -> + p1_sha:sha(<>) == Digest + end, + case AuthRes of + true -> + State1 = try Mod:handle_auth_success( + RemoteServer, <<"handshake">>, undefined, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + State2 = send_element(State1, #handshake{}), + process_stream_established(State2) + end; + false -> + State1 = try Mod:handle_auth_failure( + RemoteServer, <<"handshake">>, 'not-authorized', State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> send_element(State1, xmpp:serr_not_authorized()) + end + end. + +-spec process_stream_established(state()) -> state(). +process_stream_established(#{stream_state := StateName} = State) + when StateName == disconnected; StateName == established -> + State; +process_stream_established(#{mod := Mod} = State) -> + State1 = State#{stream_authenticated := true, + stream_state => established, + stream_timeout => infinity}, + try Mod:handle_stream_established(State1) + catch _:undef -> State1 end. +-spec process_compress(compress(), state()) -> state(). process_compress(#compress{}, #{stream_compressed := true} = State) -> send_element(State, #compress_failure{reason = 'setup-failed'}); process_compress(#compress{methods = HisMethods}, @@ -468,16 +666,17 @@ process_compress(#compress{methods = HisMethods}, true -> BCompressed = fxml:element_to_binary(xmpp:encode(#compressed{})), ZlibSocket = SockMod:compress(Socket, BCompressed), - State1 = State#{socket => ZlibSocket, - stream_id => new_id(), - stream_restarted => true, - stream_state => wait_for_stream, - stream_compressed => true}, - {noreply, State1}; + State#{socket => ZlibSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_compressed => true}; false -> send_element(State, #compress_failure{reason = 'unsupported-method'}) end. +-spec process_starttls(state()) -> state(). process_starttls(#{socket := Socket, sockmod := SockMod, mod := Mod} = State) -> TLSOpts = try Mod:tls_options(State) @@ -485,38 +684,69 @@ process_starttls(#{socket := Socket, end, case SockMod:starttls(Socket, TLSOpts) of {ok, TLSSocket} -> - case send_element(State, #starttls_proceed{}) of - {noreply, State1} -> - {noreply, State1#{socket => TLSSocket, - stream_id => new_id(), - stream_restarted => true, - stream_state => wait_for_stream, - stream_tlsed => true}}; - Err -> - Err + State1 = send_element(State, #starttls_proceed{}), + case is_disconnected(State1) of + true -> State1; + false -> + State1#{socket => TLSSocket, + stream_id => new_id(), + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true} end; - {error, _Reason} -> - send_element(State, #starttls_failure{}) + {error, Reason} -> + process_starttls_failure(Reason, State) end. -process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>}, - #{stream_tlsed := false} = State) -> - process_sasl_failure('encryption-required', <<"">>, State); -process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, - #{mod := Mod} = State) -> - try Mod:init_sasl(State) of - SASLState -> - SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn), - process_sasl_result(SASLResult, State) - catch _:undef -> - process_sasl_failure('temporary-auth-failure', <<"">>, State) +-spec process_starttls_failure(term(), state()) -> state(). +process_starttls_failure(Why, State) -> + State1 = send_element(State, #starttls_failure{}), + case is_disconnected(State1) of + true -> State1; + false -> process_stream_end({error, {tls, Why}}, State1) end. +-spec process_sasl_request(sasl_auth(), state()) -> state(). +process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, + #{mod := Mod, lserver := LServer} = State) -> + GetPW = try Mod:get_password_fun(State) + catch _:undef -> fun(_) -> false end + end, + CheckPW = try Mod:check_password_fun(State) + catch _:undef -> fun(_, _, _) -> false end + end, + CheckPWDigest = try Mod:check_password_digest_fun(State) + catch _:undef -> fun(_, _, _, _, _) -> false end + end, + SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [], + GetPW, CheckPW, CheckPWDigest), + State1 = State#{sasl_state => SASLState, sasl_mech => Mech}, + Mechs = get_sasl_mechanisms(State1), + SASLResult = case lists:member(Mech, Mechs) of + true when Mech == <<"EXTERNAL">> -> + case xmpp_stream_pkix:authenticate(State1, ClientIn) of + {ok, Peer} -> + {ok, [{auth_module, pkix}, + {username, Peer}]}; + {error, _Reason, Peer} -> + %% TODO: return meaningful error + {error, 'not-authorized', Peer} + end; + true -> + cyrsasl:server_start(SASLState, Mech, ClientIn); + false -> + {error, 'invalid-mechanism'} + end, + process_sasl_result(SASLResult, State1). + +-spec process_sasl_response(sasl_response(), state()) -> state(). process_sasl_response(#sasl_response{text = ClientIn}, #{sasl_state := SASLState} = State) -> SASLResult = cyrsasl:server_step(SASLState, ClientIn), process_sasl_result(SASLResult, State). +-spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state(). process_sasl_result({ok, Props}, State) -> process_sasl_success(Props, <<"">>, State); process_sasl_result({ok, Props, ServerOut}, State) -> @@ -528,58 +758,59 @@ process_sasl_result({error, Reason, User}, State) -> process_sasl_result({error, Reason}, State) -> process_sasl_failure(Reason, <<"">>, State). +-spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state(). process_sasl_success(Props, ServerOut, #{socket := Socket, sockmod := SockMod, - mod := Mod, sasl_state := SASLState} = State) -> - Mech = cyrsasl:get_mech(SASLState), + mod := Mod, sasl_mech := Mech} = State) -> User = identity(Props), AuthModule = proplists:get_value(auth_module, Props), - case try Mod:handle_auth_success(User, Mech, AuthModule, State) - catch _:undef -> {noreply, State} - end of - {noreply, State1} -> + State1 = try Mod:handle_auth_success(User, Mech, AuthModule, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> SockMod:reset_stream(Socket), - case send_element(State1, #sasl_success{text = ServerOut}) of - {noreply, State2} -> - State3 = maps:remove(sasl_state, State2), - {noreply, State3#{stream_id => new_id(), - stream_authenticated => true, - stream_restarted => true, - stream_state => wait_for_stream, - user => User}}; - Err -> - Err - end; - Err -> - Err + State2 = send_element(State1, #sasl_success{text = ServerOut}), + case is_disconnected(State2) of + true -> State2; + false -> + State3 = maps:remove(sasl_state, + maps:remove(sasl_mech, State2)), + State3#{stream_id => new_id(), + stream_authenticated => true, + stream_header_sent => false, + stream_restarted => true, + stream_state => wait_for_stream, + user => User} + end end. +-spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state(). process_sasl_continue(ServerOut, NewSASLState, State) -> - send_element(State, #sasl_challenge{text = ServerOut}), - {noreply, State#{sasl_state => NewSASLState, - stream_state => wait_for_sasl_response}}. + State1 = State#{sasl_state => NewSASLState, + stream_state => wait_for_sasl_response}, + send_element(State1, #sasl_challenge{text = ServerOut}). +-spec process_sasl_failure(atom(), binary(), state()) -> state(). process_sasl_failure(Reason, User, - #{mod := Mod, sasl_state := SASLState} = State) -> - Mech = cyrsasl:get_mech(SASLState), - case try Mod:handle_auth_failure(User, Mech, Reason, State) - catch _:undef -> {noreply, State} - end of - {noreply, State1} -> - State2 = maps:remove(sasl_state, State1), - State3 = State2#{stream_state => wait_for_sasl_request}, - send_element(State3, #sasl_failure{reason = Reason}); - Err -> - Err - end. + #{mod := Mod, sasl_mech := Mech} = State) -> + State1 = try Mod:handle_auth_failure(User, Mech, Reason, State) + catch _:undef -> State + end, + State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)), + State3 = State2#{stream_state => wait_for_sasl_request}, + send_element(State3, #sasl_failure{reason = Reason}). +-spec process_sasl_abort(state()) -> state(). process_sasl_abort(State) -> process_sasl_failure('aborted', <<"">>, State). +-spec send_features(state()) -> state(). send_features(#{stream_version := {1,0}, - stream_tlsed := TLSEnabled} = State) -> + stream_encrypted := Encrypted} = State) -> TLSRequired = is_starttls_required(State), - Features = if TLSRequired and not TLSEnabled -> + Features = if TLSRequired and not Encrypted -> get_tls_feature(State); true -> get_sasl_feature(State) ++ get_compress_feature(State) @@ -588,26 +819,38 @@ send_features(#{stream_version := {1,0}, end, send_element(State, #stream_features{sub_els = Features}); send_features(State) -> - %% clients from stone age - {noreply, State}. + %% clients and servers from stone age + State. +-spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()]. +get_sasl_mechanisms(#{stream_encrypted := Encrypted, mod := Mod, + xmlns := NS, lserver := LServer} = State) -> + Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer); + true -> [] + end, + TLSVerify = try Mod:tls_verify(State) + catch _:undef -> false + end, + if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) -> + [<<"EXTERNAL">>|Mechs]; + true -> + Mechs + end. + +-spec get_sasl_feature(state()) -> [sasl_mechanisms()]. get_sasl_feature(#{stream_authenticated := false, - mod := Mod, - stream_tlsed := TLSEnabled} = State) -> + stream_encrypted := Encrypted} = State) -> TLSRequired = is_starttls_required(State), - if TLSEnabled or not TLSRequired -> - try Mod:sasl_mechanisms(State) of - [] -> []; - List -> [#sasl_mechanisms{list = List}] - catch _:undef -> - [] - end; + if Encrypted or not TLSRequired -> + Mechs = get_sasl_mechanisms(State), + [#sasl_mechanisms{list = Mechs}]; true -> [] end; get_sasl_feature(_) -> []. +-spec get_compress_feature(state()) -> [compression()]. get_compress_feature(#{stream_compressed := false, mod := Mod} = State) -> try Mod:compress_methods(State) of [] -> []; @@ -618,23 +861,31 @@ get_compress_feature(#{stream_compressed := false, mod := Mod} = State) -> get_compress_feature(_) -> []. +-spec get_tls_feature(state()) -> [starttls()]. get_tls_feature(#{stream_authenticated := false, - stream_tlsed := false} = State) -> + stream_encrypted := false} = State) -> TLSRequired = is_starttls_required(State), [#starttls{required = TLSRequired}]; get_tls_feature(_) -> []. -get_bind_feature(#{stream_authenticated := true, resource := <<"">>}) -> +-spec get_bind_feature(state()) -> [bind()]. +get_bind_feature(#{xmlns := ?NS_CLIENT, + stream_authenticated := true, + resource := <<"">>}) -> [#bind{}]; get_bind_feature(_) -> []. -get_session_feature(#{stream_authenticated := true, resource := <<"">>}) -> +-spec get_session_feature(state()) -> [xmpp_session()]. +get_session_feature(#{xmlns := ?NS_CLIENT, + stream_authenticated := true, + resource := <<"">>}) -> [#xmpp_session{optional = true}]; get_session_feature(_) -> []. +-spec get_other_features(state()) -> [xmpp_element()]. get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> try if Auth -> Mod:authenticated_stream_features(State); @@ -644,15 +895,18 @@ get_other_features(#{stream_authenticated := Auth, mod := Mod} = State) -> [] end. +-spec is_starttls_required(state()) -> boolean(). is_starttls_required(#{mod := Mod} = State) -> try Mod:tls_required(State) catch _:undef -> false end. +-spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} | + {error, stream_error()}. set_from_to(Pkt, _State) when not ?is_stanza(Pkt) -> {ok, Pkt}; set_from_to(Pkt, #{user := U, server := S, resource := R, - xmlns := ?NS_CLIENT}) -> + lang := Lang, xmlns := ?NS_CLIENT}) -> JID = jid:make(U, S, R), From = case xmpp:get_from(Pkt) of undefined -> JID; @@ -668,7 +922,8 @@ set_from_to(Pkt, #{user := U, server := S, resource := R, end, {ok, xmpp:set_from_to(Pkt, JID, To)}; true -> - {error, xmpp:serr_invalid_from()} + Txt = <<"Improper 'from' attribute">>, + {error, xmpp:serr_invalid_from(Txt, Lang)} end; set_from_to(Pkt, #{lang := Lang}) -> From = xmpp:get_from(Pkt), @@ -683,17 +938,22 @@ set_from_to(Pkt, #{lang := Lang}) -> {ok, Pkt} end. +-spec send_header(state()) -> state(). send_header(State) -> send_header(State, #stream_start{}). -send_header(#{stream_state := wait_for_stream, - stream_id := StreamID, +-spec send_header(state(), stream_start()) -> state(). +send_header(#{stream_id := StreamID, stream_version := MyVersion, + stream_header_sent := false, lang := MyLang, xmlns := NS, server := DefaultServer} = State, #stream_start{to = To, lang = HisLang, version = HisVersion}) -> - Lang = choose_lang(MyLang, HisLang), + Lang = select_lang(MyLang, HisLang), + NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; + true -> <<"">> + end, From = case To of #jid{} -> To; undefined -> jid:make(DefaultServer) @@ -706,63 +966,114 @@ send_header(#{stream_state := wait_for_stream, lang = Lang, xmlns = NS, stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, id = StreamID, from = From}), - State1 = State#{lang => Lang}, + State1 = State#{lang => Lang, stream_header_sent => true}, case send_text(State1, fxml:element_to_header(Header)) of - ok -> {noreply, State1}; - {error, _} -> {stop, normal, State1} + ok -> State1; + {error, Why} -> process_stream_close({error, {socket, Why}}, State1) end; send_header(State, _) -> - {noreply, State}. + State. +-spec send_element(state(), xmpp_element()) -> state(). send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> El = xmpp:encode(Pkt, NS), Data = fxml:element_to_binary(El), - case send_text(State, Data) of - ok when is_record(Pkt, stream_error) -> - {stop, normal, State}; - ok when is_record(Pkt, starttls_failure) -> - {stop, normal, State}; - Res -> - try Mod:handle_send(Res, Pkt, El, Data, State) - catch _:undef when Res == ok -> - {noreply, State}; - _:undef -> - {stop, normal, State} - end + Result = send_text(State, Data), + State1 = try Mod:handle_send(Pkt, Result, State) + catch _:undef -> State + end, + case Result of + _ when is_record(Pkt, stream_error) -> + process_stream_end({error, {stream, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_close({error, {socket, Why}}, State1) end. -send_error(State, Pkt, Err) when ?is_stanza(Pkt) -> - case xmpp:get_type(Pkt) of - result -> {noreply, State}; - error -> {noreply, State}; - _ -> - ErrPkt = xmpp:make_error(Pkt, Err), - send_element(State, ErrPkt) - end; -send_error(State, _, _) -> - {noreply, State}. +-spec send_error(state(), xmpp_element(), stanza_error()) -> state(). +send_error(State, Pkt, Err) -> + case xmpp:is_stanza(Pkt) of + true -> + case xmpp:get_type(Pkt) of + result -> State; + error -> State; + <<"result">> -> State; + <<"error">> -> State; + _ -> + ErrPkt = xmpp:make_error(Pkt, Err), + send_element(State, ErrPkt) + end; + false -> + State + end. + +-spec send_trailer(state()) -> state(). +send_trailer(State) -> + send_text(State, <<"">>), + close_socket(State). -send_text(#{socket := Sock, sockmod := SockMod}, Data) -> - SockMod:send(Sock, Data). +-spec send_text(state(), binary()) -> ok | {error, inet:posix()}. +send_text(#{socket := Sock, sockmod := SockMod, + stream_state := StateName, + stream_header_sent := true}, Data) when StateName /= disconnected -> + SockMod:send(Sock, Data); +send_text(_, _) -> + {error, einval}. -choose_lang(Lang, <<"">>) -> Lang; -choose_lang(_, Lang) -> Lang. +-spec close_socket(state()) -> state(). +close_socket(#{sockmod := SockMod, socket := Socket} = State) -> + SockMod:close(Socket), + State#{stream_timeout => infinity, + stream_state => disconnected}. +-spec select_lang(binary(), binary()) -> binary(). +select_lang(Lang, <<"">>) -> Lang; +select_lang(_, Lang) -> Lang. + +-spec set_lang(xmpp_element(), state()) -> xmpp_element(). set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) -> HisLang = xmpp:get_lang(Pkt), - Lang = choose_lang(MyLang, HisLang), + Lang = select_lang(MyLang, HisLang), xmpp:set_lang(Pkt, Lang); set_lang(Pkt, _) -> Pkt. +-spec format_inet_error(atom()) -> string(). +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). +format_stream_error(Reason, Txt) -> + Slogan = case Reason of + #'see-other-host'{} -> "see-other-host"; + _ -> atom_to_list(Reason) + end, + case Txt of + undefined -> Slogan; + #text{data = <<"">>} -> Slogan; + #text{data = Data} -> + binary_to_list(Data) ++ " (" ++ Slogan ++ ")" + end. + +-spec format(io:format(), list()) -> binary(). +format(Fmt, Args) -> + iolist_to_binary(io_lib:format(Fmt, Args)). + +-spec lists_intersection(list(), list()) -> list(). lists_intersection(L1, L2) -> lists:filter( fun(E) -> lists:member(E, L2) end, L1). +-spec identity([cyrsasl:sasl_property()]) -> binary(). identity(Props) -> case proplists:get_value(authzid, Props, <<>>) of <<>> -> proplists:get_value(username, Props, <<>>); diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl new file mode 100644 index 000000000..fc373fff8 --- /dev/null +++ b/src/xmpp_stream_out.erl @@ -0,0 +1,856 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2016, Evgeny Khramtsov +%%% @doc +%%% +%%% @end +%%% Created : 14 Dec 2016 by Evgeny Khramtsov +%%%------------------------------------------------------------------- +-module(xmpp_stream_out). +-behaviour(gen_server). + +-protocol({rfc, 6120}). + +%% API +-export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1, + stop/1, send/2, close/1, close/2, establish/1, format_error/1, + set_timeout/2, get_transport/1, change_shaper/2]). +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). + +%%-define(DBGFSM, true). +-ifdef(DBGFSM). +-define(FSMOPTS, [{debug, [trace]}]). +-else. +-define(FSMOPTS, []). +-endif. + +-define(TCP_SEND_TIMEOUT, 15000). + +-include("xmpp.hrl"). +-include("logger.hrl"). +-include_lib("kernel/include/inet.hrl"). + +-type state() :: map(). +-type host_port() :: {inet:hostname(), inet:port_number()}. +-type ip_port() :: {inet:ip_address(), inet:port_number()}. +-type network_error() :: {error, inet:posix() | inet_res:res_error()}. +-type stop_reason() :: {idna, bad_string} | + {dns, inet:posix() | inet_res:res_error()} | + {stream, reset | stream_error()} | + {tls, term()} | + {pkix, binary()} | + {auth, atom() | binary() | string()} | + {socket, inet:posix() | closed | timeout}. + +-callback init(list()) -> {ok, state()} | {stop, term()} | ignore. + +%%%=================================================================== +%%% API +%%%=================================================================== +start(Mod, Args, Opts) -> + gen_server:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +start_link(Mod, Args, Opts) -> + gen_server:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS). + +call(Ref, Msg, Timeout) -> + gen_server:call(Ref, Msg, Timeout). + +cast(Ref, Msg) -> + gen_server:cast(Ref, Msg). + +reply(Ref, Reply) -> + gen_server:reply(Ref, Reply). + +-spec connect(pid()) -> ok. +connect(Ref) -> + cast(Ref, connect). + +-spec stop(pid()) -> ok; + (state()) -> no_return(). +stop(Pid) when is_pid(Pid) -> + cast(Pid, stop); +stop(#{owner := Owner} = State) when Owner == self() -> + terminate(normal, State), + exit(normal); +stop(_) -> + erlang:error(badarg). + +-spec send(pid(), xmpp_element()) -> ok; + (state(), xmpp_element()) -> state(). +send(Pid, Pkt) when is_pid(Pid) -> + cast(Pid, {send, Pkt}); +send(#{owner := Owner} = State, Pkt) when Owner == self() -> + send_element(State, Pkt); +send(_, _) -> + erlang:error(badarg). + +-spec close(pid()) -> ok; + (state()) -> state(). +close(Ref) -> + close(Ref, true). + +-spec close(pid(), boolean()) -> ok; + (state(), boolean()) -> state(). +close(Pid, SendTrailer) when is_pid(Pid) -> + cast(Pid, {close, SendTrailer}); +close(#{owner := Owner} = State, SendTrailer) when Owner == self() -> + if SendTrailer -> send_trailer(State); + true -> close_socket(State) + end; +close(_, _) -> + erlang:error(badarg). + +-spec establish(state()) -> state(). +establish(State) -> + process_stream_established(State). + +-spec set_timeout(state(), non_neg_integer() | infinity) -> state(). +set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> + case Timeout of + infinity -> State#{stream_timeout => infinity}; + _ -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State#{stream_timeout => {Timeout, Time}} + end; +set_timeout(_, _) -> + erlang:error(badarg). + +get_transport(#{sockmod := SockMod, socket := Socket, owner := Owner}) + when Owner == self() -> + SockMod:get_transport(Socket); +get_transport(_) -> + erlang:error(badarg). + +-spec change_shaper(state(), shaper:shaper()) -> ok. +change_shaper(#{sockmod := SockMod, socket := Socket, owner := Owner}, Shaper) + when Owner == self() -> + SockMod:change_shaper(Socket, Shaper); +change_shaper(_, _) -> + erlang:error(badarg). + +-spec format_error(stop_reason()) -> binary(). +format_error({idna, _}) -> + <<"Not an IDN hostname">>; +format_error({dns, Reason}) -> + format("DNS lookup failed: ~s", [format_inet_error(Reason)]); +format_error({socket, Reason}) -> + format("Connection failed: ~s", [format_inet_error(Reason)]); +format_error({pkix, Reason}) -> + format("Peer certificate rejected: ~s", [Reason]); +format_error({stream, reset}) -> + <<"Stream reset by peer">>; +format_error({stream, #stream_error{reason = Reason, text = Txt}}) -> + format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); +format_error({tls, Reason}) -> + format("TLS failed: ~w", [Reason]); +format_error({auth, Reason}) -> + format("Authentication failed: ~s", [Reason]); +format_error(Err) -> + format("Unrecognized error: ~w", [Err]). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== +init([Mod, SockMod, From, To, Opts]) -> + Time = p1_time_compat:monotonic_time(milli_seconds), + State = #{owner => self(), + mod => Mod, + sockmod => SockMod, + server => From, + user => <<"">>, + resource => <<"">>, + lang => <<"">>, + remote_server => To, + xmlns => ?NS_SERVER, + stream_direction => out, + stream_timeout => {timer:seconds(30), Time}, + stream_id => new_id(), + stream_encrypted => false, + stream_verified => false, + stream_authenticated => false, + stream_restarted => false, + stream_state => connecting}, + case try Mod:init([State, Opts]) + catch _:undef -> {ok, State} + end of + {ok, State1} -> + {_, State2, Timeout} = noreply(State1), + {ok, State2, Timeout}; + Err -> + Err + end. + +handle_call(Call, From, #{mod := Mod} = State) -> + noreply(try Mod:handle_call(Call, From, State) + catch _:undef -> State + end). + +handle_cast(connect, #{remote_server := RemoteServer, + sockmod := SockMod, + stream_state := connecting} = State) -> + case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of + false -> + noreply(process_stream_close({error, {idna, bad_string}}, State)); + ASCIIName -> + case resolve(binary_to_list(ASCIIName), State) of + {ok, AddrPorts} -> + case connect(AddrPorts, State) of + {ok, Socket, AddrPort} -> + SocketMonitor = SockMod:monitor(Socket), + State1 = State#{ip => AddrPort, + socket => Socket, + socket_monitor => SocketMonitor}, + State2 = State1#{stream_state => wait_for_stream}, + noreply(send_header(State2)); + {error, Why} -> + Err = {error, {socket, Why}}, + noreply(process_stream_close(Err, State)) + end; + {error, Why} -> + noreply(process_stream_close({error, {dns, Why}}, State)) + end + end; +handle_cast(connect, State) -> + %% Ignoring connection attempts in other states + noreply(State); +handle_cast({send, Pkt}, State) -> + noreply(send_element(State, Pkt)); +handle_cast(stop, State) -> + {stop, normal, State}; +handle_cast(Cast, #{mod := Mod} = State) -> + noreply(try Mod:handle_cast(Cast, State) + catch _:undef -> State + end). + +handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, + #{stream_state := wait_for_stream, + xmlns := XMLNS, lang := MyLang} = State) -> + El = #xmlel{name = Name, attrs = Attrs}, + try xmpp:decode(El, XMLNS, []) of + #stream_start{} = Pkt -> + noreply(process_stream(Pkt, State)); + _ -> + noreply(send_element(State, xmpp:serr_invalid_xml())) + catch _:{xmpp_codec, Why} -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + noreply(send_element(State, Err)) + end; +handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> + Err = case Reason of + <<"XML stanza is too big">> -> + xmpp:serr_policy_violation(Reason, Lang); + _ -> + xmpp:serr_not_well_formed() + end, + noreply(send_element(State1, Err)) + end; +handle_info({'$gen_event', {xmlstreamelement, El}}, + #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> + try xmpp:decode(El, NS, [ignore_els]) of + Pkt -> + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> noreply(process_element(Pkt, State1)) + end + catch _:{xmpp_codec, Why} -> + State1 = try Mod:handle_recv(El, undefined, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) + end + end; +handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, + #{mod := Mod} = State) -> + noreply(try Mod:handle_cdata(Data, State) + catch _:undef -> State + end); +handle_info({'$gen_event', {xmlstreamend, _}}, State) -> + noreply(process_stream_end({error, {stream, reset}}, State)); +handle_info({'$gen_event', closed}, State) -> + noreply(process_stream_close({error, {socket, closed}}, State)); +handle_info(timeout, #{mod := Mod} = State) -> + Disconnected = is_disconnected(State), + noreply(try Mod:handle_timeout(State) + catch _:undef when not Disconnected -> + send_element(State, xmpp:serr_connection_timeout()); + _:undef -> + stop(State) + end); +handle_info({'DOWN', MRef, _Type, _Object, _Info}, + #{socket_monitor := MRef} = State) -> + noreply(process_stream_close({error, {socket, closed}}, State)); +handle_info(Info, #{mod := Mod} = State) -> + noreply(try Mod:handle_info(Info, State) + catch _:undef -> State + end). + +terminate(Reason, #{mod := Mod} = State) -> + case get(already_terminated) of + true -> + State; + _ -> + put(already_terminated, true), + try Mod:terminate(Reason, State) + catch _:undef -> ok + end, + send_trailer(State) + end. + +code_change(OldVsn, #{mod := Mod} = State, Extra) -> + Mod:code_change(OldVsn, State, Extra). + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +-spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +noreply(#{stream_timeout := infinity} = State) -> + {noreply, State, infinity}; +noreply(#{stream_timeout := {MSecs, OldTime}} = State) -> + NewTime = p1_time_compat:monotonic_time(milli_seconds), + Timeout = max(0, MSecs - NewTime + OldTime), + {noreply, State, Timeout}. + +-spec new_id() -> binary(). +new_id() -> + randoms:get_string(). + +-spec is_disconnected(state()) -> boolean(). +is_disconnected(#{stream_state := StreamState}) -> + StreamState == disconnected. + +-spec process_stream_close(stop_reason(), state()) -> state(). +process_stream_close(_, #{stream_state := disconnected} = State) -> + State; +process_stream_close(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_close(Reason, State1) + catch _:undef -> stop(State1) + end. + +-spec process_stream_end(stop_reason(), state()) -> state(). +process_stream_end(_, #{stream_state := disconnected} = State) -> + State; +process_stream_end(Reason, #{mod := Mod} = State) -> + State1 = send_trailer(State), + try Mod:handle_stream_end(Reason, State1) + catch _:undef -> stop(State1) + end. + +-spec process_stream(stream_start(), state()) -> state(). +process_stream(#stream_start{xmlns = XML_NS, + stream_xmlns = STREAM_NS}, + #{xmlns := NS} = State) + when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> + send_element(State, xmpp:serr_invalid_namespace()); +process_stream(#stream_start{lang = Lang, id = ID, + version = Version} = StreamStart, + #{mod := Mod} = State) -> + State1 = State#{stream_remote_id => ID, lang => Lang}, + State2 = try Mod:handle_stream_start(StreamStart, State1) + catch _:undef -> State1 + end, + case is_disconnected(State2) of + true -> State2; + false -> + case Version of + {1,0} -> State2#{stream_state => wait_for_features}; + _ -> process_stream_downgrade(StreamStart, State) + end + end. + +-spec process_element(xmpp_element(), state()) -> state(). +process_element(Pkt, #{stream_state := StateName} = State) -> + case Pkt of + #stream_features{} when StateName == wait_for_features -> + process_features(Pkt, State); + #starttls_proceed{} when StateName == wait_for_starttls_response -> + process_starttls(State); + #sasl_success{} when StateName == wait_for_sasl_response -> + process_sasl_success(State); + #sasl_failure{} when StateName == wait_for_sasl_response -> + process_sasl_failure(Pkt, State); + #stream_error{} -> + process_stream_end({error, {stream, Pkt}}, State); + _ when is_record(Pkt, stream_features); + is_record(Pkt, starttls_proceed); + is_record(Pkt, starttls); + is_record(Pkt, sasl_auth); + is_record(Pkt, sasl_success); + is_record(Pkt, sasl_failure); + is_record(Pkt, sasl_response); + is_record(Pkt, sasl_abort); + is_record(Pkt, compress); + is_record(Pkt, handshake) -> + %% Do not pass this crap upstream + State; + _ -> + process_packet(Pkt, State) + end. + +-spec process_features(stream_features(), state()) -> state(). +process_features(StreamFeatures, + #{stream_authenticated := true, mod := Mod} = State) -> + State1 = try Mod:handle_authenticated_features(StreamFeatures, State) + catch _:undef -> State + end, + process_stream_established(State1); +process_features(#stream_features{sub_els = Els} = StreamFeatures, + #{stream_encrypted := Encrypted, + mod := Mod, lang := Lang} = State) -> + State1 = try Mod:handle_unauthenticated_features(StreamFeatures, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + TLSRequired = is_starttls_required(State1), + %% TODO: improve xmpp.erl + Msg = #message{sub_els = Els}, + case xmpp:get_subtag(Msg, #starttls{}) of + false when TLSRequired and not Encrypted -> + Txt = <<"Use of STARTTLS required">>, + send_element(State1, xmpp:err_policy_violation(Txt, Lang)); + #starttls{} when not Encrypted -> + State2 = State1#{stream_state => wait_for_starttls_response}, + send_element(State2, #starttls{}); + _ -> + State2 = process_cert_verification(State1), + case is_disconnected(State2) of + true -> State2; + false -> + case xmpp:get_subtag(Msg, #sasl_mechanisms{}) of + #sasl_mechanisms{list = Mechs} -> + process_sasl_mechanisms(Mechs, State2); + false -> + process_sasl_failure( + #sasl_failure{reason = 'invalid-mechanism'}, + State2) + end + end + end + end. + +-spec process_stream_established(state()) -> state(). +process_stream_established(#{stream_state := StateName} = State) + when StateName == disconnected; StateName == established -> + State; +process_stream_established(#{mod := Mod} = State) -> + State1 = State#{stream_authenticated := true, + stream_state => established, + stream_timeout => infinity}, + try Mod:handle_stream_established(State1) + catch _:undef -> State1 + end. + +-spec process_sasl_mechanisms([binary()], state()) -> state(). +process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) -> + %% TODO: support other mechanisms + Mech = <<"EXTERNAL">>, + case lists:member(<<"EXTERNAL">>, Mechs) of + true -> + State1 = State#{stream_state => wait_for_sasl_response}, + Authzid = jid:to_string(jid:make(User, Server)), + send_element(State1, #sasl_auth{mechanism = Mech, text = Authzid}); + false -> + process_sasl_failure( + #sasl_failure{reason = 'invalid-mechanism'}, State) + end. + +-spec process_starttls(state()) -> state(). +process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) -> + TLSOpts = try Mod:tls_options(State) + catch _:undef -> [] + end, + case SockMod:starttls(Socket, [connect|TLSOpts]) of + {ok, TLSSocket} -> + State1 = State#{socket => TLSSocket, + stream_id => new_id(), + stream_restarted => true, + stream_state => wait_for_stream, + stream_encrypted => true}, + send_header(State1); + {error, Why} -> + process_stream_close({error, {tls, Why}}, State) + end. + +-spec process_stream_downgrade(stream_start(), state()) -> state(). +process_stream_downgrade(StreamStart, #{mod := Mod} = State) -> + try Mod:downgrade_stream(StreamStart, State) + catch _:undef -> + send_element(State, xmpp:serr_unsupported_version()) + end. + +-spec process_cert_verification(state()) -> state(). +process_cert_verification(#{stream_encrypted := true, + stream_verified := false, + mod := Mod} = State) -> + case try Mod:tls_verify(State) + catch _:undef -> true + end of + true -> + case xmpp_stream_pkix:authenticate(State) of + {ok, _} -> + State#{stream_verified => true}; + {error, Why, _Peer} -> + process_stream_close({error, {pkix, Why}}, State) + end; + false -> + State#{stream_verified => true} + end; +process_cert_verification(State) -> + State. + +-spec process_sasl_success(state()) -> state(). +process_sasl_success(#{mod := Mod, + sockmod := SockMod, + socket := Socket} = State) -> + State1 = try Mod:handle_auth_success(<<"EXTERNAL">>, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + SockMod:reset_stream(Socket), + State2 = State1#{stream_id => new_id(), + stream_restarted => true, + stream_state => wait_for_stream, + stream_authenticated => true}, + send_header(State2) + end. + +-spec process_sasl_failure(sasl_failure(), state()) -> state(). +process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) -> + try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State) + catch _:undef -> process_stream_close({error, {auth, Reason}}, State) + end. + +-spec process_packet(xmpp_element(), state()) -> state(). +process_packet(Pkt, #{mod := Mod} = State) -> + try Mod:handle_packet(Pkt, State) + catch _:undef -> State + end. + +-spec is_starttls_required(state()) -> boolean(). +is_starttls_required(#{mod := Mod} = State) -> + try Mod:tls_required(State) + catch _:undef -> false + end. + +-spec send_header(state()) -> state(). +send_header(#{remote_server := RemoteServer, + stream_encrypted := Encrypted, + lang := Lang, + xmlns := NS, + user := User, + resource := Resource, + server := Server} = State) -> + NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK; + true -> <<"">> + end, + From = if Encrypted -> + jid:make(User, Server, Resource); + NS == ?NS_SERVER -> + jid:make(Server); + true -> + undefined + end, + Header = xmpp:encode( + #stream_start{xmlns = NS, + lang = Lang, + stream_xmlns = ?NS_STREAM, + db_xmlns = NS_DB, + from = From, + to = jid:make(RemoteServer), + version = {1,0}}), + case send_text(State, fxml:element_to_header(Header)) of + ok -> State; + {error, Why} -> process_stream_close({error, {socket, Why}}, State) + end. + +-spec send_element(state(), xmpp_element()) -> state(). +send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> + El = xmpp:encode(Pkt, NS), + Data = fxml:element_to_binary(El), + State1 = try Mod:handle_send(Pkt, El, Data, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + case send_text(State1, Data) of + _ when is_record(Pkt, stream_error) -> + process_stream_end({error, {stream, Pkt}}, State1); + ok -> + State1; + {error, Why} -> + process_stream_close({error, {socket, Why}}, State1) + end + end. + +-spec send_error(state(), xmpp_element(), stanza_error()) -> state(). +send_error(State, Pkt, Err) -> + case xmpp:is_stanza(Pkt) of + true -> + case xmpp:get_type(Pkt) of + result -> State; + error -> State; + <<"result">> -> State; + <<"error">> -> State; + _ -> + ErrPkt = xmpp:make_error(Pkt, Err), + send_element(State, ErrPkt) + end; + false -> + State + end. + +-spec send_text(state(), binary()) -> ok | {error, inet:posix()}. +send_text(#{sockmod := SockMod, socket := Socket, + stream_state := StateName}, Data) when StateName /= disconnected -> + SockMod:send(Socket, Data); +send_text(_, _) -> + {error, einval}. + +-spec send_trailer(state()) -> state(). +send_trailer(State) -> + send_text(State, <<"">>), + close_socket(State). + +-spec close_socket(state()) -> state(). +close_socket(State) -> + case State of + #{sockmod := SockMod, socket := Socket} -> + SockMod:close(Socket); + _ -> + ok + end, + State#{stream_timeout => infinity, + stream_state => disconnected}. + +-spec select_lang(binary(), binary()) -> binary(). +select_lang(Lang, <<"">>) -> Lang; +select_lang(_, Lang) -> Lang. + +-spec format_inet_error(atom()) -> string(). +format_inet_error(Reason) -> + case inet:format_error(Reason) of + "unknown POSIX error" -> atom_to_list(Reason); + Txt -> Txt + end. + +-spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string(). +format_stream_error(Reason, Txt) -> + Slogan = case Reason of + #'see-other-host'{} -> "see-other-host"; + _ -> atom_to_list(Reason) + end, + case Txt of + undefined -> Slogan; + #text{data = <<"">>} -> Slogan; + #text{data = Data} -> + binary_to_list(Data) ++ " (" ++ Slogan ++ ")" + end. + +-spec format(io:format(), list()) -> binary(). +format(Fmt, Args) -> + iolist_to_binary(io_lib:format(Fmt, Args)). + +%%%=================================================================== +%%% Connection stuff +%%%=================================================================== +-spec resolve(string(), state()) -> {ok, [host_port()]} | network_error(). +resolve(Host, State) -> + case srv_lookup(Host, State) of + {error, _Reason} -> + DefaultPort = get_default_port(State), + a_lookup([{Host, DefaultPort}], State); + {ok, HostPorts} -> + a_lookup(HostPorts, State) + end. + +-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error(). +srv_lookup(Host, State) -> + %% Only perform SRV lookups for FQDN names + case string:chr(Host, $.) of + 0 -> + {error, nxdomain}; + _ -> + case inet_parse:address(Host) of + {ok, _} -> + {error, nxdomain}; + {error, _} -> + Timeout = get_dns_timeout(State), + Retries = get_dns_retries(State), + srv_lookup(Host, Timeout, Retries) + end + end. + +-spec srv_lookup(string(), non_neg_integer(), integer()) -> + {ok, [host_port()]} | network_error(). +srv_lookup(_Host, _Timeout, Retries) when Retries < 1 -> + {error, timeout}; +srv_lookup(Host, Timeout, Retries) -> + SRVName = "_xmpp-server._tcp." ++ Host, + case inet_res:getbyname(SRVName, srv, Timeout) of + {ok, HostEntry} -> + host_entry_to_host_ports(HostEntry); + {error, _} -> + LegacySRVName = "_jabber._tcp." ++ Host, + case inet_res:getbyname(LegacySRVName, srv, Timeout) of + {error, timeout} -> + srv_lookup(Host, Timeout, Retries - 1); + {error, _} = Err -> + Err; + {ok, HostEntry} -> + host_entry_to_host_ports(HostEntry) + end + end. + +-spec a_lookup([{inet:hostname(), inet:port_number()}], state()) -> + {ok, [ip_port()]} | network_error(). +a_lookup(HostPorts, State) -> + HostPortFamilies = [{Host, Port, Family} + || {Host, Port} <- HostPorts, + Family <- get_address_families(State)], + a_lookup(HostPortFamilies, State, {error, nxdomain}). + +-spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}], + state(), network_error()) -> {ok, [ip_port()]} | network_error(). +a_lookup([{Host, Port, Family}|HostPortFamilies], State, _) -> + Timeout = get_dns_timeout(State), + Retries = get_dns_retries(State), + case a_lookup(Host, Port, Family, Timeout, Retries) of + {error, _} = Err -> + a_lookup(HostPortFamilies, State, Err); + {ok, AddrPorts} -> + {ok, AddrPorts} + end; +a_lookup([], _State, Err) -> + Err. + +-spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(), + non_neg_integer(), integer()) -> {ok, [ip_port()]} | network_error(). +a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 -> + {error, timeout}; +a_lookup(Host, Port, Family, Timeout, Retries) -> + case inet:gethostbyname(Host, Family, Timeout) of + {error, timeout} -> + a_lookup(Host, Port, Family, Timeout, Retries - 1); + {error, _} = Err -> + Err; + {ok, HostEntry} -> + host_entry_to_addr_ports(HostEntry, Port) + end. + +-spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} | + {error, nxdomain}. +host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) -> + PrioHostPorts = lists:flatmap( + fun({Priority, Weight, Port, Host}) -> + N = case Weight of + 0 -> 0; + _ -> (Weight + 1) * randoms:uniform() + end, + [{Priority * 65536 - N, Host, Port}]; + (_) -> + [] + end, AddrList), + HostPorts = [{Host, Port} + || {_Priority, Host, Port} <- lists:usort(PrioHostPorts)], + case HostPorts of + [] -> {error, nxdomain}; + _ -> {ok, HostPorts} + end. + +-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) -> + {ok, [ip_port()]} | {error, nxdomain}. +host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) -> + AddrPorts = lists:flatmap( + fun(Addr) -> + try get_addr_type(Addr) of + _ -> [{Addr, Port}] + catch _:_ -> + [] + end + end, AddrList), + case AddrPorts of + [] -> {error, nxdomain}; + _ -> {ok, AddrPorts} + end. + +-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} | network_error(). +connect(AddrPorts, #{sockmod := SockMod} = State) -> + Timeout = get_connect_timeout(State), + connect(AddrPorts, SockMod, Timeout, {error, nxdomain}). + +-spec connect([ip_port()], module(), non_neg_integer(), network_error()) -> + {ok, term(), ip_port()} | network_error(). +connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) -> + Type = get_addr_type(Addr), + case SockMod:connect(Addr, Port, + [binary, {packet, 0}, + {send_timeout, ?TCP_SEND_TIMEOUT}, + {send_timeout_close, true}, + {active, false}, Type], + Timeout) of + {ok, Socket} -> + {ok, Socket, {Addr, Port}}; + Err -> + connect(AddrPorts, SockMod, Timeout, Err) + end; +connect([], _SockMod, _Timeout, Err) -> + Err. + +-spec get_addr_type(inet:ip_address()) -> inet:address_family(). +get_addr_type({_, _, _, _}) -> inet; +get_addr_type({_, _, _, _, _, _, _, _}) -> inet6. + +-spec get_dns_timeout(state()) -> non_neg_integer(). +get_dns_timeout(#{mod := Mod} = State) -> + timer:seconds( + try Mod:dns_timeout(State) + catch _:undef -> 10 + end). + +-spec get_dns_retries(state()) -> non_neg_integer(). +get_dns_retries(#{mod := Mod} = State) -> + try Mod:dns_retries(State) + catch _:undef -> 2 + end. + +-spec get_default_port(state()) -> inet:port_number(). +get_default_port(#{mod := Mod, xmlns := NS} = State) -> + try Mod:default_port(State) + catch _:undef when NS == ?NS_SERVER -> 5269; + _:undef when NS == ?NS_CLIENT -> 5222 + end. + +-spec get_address_families(state()) -> [inet:address_family()]. +get_address_families(#{mod := Mod} = State) -> + try Mod:address_families(State) + catch _:undef -> [inet, inet6] + end. + +-spec get_connect_timeout(state()) -> non_neg_integer(). +get_connect_timeout(#{mod := Mod} = State) -> + timer:seconds( + try Mod:connect_timeout(State) + catch _:undef -> 10 + end). diff --git a/src/xmpp_stream_pkix.erl b/src/xmpp_stream_pkix.erl new file mode 100644 index 000000000..59f5d820e --- /dev/null +++ b/src/xmpp_stream_pkix.erl @@ -0,0 +1,159 @@ +%%%------------------------------------------------------------------- +%%% @author Evgeny Khramtsov +%%% @copyright (C) 2016, Evgeny Khramtsov +%%% @doc +%%% +%%% @end +%%% Created : 13 Dec 2016 by Evgeny Khramtsov +%%%------------------------------------------------------------------- +-module(xmpp_stream_pkix). + +%% API +-export([authenticate/1, authenticate/2]). + +-include("xmpp.hrl"). +-include_lib("public_key/include/public_key.hrl"). +-include("XmppAddr.hrl"). + +%%%=================================================================== +%%% API +%%%=================================================================== +-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state()) + -> {ok, binary()} | {error, binary(), binary()}. +authenticate(State) -> + authenticate(State, <<"">>). + +-spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary()) + -> {ok, binary()} | {error, binary(), binary()}. +authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer, + sockmod := SockMod, socket := Socket}, _Authzid) -> + case SockMod:get_peer_certificate(Socket) of + {ok, Cert} -> + case SockMod:get_verify_result(Socket) of + 0 -> + case ejabberd_idna:domain_utf8_to_ascii(Peer) of + false -> + {error, <<"Cannot decode remote server name">>, Peer}; + AsciiPeer -> + case lists:any( + fun(D) -> match_domain(AsciiPeer, D) end, + get_cert_domains(Cert)) of + true -> + {ok, Peer}; + false -> + {error, <<"Certificate host name mismatch">>, Peer} + end + end; + VerifyRes -> + {error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer} + end; + {error, _Reason} -> + {error, <<"Cannot get peer certificate">>, Peer}; + error -> + {error, <<"Cannot get peer certificate">>, Peer} + end; +authenticate(_State, _Authzid) -> + %% TODO: client PKIX authentication + {error, <<"Client certificate verification not implemented">>, <<"">>}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +get_cert_domains(Cert) -> + TBSCert = Cert#'Certificate'.tbsCertificate, + Subject = case TBSCert#'TBSCertificate'.subject of + {rdnSequence, Subj} -> lists:flatten(Subj); + _ -> [] + end, + Extensions = case TBSCert#'TBSCertificate'.extensions of + Exts when is_list(Exts) -> Exts; + _ -> [] + end, + lists:flatmap( + fun(#'AttributeTypeAndValue'{type = ?'id-at-commonName',value = Val}) -> + case 'OTP-PUB-KEY':decode('X520CommonName', Val) of + {ok, {_, D1}} -> + D = if is_binary(D1) -> D1; + is_list(D1) -> list_to_binary(D1); + true -> error + end, + if D /= error -> + case jid:from_string(D) of + #jid{luser = <<"">>, lserver = LD, + lresource = <<"">>} -> + [LD]; + _ -> [] + end; + true -> [] + end; + _ -> [] + end; + (_) -> [] + end, Subject) ++ + lists:flatmap( + fun(#'Extension'{extnID = ?'id-ce-subjectAltName', + extnValue = Val}) -> + BVal = if is_list(Val) -> list_to_binary(Val); + true -> Val + end, + case 'OTP-PUB-KEY':decode('SubjectAltName', BVal) of + {ok, SANs} -> + lists:flatmap( + fun({otherName, #'AnotherName'{'type-id' = ?'id-on-xmppAddr', + value = XmppAddr}}) -> + case 'XmppAddr':decode('XmppAddr', XmppAddr) of + {ok, D} when is_binary(D) -> + case jid:from_string(D) of + #jid{luser = <<"">>, + lserver = LD, + lresource = <<"">>} -> + case ejabberd_idna:domain_utf8_to_ascii(LD) of + false -> + []; + PCLD -> + [PCLD] + end; + _ -> [] + end; + _ -> [] + end; + ({dNSName, D}) when is_list(D) -> + case jid:from_string(list_to_binary(D)) of + #jid{luser = <<"">>, + lserver = LD, + lresource = <<"">>} -> + [LD]; + _ -> [] + end; + (_) -> [] + end, SANs); + _ -> [] + end; + (_) -> [] + end, Extensions). + +match_domain(Domain, Domain) -> true; +match_domain(Domain, Pattern) -> + DLabels = str:tokens(Domain, <<".">>), + PLabels = str:tokens(Pattern, <<".">>), + match_labels(DLabels, PLabels). + +match_labels([], []) -> true; +match_labels([], [_ | _]) -> false; +match_labels([_ | _], []) -> false; +match_labels([DL | DLabels], [PL | PLabels]) -> + case lists:all(fun (C) -> + $a =< C andalso C =< $z orelse + $0 =< C andalso C =< $9 orelse + C == $- orelse C == $* + end, + binary_to_list(PL)) + of + true -> + Regexp = ejabberd_regexp:sh_to_awk(PL), + case ejabberd_regexp:run(DL, Regexp) of + match -> match_labels(DLabels, PLabels); + nomatch -> false + end; + false -> false + end.