From: Evgeniy Khramtsov Date: Thu, 29 Dec 2016 21:00:36 +0000 (+0300) Subject: More refactoring on session management X-Git-Tag: 17.03-beta~91^2~32 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=e7fe4dc474ed180a4200b2bdefc2ff58b12340c0;p=ejabberd More refactoring on session management --- diff --git a/src/ejabberd_auth.erl b/src/ejabberd_auth.erl index eba0a4038..17bc729cc 100644 --- a/src/ejabberd_auth.erl +++ b/src/ejabberd_auth.erl @@ -188,7 +188,7 @@ try_register(User, Server, Password) -> true -> {atomic, exists}; false -> LServer = jid:nameprep(Server), - case lists:member(LServer, ?MYHOSTS) of + case ejabberd_router:is_my_host(LServer) of true -> Res = lists:foldl(fun (_M, {atomic, ok} = Res) -> Res; (M, _) -> diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index 07d04fbc4..f22960c50 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -37,17 +37,18 @@ compress_methods/1, bind/2, get_password_fun/1, check_password_fun/1, check_password_digest_fun/1, unauthenticated_stream_features/1, authenticated_stream_features/1, - handle_stream_start/2, handle_stream_end/2, handle_stream_close/2, + handle_stream_start/2, handle_stream_end/2, handle_unauthenticated_packet/2, handle_authenticated_packet/2, handle_auth_success/4, handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2, handle_unbinded_packet/2]). %% Hooks --export([handle_unexpected_info/2, handle_unexpected_cast/2, - reject_unauthenticated_packet/2, process_closed/2]). +-export([handle_unexpected_cast/2, + reject_unauthenticated_packet/2, process_closed/2, + process_terminated/2, process_info/2]). %% API -export([get_presence/1, get_subscription/2, get_subscribed/1, open_session/1, call/3, send/2, close/1, close/2, stop/1, establish/1, - copy_state/2, add_hooks/0]). + reply/2, copy_state/2, set_timeout/2, add_hooks/1]). -include("ejabberd.hrl"). -include("xmpp.hrl"). @@ -76,6 +77,9 @@ socket_type() -> call(Ref, Msg, Timeout) -> xmpp_stream_in:call(Ref, Msg, Timeout). +reply(Ref, Reply) -> + xmpp_stream_in:reply(Ref, Reply). + -spec get_presence(pid()) -> presence(). get_presence(Ref) -> call(Ref, get_presence, 1000). @@ -112,37 +116,39 @@ stop(Ref) -> send(Pid, Pkt) when is_pid(Pid) -> xmpp_stream_in:send(Pid, Pkt); send(#{lserver := LServer} = State, Pkt) -> - case ejabberd_hooks:run_fold(c2s_filter_send, LServer, Pkt, [State]) of - drop -> State; - Pkt1 -> xmpp_stream_in:send(State, Pkt1) + case ejabberd_hooks:run_fold(c2s_filter_send, LServer, {Pkt, State}, []) of + {drop, State1} -> State1; + {Pkt1, State1} -> xmpp_stream_in:send(State1, Pkt1) end. +-spec set_timeout(state(), timeout()) -> state(). +set_timeout(State, Timeout) -> + xmpp_stream_in:set_timeout(State, Timeout). + -spec establish(state()) -> state(). establish(State) -> xmpp_stream_in:establish(State). --spec add_hooks() -> ok. -add_hooks() -> - lists:foreach( - fun(Host) -> - ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100), - ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, - reject_unauthenticated_packet, 100), - ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, - handle_unexpected_info, 100), - ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE, - handle_unexpected_cast, 100) - - end, ?MYHOSTS). +-spec add_hooks(binary()) -> ok. +add_hooks(Host) -> + ejabberd_hooks:add(c2s_closed, Host, ?MODULE, process_closed, 100), + ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, + process_terminated, 100), + ejabberd_hooks:add(c2s_unauthenticated_packet, Host, ?MODULE, + reject_unauthenticated_packet, 100), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, + process_info, 100), + ejabberd_hooks:add(c2s_handle_cast, Host, ?MODULE, + handle_unexpected_cast, 100). %% Copies content of one c2s state to another. %% This is needed for session migration from one pid to another. -spec copy_state(state(), state()) -> state(). copy_state(#{owner := Owner} = NewState, - #{jid := JID, resource := Resource, sid := {Time, _}, - auth_module := AuthModule, lserver := LServer, - pres_t := PresT, pres_a := PresA, - pres_f := PresF} = OldState) -> + #{jid := JID, resource := Resource, sid := {Time, _}, + auth_module := AuthModule, lserver := LServer, + pres_t := PresT, pres_a := PresA, + pres_f := PresF} = OldState) -> State1 = case OldState of #{pres_last := Pres, pres_timestamp := PresTS} -> NewState#{pres_last => Pres, pres_timestamp => PresTS}; @@ -158,10 +164,46 @@ copy_state(#{owner := Owner} = NewState, pres_f => PresF}, ejabberd_hooks:run_fold(c2s_copy_state, LServer, State2, [OldState]). +-spec open_session(state()) -> {ok, state()} | state(). +open_session(#{user := U, server := S, resource := R, + sid := SID, ip := IP, auth_module := AuthModule} = State) -> + JID = jid:make(U, S, R), + change_shaper(State), + Conn = get_conn_type(State), + State1 = State#{conn => Conn, resource => R, jid => JID}, + Prio = try maps:get(pres_last, State) of + Pres -> get_priority_from_presence(Pres) + catch _:{badkey, _} -> + undefined + end, + Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}], + ejabberd_sm:open_session(SID, U, S, R, Prio, Info), + xmpp_stream_in:establish(State1). + %%%=================================================================== %%% Hooks %%%=================================================================== -handle_unexpected_info(State, Info) -> +process_info(#{lserver := LServer} = State, + {route, From, To, Packet0}) -> + Packet = xmpp:set_from_to(Packet0, From, To), + {Pass, State1} = case Packet of + #presence{} -> + process_presence_in(State, Packet); + #message{} -> + process_message_in(State, Packet); + #iq{} -> + process_iq_in(State, Packet) + end, + if Pass -> + Packet1 = ejabberd_hooks:run_fold( + user_receive_packet, LServer, Packet, [State1]), + ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), + send(State1, Packet1); + true -> + ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), + State1 + end; +process_info(State, Info) -> ?WARNING_MSG("got unexpected info: ~p", [Info]), State. @@ -173,8 +215,22 @@ reject_unauthenticated_packet(State, Pkt) -> Err = xmpp:err_not_authorized(), xmpp_stream_in:send_error(State, Pkt, Err). -process_closed(State, _Reason) -> - stop(State). +process_closed(State, Reason) -> + stop(State#{stop_reason => Reason}). + +process_terminated(#{socket := Socket, jid := JID} = State, + Reason) -> + Status = format_reason(State, Reason), + ?INFO_MSG("(~s) Closing c2s connection for ~s: ~s", + [ejabberd_socket:pp(Socket), jid:to_string(JID), Status]), + Pres = #presence{type = unavailable, + status = xmpp:mk_text(Status), + from = JID, to = jid:remove_resource(JID)}, + State1 = broadcast_presence_unavailable(State, Pres), + bounce_message_queue(), + State1; +process_terminated(State, _Reason) -> + State. %%%=================================================================== %%% xmpp_stream_in callbacks @@ -248,25 +304,9 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang, end end. --spec open_session(state()) -> {ok, state()} | state(). -open_session(#{user := U, server := S, resource := R, - sid := SID, ip := IP, auth_module := AuthModule} = State) -> - JID = jid:make(U, S, R), - change_shaper(State), - Conn = get_conn_type(State), - State1 = State#{conn => Conn, resource => R, jid => JID}, - Prio = try maps:get(pres_last, State) of - Pres -> get_priority_from_presence(Pres) - catch _:{badkey, _} -> - undefined - end, - Info = [{ip, IP}, {conn, Conn}, {auth_module, AuthModule}], - ejabberd_sm:open_session(SID, U, S, R, Prio, Info), - State1. - handle_stream_start(StreamStart, #{lserver := LServer, ip := IP, lang := Lang} = State) -> - case lists:member(LServer, ?MYHOSTS) of + case ejabberd_router:is_my_host(LServer) of false -> send(State, xmpp:serr_host_unknown()); true -> @@ -284,10 +324,8 @@ handle_stream_start(StreamStart, end. handle_stream_end(Reason, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_closed, LServer, State, [Reason]). - -handle_stream_close(_Reason, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_closed, LServer, State, [normal]). + State1 = State#{stop_reason => Reason}, + ejabberd_hooks:run_fold(c2s_closed, LServer, State1, [Reason]). handle_auth_success(User, Mech, AuthModule, #{socket := Socket, ip := IP, lserver := LServer} = State) -> @@ -296,8 +334,7 @@ handle_auth_success(User, Mech, AuthModule, ejabberd_auth:backend_type(AuthModule), ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), State1 = State#{auth_module => AuthModule}, - ejabberd_hooks:run_fold(c2s_auth_result, LServer, - State1, [true, User]). + ejabberd_hooks:run_fold(c2s_auth_result, LServer, State1, [true, User]). handle_auth_failure(User, Mech, Reason, #{socket := Socket, ip := IP, lserver := LServer} = State) -> @@ -307,16 +344,13 @@ handle_auth_failure(User, Mech, Reason, true -> "" end, ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), - ejabberd_hooks:run_fold(c2s_auth_result, LServer, - State, [false, User]). + ejabberd_hooks:run_fold(c2s_auth_result, LServer, State, [false, User]). handle_unbinded_packet(Pkt, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, - State, [Pkt]). + ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, State, [Pkt]). handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) -> - ejabberd_hooks:run_fold(c2s_unauthenticated_packet, - LServer, State, [Pkt]). + ejabberd_hooks:run_fold(c2s_unauthenticated_packet, LServer, State, [Pkt]). handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) -> ejabberd_hooks:run_fold(c2s_authenticated_packet, @@ -366,20 +400,22 @@ init([State, Opts]) -> zlib => Zlib, lang => ?MYLANG, server => ?MYNAME, + lserver => ?MYNAME, access => Access, shaper => Shaper}, ejabberd_hooks:run_fold(c2s_init, {ok, State1}, [Opts]). -handle_call(get_presence, _From, #{jid := JID} = State) -> +handle_call(get_presence, From, #{jid := JID} = State) -> Pres = try maps:get(pres_last, State) catch _:{badkey, _} -> BareJID = jid:remove_resource(JID), #presence{from = JID, to = BareJID, type = unavailable} end, - {reply, Pres, State}; -handle_call(get_subscribed, _From, #{pres_f := PresF} = State) -> - Subscribed = ?SETS:to_list(PresF), - {reply, Subscribed, State}; + reply(From, Pres), + State; +handle_call(get_subscribed, From, #{pres_f := PresF} = State) -> + reply(From, ?SETS:to_list(PresF)), + State; handle_call(Request, From, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold( c2s_handle_call, LServer, State, [Request, From]). @@ -387,30 +423,22 @@ handle_call(Request, From, #{lserver := LServer} = State) -> handle_cast(Msg, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_handle_cast, LServer, State, [Msg]). -handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) -> - Packet = xmpp:set_from_to(Packet0, From, To), - {Pass, NewState} = case Packet of - #presence{} -> - process_presence_in(State, Packet); - #message{} -> - process_message_in(State, Packet); - #iq{} -> - process_iq_in(State, Packet) - end, - if Pass -> - Packet1 = ejabberd_hooks:run_fold( - user_receive_packet, LServer, Packet, [NewState]), - ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - send(NewState, Packet1); - true -> - ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), - NewState - end; handle_info(Info, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_handle_info, LServer, State, [Info]). -terminate(_Reason, _State) -> - ok. +terminate(Reason, #{sid := SID, jid := _, + user := U, server := S, resource := R, + lserver := LServer} = State) -> + Status = format_reason(State, Reason), + case maps:is_key(pres_last, State) of + true -> + ejabberd_sm:close_session_unset_presence(SID, U, S, R, Status); + false -> + ejabberd_sm:close_session(SID, U, S, R) + end, + ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]); +terminate(Reason, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_terminated, LServer, State, [Reason]). code_change(_OldVsn, State, _Extra) -> {ok, State}. @@ -684,6 +712,15 @@ resource_conflict_action(U, S, R) -> {accept_resource, Rnew} end. +-spec bounce_message_queue() -> ok. +bounce_message_queue() -> + receive {route, From, To, Pkt} -> + ejabberd_router:route(From, To, Pkt), + bounce_message_queue() + after 0 -> + ok + end. + -spec new_uniq_id() -> binary(). new_uniq_id() -> iolist_to_binary( @@ -735,6 +772,14 @@ do_some_magic(#{pres_a := PresA, pres_f := PresF} = State, From) -> end end. +-spec format_reason(state(), term()) -> binary(). +format_reason(#{stop_reason := Reason}, _) -> + xmpp_stream_in:format_error(Reason); +format_reason(_, Reason) when Reason /= normal -> + <<"internal server error">>; +format_reason(_, _) -> + <<"">>. + transform_listen_option(Opt, Opts) -> [Opt|Opts]. diff --git a/src/ejabberd_http.erl b/src/ejabberd_http.erl index c0c7bbbd6..1a33580ea 100644 --- a/src/ejabberd_http.erl +++ b/src/ejabberd_http.erl @@ -322,7 +322,7 @@ add_header(Name, Value, State)-> get_host_really_served(undefined, Provided) -> Provided; get_host_really_served(Default, Provided) -> - case lists:member(Provided, ?MYHOSTS) of + case ejabberd_router:is_my_host(Provided) of true -> Provided; false -> Default end. diff --git a/src/ejabberd_piefxis.erl b/src/ejabberd_piefxis.erl index b6f90ccf8..36d734004 100644 --- a/src/ejabberd_piefxis.erl +++ b/src/ejabberd_piefxis.erl @@ -350,7 +350,7 @@ process_el({xmlstreamelement, #xmlel{name = <<"host">>, JIDS = fxml:get_attr_s(<<"jid">>, Attrs), case jid:from_string(JIDS) of #jid{lserver = S} -> - case lists:member(S, ?MYHOSTS) of + case ejabberd_router:is_my_host(S) of true -> process_users(Els, State#state{server = S}); false -> diff --git a/src/ejabberd_s2s_in.erl b/src/ejabberd_s2s_in.erl index 93f75bfcf..a31af337e 100644 --- a/src/ejabberd_s2s_in.erl +++ b/src/ejabberd_s2s_in.erl @@ -34,7 +34,7 @@ -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, compress_methods/1, unauthenticated_stream_features/1, authenticated_stream_features/1, - handle_stream_start/2, handle_stream_end/2, handle_stream_close/2, + handle_stream_start/2, handle_stream_end/2, handle_stream_established/1, handle_auth_success/4, handle_auth_failure/4, handle_send/3, handle_recv/3, handle_cdata/2, handle_unauthenticated_packet/2, handle_authenticated_packet/2]). @@ -160,9 +160,6 @@ handle_stream_start(_StreamStart, #{lserver := LServer} = State) -> handle_stream_end(Reason, #{server_host := LServer} = State) -> ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [Reason]). -handle_stream_close(_Reason, #{server_host := LServer} = State) -> - ejabberd_hooks:run_fold(s2s_in_closed, LServer, State, [normal]). - handle_stream_established(State) -> set_idle_timeout(State#{established => true}). diff --git a/src/ejabberd_s2s_out.erl b/src/ejabberd_s2s_out.erl index 72d9dfea8..6069c786c 100644 --- a/src/ejabberd_s2s_out.erl +++ b/src/ejabberd_s2s_out.erl @@ -15,14 +15,14 @@ %% xmpp_stream_out callbacks -export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1, handle_auth_success/2, handle_auth_failure/3, handle_packet/2, - handle_stream_end/2, handle_stream_close/2, + handle_stream_end/2, handle_stream_downgraded/2, handle_recv/3, handle_send/4, handle_cdata/2, handle_stream_established/1, handle_timeout/1]). -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). %% Hooks -export([process_auth_result/2, process_closed/2, handle_unexpected_info/2, - handle_unexpected_cast/2]). + handle_unexpected_cast/2, process_downgraded/2]). %% API -export([start/3, start_link/3, connect/1, close/1, stop/1, send/2, route/2, establish/1, update_state/2, add_hooks/0]). @@ -83,7 +83,9 @@ add_hooks() -> ejabberd_hooks:add(s2s_out_handle_info, Host, ?MODULE, handle_unexpected_info, 100), ejabberd_hooks:add(s2s_out_handle_cast, Host, ?MODULE, - handle_unexpected_cast, 100) + handle_unexpected_cast, 100), + ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE, + process_downgraded, 100) end, ?MYHOSTS). %%%=================================================================== @@ -95,25 +97,28 @@ process_auth_result(#{server := LServer, remote_server := RServer} = State, ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: authentication failed;" " bouncing for ~p seconds", [LServer, RServer, Delay]), - State1 = close(State), - State2 = bounce_queue(State1), - xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)); + State1 = State#{on_route => bounce}, + State2 = close(State1), + State3 = bounce_queue(State2), + xmpp_stream_out:set_timeout(State3, timer:seconds(Delay)); process_auth_result(State, true) -> State. +process_closed(#{server := LServer, remote_server := RServer, + on_route := send} = State, + Reason) -> + ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s", + [LServer, RServer, xmpp_stream_out:format_error(Reason)]), + stop(State); process_closed(#{server := LServer, remote_server := RServer} = State, - _Reason) -> + Reason) -> Delay = get_delay(), ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: ~s; " "bouncing for ~p seconds", - [LServer, RServer, - try maps:get(stop_reason, State) of - {error, Why} -> xmpp_stream_out:format_error(Why) - catch _:undef -> <<"unexplained reason">> - end, - Delay]), - State1 = bounce_queue(State), - xmpp_stream_out:set_timeout(State1, timer:seconds(Delay)). + [LServer, RServer, xmpp_stream_out:format_error(Reason), Delay]), + State1 = State#{on_route => bounce}, + State2 = bounce_queue(State1), + xmpp_stream_out:set_timeout(State2, timer:seconds(Delay)). handle_unexpected_info(State, Info) -> ?WARNING_MSG("got unexpected info: ~p", [Info]), @@ -123,6 +128,9 @@ handle_unexpected_cast(State, Msg) -> ?WARNING_MSG("got unexpected cast: ~p", [Msg]), State. +process_downgraded(State, _StreamStart) -> + send(State, xmpp:serr_unsupported_version()). + %%%=================================================================== %%% gen_server callbacks %%%=================================================================== @@ -153,21 +161,19 @@ handle_auth_failure(Mech, Reason, ?INFO_MSG("(~s) Failed outbound s2s ~s authentication ~s -> ~s (~s): ~s", [ejabberd_socket:pp(Socket), Mech, LServer, RServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP)), Reason]), - State1 = State#{on_route => bounce, - stop_reason => {error, {auth, Reason}}}, + State1 = State#{stop_reason => {auth, Reason}}, ejabberd_hooks:run_fold(s2s_out_auth_result, LServer, State1, [false]). handle_packet(Pkt, #{server := LServer} = State) -> ejabberd_hooks:run_fold(s2s_out_packet, LServer, State, [Pkt]). handle_stream_end(Reason, #{server := LServer} = State) -> - State1 = State#{on_route => bounce, stop_reason => Reason}, - ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [normal]). - -handle_stream_close(Reason, #{server := LServer} = State) -> - State1 = State#{on_route => bounce, stop_reason => Reason}, + State1 = State#{stop_reason => Reason}, ejabberd_hooks:run_fold(s2s_out_closed, LServer, State1, [Reason]). +handle_stream_downgraded(StreamStart, #{server := LServer} = State) -> + ejabberd_hooks:run_fold(s2s_out_downgraded, LServer, State, [StreamStart]). + handle_stream_established(State) -> State1 = State#{on_route => send}, State2 = resend_queue(State1), @@ -183,15 +189,10 @@ handle_send(Pkt, El, Data, #{server := LServer} = State) -> ejabberd_hooks:run_fold(s2s_out_handle_send, LServer, State, [Pkt, El, Data]). -handle_timeout(#{server := LServer, remote_server := RServer, - on_route := Action} = State) -> +handle_timeout(#{on_route := Action} = State) -> case Action of bounce -> stop(State); - queue -> send(State, xmpp:serr_connection_timeout()); - send -> - ?INFO_MSG("Closing outbound s2s connection ~s -> ~s: inactive", - [LServer, RServer]), - stop(State) + _ -> send(State, xmpp:serr_connection_timeout()) end. init([#{server := LServer, remote_server := RServer} = State, Opts]) -> @@ -229,7 +230,7 @@ terminate(Reason, #{server := LServer, ejabberd_s2s:remove_connection({LServer, RServer}, self()), State1 = case Reason of normal -> State; - _ -> State#{stop_reason => {error, internal_failure}} + _ -> State#{stop_reason => internal_failure} end, bounce_queue(State1), bounce_message_queue(State1). @@ -258,8 +259,7 @@ bounce_queue(#{queue := Q} = State) -> -spec bounce_message_queue(state()) -> state(). bounce_message_queue(State) -> - receive - {route, Pkt} -> + receive {route, Pkt} -> State1 = bounce_packet(Pkt, State), bounce_message_queue(State1) after 0 -> @@ -278,21 +278,19 @@ bounce_packet(_, State) -> State. -spec mk_bounce_error(binary(), state()) -> stanza_error(). -mk_bounce_error(Lang, State) -> - try maps:get(stop_reason, State) of - {error, internal_failure} -> +mk_bounce_error(Lang, #{stop_reason := Why}) -> + Reason = xmpp_stream_out:format_error(Why), + case Why of + internal_failure -> xmpp:err_internal_server_error(); - {error, Why} -> - Reason = xmpp_stream_out:format_error(Why), - case Why of - {dns, _} -> - xmpp:err_remote_server_timeout(Reason, Lang); - _ -> - xmpp:err_remote_server_not_found(Reason, Lang) - end - catch _:{badkey, _} -> - xmpp:err_remote_server_not_found() - end. + {dns, _} -> + xmpp:err_remote_server_not_found(Reason, Lang); + _ -> + xmpp:err_remote_server_timeout(Reason, Lang) + end; +mk_bounce_error(_Lang, _State) -> + %% We should not be here. Probably :) + xmpp:err_remote_server_not_found(). -spec get_delay() -> non_neg_integer(). get_delay() -> diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index 13efd15e7..6ecd03a4c 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -99,7 +99,7 @@ handle_stream_start(_StreamStart, #{remote_server := RemoteServer, lang := Lang, host_opts := HostOpts} = State) -> - case lists:member(RemoteServer, ?MYHOSTS) of + case ejabberd_router:is_my_host(RemoteServer) of true -> Txt = <<"Unable to register route on existing local domain">>, xmpp_stream_in:send(State, xmpp:serr_conflict(Txt, Lang)); diff --git a/src/ejabberd_sm.erl b/src/ejabberd_sm.erl index 46008bec4..a15d788d0 100644 --- a/src/ejabberd_sm.erl +++ b/src/ejabberd_sm.erl @@ -390,7 +390,8 @@ init([]) -> ejabberd_hooks:add(offline_message_hook, Host, ejabberd_sm, bounce_offline_message, 100), ejabberd_hooks:add(remove_user, Host, - ejabberd_sm, disconnect_removed_user, 100) + ejabberd_sm, disconnect_removed_user, 100), + ejabberd_c2s:add_hooks(Host) end, ?MYHOSTS), ejabberd_commands:register_commands(get_commands_spec()), {ok, #state{}}. diff --git a/src/ejabberd_web_admin.erl b/src/ejabberd_web_admin.erl index 3836beda7..e1c0760e9 100644 --- a/src/ejabberd_web_admin.erl +++ b/src/ejabberd_web_admin.erl @@ -192,7 +192,7 @@ process([<<"server">>, SHost | RPath] = Path, method = Method} = Request) -> Host = jid:nameprep(SHost), - case lists:member(Host, ?MYHOSTS) of + case ejabberd_router:is_my_host(Host) of true -> case get_auth_admin(Auth, HostHTTP, Path, Method) of {ok, {User, Server}} -> diff --git a/src/mod_legacy_auth.erl b/src/mod_legacy_auth.erl index f93b67e05..e9057b432 100644 --- a/src/mod_legacy_auth.erl +++ b/src/mod_legacy_auth.erl @@ -133,8 +133,7 @@ open_session(State, IQ, R) -> case ejabberd_c2s:bind(R, State) of {ok, State1} -> Res = xmpp:make_iq_result(IQ), - State2 = ejabberd_c2s:send(State1, Res), - ejabberd_c2s:establish(State2); + ejabberd_c2s:send(State1, Res); {error, Err, State1} -> Res = xmpp:make_error(IQ, Err), ejabberd_c2s:send(State1, Res) diff --git a/src/mod_s2s_dialback.erl b/src/mod_s2s_dialback.erl index ce9d2705b..4bdda2ca7 100644 --- a/src/mod_s2s_dialback.erl +++ b/src/mod_s2s_dialback.erl @@ -28,7 +28,8 @@ %% gen_mod API -export([start/2, stop/1, depends/2, mod_opt_type/1]). %% Hooks --export([s2s_out_auth_result/2, s2s_in_packet/2, s2s_out_packet/2, +-export([s2s_out_auth_result/2, s2s_out_downgraded/2, + s2s_in_packet/2, s2s_out_packet/2, s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]). -include("ejabberd.hrl"). @@ -57,6 +58,8 @@ start(Host, _Opts) -> s2s_in_packet, 50), ejabberd_hooks:add(s2s_out_packet, Host, ?MODULE, s2s_out_packet, 50), + ejabberd_hooks:add(s2s_out_downgraded, Host, ?MODULE, + s2s_out_downgraded, 50), ejabberd_hooks:add(s2s_out_auth_result, Host, ?MODULE, s2s_out_auth_result, 50) end. @@ -74,6 +77,8 @@ stop(Host) -> s2s_in_packet, 50), ejabberd_hooks:delete(s2s_out_packet, Host, ?MODULE, s2s_out_packet, 50), + ejabberd_hooks:delete(s2s_out_downgraded, Host, ?MODULE, + s2s_out_downgraded, 50), ejabberd_hooks:delete(s2s_out_auth_result, Host, ?MODULE, s2s_out_auth_result, 50). @@ -104,47 +109,56 @@ s2s_out_init(Acc, _Opts) -> s2s_out_closed(#{server := LServer, remote_server := RServer, - db_verify := {StreamID, _Key, _Pid}} = State, _Reason) -> + db_verify := {StreamID, _Key, _Pid}} = State, Reason) -> %% Outbound s2s verificating connection (created at step 1) is %% closed suddenly without receiving the response. %% Building a response on our own Response = #db_verify{from = RServer, to = LServer, id = StreamID, type = error, - sub_els = [mk_error(internal_server_error)]}, + sub_els = [mk_error(Reason)]}, s2s_out_packet(State, Response); s2s_out_closed(State, _Reason) -> State. -s2s_out_auth_result(#{server := LServer, - remote_server := RServer, - db_verify := {StreamID, Key, _Pid}} = State, - _) -> +s2s_out_auth_result(#{db_verify := _} = State, _) -> %% The temporary outbound s2s connect (intended for verification) %% has passed authentication state (either successfully or not, no matter) %% and at this point we can send verification request as described %% in section 2.1.2, step 2 - Request = #db_verify{from = LServer, to = RServer, - key = Key, id = StreamID}, - {stop, ejabberd_s2s_out:send(State, Request)}; + {stop, send_verify_request(State)}; s2s_out_auth_result(#{db_enabled := true, socket := Socket, ip := IP, server := LServer, - remote_server := RServer, - stream_remote_id := StreamID} = State, false) -> + remote_server := RServer} = State, false) -> %% SASL authentication has failed, retrying with dialback %% Sending dialback request, section 2.1.1, step 1 ?INFO_MSG("(~s) Retrying with s2s dialback authentication: ~s -> ~s (~s)", [ejabberd_socket:pp(Socket), LServer, RServer, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), - Key = make_key(LServer, RServer, StreamID), State1 = maps:remove(stop_reason, State#{on_route => queue}), - State2 = ejabberd_s2s_out:send(State1, #db_result{from = LServer, - to = RServer, - key = Key}), - {stop, State2}; + {stop, send_db_request(State1)}; s2s_out_auth_result(State, _) -> State. +s2s_out_downgraded(#{db_verify := _} = State, _) -> + %% The verifying outbound s2s connection detected non-RFC compliant + %% server, send verification request immediately without auth phase, + %% section 2.1.2, step 2 + {stop, send_verify_request(State)}; +s2s_out_downgraded(#{db_enabled := true, + socket := Socket, ip := IP, + server := LServer, + remote_server := RServer} = State, _) -> + %% non-RFC compliant server detected, send dialback request instantly, + %% section 2.1.1, step 1 + ?INFO_MSG("(~s) Trying s2s dialback authentication with " + "non-RFC compliant server: ~s -> ~s (~s)", + [ejabberd_socket:pp(Socket), LServer, RServer, + ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), + {stop, send_db_request(State)}; +s2s_out_downgraded(State, _) -> + State. + s2s_in_packet(#{stream_id := StreamID} = State, #db_result{from = From, to = To, key = Key, type = undefined}) -> %% Received dialback request, section 2.2.1, step 1 @@ -220,6 +234,23 @@ make_key(From, To, StreamID) -> crypto:hmac(sha256, p1_sha:to_hexlist(crypto:hash(sha256, Secret)), [To, " ", From, " ", StreamID])). +-spec send_verify_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state(). +send_verify_request(#{server := LServer, + remote_server := RServer, + db_verify := {StreamID, Key, _Pid}} = State) -> + Request = #db_verify{from = LServer, to = RServer, + key = Key, id = StreamID}, + ejabberd_s2s_out:send(State, Request). + +-spec send_db_request(ejabberd_s2s_out:state()) -> ejabberd_s2s_out:state(). +send_db_request(#{server := LServer, + remote_server := RServer, + stream_remote_id := StreamID} = State) -> + Key = make_key(LServer, RServer, StreamID), + ejabberd_s2s_out:send(State, #db_result{from = LServer, + to = RServer, + key = Key}). + -spec send_db_result(ejabberd_s2s_in:state(), db_verify()) -> ejabberd_s2s_in:state(). send_db_result(State, #db_verify{from = From, to = To, type = Type, sub_els = Els}) -> @@ -255,6 +286,9 @@ mk_error(forbidden) -> xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG); mk_error(host_unknown) -> xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG); +mk_error({_Class, _Reason} = Why) -> + Txt = xmpp_stream_out:format_error(Why), + xmpp:err_remote_server_not_found(Txt, ?MYLANG); mk_error(_) -> xmpp:err_internal_server_error(). diff --git a/src/mod_sm.erl b/src/mod_sm.erl index 82d68702d..703234419 100644 --- a/src/mod_sm.erl +++ b/src/mod_sm.erl @@ -30,8 +30,9 @@ %% hooks -export([c2s_stream_init/2, c2s_stream_started/2, c2s_stream_features/2, c2s_authenticated_packet/2, c2s_unauthenticated_packet/2, - c2s_unbinded_packet/2, c2s_closed/2, - c2s_handle_send/3, c2s_filter_send/2, c2s_handle_info/2]). + c2s_unbinded_packet/2, c2s_closed/2, c2s_terminated/2, + c2s_handle_send/3, c2s_filter_send/1, c2s_handle_info/2, + c2s_handle_call/3, c2s_handle_recv/3]). -include("xmpp.hrl"). -include("logger.hrl"). @@ -60,13 +61,13 @@ start(Host, _Opts) -> c2s_unbinded_packet, 50), ejabberd_hooks:add(c2s_authenticated_packet, Host, ?MODULE, c2s_authenticated_packet, 50), - ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, - c2s_handle_send, 50), - ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, - c2s_filter_send, 50), - ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 50), - ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50). + ejabberd_hooks:add(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50), + ejabberd_hooks:add(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50), + ejabberd_hooks:add(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50), + ejabberd_hooks:add(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:add(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50), + ejabberd_hooks:add(c2s_closed, Host, ?MODULE, c2s_closed, 50), + ejabberd_hooks:add(c2s_terminated, Host, ?MODULE, c2s_terminated, 50). stop(Host) -> %% TODO: do something with global 'c2s_init' hook @@ -80,13 +81,13 @@ stop(Host) -> c2s_unbinded_packet, 50), ejabberd_hooks:delete(c2s_authenticated_packet, Host, ?MODULE, c2s_authenticated_packet, 50), - ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, - c2s_handle_send, 50), - ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, - c2s_filter_send, 50), - ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, - c2s_handle_info, 50), - ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50). + ejabberd_hooks:delete(c2s_handle_send, Host, ?MODULE, c2s_handle_send, 50), + ejabberd_hooks:delete(c2s_handle_recv, Host, ?MODULE, c2s_handle_recv, 50), + ejabberd_hooks:delete(c2s_filter_send, Host, ?MODULE, c2s_filter_send, 50), + ejabberd_hooks:delete(c2s_handle_info, Host, ?MODULE, c2s_handle_info, 50), + ejabberd_hooks:delete(c2s_handle_call, Host, ?MODULE, c2s_handle_call, 50), + ejabberd_hooks:delete(c2s_closed, Host, ?MODULE, c2s_closed, 50), + ejabberd_hooks:delete(c2s_terminated, Host, ?MODULE, c2s_terminated, 50). depends(_Host, _Opts) -> []. @@ -115,7 +116,10 @@ c2s_stream_started(#{lserver := LServer, mgmt_options := Opts} = State, mgmt_timeout => ResumeTimeout, mgmt_max_timeout => MaxResumeTimeout, mgmt_ack_timeout => get_ack_timeout(LServer, Opts), - mgmt_resend => get_resend_on_timeout(LServer, Opts)}; + mgmt_resend => get_resend_on_timeout(LServer, Opts), + mgmt_stanzas_in => 0, + mgmt_stanzas_out => 0, + mgmt_stanzas_req => 0}; c2s_stream_started(State, _StreamStart) -> State. @@ -143,8 +147,8 @@ c2s_unbinded_packet(State, #sm_resume{} = Pkt) -> case handle_resume(State, Pkt) of {ok, ResumedState} -> {stop, ResumedState}; - error -> - {stop, State} + {error, State1} -> + {stop, State1} end; c2s_unbinded_packet(State, Pkt) when ?is_sm_packet(Pkt) -> c2s_unauthenticated_packet(State, Pkt); @@ -161,12 +165,26 @@ c2s_authenticated_packet(#{mgmt_state := MgmtState} = State, Pkt) c2s_authenticated_packet(State, Pkt) -> update_num_stanzas_in(State, Pkt). +c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) -> + Xmlns = xmpp:get_ns(El), + if Xmlns == ?NS_STREAM_MGMT_2; Xmlns == ?NS_STREAM_MGMT_3 -> + Txt = xmpp:io_format_error(Why), + Err = #sm_failed{reason = 'bad-request', + text = xmpp:mk_text(Txt, Lang), + xmlns = Xmlns}, + send(State, Err); + true -> + State + end; +c2s_handle_recv(State, _, _) -> + State. + c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result) when MgmtState == pending; MgmtState == active -> State1 = mgmt_queue_add(State, Pkt), case Result of ok when ?is_stanza(Pkt) -> - send_ack(State1); + send_rack(State1); ok -> State1; {error, _} -> @@ -175,21 +193,57 @@ c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result) c2s_handle_send(State, _Pkt, _Result) -> State. -c2s_filter_send(Pkt, _State) -> - Pkt. +c2s_filter_send({Pkt, State}) -> + {Pkt, State}. + +c2s_handle_call(#{sid := {Time, _}} = State, + {resume_session, Time}, From) -> + ejabberd_c2s:reply(From, {resume, State}), + {stop, State#{mgmt_state => resumed}}; +c2s_handle_call(State, {resume_session, _}, From) -> + ejabberd_c2s:reply(From, {error, <<"Previous session not found">>}), + {stop, State}; +c2s_handle_call(State, _Call, _From) -> + State. -c2s_handle_info(#{mgmt_ack_timer := T, jid := JID} = State, - {timeout, T, ack_timeout}) -> - ?DEBUG("Timeout waiting for stream management acknowledgement of ~s", +c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State, + {timeout, TRef, ack_timeout}) -> + ?DEBUG("Timed out waiting for stream management acknowledgement of ~s", [jid:to_string(JID)]), State1 = ejabberd_c2s:close(State, _SendTrailer = false), - c2s_closed(State1, ack_timeout); + {stop, transition_to_pending(State1)}; +c2s_handle_info(#{mgmt_state := pending, jid := JID} = State, + {timeout, _, pending_timeout}) -> + ?DEBUG("Timed out waiting for resumption of stream for ~s", + [jid:to_string(JID)]), + ejabberd_c2s:stop(State#{mgmt_state => timeout}); c2s_handle_info(State, _) -> State. -c2s_closed(#{mgmt_state := active} = State, Reason) when Reason /= normal -> - {stop, transition_to_pending(State)}; -c2s_closed(State, _) -> +c2s_closed(State, {stream, _}) -> + State; +c2s_closed(#{mgmt_state := active} = State, Reason) -> + {stop, transition_to_pending(State#{stop_reason => Reason})}; +c2s_closed(State, _Reason) -> + State. + +c2s_terminated(#{mgmt_state := resumed, jid := JID} = State, _Reason) -> + ?INFO_MSG("Closing former stream of resumed session for ~s", + [jid:to_string(JID)]), + bounce_message_queue(), + {stop, State}; +c2s_terminated(#{mgmt_state := MgmtState, mgmt_stanzas_in := In, sid := SID, + user := U, server := S, resource := R} = State, _Reason) -> + case MgmtState of + timeout -> + Info = [{num_stanzas_in, In}], + ejabberd_sm:set_offline_info(SID, U, S, R, Info); + _ -> + ok + end, + route_unacked_stanzas(State), + State; +c2s_terminated(State, _Reason) -> State. %%%=================================================================== @@ -201,17 +255,14 @@ negotiate_stream_mgmt(Pkt, State) -> case Pkt of #sm_enable{} -> handle_enable(State#{mgmt_xmlns => Xmlns}, Pkt); + _ when is_record(Pkt, sm_a); + is_record(Pkt, sm_r); + is_record(Pkt, sm_resume) -> + Err = #sm_failed{reason = 'unexpected-request', xmlns = Xmlns}, + send(State, Err); _ -> - Res = if is_record(Pkt, sm_a); - is_record(Pkt, sm_r); - is_record(Pkt, sm_resume) -> - #sm_failed{reason = 'unexpected-request', - xmlns = Xmlns}; - true -> - #sm_failed{reason = 'bad-request', - xmlns = Xmlns} - end, - send(State, Res) + Err = #sm_failed{reason = 'bad-request', xmlns = Xmlns}, + send(State, Err) end. -spec perform_stream_mgmt(xmpp_element(), state()) -> state(). @@ -223,16 +274,13 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) -> handle_r(State); #sm_a{} -> handle_a(State, Pkt); + _ when is_record(Pkt, sm_enable); + is_record(Pkt, sm_resume) -> + send(State, #sm_failed{reason = 'unexpected-request', + xmlns = Xmlns}); _ -> - Res = if is_record(Pkt, sm_enable); - is_record(Pkt, sm_resume) -> - #sm_failed{reason = 'unexpected-request', - xmlns = Xmlns}; - true -> - #sm_failed{reason = 'bad-request', - xmlns = Xmlns} - end, - send(State, Res) + send(State, #sm_failed{reason = 'bad-request', + xmlns = Xmlns}) end; _ -> send(State, #sm_failed{reason = 'unsupported-version', xmlns = Xmlns}) @@ -241,7 +289,7 @@ perform_stream_mgmt(Pkt, #{mgmt_xmlns := Xmlns} = State) -> -spec handle_enable(state(), sm_enable()) -> state(). handle_enable(#{mgmt_timeout := DefaultTimeout, mgmt_max_timeout := MaxTimeout, - xmlns := Xmlns, jid := JID} = State, + mgmt_xmlns := Xmlns, jid := JID} = State, #sm_enable{resume = Resume, max = Max}) -> Timeout = if Resume == false -> 0; @@ -264,7 +312,7 @@ handle_enable(#{mgmt_timeout := DefaultTimeout, end, State1 = State#{mgmt_state => active, mgmt_queue => queue_new(), - mgmt_timeout => Timeout * 1000}, + mgmt_timeout => Timeout}, send(State1, Res). -spec handle_r(state()) -> state(). @@ -275,23 +323,26 @@ handle_r(#{mgmt_xmlns := Xmlns, mgmt_stanzas_in := H} = State) -> -spec handle_a(state(), sm_a()) -> state(). handle_a(State, #sm_a{h = H}) -> State1 = check_h_attribute(State, H), - resend_ack(State1). + resend_rack(State1). -spec handle_resume(state(), sm_resume()) -> {ok, state()} | {error, state()}. -handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State, +handle_resume(#{user := User, lserver := LServer, + lang := Lang, socket := Socket} = State, #sm_resume{h = H, previd = PrevID, xmlns = Xmlns}) -> R = case inherit_session_state(State, PrevID) of {ok, InheritedState} -> {ok, InheritedState, H}; {error, Err, InH} -> {error, #sm_failed{reason = 'item-not-found', + text = xmpp:mk_text(Err, Lang), h = InH, xmlns = Xmlns}, Err}; {error, Err} -> {error, #sm_failed{reason = 'item-not-found', + text = xmpp:mk_text(Err, Lang), xmlns = Xmlns}, Err} end, case R of - {ok, ResumedState, NumHandled} -> + {ok, #{jid := JID} = ResumedState, NumHandled} -> State1 = check_h_attribute(ResumedState, NumHandled), #{mgmt_xmlns := AttrXmlns, mgmt_stanzas_in := AttrH} = State1, AttrId = make_resume_id(State1), @@ -307,14 +358,20 @@ handle_resume(#{lserver := LServer, jid := JID, socket := Socket} = State, [ejabberd_socket:pp(Socket), jid:to_string(JID)]), {ok, State5}; {error, El, Msg} -> - ?INFO_MSG("Cannot resume session for ~s: ~s", [jid:to_string(JID), Msg]), + ?INFO_MSG("Cannot resume session for ~s@~s: ~s", + [User, LServer, Msg]), {error, send(State, El)} end. -spec transition_to_pending(state()) -> state(). -transition_to_pending(#{mgmt_state := active} = State) -> - %% TODO - State; +transition_to_pending(#{mgmt_state := active, jid := JID, + lserver := LServer, mgmt_timeout := Timeout} = State) -> + State1 = cancel_ack_timer(State), + ?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]), + State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []), + State3 = ejabberd_c2s:close(State2, _SendTrailer = false), + erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout), + State3#{mgmt_state => pending}; transition_to_pending(State) -> State. @@ -345,25 +402,25 @@ update_num_stanzas_in(#{mgmt_state := MgmtState, update_num_stanzas_in(State, _El) -> State. -send_ack(#{mgmt_ack_timer := _} = State) -> +send_rack(#{mgmt_ack_timer := _} = State) -> State; -send_ack(#{mgmt_xmlns := Xmlns, +send_rack(#{mgmt_xmlns := Xmlns, mgmt_stanzas_out := NumStanzasOut, mgmt_ack_timeout := AckTimeout} = State) -> State1 = send(State, #sm_r{xmlns = Xmlns}), TRef = erlang:start_timer(AckTimeout, self(), ack_timeout), State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}. -resend_ack(#{mgmt_ack_timer := _, - mgmt_queue := Queue, - mgmt_stanzas_out := NumStanzasOut, - mgmt_stanzas_req := NumStanzasReq} = State) -> +resend_rack(#{mgmt_ack_timer := _, + mgmt_queue := Queue, + mgmt_stanzas_out := NumStanzasOut, + mgmt_stanzas_req := NumStanzasReq} = State) -> State1 = cancel_ack_timer(State), case NumStanzasReq < NumStanzasOut andalso not queue_is_empty(Queue) of - true -> send_ack(State1); + true -> send_rack(State1); false -> State1 end; -resend_ack(State) -> +resend_rack(State) -> State. -spec mgmt_queue_add(state(), xmpp_element()) -> state(). @@ -492,10 +549,22 @@ inherit_session_state(#{user := U, server := S} = State, ResumeID) -> OldPID -> OldSID = {Time, OldPID}, try resume_session(OldSID, State) of - {resume, OldState} -> + {resume, #{mgmt_xmlns := Xmlns, + mgmt_queue := Queue, + mgmt_timeout := Timeout, + mgmt_stanzas_in := NumStanzasIn, + mgmt_stanzas_out := NumStanzasOut} = OldState} -> State1 = ejabberd_c2s:copy_state(State, OldState), - State2 = ejabberd_c2s:open_session(State1), - {ok, State2}; + State2 = State1#{mgmt_xmlns => Xmlns, + mgmt_queue => Queue, + mgmt_timeout => Timeout, + mgmt_stanzas_in => NumStanzasIn, + mgmt_stanzas_out => NumStanzasOut, + mgmt_state => active}, + ejabberd_sm:close_session(OldSID, U, S, R), + State3 = ejabberd_c2s:open_session(State2), + ejabberd_c2s:stop(OldPID), + {ok, State3}; {error, Msg} -> {error, Msg} catch exit:{noproc, _} -> @@ -591,6 +660,15 @@ cancel_ack_timer(#{mgmt_ack_timer := TRef} = State) -> cancel_ack_timer(State) -> State. +-spec bounce_message_queue() -> ok. +bounce_message_queue() -> + receive {route, From, To, Pkt} -> + ejabberd_router:route(From, To, Pkt), + bounce_message_queue() + after 0 -> + ok + end. + %%%=================================================================== %%% Configuration processing %%%=================================================================== diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index e9c1b3339..a03870643 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -44,7 +44,8 @@ -type state() :: map(). -type stop_reason() :: {stream, reset | stream_error()} | {tls, term()} | - {socket, inet:posix() | closed | timeout}. + {socket, inet:posix() | closed | timeout} | + internal_failure. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore. -callback handle_cast(term(), state()) -> state(). @@ -54,7 +55,6 @@ -callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}. -callback handle_stream_start(state()) -> state(). -callback handle_stream_end(stop_reason(), state()) -> state(). --callback handle_stream_close(stop_reason(), state()) -> state(). -callback handle_cdata(binary(), state()) -> state(). -callback handle_unauthenticated_packet(xmpp_element(), state()) -> state(). -callback handle_authenticated_packet(xmpp_element(), state()) -> state(). @@ -83,7 +83,6 @@ code_change/3, handle_stream_start/1, handle_stream_end/2, - handle_stream_close/2, handle_cdata/2, handle_authenticated_packet/2, handle_unauthenticated_packet/2, @@ -193,6 +192,8 @@ format_error({stream, #stream_error{reason = Reason, text = Txt}}) -> format("Stream failed: ~s", [format_stream_error(Reason, Txt)]); format_error({tls, Reason}) -> format("TLS failed: ~w", [Reason]); +format_error(internal_failure) -> + <<"Internal server error">>; format_error(Err) -> format("Unrecognized error: ~w", [Err]). @@ -263,75 +264,78 @@ handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, #{stream_state := wait_for_stream, xmlns := XMLNS, lang := MyLang} = State) -> El = #xmlel{name = Name, attrs = Attrs}, - try xmpp:decode(El, XMLNS, []) of - #stream_start{} = Pkt -> - State1 = send_header(State, Pkt), - case is_disconnected(State1) of - true -> State1; - false -> noreply(process_stream(Pkt, State1)) - end; - _ -> - State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> noreply(send_element(State1, xmpp:serr_invalid_xml())) - end - catch _:{xmpp_codec, Why} -> - State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - Err = xmpp:serr_invalid_xml(Txt, Lang), - noreply(send_element(State1, Err)) - end - end; + noreply( + try xmpp:decode(El, XMLNS, []) of + #stream_start{} = Pkt -> + State1 = send_header(State, Pkt), + case is_disconnected(State1) of + true -> State1; + false -> process_stream(Pkt, State1) + end; + _ -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> send_element(State1, xmpp:serr_invalid_xml()) + end + catch _:{xmpp_codec, Why} -> + State1 = send_header(State), + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + send_element(State1, Err) + end + end); handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> - Err = case Reason of - <<"XML stanza is too big">> -> - xmpp:serr_policy_violation(Reason, Lang); - _ -> - xmpp:serr_not_well_formed() - end, - noreply(send_element(State1, Err)) - end; + noreply( + case is_disconnected(State1) of + true -> State1; + false -> + Err = case Reason of + <<"XML stanza is too big">> -> + xmpp:serr_policy_violation(Reason, Lang); + _ -> + xmpp:serr_not_well_formed() + end, + send_element(State1, Err) + end); handle_info({'$gen_event', {xmlstreamelement, El}}, #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> - try xmpp:decode(El, NS, [ignore_els]) of - Pkt -> - State1 = try Mod:handle_recv(El, Pkt, State) - catch _:undef -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> noreply(process_element(Pkt, State1)) - end - catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, {error, Why}, State) - catch _:undef -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) - end - end; + noreply( + try xmpp:decode(El, NS, [ignore_els]) of + Pkt -> + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_element(Pkt, State1) + end + catch _:{xmpp_codec, Why} -> + State1 = try Mod:handle_recv(El, {error, Why}, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State1, El, xmpp:err_bad_request(Txt, Lang)) + end + end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, #{mod := Mod} = State) -> noreply(try Mod:handle_cdata(Data, State) catch _:undef -> State end); handle_info({'$gen_event', {xmlstreamend, _}}, State) -> - noreply(process_stream_end({error, {stream, reset}}, State)); + noreply(process_stream_end({stream, reset}, State)); handle_info({'$gen_event', closed}, State) -> - noreply(process_stream_close({error, {socket, closed}}, State)); + noreply(process_stream_end({socket, closed}, State)); handle_info(timeout, #{mod := Mod} = State) -> Disconnected = is_disconnected(State), noreply(try Mod:handle_timeout(State) @@ -342,7 +346,7 @@ handle_info(timeout, #{mod := Mod} = State) -> end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, #{socket_monitor := MRef} = State) -> - noreply(process_stream_close({error, {socket, closed}}, State)); + noreply(process_stream_end({socket, closed}, State)); handle_info(Info, #{mod := Mod} = State) -> noreply(try Mod:handle_info(Info, State) catch _:undef -> State @@ -390,15 +394,6 @@ peername(SockMod, Socket) -> _ -> SockMod:peername(Socket) end. --spec process_stream_close(stop_reason(), state()) -> state(). -process_stream_close(_, #{stream_state := disconnected} = State) -> - State; -process_stream_close(Reason, #{mod := Mod} = State) -> - State1 = send_trailer(State), - try Mod:handle_stream_close(Reason, State1) - catch _:undef -> stop(State1) - end. - -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; @@ -414,6 +409,8 @@ process_stream(#stream_start{xmlns = XML_NS, #{xmlns := NS} = State) when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> send_element(State, xmpp:serr_invalid_namespace()); +process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> + send_element(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang}, #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State) when size(Lang) > 35 -> @@ -520,7 +517,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) -> #handshake{} -> State; #stream_error{} -> - process_stream_end({error, {stream, Pkt}}, State); + process_stream_end({stream, Pkt}, State); _ when StateName == wait_for_sasl_request; StateName == wait_for_handshake; StateName == wait_for_sasl_response -> @@ -704,7 +701,7 @@ process_starttls_failure(Why, State) -> State1 = send_element(State, #starttls_failure{}), case is_disconnected(State1) of true -> State1; - false -> process_stream_end({error, {tls, Why}}, State1) + false -> process_stream_end({tls, Why}, State1) end. -spec process_sasl_request(sasl_auth(), state()) -> state(). @@ -939,8 +936,8 @@ set_from_to(Pkt, #{lang := Lang}) -> end. -spec send_header(state()) -> state(). -send_header(State) -> - send_header(State, #stream_start{}). +send_header(#{stream_version := Version} = State) -> + send_header(State, #stream_start{version = Version}). -spec send_header(state(), stream_start()) -> state(). send_header(#{stream_id := StreamID, @@ -959,8 +956,9 @@ send_header(#{stream_id := StreamID, undefined -> jid:make(DefaultServer) end, Version = case HisVersion of - undefined -> MyVersion; - _ -> HisVersion + undefined -> undefined; + {0,_} -> HisVersion; + _ -> MyVersion end, Header = xmpp:encode(#stream_start{version = Version, lang = Lang, @@ -969,10 +967,12 @@ send_header(#{stream_id := StreamID, db_xmlns = NS_DB, id = StreamID, from = From}), - State1 = State#{lang => Lang, stream_header_sent => true}, + State1 = State#{lang => Lang, + stream_version => Version, + stream_header_sent => true}, case send_text(State1, fxml:element_to_header(Header)) of ok -> State1; - {error, Why} -> process_stream_close({error, {socket, Why}}, State1) + {error, Why} -> process_stream_end({socket, Why}, State1) end; send_header(State, _) -> State. @@ -987,11 +987,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> end, case Result of _ when is_record(Pkt, stream_error) -> - process_stream_end({error, {stream, Pkt}}, State1); + process_stream_end({stream, Pkt}, State1); ok -> State1; {error, Why} -> - process_stream_close({error, {socket, Why}}, State1) + process_stream_end({socket, Why}, State1) end. -spec send_error(state(), xmpp_element(), stanza_error()) -> state(). @@ -1022,7 +1022,7 @@ send_text(#{socket := Sock, sockmod := SockMod, stream_header_sent := true}, Data) when StateName /= disconnected -> SockMod:send(Sock, Data); send_text(_, _) -> - {error, einval}. + {error, closed}. -spec close_socket(state()) -> state(). close_socket(#{sockmod := SockMod, socket := Socket} = State) -> diff --git a/src/xmpp_stream_out.erl b/src/xmpp_stream_out.erl index fc373fff8..08804e432 100644 --- a/src/xmpp_stream_out.erl +++ b/src/xmpp_stream_out.erl @@ -33,6 +33,7 @@ -include_lib("kernel/include/inet.hrl"). -type state() :: map(). +-type noreply() :: {noreply, state(), timeout()}. -type host_port() :: {inet:hostname(), inet:port_number()}. -type ip_port() :: {inet:ip_address(), inet:port_number()}. -type network_error() :: {error, inet:posix() | inet_res:res_error()}. @@ -42,7 +43,8 @@ {tls, term()} | {pkix, binary()} | {auth, atom() | binary() | string()} | - {socket, inet:posix() | closed | timeout}. + {socket, inet:posix() | closed | timeout} | + internal_failure. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore. @@ -107,7 +109,7 @@ close(_, _) -> establish(State) -> process_stream_established(State). --spec set_timeout(state(), non_neg_integer() | infinity) -> state(). +-spec set_timeout(state(), timeout()) -> state(). set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() -> case Timeout of infinity -> State#{stream_timeout => infinity}; @@ -148,12 +150,15 @@ format_error({tls, Reason}) -> format("TLS failed: ~w", [Reason]); format_error({auth, Reason}) -> format("Authentication failed: ~s", [Reason]); +format_error(internal_failure) -> + <<"Internal server error">>; format_error(Err) -> format("Unrecognized error: ~w", [Err]). %%%=================================================================== %%% gen_server callbacks %%%=================================================================== +-spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore. init([Mod, SockMod, From, To, Opts]) -> Time = p1_time_compat:monotonic_time(milli_seconds), State = #{owner => self(), @@ -183,36 +188,38 @@ init([Mod, SockMod, From, To, Opts]) -> Err end. +-spec handle_call(term(), term(), state()) -> noreply(). handle_call(Call, From, #{mod := Mod} = State) -> noreply(try Mod:handle_call(Call, From, State) catch _:undef -> State end). +-spec handle_cast(term(), state()) -> noreply(). handle_cast(connect, #{remote_server := RemoteServer, sockmod := SockMod, stream_state := connecting} = State) -> - case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of - false -> - noreply(process_stream_close({error, {idna, bad_string}}, State)); - ASCIIName -> - case resolve(binary_to_list(ASCIIName), State) of - {ok, AddrPorts} -> - case connect(AddrPorts, State) of - {ok, Socket, AddrPort} -> - SocketMonitor = SockMod:monitor(Socket), - State1 = State#{ip => AddrPort, - socket => Socket, - socket_monitor => SocketMonitor}, - State2 = State1#{stream_state => wait_for_stream}, - noreply(send_header(State2)); - {error, Why} -> - Err = {error, {socket, Why}}, - noreply(process_stream_close(Err, State)) - end; - {error, Why} -> - noreply(process_stream_close({error, {dns, Why}}, State)) - end - end; + noreply( + case ejabberd_idna:domain_utf8_to_ascii(RemoteServer) of + false -> + process_stream_end({idna, bad_string}, State); + ASCIIName -> + case resolve(binary_to_list(ASCIIName), State) of + {ok, AddrPorts} -> + case connect(AddrPorts, State) of + {ok, Socket, AddrPort} -> + SocketMonitor = SockMod:monitor(Socket), + State1 = State#{ip => AddrPort, + socket => Socket, + socket_monitor => SocketMonitor}, + State2 = State1#{stream_state => wait_for_stream}, + send_header(State2); + {error, Why} -> + process_stream_end({socket, Why}, State) + end; + {error, Why} -> + process_stream_end({dns, Why}, State) + end + end); handle_cast(connect, State) -> %% Ignoring connection attempts in other states noreply(State); @@ -225,66 +232,70 @@ handle_cast(Cast, #{mod := Mod} = State) -> catch _:undef -> State end). +-spec handle_info(term(), state()) -> noreply(). handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, #{stream_state := wait_for_stream, xmlns := XMLNS, lang := MyLang} = State) -> El = #xmlel{name = Name, attrs = Attrs}, - try xmpp:decode(El, XMLNS, []) of - #stream_start{} = Pkt -> - noreply(process_stream(Pkt, State)); - _ -> - noreply(send_element(State, xmpp:serr_invalid_xml())) - catch _:{xmpp_codec, Why} -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - Err = xmpp:serr_invalid_xml(Txt, Lang), - noreply(send_element(State, Err)) - end; + noreply( + try xmpp:decode(El, XMLNS, []) of + #stream_start{} = Pkt -> + process_stream(Pkt, State); + _ -> + send_element(State, xmpp:serr_invalid_xml()) + catch _:{xmpp_codec, Why} -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + Err = xmpp:serr_invalid_xml(Txt, Lang), + send_element(State, Err) + end); handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) -> State1 = send_header(State), - case is_disconnected(State1) of - true -> State1; - false -> - Err = case Reason of - <<"XML stanza is too big">> -> - xmpp:serr_policy_violation(Reason, Lang); - _ -> - xmpp:serr_not_well_formed() - end, - noreply(send_element(State1, Err)) - end; + noreply( + case is_disconnected(State1) of + true -> State1; + false -> + Err = case Reason of + <<"XML stanza is too big">> -> + xmpp:serr_policy_violation(Reason, Lang); + _ -> + xmpp:serr_not_well_formed() + end, + send_element(State1, Err) + end); handle_info({'$gen_event', {xmlstreamelement, El}}, #{xmlns := NS, lang := MyLang, mod := Mod} = State) -> - try xmpp:decode(El, NS, [ignore_els]) of - Pkt -> - State1 = try Mod:handle_recv(El, Pkt, State) - catch _:undef -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> noreply(process_element(Pkt, State1)) - end - catch _:{xmpp_codec, Why} -> - State1 = try Mod:handle_recv(El, undefined, State) - catch _:undef -> State - end, - case is_disconnected(State1) of - true -> State1; - false -> - Txt = xmpp:io_format_error(Why), - Lang = select_lang(MyLang, xmpp:get_lang(El)), - noreply(send_error(State1, El, xmpp:err_bad_request(Txt, Lang))) - end - end; + noreply( + try xmpp:decode(El, NS, [ignore_els]) of + Pkt -> + State1 = try Mod:handle_recv(El, Pkt, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> process_element(Pkt, State1) + end + catch _:{xmpp_codec, Why} -> + State1 = try Mod:handle_recv(El, undefined, State) + catch _:undef -> State + end, + case is_disconnected(State1) of + true -> State1; + false -> + Txt = xmpp:io_format_error(Why), + Lang = select_lang(MyLang, xmpp:get_lang(El)), + send_error(State1, El, xmpp:err_bad_request(Txt, Lang)) + end + end); handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, #{mod := Mod} = State) -> noreply(try Mod:handle_cdata(Data, State) catch _:undef -> State end); handle_info({'$gen_event', {xmlstreamend, _}}, State) -> - noreply(process_stream_end({error, {stream, reset}}, State)); + noreply(process_stream_end({stream, reset}, State)); handle_info({'$gen_event', closed}, State) -> - noreply(process_stream_close({error, {socket, closed}}, State)); + noreply(process_stream_end({socket, closed}, State)); handle_info(timeout, #{mod := Mod} = State) -> Disconnected = is_disconnected(State), noreply(try Mod:handle_timeout(State) @@ -295,12 +306,13 @@ handle_info(timeout, #{mod := Mod} = State) -> end); handle_info({'DOWN', MRef, _Type, _Object, _Info}, #{socket_monitor := MRef} = State) -> - noreply(process_stream_close({error, {socket, closed}}, State)); + noreply(process_stream_end({socket, closed}, State)); handle_info(Info, #{mod := Mod} = State) -> noreply(try Mod:handle_info(Info, State) catch _:undef -> State end). +-spec terminate(term(), state()) -> any(). terminate(Reason, #{mod := Mod} = State) -> case get(already_terminated) of true -> @@ -319,7 +331,7 @@ code_change(OldVsn, #{mod := Mod} = State, Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== --spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}. +-spec noreply(state()) -> noreply(). noreply(#{stream_timeout := infinity} = State) -> {noreply, State, infinity}; noreply(#{stream_timeout := {MSecs, OldTime}} = State) -> @@ -335,15 +347,6 @@ new_id() -> is_disconnected(#{stream_state := StreamState}) -> StreamState == disconnected. --spec process_stream_close(stop_reason(), state()) -> state(). -process_stream_close(_, #{stream_state := disconnected} = State) -> - State; -process_stream_close(Reason, #{mod := Mod} = State) -> - State1 = send_trailer(State), - try Mod:handle_stream_close(Reason, State1) - catch _:undef -> stop(State1) - end. - -spec process_stream_end(stop_reason(), state()) -> state(). process_stream_end(_, #{stream_state := disconnected} = State) -> State; @@ -359,6 +362,8 @@ process_stream(#stream_start{xmlns = XML_NS, #{xmlns := NS} = State) when XML_NS /= NS; STREAM_NS /= ?NS_STREAM -> send_element(State, xmpp:serr_invalid_namespace()); +process_stream(#stream_start{version = {N, _}}, State) when N > 1 -> + send_element(State, xmpp:serr_unsupported_version()); process_stream(#stream_start{lang = Lang, id = ID, version = Version} = StreamStart, #{mod := Mod} = State) -> @@ -370,8 +375,10 @@ process_stream(#stream_start{lang = Lang, id = ID, true -> State2; false -> case Version of - {1,0} -> State2#{stream_state => wait_for_features}; - _ -> process_stream_downgrade(StreamStart, State) + {1, _} -> + State2#{stream_state => wait_for_features}; + _ -> + process_stream_downgrade(StreamStart, State2) end end. @@ -387,7 +394,7 @@ process_element(Pkt, #{stream_state := StateName} = State) -> #sasl_failure{} when StateName == wait_for_sasl_response -> process_sasl_failure(Pkt, State); #stream_error{} -> - process_stream_end({error, {stream, Pkt}}, State); + process_stream_end({stream, Pkt}, State); _ when is_record(Pkt, stream_features); is_record(Pkt, starttls_proceed); is_record(Pkt, starttls); @@ -487,14 +494,23 @@ process_starttls(#{sockmod := SockMod, socket := Socket, mod := Mod} = State) -> stream_encrypted => true}, send_header(State1); {error, Why} -> - process_stream_close({error, {tls, Why}}, State) + process_stream_end({tls, Why}, State) end. -spec process_stream_downgrade(stream_start(), state()) -> state(). -process_stream_downgrade(StreamStart, #{mod := Mod} = State) -> - try Mod:downgrade_stream(StreamStart, State) - catch _:undef -> - send_element(State, xmpp:serr_unsupported_version()) +process_stream_downgrade(StreamStart, + #{mod := Mod, lang := Lang, + stream_encrypted := Encrypted} = State) -> + TLSRequired = is_starttls_required(State), + if not Encrypted and TLSRequired -> + Txt = <<"Use of STARTTLS required">>, + send_element(State, xmpp:err_policy_violation(Txt, Lang)); + true -> + State1 = State#{stream_state => downgraded}, + try Mod:handle_stream_downgraded(StreamStart, State1) + catch _:undef -> + send_element(State1, xmpp:serr_unsupported_version()) + end end. -spec process_cert_verification(state()) -> state(). @@ -509,7 +525,7 @@ process_cert_verification(#{stream_encrypted := true, {ok, _} -> State#{stream_verified => true}; {error, Why, _Peer} -> - process_stream_close({error, {pkix, Why}}, State) + process_stream_end({pkix, Why}, State) end; false -> State#{stream_verified => true} @@ -538,7 +554,7 @@ process_sasl_success(#{mod := Mod, -spec process_sasl_failure(sasl_failure(), state()) -> state(). process_sasl_failure(#sasl_failure{reason = Reason}, #{mod := Mod} = State) -> try Mod:handle_auth_failure(<<"EXTERNAL">>, Reason, State) - catch _:undef -> process_stream_close({error, {auth, Reason}}, State) + catch _:undef -> process_stream_end({auth, Reason}, State) end. -spec process_packet(xmpp_element(), state()) -> state(). @@ -581,7 +597,7 @@ send_header(#{remote_server := RemoteServer, version = {1,0}}), case send_text(State, fxml:element_to_header(Header)) of ok -> State; - {error, Why} -> process_stream_close({error, {socket, Why}}, State) + {error, Why} -> process_stream_end({socket, Why}, State) end. -spec send_element(state(), xmpp_element()) -> state(). @@ -596,11 +612,11 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) -> false -> case send_text(State1, Data) of _ when is_record(Pkt, stream_error) -> - process_stream_end({error, {stream, Pkt}}, State1); + process_stream_end({stream, Pkt}, State1); ok -> State1; {error, Why} -> - process_stream_close({error, {socket, Why}}, State1) + process_stream_end({socket, Why}, State1) end end. @@ -626,7 +642,7 @@ send_text(#{sockmod := SockMod, socket := Socket, stream_state := StateName}, Data) when StateName /= disconnected -> SockMod:send(Socket, Data); send_text(_, _) -> - {error, einval}. + {error, closed}. -spec send_trailer(state()) -> state(). send_trailer(State) ->