From 7f653cfe762ecf33ae9522b8df25f2902d5546df Mon Sep 17 00:00:00 2001 From: Evgeniy Khramtsov Date: Sun, 11 Dec 2016 18:24:51 +0300 Subject: [PATCH] Rewrite ejabberd_service to use new XMPP stream API --- src/cyrsasl.erl | 1 + src/ejabberd_c2s.erl | 137 +++++++------- src/ejabberd_service.erl | 398 ++++++++++++--------------------------- src/xmpp_stream_in.erl | 104 ++++++++-- 4 files changed, 286 insertions(+), 354 deletions(-) diff --git a/src/cyrsasl.erl b/src/cyrsasl.erl index e23196475..c49f8a3cb 100644 --- a/src/cyrsasl.erl +++ b/src/cyrsasl.erl @@ -71,6 +71,7 @@ mech_state }). -type sasl_state() :: #sasl_state{}. +-export_type([sasl_state/0]). -callback mech_new(binary(), fun(), fun(), fun()) -> any(). -callback mech_step(any(), binary()) -> {ok, props()} | diff --git a/src/ejabberd_c2s.erl b/src/ejabberd_c2s.erl index 1568d5db6..b5113c34b 100644 --- a/src/ejabberd_c2s.erl +++ b/src/ejabberd_c2s.erl @@ -21,20 +21,24 @@ %%%------------------------------------------------------------------- -module(ejabberd_c2s). -behaviour(xmpp_stream_in). +-behaviour(ejabberd_config). -protocol({rfc, 6121}). %% ejabberd_socket callbacks -export([start/2, socket_type/0]). +%% ejabberd_config callbacks +-export([opt_type/1, transform_listen_option/2]). %% xmpp_stream_in callbacks -export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]). --export([tls_options/1, tls_required/1, sasl_mechanisms/1, init_sasl/1, bind/2, +-export([tls_options/1, tls_required/1, compress_methods/1, + sasl_mechanisms/1, init_sasl/1, bind/2, handshake/2, unauthenticated_stream_features/1, authenticated_stream_features/1, handle_stream_start/1, handle_stream_end/1, handle_stream_close/1, handle_unauthenticated_packet/2, handle_authenticated_packet/2, - handle_auth_success/4, handle_auth_failure/4, - handle_unbinded_packet/2]). + handle_auth_success/4, handle_auth_failure/4, handle_send/5, + handle_unbinded_packet/2, handle_cdata/2]). %% API -export([get_presence/1, get_subscription/2, get_subscribed/1, send/2, close/1]). @@ -99,8 +103,7 @@ send(State, Pkt) -> %%%=================================================================== %%% xmpp_stream_in callbacks %%%=================================================================== -tls_options(#{server := Server, tls_options := TLSOpts}) -> - LServer = jid:nameprep(Server), +tls_options(#{lserver := LServer, tls_options := TLSOpts}) -> case ejabberd_config:get_option({domain_certfile, LServer}, fun iolist_to_binary/1) of undefined -> @@ -112,19 +115,21 @@ tls_options(#{server := Server, tls_options := TLSOpts}) -> tls_required(#{tls_required := TLSRequired}) -> TLSRequired. -unauthenticated_stream_features(#{server := Server}) -> - LServer = jid:nameprep(Server), +compress_methods(#{zlib := true}) -> + [<<"zlib">>]; +compress_methods(_) -> + []. + +sasl_mechanisms(#{lserver := LServer}) -> + cyrsasl:listmech(LServer). + +unauthenticated_stream_features(#{lserver := LServer}) -> ejabberd_hooks:run_fold(c2s_pre_auth_features, LServer, [], [LServer]). -authenticated_stream_features(#{server := Server}) -> - LServer = jid:nameprep(Server), +authenticated_stream_features(#{lserver := LServer}) -> ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]). -sasl_mechanisms(#{server := Server}) -> - cyrsasl:listmech(jid:nameprep(Server)). - -init_sasl(#{server := Server}) -> - LServer = jid:nameprep(Server), +init_sasl(#{lserver := LServer}) -> cyrsasl:server_new( <<"jabber">>, LServer, <<"">>, [], fun(U) -> @@ -147,8 +152,11 @@ bind(R, #{user := U, server := S} = State) -> open_session(State, Resource) end. -handle_stream_start(#{server := Server, ip := IP, lang := Lang} = State) -> - LServer = jid:nameprep(Server), +handshake(_Data, State) -> + %% This is only for jabber component + {ok, State}. + +handle_stream_start(#{lserver := LServer, ip := IP, lang := Lang} = State) -> case lists:member(LServer, ?MYHOSTS) of false -> xmpp_stream_in:send(State, xmpp:serr_host_unknown()); @@ -172,8 +180,7 @@ handle_stream_close(State) -> {stop, normal, State}. handle_auth_success(User, Mech, AuthModule, - #{socket := Socket, ip := IP, server := Server} = State) -> - LServer = jid:nameprep(Server), + #{socket := Socket, ip := IP, lserver := LServer} = State) -> ?INFO_MSG("(~w) Accepted ~s authentication for ~s@~s by ~p from ~s", [Socket, Mech, User, LServer, AuthModule, ejabberd_config:may_hide_data(jlib:ip_to_list(IP))]), @@ -182,8 +189,7 @@ handle_auth_success(User, Mech, AuthModule, {noreply, State1}, [true, User]). handle_auth_failure(User, Mech, Reason, - #{socket := Socket, ip := IP, server := Server} = State) -> - LServer = jid:nameprep(Server), + #{socket := Socket, ip := IP, lserver := LServer} = State) -> ?INFO_MSG("(~w) Failed ~s authentication ~sfrom ~s: ~s", [Socket, Mech, if User /= <<"">> -> ["for ", User, "@", LServer, " "]; @@ -193,22 +199,18 @@ handle_auth_failure(User, Mech, Reason, ejabberd_hooks:run_fold(c2s_auth_result, LServer, {noreply, State}, [false, User]). -handle_unbinded_packet(Pkt, #{server := Server} = State) -> - LServer = jid:nameprep(Server), +handle_unbinded_packet(Pkt, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_unbinded_packet, LServer, {noreply, State}, [Pkt]). -handle_unauthenticated_packet(Pkt, #{server := Server} = State) -> - LServer = jid:nameprep(Server), +handle_unauthenticated_packet(Pkt, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_unauthenticated_packet, LServer, {noreply, State}, [Pkt]). -handle_authenticated_packet(Pkt, #{server := Server} = State) when not ?is_stanza(Pkt) -> - LServer = jid:nameprep(Server), +handle_authenticated_packet(Pkt, #{lserver := LServer} = State) when not ?is_stanza(Pkt) -> ejabberd_hooks:run_fold(c2s_authenticated_packet, LServer, {noreply, State}, [Pkt]); -handle_authenticated_packet(Pkt, #{server := Server} = State) -> - LServer = jid:nameprep(Server), +handle_authenticated_packet(Pkt, #{lserver := LServer} = State) -> case ejabberd_hooks:run_fold(c2s_authenticated_packet, LServer, {noreply, State}, [Pkt]) of {noreply, State1} -> @@ -228,6 +230,14 @@ handle_authenticated_packet(Pkt, #{server := Server} = State) -> Err end. +handle_cdata(Data, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_cdata, LServer, + {noreply, State}, [Data]). + +handle_send(Reason, Pkt, El, Data, #{lserver := LServer} = State) -> + ejabberd_hooks:run_fold(c2s_handle_send, LServer, + {noreply, State}, [Reason, Pkt, El, Data]). + init([State, Opts]) -> Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), Shaper = gen_mod:get_opt(shaper, Opts, fun acl:shaper_rules_validator/1, none), @@ -239,6 +249,7 @@ init([State, Opts]) -> end, Opts), TLSRequired = proplists:get_bool(starttls_required, Opts), TLSVerify = proplists:get_bool(tls_verify, Opts), + Zlib = proplists:get_bool(zlib, Opts), State1 = State#{tls_options => TLSOpts, tls_required => TLSRequired, tls_verify => TLSVerify, @@ -246,19 +257,18 @@ init([State, Opts]) -> pres_f => ?SETS:new(), pres_t => ?SETS:new(), sid => ejabberd_sm:make_sid(), + zlib => Zlib, lang => ?MYLANG, server => ?MYNAME, access => Access, shaper => Shaper}, ejabberd_hooks:run_fold(c2s_init, {ok, State1}, []). -handle_call(get_presence, _From, - #{user := U, server := S, resource := R} = State) -> +handle_call(get_presence, _From, #{jid := JID} = State) -> Pres = case maps:get(pres_last, State, undefined) of undefined -> - From = jid:make(U, S, R), - To = jid:remove_resource(From), - #presence{from = From, to = To, type = unavailable}; + BareJID = jid:remove_resource(JID), + #presence{from = JID, to = BareJID, type = unavailable}; P -> P end, @@ -266,20 +276,17 @@ handle_call(get_presence, _From, handle_call(get_subscribed, _From, #{pres_f := PresF} = State) -> Subscribed = ?SETS:to_list(PresF), {reply, Subscribed, State}; -handle_call(Request, From, #{server := Server} = State) -> - LServer = jid:nameprep(Server), +handle_call(Request, From, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold( c2s_handle_call, LServer, {noreply, State}, [Request, From]). handle_cast(closed, State) -> handle_stream_close(State); -handle_cast(Msg, #{server := Server} = State) -> - LServer = jid:nameprep(Server), +handle_cast(Msg, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_handle_cast, LServer, {noreply, State}, [Msg]). -handle_info({route, From, To, Packet0}, #{server := Server} = State) -> +handle_info({route, From, To, Packet0}, #{lserver := LServer} = State) -> Packet = xmpp:set_from_to(Packet0, From, To), - LServer = jid:nameprep(Server), {Pass, NewState} = case Packet of #presence{} -> process_presence_in(State, Packet); @@ -289,7 +296,6 @@ handle_info({route, From, To, Packet0}, #{server := Server} = State) -> process_iq_in(State, Packet) end, if Pass -> - LServer = jid:nameprep(Server), Packet1 = ejabberd_hooks:run_fold( user_receive_packet, LServer, Packet, [NewState]), ejabberd_hooks:run(c2s_loop_debug, [{route, From, To, Packet}]), @@ -300,8 +306,7 @@ handle_info({route, From, To, Packet0}, #{server := Server} = State) -> end; handle_info(system_shutdown, State) -> xmpp_stream_in:send(State, xmpp:serr_system_shutdown()); -handle_info(Info, #{server := Server} = State) -> - LServer = jid:nameprep(Server), +handle_info(Info, #{lserver := LServer} = State) -> ejabberd_hooks:run_fold(c2s_handle_info, LServer, {noreply, State}, [Info]). terminate(_Reason, _State) -> @@ -319,11 +324,10 @@ check_bl_c2s({IP, _Port}, Lang) -> ejabberd_hooks:run_fold(check_bl_c2s, false, [IP, Lang]). -spec open_session(state(), binary()) -> {ok, state()} | {error, stanza_error(), state()}. -open_session(#{user := U, server := S, sid := SID, +open_session(#{user := U, server := S, lserver := LServer, sid := SID, socket := Socket, ip := IP, auth_module := AuthMod, access := Access, lang := Lang} = State, R) -> JID = jid:make(U, S, R), - LServer = JID#jid.lserver, case acl:access_matches(Access, #{usr => jid:split(JID), ip => IP}, LServer) of @@ -374,9 +378,8 @@ process_message_in(State, #message{type = T} = Msg) -> end. -spec process_presence_in(state(), presence()) -> {boolean(), state()}. -process_presence_in(#{server := Server, pres_a := PresA} = State0, +process_presence_in(#{lserver := LServer, pres_a := PresA} = State0, #presence{from = From, to = To, type = T} = Pres) -> - LServer = jid:nameprep(Server), State = ejabberd_hooks:run_fold(c2s_presence_in, LServer, State0, [Pres]), case T of probe -> @@ -399,7 +402,7 @@ process_presence_in(#{server := Server, pres_a := PresA} = State0, end. -spec route_probe_reply(jid(), jid(), state()) -> ok. -route_probe_reply(From, To, #{server := Server, pres_f := PresF, +route_probe_reply(From, To, #{lserver := LServer, pres_f := PresF, pres_last := LastPres, pres_timestamp := TS} = State) -> LFrom = jid:tolower(From), @@ -413,7 +416,6 @@ route_probe_reply(From, To, #{server := Server, pres_f := PresF, deny -> ok; allow -> - LServer = jid:nameprep(Server), ejabberd_hooks:run(presence_probe_hook, LServer, [From, To, self()]), @@ -432,10 +434,9 @@ route_probe_reply(_, _, _) -> ok. -spec process_presence_out(state(), presence()) -> next_state(). -process_presence_out(#{user := User, server := Server, - lang := Lang, pres_a := PresA} = State, +process_presence_out(#{user := User, server := Server, lserver := LServer, + jid := JID, lang := Lang, pres_a := PresA} = State, #presence{from = From, to = To, type = Type} = Pres) -> - LServer = jid:nameprep(Server), LTo = jid:tolower(To), case privacy_check_packet(State, Pres, out) of deny -> @@ -448,7 +449,7 @@ process_presence_out(#{user := User, server := Server, Access = gen_mod:get_module_opt(LServer, mod_roster, access, fun(A) when is_atom(A) -> A end, all), - MyBareJID = jid:make(User, Server, <<"">>), + MyBareJID = jid:remove_resource(JID), case acl:match_rule(LServer, Access, MyBareJID) of deny -> ErrText = <<"Denied by ACL">>, @@ -485,9 +486,8 @@ process_self_presence(#{ip := IP, conn := Conn, State1 = broadcast_presence_unavailable(State, Pres), State2 = maps:remove(pres_last, maps:remove(pres_timestamp, State1)), {noreply, State2}; -process_self_presence(#{server := Server} = State, +process_self_presence(#{lserver := LServer} = State, #presence{type = available} = Pres) -> - LServer = jid:nameprep(Server), PreviousPres = maps:get(pres_last, State, undefined), update_priority(State, Pres), State1 = ejabberd_hooks:run_fold(user_available_hook, LServer, State, [Pres]), @@ -543,8 +543,7 @@ check_privacy_then_route(#{lang := Lang} = State, Pkt) -> end. -spec privacy_check_packet(state(), stanza(), in | out) -> allow | deny. -privacy_check_packet(#{server := Server} = State, Pkt, Dir) -> - LServer = jid:nameprep(Server), +privacy_check_packet(#{lserver := LServer} = State, Pkt, Dir) -> ejabberd_hooks:run_fold(privacy_check_packet, LServer, allow, [State, Pkt, Dir]). -spec get_priority_from_presence(presence()) -> integer(). @@ -555,9 +554,7 @@ get_priority_from_presence(#presence{priority = Prio}) -> end. -spec filter_blocked(state(), presence(), ?SETS:set()) -> [jid()]. -filter_blocked(#{user := U, server := S, resource := R} = State, - Pres, LJIDSet) -> - From = jid:make(U, S, R), +filter_blocked(#{jid := From} = State, Pres, LJIDSet) -> ?SETS:fold( fun(LJID, Acc) -> To = jid:make(LJID), @@ -581,8 +578,7 @@ route_error(Pkt, Err) -> ejabberd_router:route_error(To, From, Pkt, Err). -spec route_multiple(state(), [jid()], stanza()) -> ok. -route_multiple(#{server := Server}, JIDs, Pkt) -> - LServer = jid:nameprep(Server), +route_multiple(#{lserver := LServer}, JIDs, Pkt) -> From = xmpp:get_from(Pkt), ejabberd_router_multicast:route_multicast(From, LServer, JIDs, Pkt). @@ -636,9 +632,9 @@ get_conn_type(State) -> end. -spec change_shaper(state()) -> ok. -change_shaper(#{shaper := ShaperName, ip := IP, +change_shaper(#{shaper := ShaperName, ip := IP, lserver := LServer, user := U, server := S, resource := R} = State) -> - #jid{lserver = LServer} = JID = jid:make(U, S, R), + JID = jid:make(U, S, R), Shaper = acl:access_matches(ShaperName, #{usr => jid:split(JID), ip => IP}, LServer), @@ -680,3 +676,18 @@ fsm_limit_opts(Opts) -> N -> [{max_queue, N}] end end. + +transform_listen_option(Opt, Opts) -> + [Opt|Opts]. + +opt_type(domain_certfile) -> fun iolist_to_binary/1; +opt_type(max_fsm_queue) -> + fun (I) when is_integer(I), I > 0 -> I end; +opt_type(resource_conflict) -> + fun (setresource) -> setresource; + (closeold) -> closeold; + (closenew) -> closenew; + (acceptnew) -> acceptnew + end; +opt_type(_) -> + [domain_certfile, max_fsm_queue, resource_conflict]. diff --git a/src/ejabberd_service.erl b/src/ejabberd_service.erl index 35cfe15af..c48cd536c 100644 --- a/src/ejabberd_service.erl +++ b/src/ejabberd_service.erl @@ -1,8 +1,5 @@ -%%%---------------------------------------------------------------------- -%%% File : ejabberd_service.erl -%%% Author : Alexey Shchepin -%%% Purpose : External component management (XEP-0114) -%%% Created : 6 Dec 2002 by Alexey Shchepin +%%%------------------------------------------------------------------- +%%% Created : 11 Dec 2016 by Evgeny Khramtsov %%% %%% %%% ejabberd, Copyright (C) 2002-2016 ProcessOne @@ -21,77 +18,60 @@ %%% with this program; if not, write to the Free Software Foundation, Inc., %%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. %%% -%%%---------------------------------------------------------------------- - +%%%------------------------------------------------------------------- -module(ejabberd_service). - +-behaviour(xmpp_stream_in). -behaviour(ejabberd_config). --author('alexey@process-one.net'). - -protocol({xep, 114, '1.6'}). --define(GEN_FSM, p1_fsm). - --behaviour(?GEN_FSM). - -%% External exports --export([start/2, start_link/2, send_text/2, - send_element/2, socket_type/0, transform_listen_option/2]). - --export([init/1, wait_for_stream/2, - wait_for_handshake/2, stream_established/2, - handle_event/3, handle_sync_event/4, code_change/4, - handle_info/3, terminate/3, print_state/1, opt_type/1]). +%% ejabberd_socket callbacks +-export([start/2, socket_type/0]). +%% ejabberd_config callbacks +-export([opt_type/1, transform_listen_option/2]). +%% xmpp_stream_in callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, + terminate/2, code_change/3]). +-export([handshake/2, handle_stream_start/1, handle_authenticated_packet/2]). +%% API +-export([send/2]). -include("ejabberd.hrl"). --include("logger.hrl"). -include("xmpp.hrl"). +-include("logger.hrl"). --record(state, - {socket :: ejabberd_socket:socket_state(), - sockmod = ejabberd_socket :: ejabberd_socket | ejabberd_frontend_socket, - streamid = <<"">> :: binary(), - host_opts = dict:new() :: ?TDICT, - host = <<"">> :: binary(), - access :: atom(), - check_from = true :: boolean()}). - --type state_name() :: wait_for_stream | wait_for_handshake | stream_established. --type state() :: #state{}. --type fsm_next() :: {next_state, state_name(), state()}. --type fsm_stop() :: {stop, normal, state()}. --type fsm_transition() :: fsm_stop() | fsm_next(). - -%-define(DBGFSM, true). +%%-define(DBGFSM, true). -ifdef(DBGFSM). -define(FSMOPTS, [{debug, [trace]}]). -else. -define(FSMOPTS, []). -endif. -%%%---------------------------------------------------------------------- +-type state() :: map(). +-type next_state() :: {noreply, state()} | {stop, term(), state()}. +-export_type([state/0, next_state/0]). + +%%%=================================================================== %%% API -%%%---------------------------------------------------------------------- +%%%=================================================================== start(SockData, Opts) -> - supervisor:start_child(ejabberd_service_sup, - [SockData, Opts]). + xmpp_stream_in:start(?MODULE, [SockData, Opts], + fsm_limit_opts(Opts) ++ ?FSMOPTS). -start_link(SockData, Opts) -> - (?GEN_FSM):start_link(ejabberd_service, - [SockData, Opts], fsm_limit_opts(Opts) ++ (?FSMOPTS)). +socket_type() -> + xml_stream. -socket_type() -> xml_stream. +-spec send(state(), xmpp_element()) -> next_state(). +send(State, Pkt) -> + xmpp_stream_in:send(State, Pkt). -%%%---------------------------------------------------------------------- -%%% Callback functions from gen_fsm -%%%---------------------------------------------------------------------- -init([{SockMod, Socket}, Opts]) -> +%%%=================================================================== +%%% xmpp_stream_in callbacks +%%%=================================================================== +init([#{socket := Socket} = State, Opts]) -> ?INFO_MSG("(~w) External service connected", [Socket]), - Access = case lists:keysearch(access, 1, Opts) of - {value, {_, A}} -> A; - _ -> all - end, + Access = gen_mod:get_opt(access, Opts, fun acl:access_rules_validator/1, all), + Shaper = gen_mod:get_opt(shaper_rule, Opts, fun acl:shaper_rules_validator/1, none), HostOpts = case lists:keyfind(hosts, 1, Opts) of {hosts, HOpts} -> lists:foldl( @@ -107,252 +87,120 @@ init([{SockMod, Socket}, Opts]) -> p1_sha:sha(randoms:bytes(20))), dict:from_list([{global, Pass}]) end, - Shaper = case lists:keysearch(shaper_rule, 1, Opts) of - {value, {_, S}} -> S; - _ -> none - end, - CheckFrom = case lists:keysearch(service_check_from, 1, - Opts) - of - {value, {_, CF}} -> CF; - _ -> true - end, - SockMod:change_shaper(Socket, Shaper), - {ok, wait_for_stream, - #state{socket = Socket, sockmod = SockMod, - streamid = new_id(), host_opts = HostOpts, - access = Access, check_from = CheckFrom}}. - -wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of - #stream_start{xmlns = NS_COMPONENT, stream_xmlns = NS_STREAM} - when NS_COMPONENT /= ?NS_COMPONENT; NS_STREAM /= ?NS_STREAM -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_invalid_namespace()), - {stop, normal, StateData}; - #stream_start{to = To} when is_record(To, jid) -> - Host = To#jid.lserver, - send_header(StateData, Host), - HostOpts = case dict:is_key(Host, StateData#state.host_opts) of - true -> - StateData#state.host_opts; - false -> - case dict:find(global, StateData#state.host_opts) of - {ok, GlobalPass} -> - dict:from_list([{Host, GlobalPass}]); - error -> - StateData#state.host_opts - end - end, - {next_state, wait_for_handshake, - StateData#state{host = Host, host_opts = HostOpts}}; - #stream_start{} -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_improper_addressing()), - {stop, normal, StateData}; - _ -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_invalid_xml()), - {stop, normal, StateData} - catch _:{xmpp_codec, Why} -> - Txt = xmpp:format_error(Why), - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)), - {stop, normal, StateData} - end; -wait_for_stream({xmlstreamerror, _}, StateData) -> - send_header(StateData, ?MYNAME), - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_stream(closed, StateData) -> - {stop, normal, StateData}. - -wait_for_handshake({xmlstreamelement, El}, StateData) -> - decode_element(El, wait_for_handshake, StateData); -wait_for_handshake(#handshake{data = Digest}, StateData) -> - case dict:find(StateData#state.host, StateData#state.host_opts) of + CheckFrom = gen_mod:get_opt(check_from, Opts, + fun(Flag) when is_boolean(Flag) -> Flag end), + xmpp_stream_in:change_shaper(State, Shaper), + State1 = State#{access => Access, + xmlns => ?NS_COMPONENT, + lang => ?MYLANG, + server => ?MYNAME, + host_opts => HostOpts, + check_from => CheckFrom}, + ejabberd_hooks:run_fold(component_init, {ok, State1}, []). + +handle_stream_start(#{remote_server := RemoteServer, + host_opts := HostOpts} = State) -> + NewHostOpts = case dict:is_key(RemoteServer, HostOpts) of + true -> + HostOpts; + false -> + case dict:find(global, HostOpts) of + {ok, GlobalPass} -> + dict:from_list([{RemoteServer, GlobalPass}]); + error -> + HostOpts + end + end, + {noreply, State#{host_opts => NewHostOpts}}. + +handshake(Digest, #{remote_server := RemoteServer, + stream_id := StreamID, + host_opts := HostOpts} = State) -> + case dict:find(RemoteServer, HostOpts) of {ok, Password} -> - case p1_sha:sha(<<(StateData#state.streamid)/binary, - Password/binary>>) of + case p1_sha:sha(<>) of Digest -> - send_element(StateData, #handshake{}), lists:foreach( fun (H) -> ejabberd_router:register_route(H, ?MYNAME), - ?INFO_MSG("Route registered for service ~p~n", - [H]), + ?INFO_MSG("Route registered for service ~p~n", [H]), ejabberd_hooks:run(component_connected, [H]) - end, dict:fetch_keys(StateData#state.host_opts)), - {next_state, stream_established, StateData}; - _ -> - send_element(StateData, xmpp:serr_not_authorized()), - {stop, normal, StateData} - end; - _ -> - send_element(StateData, xmpp:serr_not_authorized()), - {stop, normal, StateData} - end; -wait_for_handshake({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -wait_for_handshake({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -wait_for_handshake(closed, StateData) -> - {stop, normal, StateData}; -wait_for_handshake(_Pkt, StateData) -> - {next_state, wait_for_handshake, StateData}. - -stream_established({xmlstreamelement, El}, StateData) -> - decode_element(El, stream_established, StateData); -stream_established(El, StateData) when ?is_stanza(El) -> - From = xmpp:get_from(El), - To = xmpp:get_to(El), - Lang = xmpp:get_lang(El), - if From == undefined orelse To == undefined -> - Txt = <<"Missing 'from' or 'to' attribute">>, - send_error(StateData, El, xmpp:err_jid_malformed(Txt, Lang)); - true -> - case check_from(From, StateData) of - true -> - ejabberd_router:route(From, To, El); - false -> - Txt = <<"Improper domain part of 'from' attribute">>, - send_error(StateData, El, xmpp:err_not_allowed(Txt, Lang)) - end - end, - {next_state, stream_established, StateData}; -stream_established({xmlstreamend, _Name}, StateData) -> - {stop, normal, StateData}; -stream_established({xmlstreamerror, _}, StateData) -> - send_element(StateData, xmpp:serr_not_well_formed()), - {stop, normal, StateData}; -stream_established(closed, StateData) -> - {stop, normal, StateData}; -stream_established(_Event, StateData) -> - {next_state, stream_established, StateData}. + end, dict:fetch_keys(HostOpts)), + {ok, State}; + _ -> + ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]), + {error, xmpp:serr_not_authorized(), State} + end; + _ -> + ?ERROR_MSG("Failed authentication for service ~s", [RemoteServer]), + {error, xmpp:serr_not_authorized(), State} + end. -handle_event(_Event, StateName, StateData) -> - {next_state, StateName, StateData}. +handle_authenticated_packet(Pkt, #{lang := Lang} = State) -> + From = xmpp:get_from(Pkt), + case check_from(From, State) of + true -> + To = xmpp:get_to(Pkt), + ejabberd_router:route(From, To, Pkt), + {noreply, State}; + false -> + Txt = <<"Improper domain part of 'from' attribute">>, + Err = xmpp:serr_invalid_from(Txt, Lang), + xmpp_stream_in:send(State, Err) + end. -handle_sync_event(_Event, _From, StateName, - StateData) -> - Reply = ok, {reply, Reply, StateName, StateData}. +handle_call(_Request, _From, State) -> + Reply = ok, + {reply, Reply, State}. -code_change(_OldVsn, StateName, StateData, _Extra) -> - {ok, StateName, StateData}. +handle_cast(_Msg, State) -> + {noreply, State}. -handle_info({send_text, Text}, StateName, StateData) -> - send_text(StateData, Text), - {next_state, StateName, StateData}; -handle_info({send_element, El}, StateName, StateData) -> - send_element(StateData, El), - {next_state, StateName, StateData}; -handle_info({route, From, To, Packet}, StateName, - StateData) -> - case acl:match_rule(global, StateData#state.access, From) of - allow -> +handle_info({route, From, To, Packet}, #{access := Access} = State) -> + case acl:match_rule(global, Access, From) of + allow -> Pkt = xmpp:set_from_to(Packet, From, To), - send_element(StateData, Pkt); + xmpp_stream_in:send(State, Pkt); deny -> Lang = xmpp:get_lang(Packet), Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang), - ejabberd_router:route_error(To, From, Packet, Err) - end, - {next_state, StateName, StateData}; -handle_info(Info, StateName, StateData) -> + ejabberd_router:route_error(To, From, Packet, Err), + {noreply, State} + end; +handle_info(Info, State) -> ?ERROR_MSG("Unexpected info: ~p", [Info]), - {next_state, StateName, StateData}. - -terminate(Reason, StateName, StateData) -> - ?INFO_MSG("terminated: ~p", [Reason]), - case StateName of - stream_established -> - lists:foreach(fun (H) -> - ejabberd_router:unregister_route(H), - ejabberd_hooks:run(component_disconnected, - [H, Reason]) - end, - dict:fetch_keys(StateData#state.host_opts)); - _ -> ok - end, - catch send_trailer(StateData), - (StateData#state.sockmod):close(StateData#state.socket), - ok. - -%%---------------------------------------------------------------------- -%% Func: print_state/1 -%% Purpose: Prepare the state to be printed on error log -%% Returns: State to print -%%---------------------------------------------------------------------- -print_state(State) -> State. - -%%%---------------------------------------------------------------------- -%%% Internal functions -%%%---------------------------------------------------------------------- - --spec send_text(state(), iodata()) -> ok. -send_text(StateData, Text) -> - (StateData#state.sockmod):send(StateData#state.socket, - Text). - --spec send_element(state(), xmpp_element()) -> ok. -send_element(StateData, El) -> - El1 = xmpp:encode(El, ?NS_COMPONENT), - send_text(StateData, fxml:element_to_binary(El1)). - --spec send_error(state(), xmlel() | stanza(), stanza_error()) -> ok. -send_error(StateData, Stanza, Error) -> - Type = xmpp:get_type(Stanza), - if Type == error; Type == result; - Type == <<"error">>; Type == <<"result">> -> - ok; - true -> - send_element(StateData, xmpp:make_error(Stanza, Error)) + {noreply, State}. + +terminate(Reason, #{stream_state := StreamState, host_opts := HostOpts}) -> + ?INFO_MSG("External service disconnected: ~p", [Reason]), + case StreamState of + session_established -> + lists:foreach( + fun(H) -> + ejabberd_router:unregister_route(H), + ejabberd_hooks:run(component_disconnected, [H, Reason]) + end, dict:fetch_keys(HostOpts)); + _ -> + ok end. --spec send_header(state(), binary()) -> ok. -send_header(StateData, Host) -> - Header = xmpp:encode( - #stream_start{xmlns = ?NS_COMPONENT, - stream_xmlns = ?NS_STREAM, - from = jid:make(Host), - id = StateData#state.streamid}), - send_text(StateData, fxml:element_to_header(Header)). - --spec send_trailer(state()) -> ok. -send_trailer(StateData) -> - send_text(StateData, <<"">>). - --spec decode_element(xmlel(), state_name(), state()) -> fsm_transition(). -decode_element(#xmlel{} = El, StateName, StateData) -> - try xmpp:decode(El, ?NS_COMPONENT, [ignore_els]) of - Pkt -> ?MODULE:StateName(Pkt, StateData) - catch error:{xmpp_codec, Why} -> - case xmpp:is_stanza(El) of - true -> - Lang = xmpp:get_lang(El), - Txt = xmpp:format_error(Why), - send_error(StateData, El, xmpp:err_bad_request(Txt, Lang)); - false -> - ok - end, - {next_state, StateName, StateData} - end. +code_change(_OldVsn, State, _Extra) -> + {ok, State}. +%%%=================================================================== +%%% Internal functions +%%%=================================================================== -spec check_from(jid(), state()) -> boolean(). -check_from(_From, #state{check_from = false}) -> +check_from(_From, #{check_from := false}) -> %% If the admin does not want to check the from field %% when accept packets from any address. %% In this case, the component can send packet of %% behalf of the server users. true; -check_from(From, StateData) -> +check_from(From, #{host_opts := HostOpts}) -> %% The default is the standard behaviour in XEP-0114 Server = From#jid.lserver, - dict:is_key(Server, StateData#state.host_opts). - --spec new_id() -> binary(). -new_id() -> randoms:get_string(). + dict:is_key(Server, HostOpts). transform_listen_option({hosts, Hosts, O}, Opts) -> case lists:keyfind(hosts, 1, Opts) of diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index 6294a7893..1307f9da4 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -37,7 +37,48 @@ -type next_state() :: {noreply, state()} | {stop, term(), state()}. -callback init(list()) -> {ok, state()} | {stop, term()} | ignore. +-callback handle_stream_start(state()) -> next_state(). +-callback handle_stream_end(state()) -> next_state(). +-callback handle_stream_close(state()) -> next_state(). +-callback handle_cdata(binary(), state()) -> next_state(). +-callback handle_unauthenticated_packet(xmpp_element(), state()) -> next_state(). -callback handle_authenticated_packet(xmpp_element(), state()) -> next_state(). +-callback handle_unbinded_packet(xmpp_element(), state()) -> next_state(). +-callback handle_auth_success(binary(), binary(), module(), state()) -> next_state(). +-callback handle_auth_failure(binary(), binary(), atom(), state()) -> next_state(). +-callback handle_send(ok | {error, atom()}, + xmpp_element(), fxml:xmlel(), binary(), state()) -> next_state(). +-callback init_sasl(state()) -> cyrsasl:sasl_state(). +-callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}. +-callback handshake(binary(), state()) -> {ok, state()} | {error, stream_error(), state()}. +-callback compress_methods(state()) -> [binary()]. +-callback tls_options(state()) -> [proplists:property()]. +-callback tls_required(state()) -> boolean(). +-callback sasl_mechanisms(state()) -> [binary()]. +-callback unauthenticated_stream_features(state()) -> [xmpp_element()]. +-callback authenticated_stream_features(state()) -> [xmpp_element()]. + +%% All callbacks are optional +-optional_callbacks([init/1, + handle_stream_start/1, + handle_stream_end/1, + handle_stream_close/1, + handle_cdata/2, + handle_authenticated_packet/2, + handle_unauthenticated_packet/2, + handle_unbinded_packet/2, + handle_auth_success/4, + handle_auth_failure/4, + handle_send/5, + init_sasl/1, + bind/2, + handshake/2, + compress_methods/1, + tls_options/1, + tls_required/1, + sasl_mechanisms/1, + unauthenticated_stream_features/1, + authenticated_stream_features/1]). %%%=================================================================== %%% API @@ -94,21 +135,28 @@ init([Module, {SockMod, Socket}, Opts]) -> user => <<"">>, server => <<"">>, resource => <<"">>, + lserver => <<"">>, ip => IP}, - Module:init([State, Opts]); + try Module:init([State, Opts]) + catch _:undef -> {ok, State} + end; {error, Reason} -> {stop, Reason} end. handle_cast(Cast, #{mod := Mod} = State) -> - Mod:handle_cast(Cast, State). + try Mod:handle_cast(Cast, State) + catch _:undef -> {noreply, State} + end. handle_call(Call, From, #{mod := Mod} = State) -> - Mod:handle_call(Call, From, State). + try Mod:handle_call(Call, From, State) + catch _:undef -> {reply, ok, State} + end. handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}}, - #{stream_state := wait_for_stream} = State) -> - try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of + #{stream_state := wait_for_stream, xmlns := XMLNS} = State) -> + try xmpp:decode(#xmlel{name = Name, attrs = Attrs}, XMLNS, []) of #stream_start{} = Pkt -> case send_header(State, Pkt) of {noreply, State1} -> @@ -169,11 +217,15 @@ handle_info({'DOWN', MRef, _Type, _Object, _Info}, catch _:undef -> {stop, normal, State} end; handle_info(Info, #{mod := Mod} = State) -> - Mod:handle_info(Info, State). + try Mod:handle_info(Info, State) + catch _:undef -> {noreply, State} + end. terminate(Reason, #{mod := Mod, socket := Socket, sockmod := SockMod} = State) -> - Mod:terminate(Reason, State), + try Mod:terminate(Reason, State) + catch _:undef -> ok + end, send_text(State, <<"">>), SockMod:close(Socket). @@ -234,13 +286,14 @@ process_stream(#stream_start{to = #jid{lserver = RemoteServer}}, Err -> Err end; -process_stream(#stream_start{to = #jid{server = Server}, from = From}, +process_stream(#stream_start{to = #jid{server = Server, lserver = LServer}, + from = From}, #{stream_authenticated := Authenticated, stream_restarted := StreamWasRestarted, mod := Mod, xmlns := NS, resource := Resource, stream_tlsed := TLSEnabled} = State) -> case if not StreamWasRestarted -> - State1 = State#{server => Server}, + State1 = State#{server => Server, lserver => LServer}, try Mod:handle_stream_start(State1) catch _:undef -> {noreply, State1} end; @@ -342,10 +395,18 @@ process_authenticated_packet(Pkt, #{xmlns := NS, mod := Mod} = State) -> #xmpp_session{} -> send_element(State, xmpp:make_iq_result(Pkt2)); _ -> - Mod:handle_authenticated_packet(Pkt2, State) + try Mod:handle_authenticated_packet(Pkt2, State) + catch _:undef -> + Err = xmpp:err_service_unavailable(), + send_error(State, Pkt, Err) + end end; {ok, Pkt2} -> - Mod:handle_authenticated_packet(Pkt2, State); + try Mod:handle_authenticated_packet(Pkt2, State) + catch _:undef -> + Err = xmpp:err_service_unavailable(), + send_error(State, Pkt, Err) + end; {error, Err} -> send_element(State, Err) end. @@ -385,8 +446,15 @@ process_bind(Pkt, #{mod := Mod} = State) -> send_error(State, Pkt, Err) end. -process_handshake(#handshake{} = Pkt, #{mod := Mod} = State) -> - Mod:handle_handshake(Pkt, State). +process_handshake(#handshake{data = Data}, #{mod := Mod} = State) -> + case Mod:handshake(Data, State) of + {ok, State1} -> + State2 = State1#{stream_state => session_established, + stream_authenticated => true}, + send_element(State2, #handshake{}); + {error, #stream_error{} = Err, State1} -> + send_element(State1, Err) + end. process_compress(#compress{}, #{stream_compressed := true} = State) -> send_element(State, #compress_failure{reason = 'setup-failed'}); @@ -436,9 +504,13 @@ process_sasl_request(#sasl_auth{mechanism = <<"EXTERNAL">>}, process_sasl_failure('encryption-required', <<"">>, State); process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, #{mod := Mod} = State) -> - SASLState = Mod:init_sasl(State), - SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn), - process_sasl_result(SASLResult, State). + try Mod:init_sasl(State) of + SASLState -> + SASLResult = cyrsasl:server_start(SASLState, Mech, ClientIn), + process_sasl_result(SASLResult, State) + catch _:undef -> + process_sasl_failure('temporary-auth-failure', <<"">>, State) + end. process_sasl_response(#sasl_response{text = ClientIn}, #{sasl_state := SASLState} = State) -> -- 2.40.0