]> granicus.if.org Git - ejabberd/commitdiff
Rewrite S2S and ejabberd_service code to use XML generator
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Wed, 27 Jul 2016 07:45:08 +0000 (10:45 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Wed, 27 Jul 2016 07:45:08 +0000 (10:45 +0300)
13 files changed:
include/ejabberd.hrl
include/ns.hrl
include/xmpp_codec.hrl
src/ejabberd_c2s.erl
src/ejabberd_http_bind.erl
src/ejabberd_router.erl
src/ejabberd_s2s.erl
src/ejabberd_s2s_in.erl
src/ejabberd_s2s_out.erl
src/ejabberd_service.erl
src/xmpp.erl
src/xmpp_codec.erl
tools/xmpp_codec.spec

index 6316d7813461ad8fab9141243ff1e73b73746d8b..a97474d2ba52c26d5aee0d8cee594e9b07db2d7a 100644 (file)
@@ -39,7 +39,7 @@
 
 -define(EJABBERD_URI, <<"http://www.process-one.net/en/ejabberd/">>).
 
--define(S2STIMEOUT, 600000).
+-define(S2STIMEOUT, timer:minutes(10)).
 
 %%-define(DBGFSM, true).
 
index a150746e73db63f6a6f3900a313f58664dd57405..b301615658f3eacdcda6d5f8b8bf75cf64274670 100644 (file)
 %%%
 %%%----------------------------------------------------------------------
 
+-define(NS_COMPONENT, <<"jabber:component:accept">>).
+-define(NS_SERVER, <<"jabber:server">>).
+-define(NS_SERVER_DIALBACK, <<"jabber:server:dialback">>).
+-define(NS_CLIENT, <<"jabber:client">>).
 -define(NS_DISCO_ITEMS,
        <<"http://jabber.org/protocol/disco#items">>).
 -define(NS_DISCO_INFO,
index b14c0d11fbe9022915be1be9a3d8e0a5dbec2e38..43bb6b098648efa7022cd876370b457e7c2dd7a8 100644 (file)
                xmlns :: binary()}).
 -type sm_a() :: #sm_a{}.
 
+-record(stream_start, {from :: any(),
+                       to :: any(),
+                       id = <<>> :: binary(),
+                       version = <<>> :: binary(),
+                       xmlns :: binary(),
+                       stream_xmlns = <<>> :: binary(),
+                       db_xmlns = <<>> :: binary(),
+                       lang = <<>> :: binary()}).
+-type stream_start() :: #stream_start{}.
+
 -record(muc_subscribe, {nick :: binary(),
                         events = [] :: [binary()]}).
 -type muc_subscribe() :: #muc_subscribe{}.
 -record(sasl_challenge, {text :: any()}).
 -type sasl_challenge() :: #sasl_challenge{}.
 
+-record(handshake, {data = <<>> :: binary()}).
+-type handshake() :: #handshake{}.
+
 -record(gone, {uri :: binary()}).
 -type gone() :: #gone{}.
 
                 text :: #text{}}).
 -type error() :: #error{}.
 
+-record(db_verify, {from :: any(),
+                    to :: any(),
+                    id :: binary(),
+                    type :: 'error' | 'invalid' | 'valid',
+                    key = <<>> :: binary(),
+                    error :: #error{}}).
+-type db_verify() :: #db_verify{}.
+
+-record(db_result, {from :: any(),
+                    to :: any(),
+                    type :: 'error' | 'invalid' | 'valid',
+                    key = <<>> :: binary(),
+                    error :: #error{}}).
+-type db_result() :: #db_result{}.
+
 -record(presence, {id :: binary(),
                    type = available :: 'available' | 'error' | 'probe' | 'subscribe' | 'subscribed' | 'unavailable' | 'unsubscribe' | 'unsubscribed',
                    lang :: binary(),
                utc :: any()}).
 -type time() :: #time{}.
 
--type xmpp_element() :: compression() |
+-type xmpp_element() :: muc_admin() |
+                        compression() |
                         pubsub_subscription() |
                         xdata_option() |
                         version() |
                         pubsub_affiliation() |
-                        muc_admin() |
                         mam_fin() |
                         sm_a() |
                         carbons_sent() |
                         compressed() |
                         block_list() |
                         rsm_set() |
+                        db_result() |
                         'see-other-host'() |
                         hint() |
+                        stream_start() |
                         stanza_id() |
                         starttls_proceed() |
                         client_id() |
                         pubsub() |
                         muc_owner() |
                         muc_actor() |
+                        vcard_name() |
                         adhoc_note() |
                         rosterver_feature() |
                         muc_invite() |
                         sm_enable() |
                         starttls_failure() |
                         sasl_challenge() |
+                        handshake() |
                         x_conference() |
                         private() |
                         compress_failure() |
                         sasl_failure() |
                         bookmark_storage() |
-                        vcard_name() |
                         muc_decline() |
                         sasl_auth() |
                         p1_push() |
                         csi() |
                         roster_query() |
                         mam_query() |
+                        db_verify() |
                         bookmark_url() |
                         vcard_email() |
                         vcard_label() |
index 8d217a354149b4717d23662de8106db189e71b55..1ae9a7c29cda95457bc416b60391d93befd8631c 100644 (file)
@@ -320,42 +320,46 @@ get_subscribed(FsmRef) ->
     (?GEN_FSM):sync_send_all_state_event(FsmRef,
                                         get_subscribed, 1000).
 
-wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
-    DefaultLang = ?MYLANG,
-    case fxml:get_attr_s(<<"xmlns:stream">>, Attrs) of
-       ?NS_STREAM ->
-           Server =
-               case StateData#state.server of
-               <<"">> ->
-                   jid:nameprep(fxml:get_attr_s(<<"to">>, Attrs));
-               S -> S
-           end,
-           Lang = case fxml:get_attr_s(<<"xml:lang">>, Attrs) of
-               Lang1 when byte_size(Lang1) =< 35 ->
-                   %% As stated in BCP47, 4.4.1:
-                   %% Protocols or specifications that
-                   %% specify limited buffer sizes for
-                   %% language tags MUST allow for
-                   %% language tags of at least 35 characters.
-                   Lang1;
-               _ ->
-                   %% Do not store long language tag to
-                   %% avoid possible DoS/flood attacks
-                   <<"">>
-           end,
-           StreamVersion = case fxml:get_attr_s(<<"version">>, Attrs) of
-                             <<"1.0">> ->
-                                 <<"1.0">>;
-                             _ ->
-                                 <<"">>
-                         end,
+wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
+    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
+       #stream_start{xmlns = NS_CLIENT, stream_xmlns = NS_STREAM, lang = Lang}
+          when NS_CLIENT /= ?NS_CLIENT; NS_STREAM /= ?NS_STREAM ->
+           send_header(StateData, ?MYNAME, <<"">>, Lang),
+            send_element(StateData, xmpp:serr_invalid_namespace()),
+            {stop, normal, StateData};
+       #stream_start{lang = Lang} when byte_size(Lang) > 35 ->
+           %% As stated in BCP47, 4.4.1:
+           %% Protocols or specifications that specify limited buffer sizes for
+           %% language tags MUST allow for language tags of at least 35 characters.
+           %% Do not store long language tag to avoid possible DoS/flood attacks
+           send_header(StateData, ?MYNAME, <<"">>, ?MYLANG),
+           Txt = <<"Too long value of 'xml:lang' attribute">>,
+           send_element(StateData,
+                        xmpp:serr_policy_violation(Txt, ?MYLANG)),
+           {stop, normal, StateData};
+       #stream_start{to = undefined, lang = Lang} ->
+           Txt = <<"Missing 'to' attribute">>,
+           send_header(StateData, ?MYNAME, <<"">>, Lang),
+           send_element(StateData,
+                        xmpp:serr_improper_addressing(Txt, Lang)),
+           {stop, normal, StateData};
+       #stream_start{to = #jid{lserver = To}, lang = Lang,
+                     version = Version} ->
+           Server = case StateData#state.server of
+                        <<"">> -> To;
+                        S -> S
+                    end,
+           StreamVersion = case Version of
+                               <<"1.0">> -> <<"1.0">>;
+                               _ -> <<"">>
+                           end,
            IsBlacklistedIP = is_ip_blacklisted(StateData#state.ip, Lang),
            case lists:member(Server, ?MYHOSTS) of
                true when IsBlacklistedIP == false ->
                    change_shaper(StateData, jid:make(<<"">>, Server, <<"">>)),
                    case StreamVersion of
                        <<"1.0">> ->
-                           send_header(StateData, Server, <<"1.0">>, DefaultLang),
+                           send_header(StateData, Server, <<"1.0">>, ?MYLANG),
                            case StateData#state.authenticated of
                                false ->
                                    TLS = StateData#state.tls,
@@ -458,7 +462,7 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
                                    end
                            end;
                        _ ->
-                           send_header(StateData, Server, <<"">>, DefaultLang),
+                           send_header(StateData, Server, <<"">>, ?MYLANG),
                            if not StateData#state.tls_enabled and
                                        StateData#state.tls_required ->
                                    send_element(
@@ -477,17 +481,18 @@ wait_for_stream({xmlstreamstart, _Name, Attrs}, StateData) ->
                    {true, LogReason, ReasonT} = IsBlacklistedIP,
                    ?INFO_MSG("Connection attempt from blacklisted IP ~s: ~s",
                        [jlib:ip_to_list(IP), LogReason]),
-                   send_header(StateData, Server, StreamVersion, DefaultLang),
+                   send_header(StateData, Server, StreamVersion, ?MYLANG),
                    send_element(StateData, xmpp:serr_policy_violation(ReasonT, Lang)),
                    {stop, normal, StateData};
                _ ->
-                   send_header(StateData, ?MYNAME, StreamVersion, DefaultLang),
+                   send_header(StateData, ?MYNAME, StreamVersion, ?MYLANG),
                    send_element(StateData, xmpp:serr_host_unknown()),
                    {stop, normal, StateData}
-           end;
-       _ ->
-           send_header(StateData, ?MYNAME, <<"">>, DefaultLang),
-           send_element(StateData, xmpp:serr_invalid_namespace()),
+           end
+    catch _:{xmpp_codec, Why} ->
+           Txt = xmpp:format_error(Why),
+           send_header(StateData, ?MYNAME, <<"">>, ?MYLANG),
+           send_element(StateData, xmpp:serr_not_well_formed(Txt, ?MYLANG)),
            {stop, normal, StateData}
     end;
 wait_for_stream(timeout, StateData) ->
@@ -854,38 +859,36 @@ resource_conflict_action(U, S, R) ->
          {accept_resource, Rnew}
     end.
 
--spec decode_subels(stanza()) -> stanza().
-decode_subels(#iq{sub_els = [El], type = T} = IQ) when T == set; T == get ->
-    NewEl = case xmpp:get_ns(El) of
-               ?NS_BIND when T == set -> xmpp:decode(El);
-               ?NS_AUTH -> xmpp:decode(El);
-               ?NS_PRIVACY -> xmpp:decode(El);
-               ?NS_BLOCKING -> xmpp:decode(El);
-               _ -> El
-           end,
-    IQ#iq{sub_els = [NewEl]};
-decode_subels(Pkt) ->
-    Pkt.
-
--spec decode_element(xmlel(), state_name(), state()) -> fsm_next().
+-spec decode_element(xmlel(), state_name(), state()) -> fsm_transition().
 decode_element(#xmlel{} = El, StateName, StateData) ->
-    try
-       Pkt0 = xmpp:decode(El, [ignore_els]),
-       Pkt = decode_subels(Pkt0),
-       ?MODULE:StateName(Pkt, StateData)
+    try case xmpp:decode(El, [ignore_els]) of
+           #iq{sub_els = [_], type = T} = Pkt when T == set; T == get ->
+               NewPkt = xmpp:decode_els(
+                          Pkt,
+                          fun(SubEl) when StateName == session_established ->
+                                  case xmpp:get_ns(SubEl) of
+                                      ?NS_PRIVACY -> true;
+                                      ?NS_BLOCKING -> true;
+                                      _ -> false
+                                  end;
+                             (SubEl) ->
+                                  xmpp_codec:is_known_tag(SubEl)
+                          end),
+               ?MODULE:StateName(NewPkt, StateData);
+           Pkt ->
+               ?MODULE:StateName(Pkt, StateData)
+       end
     catch error:{xmpp_codec, Why} ->
-           Type = xmpp:get_type(El),
            NS = xmpp:get_ns(El),
            case xmpp:is_stanza(El) of
-               true when Type /= <<"result">>, Type /= <<"error">> ->
+               true ->
                    Lang = xmpp:get_lang(El),
                    Txt = xmpp:format_error(Why),
-                   Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)),
-                   send_element(StateData, Err);
-               _ when NS == ?NS_STREAM_MGMT_2; NS == ?NS_STREAM_MGMT_3 ->
+                   send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
+               false when NS == ?NS_STREAM_MGMT_2; NS == ?NS_STREAM_MGMT_3 ->
                    Err = #sm_failed{reason = 'bad-request', xmlns = NS},
                    send_element(StateData, Err);
-               _ ->
+               false ->
                    ok
            end,
            fsm_next_state(StateName, StateData)
@@ -951,13 +954,7 @@ wait_for_bind(stop, StateData) ->
 wait_for_bind(Pkt, StateData) ->
     case xmpp:is_stanza(Pkt) of
        true ->
-           Type = xmpp:get_type(Pkt),
-           if Type /= error, Type /= result ->
-                   Err = xmpp:make_error(Pkt, xmpp:err_not_acceptable()),
-                   send_element(StateData, Err);
-              true ->
-                   ok
-           end;
+           send_error(StateData, Pkt, xmpp:err_not_acceptable());
        false ->
            ok
     end,
@@ -1046,7 +1043,7 @@ session_established(closed, StateData) ->
     {stop, normal, StateData};
 session_established(stop, StateData) ->
     {stop, normal, StateData};
-session_established(Pkt, StateData) ->
+session_established(Pkt, StateData) when ?is_stanza(Pkt) ->
     FromJID = StateData#state.jid,
     case check_from(Pkt, FromJID) of
        'invalid-from' ->
@@ -1055,11 +1052,13 @@ session_established(Pkt, StateData) ->
        _ ->
            NewStateData = update_num_stanzas_in(StateData, Pkt),
            session_established2(Pkt, NewStateData)
-    end.
+    end;
+session_established(_Pkt, StateData) ->
+    fsm_next_state(session_established, StateData).
 
 -spec session_established2(xmpp_element(), state()) -> fsm_next().
 %% Process packets sent by user (coming from user on c2s XMPP connection)
-session_established2(Pkt, StateData) when ?is_stanza(Pkt) ->
+session_established2(Pkt, StateData) ->
     User = StateData#state.user,
     Server = StateData#state.server,
     FromJID = StateData#state.jid,
@@ -1116,11 +1115,7 @@ session_established2(Pkt, StateData) when ?is_stanza(Pkt) ->
        end,
     ejabberd_hooks:run(c2s_loop_debug,
                       [{xmlstreamelement, Pkt}]),
-    fsm_next_state(session_established, NewState);
-session_established2(Pkt, StateData) ->
-    ejabberd_hooks:run(c2s_loop_debug,
-                      [{xmlstreamelement, Pkt}]),
-    fsm_next_state(session_established, StateData).
+    fsm_next_state(session_established, NewState).
 
 wait_for_resume({xmlstreamelement, _El} = Event, StateData) ->
     Result = session_established(Event, StateData),
@@ -1573,6 +1568,16 @@ send_element(StateData, #xmlel{} = El) ->
 send_element(StateData, Pkt) ->
     send_element(StateData, xmpp:encode(Pkt)).
 
+-spec send_error(state(), xmlel() | stanza(), error()) -> ok.
+send_error(StateData, Stanza, Error) ->
+    Type = xmpp:get_type(Stanza),
+    if Type == error; Type == result;
+       Type == <<"error">>; Type == <<"result">> ->
+           ok;
+       true ->
+           send_element(StateData, xmpp:make_error(Stanza, Error))
+    end.
+
 -spec send_stanza(state(), xmpp_element()) -> state().
 send_stanza(StateData, Stanza) when StateData#state.csi_state == inactive ->
     csi_filter_stanza(StateData, Stanza);
@@ -2136,28 +2141,24 @@ is_ip_blacklisted({IP, _Port}, Lang) ->
 
 %% Check from attributes
 %% returns invalid-from|NewElement
+-spec check_from(stanza(), jid()) -> 'invalid-from' | stanza().
 check_from(Pkt, FromJID) ->
-    case xmpp:is_stanza(Pkt) of
-       false ->
+    JID = xmpp:get_from(Pkt),
+    case JID of
+       undefined ->
            Pkt;
-       true ->
-           JID = xmpp:get_from(Pkt),
-           case JID of
-               undefined ->
+       #jid{} ->
+           if
+               (JID#jid.luser == FromJID#jid.luser) and
+               (JID#jid.lserver == FromJID#jid.lserver) and
+               (JID#jid.lresource == FromJID#jid.lresource) ->
                    Pkt;
-               #jid{} ->
-                   if
-                       (JID#jid.luser == FromJID#jid.luser) and
-                       (JID#jid.lserver == FromJID#jid.lserver) and
-                       (JID#jid.lresource == FromJID#jid.lresource) ->
-                           Pkt;
-                       (JID#jid.luser == FromJID#jid.luser) and
-                       (JID#jid.lserver == FromJID#jid.lserver) and
-                       (JID#jid.lresource == <<"">>) ->
-                           Pkt;
-                       true ->
-                           'invalid-from'
-                   end
+               (JID#jid.luser == FromJID#jid.luser) and
+               (JID#jid.lserver == FromJID#jid.lserver) and
+               (JID#jid.lresource == <<"">>) ->
+                   Pkt;
+               true ->
+                   'invalid-from'
            end
     end.
 
index ea8cd792f1e28ba7e93c4d715d68452da208be6a..6fa38110cf9d2143e958c1919490316df2cd6d0a 100644 (file)
 %% Wait 100ms before continue processing, to allow the client provide more related stanzas.
 -define(BOSH_VERSION, <<"1.8">>).
 
--define(NS_CLIENT, <<"jabber:client">>).
-
 -define(NS_BOSH, <<"urn:xmpp:xbosh">>).
 
 -define(NS_HTTP_BIND,
index 83ffd932b27fa884039840119e77fa629b64fb9f..d65edc6e327d26ceae3b3ad8fde667e50198503f 100644 (file)
@@ -389,7 +389,13 @@ do_route(OrigFrom, OrigTo, OrigPacket) ->
       {From, To, Packet} ->
          LDstDomain = To#jid.lserver,
          case mnesia:dirty_read(route, LDstDomain) of
-           [] -> ejabberd_s2s:route(From, To, Packet);
+           [] ->
+                 try xmpp:decode(Packet, [ignore_els]) of
+                     Pkt ->
+                         ejabberd_s2s:route(From, To, Pkt)
+                 catch _:{xmpp_codec, Why} ->
+                         log_decoding_error(From, To, Packet, Why)
+                 end;
            [R] ->
                do_route(From, To, Packet, R);
            Rs ->
@@ -425,15 +431,18 @@ do_route(From, To, Packet, #route{local_hint = LocalHint,
                    Pid ! {route, From, To, Pkt}
            end
     catch error:{xmpp_codec, Why} ->
-           ?ERROR_MSG("failed to decode xml element ~p when "
-                      "routing from ~s to ~s: ~s",
-                      [Packet, jid:to_string(From), jid:to_string(To),
-                       xmpp:format_error(Why)]),
-           drop
+           log_decoding_error(From, To, Packet, Why)
     end;
 do_route(_From, _To, _Packet, _Route) ->
     drop.
 
+-spec log_decoding_error(jid(), jid(), xmlel() | xmpp_element(), term()) -> ok.
+log_decoding_error(From, To, Packet, Reason) ->
+    ?ERROR_MSG("failed to decode xml element ~p when "
+              "routing from ~s to ~s: ~s",
+              [Packet, jid:to_string(From), jid:to_string(To),
+               xmpp:format_error(Reason)]).
+
 -spec get_component_number(binary()) -> pos_integer() | undefined.
 get_component_number(LDomain) ->
     ejabberd_config:get_option(
index 19de64adb0e1e2a1ccf24a2537804181958f76f5..e585257e88840b990f26a70910ac4d91566210f8 100644 (file)
@@ -55,7 +55,7 @@
 -include("ejabberd.hrl").
 -include("logger.hrl").
 
--include("jlib.hrl").
+-include("xmpp.hrl").
 
 -include("ejabberd_commands.hrl").
 
@@ -89,7 +89,7 @@ start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [],
                          []).
 
--spec route(jid(), jid(), xmlel()) -> ok.
+-spec route(jid(), jid(), xmpp_element()) -> ok.
 
 route(From, To, Packet) ->
     case catch do_route(From, To, Packet) of
@@ -222,6 +222,7 @@ check_peer_certificate(SockMod, Sock, Peer) ->
            {error, <<"Cannot get peer certificate">>}
     end.
 
+-spec make_key({binary(), binary()}, binary()) -> binary().
 make_key({From, To}, StreamID) ->
     Secret = ejabberd_config:get_option(shared_key, fun(V) -> V end),
     p1_sha:to_hexlist(
@@ -275,7 +276,7 @@ code_change(_OldVsn, State, _Extra) ->
 %%--------------------------------------------------------------------
 %%% Internal functions
 %%--------------------------------------------------------------------
-
+-spec clean_table_from_bad_node(node()) -> any().
 clean_table_from_bad_node(Node) ->
     F = fun() ->
                Es = mnesia:select(
@@ -289,6 +290,7 @@ clean_table_from_bad_node(Node) ->
        end,
     mnesia:async_dirty(F).
 
+-spec do_route(jid(), jid(), stanza()) -> ok | false.
 do_route(From, To, Packet) ->
     ?DEBUG("s2s manager~n\tfrom ~p~n\tto ~p~n\tpacket "
           "~P~n",
@@ -296,28 +298,16 @@ do_route(From, To, Packet) ->
     case find_connection(From, To) of
       {atomic, Pid} when is_pid(Pid) ->
          ?DEBUG("sending to process ~p~n", [Pid]),
-         #xmlel{name = Name, attrs = Attrs, children = Els} =
-             Packet,
-         NewAttrs =
-             jlib:replace_from_to_attrs(jid:to_string(From),
-                                        jid:to_string(To), Attrs),
          #jid{lserver = MyServer} = From,
          ejabberd_hooks:run(s2s_send_packet, MyServer,
                             [From, To, Packet]),
-         send_element(Pid,
-                      #xmlel{name = Name, attrs = NewAttrs, children = Els}),
+         send_element(Pid, xmpp:set_from_to(Packet, From, To)),
          ok;
       {aborted, _Reason} ->
-         case fxml:get_tag_attr_s(<<"type">>, Packet) of
-           <<"error">> -> ok;
-           <<"result">> -> ok;
-           _ ->
-               Lang = fxml:get_tag_attr_s(<<"xml:lang">>, Packet),
-               Txt = <<"No s2s connection found">>,
-               Err = jlib:make_error_reply(
-                       Packet, ?ERRT_SERVICE_UNAVAILABLE(Lang, Txt)),
-               ejabberd_router:route(To, From, Err)
-         end,
+         Lang = xmpp:get_lang(Packet),
+         Txt = <<"No s2s connection found">>,
+         Err = xmpp:err_service_unavailable(Txt, Lang),
+         ejabberd_router:route_error(To, From, Packet, Err),
          false
     end.
 
@@ -367,9 +357,11 @@ find_connection(From, To) ->
          end
     end.
 
+-spec choose_connection(jid(), [#s2s{}]) -> pid().
 choose_connection(From, Connections) ->
     choose_pid(From, [C#s2s.pid || C <- Connections]).
 
+-spec choose_pid(jid(), [pid()]) -> pid().
 choose_pid(From, Pids) ->
     Pids1 = case [P || P <- Pids, node(P) == node()] of
              [] -> Pids;
@@ -417,22 +409,21 @@ new_connection(MyServer, Server, From, FromTo,
     end,
     TRes.
 
+-spec max_s2s_connections_number({binary(), binary()}) -> integer().
 max_s2s_connections_number({From, To}) ->
-    case acl:match_rule(From, max_s2s_connections,
-                       jid:make(<<"">>, To, <<"">>))
-       of
+    case acl:match_rule(From, max_s2s_connections, jid:make(To)) of
       Max when is_integer(Max) -> Max;
       _ -> ?DEFAULT_MAX_S2S_CONNECTIONS_NUMBER
     end.
 
+-spec max_s2s_connections_number_per_node({binary(), binary()}) -> integer().
 max_s2s_connections_number_per_node({From, To}) ->
-    case acl:match_rule(From, max_s2s_connections_per_node,
-                       jid:make(<<"">>, To, <<"">>))
-       of
+    case acl:match_rule(From, max_s2s_connections_per_node, jid:make(To)) of
       Max when is_integer(Max) -> Max;
       _ -> ?DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE
     end.
 
+-spec needed_connections_number([#s2s{}], integer(), integer()) -> integer().
 needed_connections_number(Ls, MaxS2SConnectionsNumber,
                          MaxS2SConnectionsNumberPerNode) ->
     LocalLs = [L || L <- Ls, node(L#s2s.pid) == node()],
@@ -444,6 +435,7 @@ needed_connections_number(Ls, MaxS2SConnectionsNumber,
 %% Description: Return true if the destination must be considered as a
 %% service.
 %% --------------------------------------------------------------------
+-spec is_service(jid(), jid()) -> boolean().
 is_service(From, To) ->
     LFromDomain = From#jid.lserver,
     case ejabberd_config:get_option(
@@ -541,7 +533,7 @@ allow_host1(MyHost, S2SHost) ->
              s2s_access,
              fun(A) -> A end,
              all),
-    JID = jid:make(<<"">>, S2SHost, <<"">>),
+    JID = jid:make(S2SHost),
     case acl:match_rule(MyHost, Rule, JID) of
         deny -> false;
         allow ->
index d8d0a400a03508e49709469539ce33e4e9b458be..04b961b3d1bf445fcba73a9bdde28223877fa2ca 100644 (file)
@@ -42,7 +42,7 @@
 -include("ejabberd.hrl").
 -include("logger.hrl").
 
--include("jlib.hrl").
+-include("xmpp.hrl").
 
 -define(DICT, dict).
 
         connections = (?DICT):new() :: ?TDICT,
          timer = make_ref()          :: reference()}).
 
-%-define(DBGFSM, true).
+-type state_name() :: wait_for_stream | wait_for_features | stream_established.
+-type state() :: #state{}.
+-type fsm_next() :: {next_state, state_name(), state()}.
+-type fsm_stop() :: {stop, normal, state()}.
+-type fsm_transition() :: fsm_stop() | fsm_next().
 
+%%-define(DBGFSM, true).
 -ifdef(DBGFSM).
-
 -define(FSMOPTS, [{debug, [trace]}]).
-
 -else.
-
 -define(FSMOPTS, []).
-
 -endif.
 
--define(STREAM_HEADER(Version),
-       <<"<?xml version='1.0'?><stream:stream "
-         "xmlns:stream='http://etherx.jabber.org/stream"
-         "s' xmlns='jabber:server' xmlns:db='jabber:ser"
-         "ver:dialback' id='",
-         (StateData#state.streamid)/binary, "'", Version/binary,
-         ">">>).
-
--define(STREAM_TRAILER, <<"</stream:stream>">>).
-
--define(INVALID_NAMESPACE_ERR,
-       fxml:element_to_binary(?SERR_INVALID_NAMESPACE)).
-
--define(HOST_UNKNOWN_ERR,
-       fxml:element_to_binary(?SERR_HOST_UNKNOWN)).
-
--define(INVALID_FROM_ERR,
-       fxml:element_to_binary(?SERR_INVALID_FROM)).
-
--define(INVALID_XML_ERR,
-       fxml:element_to_binary(?SERR_XML_NOT_WELL_FORMED)).
-
 start(SockData, Opts) ->
     supervisor:start_child(ejabberd_s2s_in_sup,
                             [SockData, Opts]).
@@ -185,319 +164,252 @@ init([{SockMod, Socket}, Opts]) ->
 %%          {next_state, NextStateName, NextStateData, Timeout} |
 %%          {stop, Reason, NewStateData}
 %%----------------------------------------------------------------------
-
-wait_for_stream({xmlstreamstart, _Name, Attrs},
-               StateData) ->
-    case {fxml:get_attr_s(<<"xmlns">>, Attrs),
-         fxml:get_attr_s(<<"xmlns:db">>, Attrs),
-         fxml:get_attr_s(<<"to">>, Attrs),
-         fxml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>}
-       of
-      {<<"jabber:server">>, _, Server, true}
-         when StateData#state.tls and
-                not StateData#state.authenticated ->
-         send_text(StateData,
-                   ?STREAM_HEADER(<<" version='1.0'">>)),
-         Auth = if StateData#state.tls_enabled ->
-                       case jid:nameprep(fxml:get_attr_s(<<"from">>, Attrs)) of
-                         From when From /= <<"">>, From /= error ->
-                             {Result, Message} =
-                                 ejabberd_s2s:check_peer_certificate(StateData#state.sockmod,
-                                                                     StateData#state.socket,
-                                                                     From),
-                             {Result, From, Message};
-                         _ ->
-                             {error, <<"(unknown)">>,
-                              <<"Got no valid 'from' attribute">>}
-                       end;
-                   true ->
-                       {no_verify, <<"(unknown)">>,
-                        <<"TLS not (yet) enabled">>}
-                end,
-         StartTLS = if StateData#state.tls_enabled -> [];
-                       not StateData#state.tls_enabled and
+wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
+    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
+       #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
+         when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
+           send_header(StateData, <<" version='1.0'">>),
+           send_element(StateData, xmpp:serr_invalid_namespace()),
+           {stop, normal, StateData};
+       #stream_start{to = #jid{lserver = Server},
+                     from = #jid{lserver = From},
+                     version = <<"1.0">>}
+         when StateData#state.tls and not StateData#state.authenticated ->
+           send_header(StateData, <<" version='1.0'">>),
+           Auth = if StateData#state.tls_enabled ->
+                          {Result, Message} =
+                              ejabberd_s2s:check_peer_certificate(
+                                StateData#state.sockmod,
+                                StateData#state.socket,
+                                From),
+                          {Result, From, Message};
+                     true ->
+                          {no_verify, <<"(unknown)">>, <<"TLS not (yet) enabled">>}
+                  end,
+           StartTLS = if StateData#state.tls_enabled -> [];
+                         not StateData#state.tls_enabled and
                          not StateData#state.tls_required ->
-                           [#xmlel{name = <<"starttls">>,
-                                   attrs = [{<<"xmlns">>, ?NS_TLS}],
-                                   children = []}];
-                       not StateData#state.tls_enabled and
+                              [#starttls{required = false}];
+                         not StateData#state.tls_enabled and
                          StateData#state.tls_required ->
-                           [#xmlel{name = <<"starttls">>,
-                                   attrs = [{<<"xmlns">>, ?NS_TLS}],
-                                   children =
-                                       [#xmlel{name = <<"required">>,
-                                               attrs = [], children = []}]}]
-                    end,
-         case Auth of
-           {error, RemoteServer, CertError}
-               when StateData#state.tls_certverify ->
-               ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
-                         [StateData#state.server, RemoteServer, CertError]),
-               send_text(StateData,
-                         <<(fxml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>,
-                                                                          CertError)))/binary,
-                           (?STREAM_TRAILER)/binary>>),
-               {stop, normal, StateData};
-           {VerifyResult, RemoteServer, Msg} ->
-               {SASL, NewStateData} = case VerifyResult of
-                                        ok ->
-                                            {[#xmlel{name = <<"mechanisms">>,
-                                                     attrs = [{<<"xmlns">>, ?NS_SASL}],
-                                                     children =
-                                                         [#xmlel{name = <<"mechanism">>,
-                                                                 attrs = [],
-                                                                 children =
-                                                                     [{xmlcdata,
-                                                                       <<"EXTERNAL">>}]}]}],
-                                             StateData#state{auth_domain = RemoteServer}};
-                                        error ->
-                                            ?DEBUG("Won't accept certificate of ~s: ~s",
-                                                   [RemoteServer, Msg]),
-                                            {[], StateData};
-                                        no_verify ->
-                                            {[], StateData}
-                                      end,
-               send_element(NewStateData,
-                            #xmlel{name = <<"stream:features">>, attrs = [],
-                                   children =
-                                       SASL ++
-                                         StartTLS ++
-                                           ejabberd_hooks:run_fold(s2s_stream_features,
-                                                                   Server, [],
-                                                                   [Server])}),
-               {next_state, wait_for_feature_request,
-                NewStateData#state{server = Server}}
-         end;
-      {<<"jabber:server">>, _, Server, true}
-         when StateData#state.authenticated ->
-         send_text(StateData,
-                   ?STREAM_HEADER(<<" version='1.0'">>)),
-         send_element(StateData,
-                      #xmlel{name = <<"stream:features">>, attrs = [],
-                             children =
-                                 ejabberd_hooks:run_fold(s2s_stream_features,
-                                                         Server, [],
-                                                         [Server])}),
-         {next_state, stream_established, StateData};
-      {<<"jabber:server">>, <<"jabber:server:dialback">>,
-       _Server, _} when
-             (StateData#state.tls_required and StateData#state.tls_enabled)
-             or (not StateData#state.tls_required) ->
-         send_text(StateData, ?STREAM_HEADER(<<"">>)),
-         {next_state, stream_established, StateData};
-      _ ->
-         send_text(StateData, ?INVALID_NAMESPACE_ERR),
-         {stop, normal, StateData}
+                              [#starttls{required = true}]
+                      end,
+           case Auth of
+               {error, RemoteServer, CertError}
+                 when StateData#state.tls_certverify ->
+                   ?INFO_MSG("Closing s2s connection: ~s <--> ~s (~s)",
+                             [StateData#state.server, RemoteServer, CertError]),
+                   send_element(StateData,
+                                xmpp:serr_policy_violation(CertError, ?MYLANG)),
+                   {stop, normal, StateData};
+               {VerifyResult, RemoteServer, Msg} ->
+                   {SASL, NewStateData} =
+                       case VerifyResult of
+                           ok ->
+                               {[#sasl_mechanisms{list = [<<"EXTERNAL">>]}],
+                                StateData#state{auth_domain = RemoteServer}};
+                           error ->
+                               ?DEBUG("Won't accept certificate of ~s: ~s",
+                                      [RemoteServer, Msg]),
+                               {[], StateData};
+                           no_verify ->
+                               {[], StateData}
+                       end,
+                   send_element(NewStateData,
+                                #stream_features{
+                                   sub_els = SASL ++ StartTLS ++
+                                       ejabberd_hooks:run_fold(
+                                         s2s_stream_features, Server, [],
+                                         [Server])}),
+                   {next_state, wait_for_feature_request,
+                    NewStateData#state{server = Server}}
+           end;
+       #stream_start{to = #jid{lserver = Server},
+                     version = <<"1.0">>} when StateData#state.authenticated ->
+           send_header(StateData, <<" version='1.0'">>),
+           send_element(StateData,
+                        #stream_features{
+                           sub_els = ejabberd_hooks:run_fold(
+                                       s2s_stream_features, Server, [],
+                                       [Server])}),
+           {next_state, stream_established, StateData};
+       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK}
+         when (StateData#state.tls_required and StateData#state.tls_enabled)
+              or (not StateData#state.tls_required) ->
+           send_header(StateData, <<"">>),
+           {next_state, stream_established, StateData};
+       #stream_start{} ->
+           send_header(StateData, <<" version='1.0'">>),
+           send_element(StateData, xmpp:serr_undefined_condition()),
+           {stop, normal, StateData}
+    catch _:{xmpp_codec, Why} ->
+           Txt = xmpp:format_error(Why),
+           send_header(StateData, <<" version='1.0'">>),
+           send_element(StateData, xmpp:serr_not_well_formed(Txt, ?MYLANG)),
+           {stop, normal, StateData}
     end;
 wait_for_stream({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?STREAM_HEADER(<<"">>))/binary,
-               (?INVALID_XML_ERR)/binary, (?STREAM_TRAILER)/binary>>),
+    send_header(StateData, <<"">>),
+    send_element(StateData, xmpp:serr_not_well_formed()),
     {stop, normal, StateData};
 wait_for_stream(timeout, StateData) ->
+    send_header(StateData, <<"">>),
+    send_element(StateData, xmpp:serr_connection_timeout()),
     {stop, normal, StateData};
 wait_for_stream(closed, StateData) ->
     {stop, normal, StateData}.
 
-wait_for_feature_request({xmlstreamelement, El},
-                        StateData) ->
-    #xmlel{name = Name, attrs = Attrs} = El,
-    TLS = StateData#state.tls,
-    TLSEnabled = StateData#state.tls_enabled,
-    SockMod =
-       (StateData#state.sockmod):get_sockmod(StateData#state.socket),
-    case {fxml:get_attr_s(<<"xmlns">>, Attrs), Name} of
-      {?NS_TLS, <<"starttls">>}
-         when TLS == true, TLSEnabled == false,
-              SockMod == gen_tcp ->
-         ?DEBUG("starttls", []),
-         Socket = StateData#state.socket,
-         TLSOpts1 = case
-                     ejabberd_config:get_option(
-                        {domain_certfile, StateData#state.server},
-                        fun iolist_to_binary/1) of
-                     undefined -> StateData#state.tls_options;
-                     CertFile ->
-                         [{certfile, CertFile} | lists:keydelete(certfile, 1,
-                                                                 StateData#state.tls_options)]
-                   end,
-          TLSOpts = case ejabberd_config:get_option(
-                           {s2s_tls_compression, StateData#state.server},
-                           fun(true) -> true;
-                              (false) -> false
-                           end, false) of
-                        true -> lists:delete(compression_none, TLSOpts1);
-                        false -> [compression_none | TLSOpts1]
-                    end,
-         TLSSocket = (StateData#state.sockmod):starttls(Socket,
-                                                        TLSOpts,
-                                                        fxml:element_to_binary(#xmlel{name
-                                                                                         =
-                                                                                         <<"proceed">>,
-                                                                                     attrs
-                                                                                         =
-                                                                                         [{<<"xmlns">>,
-                                                                                           ?NS_TLS}],
-                                                                                     children
-                                                                                         =
-                                                                                         []})),
-         {next_state, wait_for_stream,
-          StateData#state{socket = TLSSocket, streamid = new_id(),
-                          tls_enabled = true, tls_options = TLSOpts}};
-      {?NS_SASL, <<"auth">>} when TLSEnabled ->
-         Mech = fxml:get_attr_s(<<"mechanism">>, Attrs),
-         case Mech of
-           <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
-               AuthDomain = StateData#state.auth_domain,
-               AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>,
-                                                         AuthDomain),
-               if AllowRemoteHost ->
-                      (StateData#state.sockmod):reset_stream(StateData#state.socket),
-                      send_element(StateData,
-                                   #xmlel{name = <<"success">>,
-                                          attrs = [{<<"xmlns">>, ?NS_SASL}],
-                                          children = []}),
-                      ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
-                                [AuthDomain, StateData#state.tls_enabled]),
-                      change_shaper(StateData, <<"">>,
-                                    jid:make(<<"">>, AuthDomain, <<"">>)),
-                      {next_state, wait_for_stream,
-                       StateData#state{streamid = new_id(),
-                                       authenticated = true}};
-                  true ->
-                      send_element(StateData,
-                                   #xmlel{name = <<"failure">>,
-                                          attrs = [{<<"xmlns">>, ?NS_SASL}],
-                                          children = []}),
-                      send_text(StateData, ?STREAM_TRAILER),
-                      {stop, normal, StateData}
-               end;
-           _ ->
-               send_element(StateData,
-                            #xmlel{name = <<"failure">>,
-                                   attrs = [{<<"xmlns">>, ?NS_SASL}],
-                                   children =
-                                       [#xmlel{name = <<"invalid-mechanism">>,
-                                               attrs = [], children = []}]}),
-               {stop, normal, StateData}
-         end;
-      _ ->
-         stream_established({xmlstreamelement, El}, StateData)
+wait_for_feature_request({xmlstreamelement, El}, StateData) ->
+    decode_element(El, wait_for_feature_request, StateData);
+wait_for_feature_request(#starttls{},
+                        #state{tls = true, tls_enabled = false} = StateData) ->
+    case (StateData#state.sockmod):get_sockmod(StateData#state.socket) of
+       gen_tcp ->
+           ?DEBUG("starttls", []),
+           Socket = StateData#state.socket,
+           TLSOpts1 = case
+                          ejabberd_config:get_option(
+                            {domain_certfile, StateData#state.server},
+                            fun iolist_to_binary/1) of
+                          undefined -> StateData#state.tls_options;
+                          CertFile ->
+                              lists:keystore(certfile, 1,
+                                             StateData#state.tls_options,
+                                             {certfile, CertFile})
+                      end,
+           TLSOpts = case ejabberd_config:get_option(
+                            {s2s_tls_compression, StateData#state.server},
+                            fun(true) -> true;
+                               (false) -> false
+                            end, false) of
+                         true -> lists:delete(compression_none, TLSOpts1);
+                         false -> [compression_none | TLSOpts1]
+                     end,
+           TLSSocket = (StateData#state.sockmod):starttls(
+                         Socket, TLSOpts,
+                         fxml:element_to_binary(#starttls_proceed{})),
+           {next_state, wait_for_stream,
+            StateData#state{socket = TLSSocket, streamid = new_id(),
+                            tls_enabled = true, tls_options = TLSOpts}};
+       _ ->
+            Txt = <<"Unsupported TLS transport">>,
+            send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)),
+            {stop, normal, StateData}
+    end;
+wait_for_feature_request(#sasl_auth{mechanism = Mech},
+                        #state{tls_enabled = true} = StateData) ->
+    case Mech of
+       <<"EXTERNAL">> when StateData#state.auth_domain /= <<"">> ->
+           AuthDomain = StateData#state.auth_domain,
+           AllowRemoteHost = ejabberd_s2s:allow_host(<<"">>, AuthDomain),
+           if AllowRemoteHost ->
+                   (StateData#state.sockmod):reset_stream(StateData#state.socket),
+                   send_element(StateData, #sasl_success{}),
+                   ?INFO_MSG("Accepted s2s EXTERNAL authentication for ~s (TLS=~p)",
+                             [AuthDomain, StateData#state.tls_enabled]),
+                   change_shaper(StateData, <<"">>, jid:make(AuthDomain)),
+                   {next_state, wait_for_stream,
+                    StateData#state{streamid = new_id(),
+                                    authenticated = true}};
+              true ->
+                   send_element(StateData, #sasl_failure{}),
+                   {stop, normal, StateData}
+           end;
+       _ ->
+           send_element(StateData, #sasl_failure{reason = 'invalid-mechanism'}),
+           {stop, normal, StateData}
     end;
-wait_for_feature_request({xmlstreamend, _Name},
-                        StateData) ->
-    send_text(StateData, ?STREAM_TRAILER),
+wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
     {stop, normal, StateData};
-wait_for_feature_request({xmlstreamerror, _},
-                        StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
+wait_for_feature_request({xmlstreamerror, _}, StateData) ->
+    send_element(StateData, xmpp:serr_not_well_formed()),
     {stop, normal, StateData};
 wait_for_feature_request(closed, StateData) ->
-    {stop, normal, StateData}.
+    {stop, normal, StateData};
+wait_for_feature_request(_Pkt, #state{tls_required = TLSRequired,
+                                     tls_enabled = TLSEnabled} = StateData)
+  when TLSRequired and not TLSEnabled ->
+    Txt = <<"Use of STARTTLS required">>,
+    send_element(StateData, xmpp:serr_policy_violation(Txt, ?MYLANG)),
+    {stop, normal, StateData};
+wait_for_feature_request(El, StateData) ->
+    stream_established({xmlstreamelement, El}, StateData).
 
 stream_established({xmlstreamelement, El}, StateData) ->
     cancel_timer(StateData#state.timer),
     Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
-    case is_key_packet(El) of
-      {key, To, From, Id, Key} ->
-         ?DEBUG("GET KEY: ~p", [{To, From, Id, Key}]),
-         LTo = jid:nameprep(To),
-         LFrom = jid:nameprep(From),
-         case {ejabberd_s2s:allow_host(LTo, LFrom),
-               lists:member(LTo,
-                            ejabberd_router:dirty_get_all_domains())}
-             of
-           {true, true} ->
-               ejabberd_s2s_out:terminate_if_waiting_delay(LTo, LFrom),
-               ejabberd_s2s_out:start(LTo, LFrom,
-                                      {verify, self(), Key,
-                                       StateData#state.streamid}),
-               Conns = (?DICT):store({LFrom, LTo},
-                                     wait_for_verification,
-                                     StateData#state.connections),
-               change_shaper(StateData, LTo,
-                             jid:make(<<"">>, LFrom, <<"">>)),
-               {next_state, stream_established,
-                StateData#state{connections = Conns, timer = Timer}};
-           {_, false} ->
-               send_text(StateData, ?HOST_UNKNOWN_ERR),
-               {stop, normal, StateData};
-           {false, _} ->
-               send_text(StateData, ?INVALID_FROM_ERR),
-               {stop, normal, StateData}
-         end;
-      {verify, To, From, Id, Key} ->
-         ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
-         LTo = jid:nameprep(To),
-         LFrom = jid:nameprep(From),
-         Type = case ejabberd_s2s:make_key({LTo, LFrom}, Id) of
-                    Key -> <<"valid">>;
-                    _ -> <<"invalid">>
-                end,
-         send_element(StateData,
-                      #xmlel{name = <<"db:verify">>,
-                             attrs =
-                                 [{<<"from">>, To}, {<<"to">>, From},
-                                  {<<"id">>, Id}, {<<"type">>, Type}],
-                             children = []}),
-         {next_state, stream_established,
-          StateData#state{timer = Timer}};
-      _ ->
-         NewEl = jlib:remove_attr(<<"xmlns">>, El),
-         #xmlel{name = Name, attrs = Attrs} = NewEl,
-         From_s = fxml:get_attr_s(<<"from">>, Attrs),
-         From = jid:from_string(From_s),
-         To_s = fxml:get_attr_s(<<"to">>, Attrs),
-         To = jid:from_string(To_s),
-         if (To /= error) and (From /= error) ->
-                LFrom = From#jid.lserver,
-                LTo = To#jid.lserver,
-                if StateData#state.authenticated ->
-                       case LFrom == StateData#state.auth_domain andalso
-                              lists:member(LTo,
-                                           ejabberd_router:dirty_get_all_domains())
-                           of
-                         true ->
-                             if (Name == <<"iq">>) or (Name == <<"message">>)
-                                  or (Name == <<"presence">>) ->
-                                    ejabberd_hooks:run(s2s_receive_packet, LTo,
-                                                       [From, To, NewEl]),
-                                    ejabberd_router:route(From, To, NewEl);
-                                true -> error
-                             end;
-                         false -> error
-                       end;
-                   true ->
-                       case (?DICT):find({LFrom, LTo},
-                                         StateData#state.connections)
-                           of
-                         {ok, established} ->
-                             if (Name == <<"iq">>) or (Name == <<"message">>)
-                                  or (Name == <<"presence">>) ->
-                                    ejabberd_hooks:run(s2s_receive_packet, LTo,
-                                                       [From, To, NewEl]),
-                                    ejabberd_router:route(From, To, NewEl);
-                                true -> error
-                             end;
-                         _ -> error
-                       end
-                end;
-            true -> error
-         end,
-         ejabberd_hooks:run(s2s_loop_debug,
-                            [{xmlstreamelement, El}]),
-         {next_state, stream_established,
-          StateData#state{timer = Timer}}
+    decode_element(El, stream_established, StateData#state{timer = Timer});
+stream_established(#db_result{to = To, from = From, key = Key},
+                  StateData) ->
+    ?DEBUG("GET KEY: ~p", [{To, From, Key}]),
+    LTo = To#jid.lserver,
+    LFrom = From#jid.lserver,
+    case {ejabberd_s2s:allow_host(LTo, LFrom),
+         lists:member(LTo, ejabberd_router:dirty_get_all_domains())} of
+       {true, true} ->
+           ejabberd_s2s_out:terminate_if_waiting_delay(LTo, LFrom),
+           ejabberd_s2s_out:start(LTo, LFrom,
+                                  {verify, self(), Key,
+                                   StateData#state.streamid}),
+           Conns = (?DICT):store({LFrom, LTo},
+                                 wait_for_verification,
+                                 StateData#state.connections),
+           change_shaper(StateData, LTo, jid:make(LFrom)),
+           {next_state, stream_established,
+            StateData#state{connections = Conns}};
+       {_, false} ->
+           send_element(StateData, xmpp:serr_host_unknown()),
+           {stop, normal, StateData};
+       {false, _} ->
+           send_element(StateData, xmpp:serr_invalid_from()),
+           {stop, normal, StateData}
     end;
+stream_established(#db_verify{to = To, from = From, id = Id, key = Key},
+                  StateData) ->
+    ?DEBUG("VERIFY KEY: ~p", [{To, From, Id, Key}]),
+    LTo = jid:nameprep(To),
+    LFrom = jid:nameprep(From),
+    Type = case ejabberd_s2s:make_key({LTo, LFrom}, Id) of
+              Key -> valid;
+              _ -> invalid
+          end,
+    send_element(StateData,
+                #db_verify{from = To, to = From, id = Id, type = Type}),
+    {next_state, stream_established, StateData};
+stream_established(Pkt, StateData) when ?is_stanza(Pkt) ->
+    From = xmpp:get_from(Pkt),
+    To = xmpp:get_to(Pkt),
+    if To /= undefined, From /= undefined ->
+           LFrom = From#jid.lserver,
+           LTo = To#jid.lserver,
+           if StateData#state.authenticated ->
+                   case LFrom == StateData#state.auth_domain andalso
+                       lists:member(LTo, ejabberd_router:dirty_get_all_domains()) of
+                       true ->
+                           ejabberd_hooks:run(s2s_receive_packet, LTo,
+                                              [From, To, Pkt]),
+                           ejabberd_router:route(From, To, Pkt);
+                       false ->
+                           send_error(StateData, Pkt, xmpp:err_not_authorized())
+                   end;
+              true ->
+                   case (?DICT):find({LFrom, LTo}, StateData#state.connections) of
+                       {ok, established} ->
+                           ejabberd_hooks:run(s2s_receive_packet, LTo,
+                                              [From, To, Pkt]),
+                           ejabberd_router:route(From, To, Pkt);
+                       _ ->
+                           send_error(StateData, Pkt, xmpp:err_not_authorized())
+                   end
+           end;
+       true ->
+           send_error(StateData, Pkt, xmpp:err_jid_malformed())
+    end,
+    ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
+    {next_state, stream_established, StateData};
 stream_established({valid, From, To}, StateData) ->
     send_element(StateData,
-                #xmlel{name = <<"db:result">>,
-                       attrs =
-                           [{<<"from">>, To}, {<<"to">>, From},
-                            {<<"type">>, <<"valid">>}],
-                       children = []}),
+                #db_result{from = To, to = From, type = valid}),
     ?INFO_MSG("Accepted s2s dialback authentication for ~s (TLS=~p)",
              [From, StateData#state.tls_enabled]),
     LFrom = jid:nameprep(From),
@@ -508,11 +420,7 @@ stream_established({valid, From, To}, StateData) ->
     {next_state, stream_established, NSD};
 stream_established({invalid, From, To}, StateData) ->
     send_element(StateData,
-                #xmlel{name = <<"db:result">>,
-                       attrs =
-                           [{<<"from">>, To}, {<<"to">>, From},
-                            {<<"type">>, <<"invalid">>}],
-                       children = []}),
+                #db_result{from = To, to = From, type = invalid}),
     LFrom = jid:nameprep(From),
     LTo = jid:nameprep(To),
     NSD = StateData#state{connections =
@@ -522,14 +430,16 @@ stream_established({invalid, From, To}, StateData) ->
 stream_established({xmlstreamend, _Name}, StateData) ->
     {stop, normal, StateData};
 stream_established({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
+    send_element(StateData, xmpp:serr_not_well_formed()),
     {stop, normal, StateData};
 stream_established(timeout, StateData) ->
+    send_element(StateData, xmpp:serr_connection_timeout()),
     {stop, normal, StateData};
 stream_established(closed, StateData) ->
-    {stop, normal, StateData}.
+    {stop, normal, StateData};
+stream_established(Pkt, StateData) ->
+    ejabberd_hooks:run(s2s_loop_debug, [{xmlstreamelement, Pkt}]),
+    {next_state, stream_established, StateData}.
 
 %%----------------------------------------------------------------------
 %% Func: StateName/3
@@ -589,8 +499,14 @@ code_change(_OldVsn, StateName, StateData, _Extra) ->
 handle_info({send_text, Text}, StateName, StateData) ->
     send_text(StateData, Text),
     {next_state, StateName, StateData};
-handle_info({timeout, Timer, _}, _StateName,
+handle_info({timeout, Timer, _}, StateName,
            #state{timer = Timer} = StateData) ->
+    if StateName == wait_for_stream ->
+           send_header(StateData, <<"">>);
+       true ->
+           ok
+    end,
+    send_element(StateData, xmpp:serr_connection_timeout()),
     {stop, normal, StateData};
 handle_info(_, StateName, StateData) ->
     {next_state, StateName, StateData}.
@@ -603,6 +519,7 @@ terminate(Reason, _StateName, StateData) ->
           || Host <- get_external_hosts(StateData)];
       _ -> ok
     end,
+    catch send_trailer(StateData),
     (StateData#state.sockmod):close(StateData#state.socket),
     ok.
 
@@ -621,39 +538,69 @@ print_state(State) -> State.
 %%% Internal functions
 %%%----------------------------------------------------------------------
 
+-spec send_text(state(), iodata()) -> ok.
 send_text(StateData, Text) ->
     (StateData#state.sockmod):send(StateData#state.socket,
                                   Text).
 
+-spec send_element(state(), xmpp_element()) -> ok.
 send_element(StateData, El) ->
-    send_text(StateData, fxml:element_to_binary(El)).
+    El1 = fix_ns(xmpp:encode(El)),
+    send_text(StateData, fxml:element_to_binary(El1)).
+
+-spec send_error(state(), xmlel() | stanza(), error()) -> ok.
+send_error(StateData, Stanza, Error) ->
+    Type = xmpp:get_type(Stanza),
+    if Type == error; Type == result;
+       Type == <<"error">>; Type == <<"result">> ->
+           ok;
+       true ->
+           send_element(StateData, xmpp:make_error(Stanza, Error))
+    end.
+
+-spec send_trailer(state()) -> ok.
+send_trailer(StateData) ->
+    send_text(StateData, <<"</stream:stream>">>).
 
+-spec send_header(state(), binary()) -> ok.
+send_header(StateData, Version) ->
+    send_text(StateData,
+             <<"<?xml version='1.0'?><stream:stream "
+               "xmlns:stream='http://etherx.jabber.org/stream"
+               "s' xmlns='jabber:server' xmlns:db='jabber:ser"
+               "ver:dialback' id='",
+               (StateData#state.streamid)/binary, "'", Version/binary,
+               ">">>).
+
+-spec change_shaper(state(), binary(), jid()) -> ok.
 change_shaper(StateData, Host, JID) ->
     Shaper = acl:match_rule(Host, StateData#state.shaper,
                            JID),
     (StateData#state.sockmod):change_shaper(StateData#state.socket,
                                            Shaper).
 
+-spec fix_ns(xmlel()) -> xmlel().
+fix_ns(#xmlel{name = Name} = El) when Name == <<"message">>;
+                                      Name == <<"iq">>;
+                                      Name == <<"presence">>;
+                                      Name == <<"db:verify">>,
+                                      Name == <<"db:result">> ->
+    Attrs = lists:filter(
+              fun({<<"xmlns">>, _}) -> false;
+                 (_) -> true
+              end, El#xmlel.attrs),
+    El#xmlel{attrs = Attrs};
+fix_ns(El) ->
+    El.
+
+-spec new_id() -> binary().
 new_id() -> randoms:get_string().
 
+-spec cancel_timer(reference()) -> ok.
 cancel_timer(Timer) ->
     erlang:cancel_timer(Timer),
     receive {timeout, Timer, _} -> ok after 0 -> ok end.
 
-is_key_packet(#xmlel{name = Name, attrs = Attrs,
-                    children = Els})
-    when Name == <<"db:result">> ->
-    {key, fxml:get_attr_s(<<"to">>, Attrs),
-     fxml:get_attr_s(<<"from">>, Attrs),
-     fxml:get_attr_s(<<"id">>, Attrs), fxml:get_cdata(Els)};
-is_key_packet(#xmlel{name = Name, attrs = Attrs,
-                    children = Els})
-    when Name == <<"db:verify">> ->
-    {verify, fxml:get_attr_s(<<"to">>, Attrs),
-     fxml:get_attr_s(<<"from">>, Attrs),
-     fxml:get_attr_s(<<"id">>, Attrs), fxml:get_cdata(Els)};
-is_key_packet(_) -> false.
-
 fsm_limit_opts(Opts) ->
     case lists:keysearch(max_fsm_queue, 1, Opts) of
       {value, {_, N}} when is_integer(N) -> [{max_queue, N}];
@@ -666,6 +613,22 @@ fsm_limit_opts(Opts) ->
          end
     end.
 
+-spec decode_element(xmlel(), state_name(), state()) -> fsm_transition().
+decode_element(#xmlel{} = El, StateName, StateData) ->
+    try xmpp:decode(El) of
+       Pkt -> ?MODULE:StateName(Pkt, StateData)
+    catch error:{xmpp_codec, Why} ->
+            case xmpp:is_stanza(El) of
+                true ->
+                   Lang = xmpp:get_lang(El),
+                    Txt = xmpp:format_error(Why),
+                   send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
+                false ->
+                    ok
+            end,
+            {next_state, StateName, StateData}
+    end.
+
 opt_type(domain_certfile) -> fun iolist_to_binary/1;
 opt_type(max_fsm_queue) ->
     fun (I) when is_integer(I), I > 0 -> I end;
index a30f2f438d0f3c7ece331cbc342a2bf888c0038c..024e51e7abb11a70f3df768b4f0f37d5334d793d 100644 (file)
@@ -50,8 +50,7 @@
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
-
--include("jlib.hrl").
+-include("xmpp.hrl").
 
 -record(state,
        {socket                           :: ejabberd_socket:socket_state(),
          bridge                           :: {atom(), atom()},
          timer = make_ref()               :: reference()}).
 
+-type state_name() :: open_socket | wait_for_stream |
+                     wait_for_validation | wait_for_features |
+                     wait_for_auth_result | wait_for_starttls_proceed |
+                     relay_to_bridge | reopen_socket | wait_before_retry |
+                     stream_established.
+-type state() :: #state{}.
+-type fsm_stop() :: {stop, normal, state()}.
+-type fsm_next() :: {next_state, state_name(), state(), non_neg_integer()} |
+                   {next_state, state_name(), state()}.
+-type fsm_transition() :: fsm_stop() | fsm_next().
+
 %%-define(DBGFSM, true).
 
 -ifdef(DBGFSM).
          "s' xmlns='jabber:server' xmlns:db='jabber:ser"
          "ver:dialback' from='~s' to='~s'~s>">>).
 
--define(STREAM_TRAILER, <<"</stream:stream>">>).
-
--define(INVALID_NAMESPACE_ERR,
-       fxml:element_to_binary(?SERR_INVALID_NAMESPACE)).
-
--define(HOST_UNKNOWN_ERR,
-       fxml:element_to_binary(?SERR_HOST_UNKNOWN)).
-
--define(INVALID_XML_ERR,
-       fxml:element_to_binary(?SERR_XML_NOT_WELL_FORMED)).
-
 -define(SOCKET_DEFAULT_RESULT, {error, badarg}).
 
 %%%----------------------------------------------------------------------
@@ -236,10 +235,7 @@ open_socket(init, StateData) ->
          NewStateData = StateData#state{socket = Socket,
                                         tls_enabled = false,
                                         streamid = new_id()},
-         send_text(NewStateData,
-                   io_lib:format(?STREAM_HEADER,
-                                 [StateData#state.myname,
-                                  StateData#state.server, Version])),
+         send_header(NewStateData, Version),
          {next_state, wait_for_stream, NewStateData,
           ?FSMTIMEOUT};
       {error, _Reason} ->
@@ -259,18 +255,8 @@ open_socket(init, StateData) ->
            _ -> wait_before_reconnect(StateData)
          end
     end;
-open_socket(closed, StateData) ->
-    ?INFO_MSG("s2s connection: ~s -> ~s (stopped in "
-             "open socket)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-open_socket(timeout, StateData) ->
-    ?INFO_MSG("s2s connection: ~s -> ~s (timeout in "
-             "open socket)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-open_socket(_, StateData) ->
-    {next_state, open_socket, StateData}.
+open_socket(Event, StateData) ->
+    handle_unexpected_event(Event, open_socket, StateData).
 
 open_socket1({_, _, _, _} = Addr, Port) ->
     open_socket2(inet, Addr, Port);
@@ -309,466 +295,215 @@ open_socket2(Type, Addr, Port) ->
 
 %%----------------------------------------------------------------------
 
-wait_for_stream({xmlstreamstart, _Name, Attrs},
-               StateData) ->
-    {CertCheckRes, CertCheckMsg, StateData0} =
-       if StateData#state.tls_certverify, StateData#state.tls_enabled ->
-              {Res, Msg} =
-                  ejabberd_s2s:check_peer_certificate(ejabberd_socket,
-                                                      StateData#state.socket,
-                                                      StateData#state.server),
-              ?DEBUG("Certificate verification result for ~s: ~s",
-                     [StateData#state.server, Msg]),
-              {Res, Msg, StateData#state{tls_certverify = false}};
+wait_for_stream({xmlstreamstart, Name, Attrs}, StateData0) ->
+    {CertCheckRes, CertCheckMsg, StateData} =
+       if StateData0#state.tls_certverify, StateData0#state.tls_enabled ->
+               {Res, Msg} =
+                   ejabberd_s2s:check_peer_certificate(ejabberd_socket,
+                                                       StateData0#state.socket,
+                                                       StateData0#state.server),
+               ?DEBUG("Certificate verification result for ~s: ~s",
+                      [StateData0#state.server, Msg]),
+               {Res, Msg, StateData0#state{tls_certverify = false}};
           true ->
-              {no_verify, <<"Not verified">>, StateData}
+               {no_verify, <<"Not verified">>, StateData0}
        end,
-    RemoteStreamID = fxml:get_attr_s(<<"id">>, Attrs),
-    NewStateData = StateData0#state{remote_streamid = RemoteStreamID},
-    case {fxml:get_attr_s(<<"xmlns">>, Attrs),
-         fxml:get_attr_s(<<"xmlns:db">>, Attrs),
-         fxml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>}
-       of
-      _ when CertCheckRes == error ->
-         send_text(NewStateData,
-                   <<(fxml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>,
-                                                                    CertCheckMsg)))/binary,
-                     (?STREAM_TRAILER)/binary>>),
-         ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)",
-                   [NewStateData#state.myname,
-                    NewStateData#state.server,
-                    CertCheckMsg]),
-         {stop, normal, NewStateData};
-      {<<"jabber:server">>, <<"jabber:server:dialback">>,
-       false} ->
-         send_db_request(NewStateData);
-      {<<"jabber:server">>, <<"jabber:server:dialback">>,
-       true}
-         when NewStateData#state.use_v10 ->
-         {next_state, wait_for_features, NewStateData, ?FSMTIMEOUT};
-      %% Clause added to handle Tigase's workaround for an old ejabberd bug:
-      {<<"jabber:server">>, <<"jabber:server:dialback">>,
-       true}
-         when not NewStateData#state.use_v10 ->
-         send_db_request(NewStateData);
-      {<<"jabber:server">>, <<"">>, true}
-         when NewStateData#state.use_v10 ->
-         {next_state, wait_for_features,
-          NewStateData#state{db_enabled = false}, ?FSMTIMEOUT};
-      {NSProvided, DB, _} ->
-         send_text(NewStateData, ?INVALID_NAMESPACE_ERR),
-         ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
-                   "namespace).~nNamespace provided: ~p~nNamespac"
-                   "e expected: \"jabber:server\"~nxmlns:db "
-                   "provided: ~p~nAll attributes: ~p",
-                   [NewStateData#state.myname, NewStateData#state.server,
-                    NSProvided, DB, Attrs]),
-         {stop, normal, NewStateData}
+    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
+       _ when CertCheckRes == error ->
+           send_element(StateData,
+                        xmpp:serr_policy_violation(CertCheckMsg, ?MYLANG)),
+           ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)",
+                     [StateData#state.myname, StateData#state.server,
+                      CertCheckMsg]),
+           {stop, normal, StateData};
+       #stream_start{xmlns = NS_SERVER, stream_xmlns = NS_STREAM}
+         when NS_SERVER /= ?NS_SERVER; NS_STREAM /= ?NS_STREAM ->
+           send_header(StateData, <<" version='1.0'">>),
+           send_element(StateData, xmpp:serr_invalid_namespace()),
+           {stop, normal, StateData};
+       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID,
+                     version = V} when V /= <<"1.0">> ->
+           send_db_request(StateData#state{remote_streamid = ID});
+       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID}
+         when StateData#state.use_v10 ->
+           {next_state, wait_for_features,
+            StateData#state{remote_streamid = ID}, ?FSMTIMEOUT};
+       #stream_start{db_xmlns = ?NS_SERVER_DIALBACK, id = ID}
+         when not StateData#state.use_v10 ->
+           %% Handle Tigase's workaround for an old ejabberd bug:
+           send_db_request(StateData#state{remote_streamid = ID});
+       #stream_start{id = ID} when StateData#state.use_v10 ->
+           {next_state, wait_for_features,
+            StateData#state{db_enabled = false, remote_streamid = ID},
+            ?FSMTIMEOUT};
+       #stream_start{} ->
+           send_header(StateData, <<"">>),
+           send_element(StateData, xmpp:serr_invalid_namespace()),
+           {stop, normal, StateData}
+    catch _:{xmpp_codec, Why} ->
+           Txt = xmpp:format_error(Why),
+           send_header(StateData, <<" version='1.0'">>),
+           send_element(StateData, xmpp:serr_not_well_formed(Txt, ?MYLANG)),
+           {stop, normal, StateData}
     end;
-wait_for_stream({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
-    ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
-             "xml)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-wait_for_stream({xmlstreamend, _Name}, StateData) ->
-    ?INFO_MSG("Closing s2s connection: ~s -> ~s (xmlstreamend)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-wait_for_stream(timeout, StateData) ->
-    ?INFO_MSG("Closing s2s connection: ~s -> ~s (timeout "
-             "in wait_for_stream)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-wait_for_stream(closed, StateData) ->
-    ?INFO_MSG("Closing s2s connection: ~s -> ~s (close "
-             "in wait_for_stream)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData}.
-
-wait_for_validation({xmlstreamelement, El},
+wait_for_stream(Event, StateData) ->
+    handle_unexpected_event(Event, wait_for_stream, StateData).
+
+wait_for_validation({xmlstreamelement, El}, StateData) ->
+    decode_element(El, wait_for_validation, StateData);
+wait_for_validation(#db_result{to = To, from = From, type = Type}, StateData) ->
+    ?DEBUG("recv result: ~p", [{From, To, Type}]),
+    case {Type, StateData#state.tls_enabled, StateData#state.tls_required} of
+       {valid, Enabled, Required} when (Enabled == true) or (Required == false) ->
+           send_queue(StateData, StateData#state.queue),
+           ?INFO_MSG("Connection established: ~s -> ~s with "
+                     "TLS=~p",
+                     [StateData#state.myname, StateData#state.server,
+                      StateData#state.tls_enabled]),
+           ejabberd_hooks:run(s2s_connect_hook,
+                              [StateData#state.myname,
+                               StateData#state.server]),
+           {next_state, stream_established, StateData#state{queue = queue:new()}};
+       {valid, Enabled, Required} when (Enabled == false) and (Required == true) ->
+           ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS "
+                     "is required but unavailable)",
+                     [StateData#state.myname, StateData#state.server]),
+           {stop, normal, StateData};
+       _ ->
+           ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
+                     "dialback key result)",
+                     [StateData#state.myname, StateData#state.server]),
+           {stop, normal, StateData}
+    end;
+wait_for_validation(#db_verify{to = To, from = From, id = Id, type = Type},
                    StateData) ->
-    case is_verify_res(El) of
-      {result, To, From, Id, Type} ->
-         ?DEBUG("recv result: ~p", [{From, To, Id, Type}]),
-         case {Type, StateData#state.tls_enabled,
-               StateData#state.tls_required}
-             of
-           {<<"valid">>, Enabled, Required}
-               when (Enabled == true) or (Required == false) ->
-               send_queue(StateData, StateData#state.queue),
-               ?INFO_MSG("Connection established: ~s -> ~s with "
-                         "TLS=~p",
-                         [StateData#state.myname, StateData#state.server,
-                          StateData#state.tls_enabled]),
-               ejabberd_hooks:run(s2s_connect_hook,
-                                  [StateData#state.myname,
-                                   StateData#state.server]),
-               {next_state, stream_established,
-                StateData#state{queue = queue:new()}};
-           {<<"valid">>, Enabled, Required}
-               when (Enabled == false) and (Required == true) ->
-               ?INFO_MSG("Closing s2s connection: ~s -> ~s (TLS "
-                         "is required but unavailable)",
-                         [StateData#state.myname, StateData#state.server]),
-               {stop, normal, StateData};
-           _ ->
-               ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
-                         "dialback key)",
-                         [StateData#state.myname, StateData#state.server]),
-               {stop, normal, StateData}
-         end;
-      {verify, To, From, Id, Type} ->
-         ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]),
-         case StateData#state.verify of
-           false ->
-               NextState = wait_for_validation,
-               {next_state, NextState, StateData,
-                get_timeout_interval(NextState)};
-           {Pid, _Key, _SID} ->
-               case Type of
-                 <<"valid">> ->
-                     p1_fsm:send_event(Pid,
-                                       {valid, StateData#state.server,
-                                        StateData#state.myname});
-                 _ ->
-                     p1_fsm:send_event(Pid,
-                                       {invalid, StateData#state.server,
-                                        StateData#state.myname})
-               end,
-               if StateData#state.verify == false ->
-                      {stop, normal, StateData};
-                  true ->
-                      NextState = wait_for_validation,
-                      {next_state, NextState, StateData,
-                       get_timeout_interval(NextState)}
-               end
-         end;
-      _ ->
-         {next_state, wait_for_validation, StateData,
-          (?FSMTIMEOUT) * 3}
+    ?DEBUG("recv verify: ~p", [{From, To, Id, Type}]),
+    case StateData#state.verify of
+       false ->
+           NextState = wait_for_validation,
+           {next_state, NextState, StateData, get_timeout_interval(NextState)};
+       {Pid, _Key, _SID} ->
+           case Type of
+               valid ->
+                   p1_fsm:send_event(Pid,
+                                     {valid, StateData#state.server,
+                                      StateData#state.myname});
+               _ ->
+                   p1_fsm:send_event(Pid,
+                                     {invalid, StateData#state.server,
+                                      StateData#state.myname})
+           end,
+           if StateData#state.verify == false ->
+                   {stop, normal, StateData};
+              true ->
+                   NextState = wait_for_validation,
+                   {next_state, NextState, StateData, get_timeout_interval(NextState)}
+           end
     end;
-wait_for_validation({xmlstreamend, _Name}, StateData) ->
-    ?INFO_MSG("wait for validation: ~s -> ~s (xmlstreamend)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-wait_for_validation({xmlstreamerror, _}, StateData) ->
-    ?INFO_MSG("wait for validation: ~s -> ~s (xmlstreamerror)",
-             [StateData#state.myname, StateData#state.server]),
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
-    {stop, normal, StateData};
 wait_for_validation(timeout,
                    #state{verify = {VPid, VKey, SID}} = StateData)
-    when is_pid(VPid) and is_binary(VKey) and
-          is_binary(SID) ->
-    ?DEBUG("wait_for_validation: ~s -> ~s (timeout "
-          "in verify connection)",
+  when is_pid(VPid) and is_binary(VKey) and is_binary(SID) ->
+    ?DEBUG("wait_for_validation: ~s -> ~s (timeout in verify connection)",
           [StateData#state.myname, StateData#state.server]),
     {stop, normal, StateData};
-wait_for_validation(timeout, StateData) ->
-    ?INFO_MSG("wait_for_validation: ~s -> ~s (connect "
-             "timeout)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-wait_for_validation(closed, StateData) ->
-    ?INFO_MSG("wait for validation: ~s -> ~s (closed)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData}.
+wait_for_validation(Event, StateData) ->
+    handle_unexpected_event(Event, wait_for_validation, StateData).
 
 wait_for_features({xmlstreamelement, El}, StateData) ->
-    case El of
-      #xmlel{name = <<"stream:features">>, children = Els} ->
-         {SASLEXT, StartTLS, StartTLSRequired} = lists:foldl(fun
-                                                               (#xmlel{name =
-                                                                           <<"mechanisms">>,
-                                                                       attrs =
-                                                                           Attrs1,
-                                                                       children
-                                                                           =
-                                                                           Els1} =
-                                                                    _El1,
-                                                                {_SEXT, STLS,
-                                                                 STLSReq} =
-                                                                    Acc) ->
-                                                                   case
-                                                                     fxml:get_attr_s(<<"xmlns">>,
-                                                                                    Attrs1)
-                                                                       of
-                                                                     ?NS_SASL ->
-                                                                         NewSEXT =
-                                                                             lists:any(fun
-                                                                                         (#xmlel{name
-                                                                                                     =
-                                                                                                     <<"mechanism">>,
-                                                                                                 children
-                                                                                                     =
-                                                                                                     Els2}) ->
-                                                                                             case
-                                                                                               fxml:get_cdata(Els2)
-                                                                                                 of
-                                                                                               <<"EXTERNAL">> ->
-                                                                                                   true;
-                                                                                               _ ->
-                                                                                                   false
-                                                                                             end;
-                                                                                         (_) ->
-                                                                                             false
-                                                                                       end,
-                                                                                       Els1),
-                                                                         {NewSEXT,
-                                                                          STLS,
-                                                                          STLSReq};
-                                                                     _ -> Acc
-                                                                   end;
-                                                               (#xmlel{name =
-                                                                           <<"starttls">>,
-                                                                       attrs =
-                                                                           Attrs1} =
-                                                                    El1,
-                                                                {SEXT, _STLS,
-                                                                 _STLSReq} =
-                                                                    Acc) ->
-                                                                   case
-                                                                     fxml:get_attr_s(<<"xmlns">>,
-                                                                                    Attrs1)
-                                                                       of
-                                                                     ?NS_TLS ->
-                                                                         Req =
-                                                                             case
-                                                                               fxml:get_subtag(El1,
-                                                                                              <<"required">>)
-                                                                                 of
-                                                                               #xmlel{} ->
-                                                                                   true;
-                                                                               false ->
-                                                                                   false
-                                                                             end,
-                                                                         {SEXT,
-                                                                          true,
-                                                                          Req};
-                                                                     _ -> Acc
-                                                                   end;
-                                                               (_, Acc) -> Acc
-                                                             end,
-                                                             {false, false,
-                                                              false},
-                                                             Els),
-         if not SASLEXT and not StartTLS and
-              StateData#state.authenticated ->
-                send_queue(StateData, StateData#state.queue),
-                ?INFO_MSG("Connection established: ~s -> ~s with "
-                          "SASL EXTERNAL and TLS=~p",
-                          [StateData#state.myname, StateData#state.server,
-                           StateData#state.tls_enabled]),
-                ejabberd_hooks:run(s2s_connect_hook,
-                                   [StateData#state.myname,
-                                    StateData#state.server]),
-                {next_state, stream_established,
-                 StateData#state{queue = queue:new()}};
-            SASLEXT and StateData#state.try_auth and
-              (StateData#state.new /= false) and
-                (StateData#state.tls_enabled or
-                  not StateData#state.tls_required) ->
-                send_element(StateData,
-                             #xmlel{name = <<"auth">>,
-                                    attrs =
-                                        [{<<"xmlns">>, ?NS_SASL},
-                                         {<<"mechanism">>, <<"EXTERNAL">>}],
-                                    children =
-                                        [{xmlcdata,
-                                          jlib:encode_base64(StateData#state.myname)}]}),
-                {next_state, wait_for_auth_result,
-                 StateData#state{try_auth = false}, ?FSMTIMEOUT};
-            StartTLS and StateData#state.tls and
-              not StateData#state.tls_enabled ->
-                send_element(StateData,
-                             #xmlel{name = <<"starttls">>,
-                                    attrs = [{<<"xmlns">>, ?NS_TLS}],
-                                    children = []}),
-                {next_state, wait_for_starttls_proceed, StateData,
-                 ?FSMTIMEOUT};
-            StartTLSRequired and not StateData#state.tls ->
-                ?DEBUG("restarted: ~p",
-                       [{StateData#state.myname, StateData#state.server}]),
-                ejabberd_socket:close(StateData#state.socket),
-                {next_state, reopen_socket,
-                 StateData#state{socket = undefined, use_v10 = false},
-                 ?FSMTIMEOUT};
-            StateData#state.db_enabled ->
-                send_db_request(StateData);
-            true ->
-                ?DEBUG("restarted: ~p",
-                       [{StateData#state.myname, StateData#state.server}]),
-                ejabberd_socket:close(StateData#state.socket),
-                {next_state, reopen_socket,
-                 StateData#state{socket = undefined, use_v10 = false},
-                 ?FSMTIMEOUT}
-         end;
-      _ ->
-         send_text(StateData,
-                   <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary,
-                     (?STREAM_TRAILER)/binary>>),
-         ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad "
-                   "format)",
-                   [StateData#state.myname, StateData#state.server]),
-         {stop, normal, StateData}
+    decode_element(El, wait_for_features, StateData);
+wait_for_features(#stream_features{sub_els = Els}, StateData) ->
+    {SASLEXT, StartTLS, StartTLSRequired} =
+       lists:foldl(
+         fun(#sasl_mechanisms{list = Mechs}, {_, STLS, STLSReq}) ->
+                 {lists:member(<<"EXTERNAL">>, Mechs), STLS, STLSReq};
+            (#starttls{required = Required}, {SEXT, _, _}) ->
+                 {SEXT, true, Required};
+            (_, Acc) ->
+                 Acc
+         end, {false, false, false}, Els),
+    if not SASLEXT and not StartTLS and StateData#state.authenticated ->
+           send_queue(StateData, StateData#state.queue),
+           ?INFO_MSG("Connection established: ~s -> ~s with "
+                     "SASL EXTERNAL and TLS=~p",
+                     [StateData#state.myname, StateData#state.server,
+                      StateData#state.tls_enabled]),
+           ejabberd_hooks:run(s2s_connect_hook,
+                              [StateData#state.myname,
+                               StateData#state.server]),
+           {next_state, stream_established,
+            StateData#state{queue = queue:new()}};
+       SASLEXT and StateData#state.try_auth and
+       (StateData#state.new /= false) and
+       (StateData#state.tls_enabled or
+       not StateData#state.tls_required) ->
+           send_element(StateData,
+                        #sasl_auth{mechanism = <<"EXTERNAL">>,
+                                   text = StateData#state.myname}),
+           {next_state, wait_for_auth_result,
+            StateData#state{try_auth = false}, ?FSMTIMEOUT};
+       StartTLS and StateData#state.tls and
+       not StateData#state.tls_enabled ->
+           send_element(StateData, #starttls{}),
+           {next_state, wait_for_starttls_proceed, StateData, ?FSMTIMEOUT};
+       StartTLSRequired and not StateData#state.tls ->
+           ?DEBUG("restarted: ~p",
+                  [{StateData#state.myname, StateData#state.server}]),
+           ejabberd_socket:close(StateData#state.socket),
+           {next_state, reopen_socket,
+            StateData#state{socket = undefined, use_v10 = false},
+            ?FSMTIMEOUT};
+       StateData#state.db_enabled ->
+           send_db_request(StateData);
+       true ->
+           ?DEBUG("restarted: ~p",
+                  [{StateData#state.myname, StateData#state.server}]),
+           ejabberd_socket:close(StateData#state.socket),
+           {next_state, reopen_socket,
+            StateData#state{socket = undefined, use_v10 = false}, ?FSMTIMEOUT}
     end;
-wait_for_features({xmlstreamend, _Name}, StateData) ->
-    ?INFO_MSG("wait_for_features: xmlstreamend", []),
-    {stop, normal, StateData};
-wait_for_features({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
-    ?INFO_MSG("wait for features: xmlstreamerror", []),
-    {stop, normal, StateData};
-wait_for_features(timeout, StateData) ->
-    ?INFO_MSG("wait for features: timeout", []),
-    {stop, normal, StateData};
-wait_for_features(closed, StateData) ->
-    ?INFO_MSG("wait for features: closed", []),
-    {stop, normal, StateData}.
-
-wait_for_auth_result({xmlstreamelement, El},
-                    StateData) ->
-    case El of
-      #xmlel{name = <<"success">>, attrs = Attrs} ->
-         case fxml:get_attr_s(<<"xmlns">>, Attrs) of
-           ?NS_SASL ->
-               ?DEBUG("auth: ~p",
-                      [{StateData#state.myname, StateData#state.server}]),
-               ejabberd_socket:reset_stream(StateData#state.socket),
-               send_text(StateData,
-                         io_lib:format(?STREAM_HEADER,
-                                       [StateData#state.myname,
-                                        StateData#state.server,
-                                        <<" version='1.0'">>])),
-               {next_state, wait_for_stream,
-                StateData#state{streamid = new_id(),
-                                authenticated = true},
-                ?FSMTIMEOUT};
-           _ ->
-               send_text(StateData,
-                         <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary,
-                           (?STREAM_TRAILER)/binary>>),
-               ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad "
-                         "format)",
-                         [StateData#state.myname, StateData#state.server]),
-               {stop, normal, StateData}
-         end;
-      #xmlel{name = <<"failure">>, attrs = Attrs} ->
-         case fxml:get_attr_s(<<"xmlns">>, Attrs) of
-           ?NS_SASL ->
-               ?DEBUG("restarted: ~p",
-                      [{StateData#state.myname, StateData#state.server}]),
-               ejabberd_socket:close(StateData#state.socket),
-               {next_state, reopen_socket,
-                StateData#state{socket = undefined}, ?FSMTIMEOUT};
-           _ ->
-               send_text(StateData,
-                         <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary,
-                           (?STREAM_TRAILER)/binary>>),
-               ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad "
-                         "format)",
-                         [StateData#state.myname, StateData#state.server]),
-               {stop, normal, StateData}
-         end;
-      _ ->
-         send_text(StateData,
-                   <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary,
-                     (?STREAM_TRAILER)/binary>>),
-         ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad "
-                   "format)",
-                   [StateData#state.myname, StateData#state.server]),
-         {stop, normal, StateData}
-    end;
-wait_for_auth_result({xmlstreamend, _Name},
-                    StateData) ->
-    ?INFO_MSG("wait for auth result: xmlstreamend", []),
-    {stop, normal, StateData};
-wait_for_auth_result({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
-    ?INFO_MSG("wait for auth result: xmlstreamerror", []),
-    {stop, normal, StateData};
-wait_for_auth_result(timeout, StateData) ->
-    ?INFO_MSG("wait for auth result: timeout", []),
-    {stop, normal, StateData};
-wait_for_auth_result(closed, StateData) ->
-    ?INFO_MSG("wait for auth result: closed", []),
-    {stop, normal, StateData}.
-
-wait_for_starttls_proceed({xmlstreamelement, El},
-                         StateData) ->
-    case El of
-      #xmlel{name = <<"proceed">>, attrs = Attrs} ->
-         case fxml:get_attr_s(<<"xmlns">>, Attrs) of
-           ?NS_TLS ->
-               ?DEBUG("starttls: ~p",
-                      [{StateData#state.myname, StateData#state.server}]),
-               Socket = StateData#state.socket,
-               TLSOpts = case
-                           ejabberd_config:get_option(
-                              {domain_certfile, StateData#state.myname},
-                              fun iolist_to_binary/1)
-                             of
-                           undefined -> StateData#state.tls_options;
-                           CertFile ->
-                               [{certfile, CertFile}
-                                | lists:keydelete(certfile, 1,
-                                                  StateData#state.tls_options)]
-                         end,
-               TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts),
-               NewStateData = StateData#state{socket = TLSSocket,
-                                              streamid = new_id(),
-                                              tls_enabled = true,
-                                              tls_options = TLSOpts},
-               send_text(NewStateData,
-                         io_lib:format(?STREAM_HEADER,
-                                       [NewStateData#state.myname,
-                                        NewStateData#state.server,
-                                        <<" version='1.0'">>])),
-               {next_state, wait_for_stream, NewStateData,
-                ?FSMTIMEOUT};
-           _ ->
-               send_text(StateData,
-                         <<(fxml:element_to_binary(?SERR_BAD_FORMAT))/binary,
-                           (?STREAM_TRAILER)/binary>>),
-               ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad "
-                         "format)",
-                         [StateData#state.myname, StateData#state.server]),
-               {stop, normal, StateData}
-         end;
-      _ ->
-         ?INFO_MSG("Closing s2s connection: ~s -> ~s (bad "
-                   "format)",
-                   [StateData#state.myname, StateData#state.server]),
-         {stop, normal, StateData}
-    end;
-wait_for_starttls_proceed({xmlstreamend, _Name},
-                         StateData) ->
-    ?INFO_MSG("wait for starttls proceed: xmlstreamend",
-             []),
-    {stop, normal, StateData};
-wait_for_starttls_proceed({xmlstreamerror, _},
-                         StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
-    ?INFO_MSG("wait for starttls proceed: xmlstreamerror",
-             []),
-    {stop, normal, StateData};
-wait_for_starttls_proceed(timeout, StateData) ->
-    ?INFO_MSG("wait for starttls proceed: timeout", []),
-    {stop, normal, StateData};
-wait_for_starttls_proceed(closed, StateData) ->
-    ?INFO_MSG("wait for starttls proceed: closed", []),
-    {stop, normal, StateData}.
+wait_for_features(Event, StateData) ->
+    handle_unexpected_event(Event, wait_for_features, StateData).
+
+wait_for_auth_result({xmlstreamelement, El}, StateData) ->
+    decode_element(El, wait_for_auth_result, StateData);
+wait_for_auth_result(#sasl_success{}, StateData) ->
+    ?DEBUG("auth: ~p", [{StateData#state.myname, StateData#state.server}]),
+    ejabberd_socket:reset_stream(StateData#state.socket),
+    send_header(StateData, <<" version='1.0'">>),
+    {next_state, wait_for_stream,
+     StateData#state{streamid = new_id(), authenticated = true},
+     ?FSMTIMEOUT};
+wait_for_auth_result(#sasl_failure{}, StateData) ->
+    ?DEBUG("restarted: ~p", [{StateData#state.myname, StateData#state.server}]),
+    ejabberd_socket:close(StateData#state.socket),
+    {next_state, reopen_socket,
+     StateData#state{socket = undefined}, ?FSMTIMEOUT};
+wait_for_auth_result(Event, StateData) ->
+    handle_unexpected_event(Event, wait_for_auth_result, StateData).
+
+wait_for_starttls_proceed({xmlstreamelement, El}, StateData) ->
+    decode_element(El, wait_for_starttls_proceed, StateData);
+wait_for_starttls_proceed(#starttls_proceed{}, StateData) ->
+    ?DEBUG("starttls: ~p", [{StateData#state.myname, StateData#state.server}]),
+    Socket = StateData#state.socket,
+    TLSOpts = case ejabberd_config:get_option(
+                    {domain_certfile, StateData#state.myname},
+                    fun iolist_to_binary/1) of
+                 undefined -> StateData#state.tls_options;
+                 CertFile ->
+                     [{certfile, CertFile}
+                      | lists:keydelete(certfile, 1,
+                                        StateData#state.tls_options)]
+             end,
+    TLSSocket = ejabberd_socket:starttls(Socket, TLSOpts),
+    NewStateData = StateData#state{socket = TLSSocket,
+                                  streamid = new_id(),
+                                  tls_enabled = true,
+                                  tls_options = TLSOpts},
+    send_header(NewStateData, <<" version='1.0'">>),
+    {next_state, wait_for_stream, NewStateData, ?FSMTIMEOUT};
+wait_for_starttls_proceed(Event, StateData) ->
+    handle_unexpected_event(Event, wait_for_starttls_proceed, StateData).
 
 reopen_socket({xmlstreamelement, _El}, StateData) ->
     {next_state, reopen_socket, StateData, ?FSMTIMEOUT};
@@ -797,47 +532,69 @@ relay_to_bridge(_Event, StateData) ->
     {next_state, relay_to_bridge, StateData}.
 
 stream_established({xmlstreamelement, El}, StateData) ->
-    ?DEBUG("s2S stream established", []),
-    case is_verify_res(El) of
-      {verify, VTo, VFrom, VId, VType} ->
-         ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]),
-         case StateData#state.verify of
-           {VPid, _VKey, _SID} ->
-               case VType of
-                 <<"valid">> ->
-                     p1_fsm:send_event(VPid,
-                                       {valid, StateData#state.server,
-                                        StateData#state.myname});
-                 _ ->
-                     p1_fsm:send_event(VPid,
-                                       {invalid, StateData#state.server,
-                                        StateData#state.myname})
-               end;
-           _ -> ok
-         end;
-      _ -> ok
+    decode_element(El, stream_established, StateData);
+stream_established(#db_verify{to = VTo, from = VFrom, id = VId, type = VType},
+                  StateData) ->
+    ?DEBUG("recv verify: ~p", [{VFrom, VTo, VId, VType}]),
+    case StateData#state.verify of
+       {VPid, _VKey, _SID} ->
+           case VType of
+               valid ->
+                   p1_fsm:send_event(VPid,
+                                     {valid, StateData#state.server,
+                                      StateData#state.myname});
+               _ ->
+                   p1_fsm:send_event(VPid,
+                                     {invalid, StateData#state.server,
+                                      StateData#state.myname})
+           end;
+       _ -> ok
     end,
     {next_state, stream_established, StateData};
-stream_established({xmlstreamend, _Name}, StateData) ->
-    ?INFO_MSG("Connection closed in stream established: "
-             "~s -> ~s (xmlstreamend)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-stream_established({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
-    ?INFO_MSG("stream established: ~s -> ~s (xmlstreamerror)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-stream_established(timeout, StateData) ->
-    ?INFO_MSG("stream established: ~s -> ~s (timeout)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData};
-stream_established(closed, StateData) ->
-    ?INFO_MSG("stream established: ~s -> ~s (closed)",
-             [StateData#state.myname, StateData#state.server]),
-    {stop, normal, StateData}.
+stream_established(Event, StateData) ->
+    handle_unexpected_event(Event, stream_established, StateData).
+
+-spec handle_unexpected_event(term(), state_name(), state()) -> fsm_transition().
+handle_unexpected_event(Event, StateName, StateData) ->
+    case Event of
+       {xmlstreamerror, _} ->
+           send_element(StateData, xmpp:serr_not_well_formed()),
+           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
+                     "got invalid XML from peer",
+                     [StateData#state.myname, StateData#state.server,
+                      StateName]),
+           {stop, normal, StateData};
+       {xmlstreamend, _} ->
+           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
+                     "XML stream closed by peer",
+                     [StateData#state.myname, StateData#state.server]),
+           {stop, normal, StateData};
+       timeout ->
+           send_element(StateData, xmpp:serr_connection_timeout()),
+           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
+                     "timed out during establishing an XML stream",
+                     [StateData#state.myname, StateData#state.server,
+                      StateName]),
+           {stop, normal, StateData};
+       closed ->
+           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
+                     "connection socket closed",
+                     [StateData#state.myname, StateData#state.server,
+                      StateName]),
+           {stop, normal, StateData};
+       Pkt when StateName == wait_for_stream;
+                StateName == wait_for_features;
+                StateName == wait_for_auth_result;
+                StateName == wait_for_starttls_proceed ->
+           send_element(StateData, xmpp:serr_bad_format()),
+           ?INFO_MSG("Closing s2s connection ~s -> ~s in state ~s: "
+                     "got unexpected event ~p",
+                     [StateData#state.myname, StateData#state.server,
+                      StateName, Pkt]),
+           {stop, normal, StateData};
+       _ ->
+           {next_state, StateName, StateData, get_timeout_interval(StateName)}
+    end.
 
 %%----------------------------------------------------------------------
 %% Func: StateName/3
@@ -917,7 +674,7 @@ handle_info({send_element, El}, StateName, StateData) ->
       %% In this state we bounce all message: We are waiting before
       %% trying to reconnect
       wait_before_retry ->
-         bounce_element(El, ?ERR_REMOTE_SERVER_NOT_FOUND),
+         bounce_element(El, xmpp:err_remote_server_not_found()),
          {next_state, StateName, StateData};
       relay_to_bridge ->
          {Mod, Fun} = StateData#state.bridge,
@@ -926,7 +683,7 @@ handle_info({send_element, El}, StateName, StateData) ->
            {'EXIT', Reason} ->
                ?ERROR_MSG("Error while relaying to bridge: ~p",
                           [Reason]),
-               bounce_element(El, ?ERR_INTERNAL_SERVER_ERROR),
+               bounce_element(El, xmpp:err_internal_server_error()),
                wait_before_reconnect(StateData);
            _ -> {next_state, StateName, StateData}
          end;
@@ -966,12 +723,13 @@ terminate(Reason, StateName, StateData) ->
                                          StateData#state.server},
                                         self())
     end,
-    bounce_queue(StateData#state.queue,
-                ?ERR_REMOTE_SERVER_NOT_FOUND),
-    bounce_messages(?ERR_REMOTE_SERVER_NOT_FOUND),
+    bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()),
+    bounce_messages(xmpp:err_remote_server_not_found()),
     case StateData#state.socket of
       undefined -> ok;
-      _Socket -> ejabberd_socket:close(StateData#state.socket)
+      _Socket ->
+           catch send_trailer(StateData),
+           ejabberd_socket:close(StateData#state.socket)
     end,
     ok.
 
@@ -981,12 +739,30 @@ print_state(State) -> State.
 %%% Internal functions
 %%%----------------------------------------------------------------------
 
+-spec send_text(state(), iodata()) -> ok.
 send_text(StateData, Text) ->
     ejabberd_socket:send(StateData#state.socket, Text).
 
+-spec send_element(state(), xmpp_element()) -> ok.
 send_element(StateData, El) ->
-    send_text(StateData, fxml:element_to_binary(El)).
-
+    El1 = fix_ns(xmpp:encode(El)),
+    send_text(StateData, fxml:element_to_binary(El1)).
+
+-spec send_header(state(), binary()) -> ok.
+send_header(StateData, Version) ->
+    Txt = io_lib:format(
+           "<?xml version='1.0'?><stream:stream "
+           "xmlns:stream='http://etherx.jabber.org/stream"
+           "s' xmlns='jabber:server' xmlns:db='jabber:ser"
+           "ver:dialback' from='~s' to='~s'~s>",
+           [StateData#state.myname, StateData#state.server, Version]),
+    send_text(StateData, Txt).
+
+-spec send_trailer(state()) -> ok.
+send_trailer(StateData) ->
+    send_text(StateData, <<"</stream:stream>">>).
+
+-spec send_queue(state(), queue:queue()) -> ok.
 send_queue(StateData, Q) ->
     case queue:out(Q) of
       {{value, El}, Q1} ->
@@ -994,21 +770,28 @@ send_queue(StateData, Q) ->
       {empty, _Q1} -> ok
     end.
 
+-spec fix_ns(xmlel()) -> xmlel().
+fix_ns(#xmlel{name = Name} = El) when Name == <<"message">>;
+                                     Name == <<"iq">>;
+                                     Name == <<"presence">>;
+                                     Name == <<"db:verify">>,
+                                     Name == <<"db:result">> ->
+    Attrs = lists:filter(
+             fun({<<"xmlns">>, _}) -> false;
+                (_) -> true
+             end, El#xmlel.attrs),
+    El#xmlel{attrs = Attrs};
+fix_ns(El) ->
+    El.
+
 %% Bounce a single message (xmlelement)
+-spec bounce_element(stanza(), error()) -> ok.
 bounce_element(El, Error) ->
-    #xmlel{attrs = Attrs} = El,
-    case fxml:get_attr_s(<<"type">>, Attrs) of
-      <<"error">> -> ok;
-      <<"result">> -> ok;
-      _ ->
-         Err = jlib:make_error_reply(El, Error),
-         From = jid:from_string(fxml:get_tag_attr_s(<<"from">>,
-                                                      El)),
-         To = jid:from_string(fxml:get_tag_attr_s(<<"to">>,
-                                                    El)),
-         ejabberd_router:route(To, From, Err)
-    end.
+    From = xmpp:get_from(El),
+    To = xmpp:get_to(El),
+    ejabberd_router:route_error(To, From, El, Error).
 
+-spec bounce_queue(queue:queue(), error()) -> ok.
 bounce_queue(Q, Error) ->
     case queue:out(Q) of
       {{value, El}, Q1} ->
@@ -1016,12 +799,15 @@ bounce_queue(Q, Error) ->
       {empty, _} -> ok
     end.
 
+-spec new_id() -> binary().
 new_id() -> randoms:get_string().
 
+-spec cancel_timer(reference()) -> ok.
 cancel_timer(Timer) ->
     erlang:cancel_timer(Timer),
     receive {timeout, Timer, _} -> ok after 0 -> ok end.
 
+-spec bounce_messages(error()) -> ok.
 bounce_messages(Error) ->
     receive
       {send_element, El} ->
@@ -1029,6 +815,7 @@ bounce_messages(Error) ->
       after 0 -> ok
     end.
 
+-spec send_db_request(state()) -> fsm_transition().
 send_db_request(StateData) ->
     Server = StateData#state.server,
     New = case StateData#state.new of
@@ -1045,22 +832,18 @@ send_db_request(StateData) ->
                       {StateData#state.myname, Server},
                       StateData#state.remote_streamid),
              send_element(StateData,
-                          #xmlel{name = <<"db:result">>,
-                                 attrs =
-                                     [{<<"from">>, StateData#state.myname},
-                                      {<<"to">>, Server}],
-                                 children = [{xmlcdata, Key1}]})
+                          #db_result{from = jid:make(StateData#state.myname),
+                                     to = jid:make(Server),
+                                     key = Key1})
        end,
        case StateData#state.verify of
          false -> ok;
          {_Pid, Key2, SID} ->
              send_element(StateData,
-                          #xmlel{name = <<"db:verify">>,
-                                 attrs =
-                                     [{<<"from">>, StateData#state.myname},
-                                      {<<"to">>, StateData#state.server},
-                                      {<<"id">>, SID}],
-                                 children = [{xmlcdata, Key2}]})
+                          #db_verify{from = jid:make(StateData#state.myname),
+                                     to = StateData#state.server,
+                                     id = SID,
+                                     key = Key2})
        end,
        {next_state, wait_for_validation, NewStateData,
         (?FSMTIMEOUT) * 6}
@@ -1068,20 +851,6 @@ send_db_request(StateData) ->
       _:_ -> {stop, normal, NewStateData}
     end.
 
-is_verify_res(#xmlel{name = Name, attrs = Attrs})
-    when Name == <<"db:result">> ->
-    {result, fxml:get_attr_s(<<"to">>, Attrs),
-     fxml:get_attr_s(<<"from">>, Attrs),
-     fxml:get_attr_s(<<"id">>, Attrs),
-     fxml:get_attr_s(<<"type">>, Attrs)};
-is_verify_res(#xmlel{name = Name, attrs = Attrs})
-    when Name == <<"db:verify">> ->
-    {verify, fxml:get_attr_s(<<"to">>, Attrs),
-     fxml:get_attr_s(<<"from">>, Attrs),
-     fxml:get_attr_s(<<"id">>, Attrs),
-     fxml:get_attr_s(<<"type">>, Attrs)};
-is_verify_res(_) -> false.
-
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 %% SRV support
 
@@ -1190,12 +959,14 @@ get_addrs(Host, Family) ->
          []
     end.
 
+-spec outgoing_s2s_port() -> pos_integer().
 outgoing_s2s_port() ->
     ejabberd_config:get_option(
       outgoing_s2s_port,
       fun(I) when is_integer(I), I > 0, I =< 65536 -> I end,
       5269).
 
+-spec outgoing_s2s_families() -> [ipv4 | ipv6].
 outgoing_s2s_families() ->
     ejabberd_config:get_option(
       outgoing_s2s_families,
@@ -1207,6 +978,7 @@ outgoing_s2s_families() ->
               Families
       end, [ipv4, ipv6]).
 
+-spec outgoing_s2s_timeout() -> pos_integer().
 outgoing_s2s_timeout() ->
     ejabberd_config:get_option(
       outgoing_s2s_timeout,
@@ -1256,21 +1028,24 @@ log_s2s_out(_, Myname, Server, Tls) ->
 
 %% Calculate timeout depending on which state we are in:
 %% Can return integer > 0 | infinity
+-spec get_timeout_interval(state_name()) -> pos_integer() | infinity.
 get_timeout_interval(StateName) ->
     case StateName of
       %% Validation implies dialback: Networking can take longer:
       wait_for_validation -> (?FSMTIMEOUT) * 6;
       %% When stream is established, we only rely on S2S Timeout timer:
       stream_established -> infinity;
+      relay_to_bridge -> infinity;
+      open_socket -> infinity;
       _ -> ?FSMTIMEOUT
     end.
 
 %% This function is intended to be called at the end of a state
 %% function that want to wait for a reconnect delay before stopping.
+-spec wait_before_reconnect(state()) -> fsm_next().
 wait_before_reconnect(StateData) ->
-    bounce_queue(StateData#state.queue,
-                ?ERR_REMOTE_SERVER_NOT_FOUND),
-    bounce_messages(?ERR_REMOTE_SERVER_NOT_FOUND),
+    bounce_queue(StateData#state.queue, xmpp:err_remote_server_not_found()),
+    bounce_messages(xmpp:err_remote_server_not_found()),
     cancel_timer(StateData#state.timer),
     Delay = case StateData#state.delay_to_retry of
              undefined_delay ->
@@ -1282,6 +1057,7 @@ wait_before_reconnect(StateData) ->
      StateData#state{timer = Timer, delay_to_retry = Delay,
                     queue = queue:new()}}.
 
+-spec get_max_retry_delay() -> pos_integer().
 get_max_retry_delay() ->
     case ejabberd_config:get_option(
            s2s_max_retry_delay,
@@ -1291,6 +1067,7 @@ get_max_retry_delay() ->
     end.
 
 %% Terminate s2s_out connections that are in state wait_before_retry
+-spec terminate_if_waiting_delay(ljid(), ljid()) -> ok.
 terminate_if_waiting_delay(From, To) ->
     FromTo = {From, To},
     Pids = ejabberd_s2s:get_connections_pids(FromTo),
@@ -1299,6 +1076,7 @@ terminate_if_waiting_delay(From, To) ->
                  end,
                  Pids).
 
+-spec fsm_limit_opts() -> [{max_queue, pos_integer()}].
 fsm_limit_opts() ->
     case ejabberd_config:get_option(
            max_fsm_queue,
@@ -1307,6 +1085,24 @@ fsm_limit_opts() ->
         N -> [{max_queue, N}]
     end.
 
+-spec decode_element(xmlel(), state_name(), state()) -> fsm_next().
+decode_element(#xmlel{} = El, StateName, StateData) ->
+    try xmpp:decode(El) of
+       Pkt -> ?MODULE:StateName(Pkt, StateData)
+    catch error:{xmpp_codec, Why} ->
+           Type = xmpp:get_type(El),
+           case xmpp:is_stanza(El) of
+               true when Type /= <<"result">>, Type /= <<"error">> ->
+                   Lang = xmpp:get_lang(El),
+                   Txt = xmpp:format_error(Why),
+                   Err = xmpp:make_error(El, xmpp:err_bad_request(Txt, Lang)),
+                   send_element(StateData, Err);
+               false ->
+                   ok
+           end,
+           {next_state, StateName, StateData, get_timeout_interval(StateName)}
+    end.
+
 opt_type(domain_certfile) -> fun iolist_to_binary/1;
 opt_type(max_fsm_queue) ->
     fun (I) when is_integer(I), I > 0 -> I end;
index 465fb587a3a680babe7bc0793213e10beb5f4d47..432253e094172dd59ed2a079aa8160b125878387 100644 (file)
@@ -46,8 +46,7 @@
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
-
--include("jlib.hrl").
+-include("xmpp.hrl").
 
 -record(state,
        {socket                    :: ejabberd_socket:socket_state(),
          access                    :: atom(),
         check_from = true         :: boolean()}).
 
-%-define(DBGFSM, true).
+-type state_name() :: wait_for_stream | wait_for_handshake | stream_established.
+-type state() :: #state{}.
+-type fsm_next() :: {next_state, state_name(), state()}.
+-type fsm_stop() :: {stop, normal, state()}.
+-type fsm_transition() :: fsm_stop() | fsm_next().
 
+%-define(DBGFSM, true).
 -ifdef(DBGFSM).
-
 -define(FSMOPTS, [{debug, [trace]}]).
-
 -else.
-
 -define(FSMOPTS, []).
-
 -endif.
 
--define(STREAM_HEADER,
-       <<"<?xml version='1.0'?><stream:stream "
-         "xmlns:stream='http://etherx.jabber.org/stream"
-         "s' xmlns='jabber:component:accept' id='~s' "
-         "from='~s'>">>).
-
--define(STREAM_TRAILER, <<"</stream:stream>">>).
-
--define(INVALID_HEADER_ERR,
-       <<"<stream:stream xmlns:stream='http://etherx.ja"
-         "bber.org/streams'><stream:error>Invalid "
-         "Stream Header</stream:error></stream:stream>">>).
-
--define(INVALID_HANDSHAKE_ERR,
-       <<"<stream:error><not-authorized xmlns='urn:ietf"
-         ":params:xml:ns:xmpp-streams'/><text "
-         "xmlns='urn:ietf:params:xml:ns:xmpp-streams' "
-         "xml:lang='en'>Invalid Handshake</text></strea"
-         "m:error></stream:stream>">>).
-
--define(INVALID_XML_ERR,
-       fxml:element_to_binary(?SERR_XML_NOT_WELL_FORMED)).
-
--define(INVALID_NS_ERR,
-       fxml:element_to_binary(?SERR_INVALID_NAMESPACE)).
-
 %%%----------------------------------------------------------------------
 %%% API
 %%%----------------------------------------------------------------------
@@ -112,14 +86,6 @@ socket_type() -> xml_stream.
 %%%----------------------------------------------------------------------
 %%% Callback functions from gen_fsm
 %%%----------------------------------------------------------------------
-
-%%----------------------------------------------------------------------
-%% Func: init/1
-%% Returns: {ok, StateName, StateData}          |
-%%          {ok, StateName, StateData, Timeout} |
-%%          ignore                              |
-%%          {stop, StopReason}
-%%----------------------------------------------------------------------
 init([{SockMod, Socket}, Opts]) ->
     ?INFO_MSG("(~w) External service connected", [Socket]),
     Access = case lists:keysearch(access, 1, Opts) of
@@ -157,177 +123,127 @@ init([{SockMod, Socket}, Opts]) ->
            streamid = new_id(), host_opts = HostOpts,
            access = Access, check_from = CheckFrom}}.
 
-%%----------------------------------------------------------------------
-%% Func: StateName/2
-%% Returns: {next_state, NextStateName, NextStateData}          |
-%%          {next_state, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}
-%%----------------------------------------------------------------------
-
-wait_for_stream({xmlstreamstart, _Name, Attrs},
-               StateData) ->
-    case fxml:get_attr_s(<<"xmlns">>, Attrs) of
-      <<"jabber:component:accept">> ->
-         To = fxml:get_attr_s(<<"to">>, Attrs),
-         Host = jid:nameprep(To),
-         if Host == error ->
-                 Header = io_lib:format(?STREAM_HEADER,
-                                        [<<"none">>, ?MYNAME]),
-                 send_text(StateData,
-                           <<(list_to_binary(Header))/binary,
-                             (?INVALID_XML_ERR)/binary,
-                             (?STREAM_TRAILER)/binary>>),
-                 {stop, normal, StateData};
-            true ->
-                 Header = io_lib:format(?STREAM_HEADER,
-                                        [StateData#state.streamid, fxml:crypt(To)]),
-                 send_text(StateData, Header),
-                 HostOpts = case dict:is_key(Host, StateData#state.host_opts) of
-                                true ->
-                                    StateData#state.host_opts;
-                                false ->
-                                    case dict:find(global, StateData#state.host_opts) of
-                                        {ok, GlobalPass} ->
-                                            dict:from_list([{Host, GlobalPass}]);
-                                        error ->
-                                            StateData#state.host_opts
-                                    end
-                            end,
-                 {next_state, wait_for_handshake,
-                  StateData#state{host = Host, host_opts = HostOpts}}
-         end;
-      _ ->
-         send_text(StateData, ?INVALID_HEADER_ERR),
-         {stop, normal, StateData}
+wait_for_stream({xmlstreamstart, Name, Attrs}, StateData) ->
+    try xmpp:decode(#xmlel{name = Name, attrs = Attrs}) of
+       #stream_start{xmlns = ?NS_COMPONENT, to = To} when is_record(To, jid) ->
+           Host = To#jid.lserver,
+           send_header(StateData, To),
+           HostOpts = case dict:is_key(Host, StateData#state.host_opts) of
+                          true ->
+                              StateData#state.host_opts;
+                          false ->
+                              case dict:find(global, StateData#state.host_opts) of
+                                  {ok, GlobalPass} ->
+                                      dict:from_list([{Host, GlobalPass}]);
+                                  error ->
+                                      StateData#state.host_opts
+                              end
+                      end,
+           {next_state, wait_for_handshake,
+            StateData#state{host = Host, host_opts = HostOpts}};
+       #stream_start{xmlns = ?NS_COMPONENT} ->
+           send_header(StateData, ?MYNAME),
+           send_element(StateData, xmpp:serr_improper_addressing()),
+           {stop, normal, StateData};
+       #stream_start{} ->
+           send_header(StateData, ?MYNAME),
+           send_element(StateData, xmpp:serr_invalid_namespace()),
+           {stop, normal, StateData}
+    catch _:{xmpp_codec, Why} ->
+           Txt = xmpp:format_error(Why),
+           send_header(StateData, ?MYNAME),
+           send_element(StateData, xmpp:serr_invalid_xml(Txt, ?MYLANG)),
+           {stop, normal, StateData}
     end;
 wait_for_stream({xmlstreamerror, _}, StateData) ->
-    Header = io_lib:format(?STREAM_HEADER,
-                          [<<"none">>, ?MYNAME]),
-    send_text(StateData,
-             <<(list_to_binary(Header))/binary, (?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
+    send_header(StateData, ?MYNAME),
+    send_element(StateData, xmpp:serr_not_well_formed()),
     {stop, normal, StateData};
 wait_for_stream(closed, StateData) ->
     {stop, normal, StateData}.
 
 wait_for_handshake({xmlstreamelement, El}, StateData) ->
-    #xmlel{name = Name, children = Els} = El,
-    case {Name, fxml:get_cdata(Els)} of
-      {<<"handshake">>, Digest} ->
-         case dict:find(StateData#state.host, StateData#state.host_opts) of
-             {ok, Password} ->
-                 case p1_sha:sha(<<(StateData#state.streamid)/binary,
-                                   Password/binary>>) of
-                     Digest ->
-                         send_text(StateData, <<"<handshake/>">>),
-                         lists:foreach(
-                           fun (H) ->
-                                   ejabberd_router:register_route(H, ?MYNAME),
-                                   ?INFO_MSG("Route registered for service ~p~n",
-                                             [H])
-                           end, dict:fetch_keys(StateData#state.host_opts)),
-                         {next_state, stream_established, StateData};
-                     _ ->
-                         send_text(StateData, ?INVALID_HANDSHAKE_ERR),
-                         {stop, normal, StateData}
-                 end;
-             _ ->
-                 send_text(StateData, ?INVALID_HANDSHAKE_ERR),
-                 {stop, normal, StateData}
-         end;
-      _ -> {next_state, wait_for_handshake, StateData}
+    decode_element(El, wait_for_handshake, StateData);
+wait_for_handshake(#handshake{data = Digest}, StateData) ->
+    case dict:find(StateData#state.host, StateData#state.host_opts) of
+       {ok, Password} ->
+           case p1_sha:sha(<<(StateData#state.streamid)/binary,
+                             Password/binary>>) of
+               Digest ->
+                   send_element(StateData, #handshake{}),
+                   lists:foreach(
+                     fun (H) ->
+                             ejabberd_router:register_route(H, ?MYNAME),
+                             ?INFO_MSG("Route registered for service ~p~n",
+                                       [H])
+                     end, dict:fetch_keys(StateData#state.host_opts)),
+                   {next_state, stream_established, StateData};
+               _ ->
+                   send_element(StateData, xmpp:serr_not_authorized()),
+                   {stop, normal, StateData}
+           end;
+       _ ->
+           send_element(StateData, xmpp:serr_not_authorized()),
+           {stop, normal, StateData}
     end;
 wait_for_handshake({xmlstreamend, _Name}, StateData) ->
     {stop, normal, StateData};
 wait_for_handshake({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
+    send_element(StateData, xmpp:serr_not_well_formed()),
     {stop, normal, StateData};
 wait_for_handshake(closed, StateData) ->
-    {stop, normal, StateData}.
+    {stop, normal, StateData};
+wait_for_handshake(_Pkt, StateData) ->
+    {next_state, wait_for_handshake, StateData}.
 
 stream_established({xmlstreamelement, El}, StateData) ->
-    NewEl = jlib:remove_attr(<<"xmlns">>, El),
-    #xmlel{name = Name, attrs = Attrs} = NewEl,
-    From = fxml:get_attr_s(<<"from">>, Attrs),
-    FromJID = case StateData#state.check_from of
-               %% If the admin does not want to check the from field
-               %% when accept packets from any address.
-               %% In this case, the component can send packet of
-               %% behalf of the server users.
-               false -> jid:from_string(From);
-               %% The default is the standard behaviour in XEP-0114
-               _ ->
-                   FromJID1 = jid:from_string(From),
-                   case FromJID1 of
-                     #jid{lserver = Server} ->
-                         case dict:is_key(Server, StateData#state.host_opts) of
-                           true -> FromJID1;
-                           false -> error
-                         end;
-                     _ -> error
-                   end
-             end,
-    To = fxml:get_attr_s(<<"to">>, Attrs),
-    ToJID = case To of
-             <<"">> -> error;
-             _ -> jid:from_string(To)
-           end,
-    if ((Name == <<"iq">>) or (Name == <<"message">>) or
-         (Name == <<"presence">>))
-        and (ToJID /= error)
-        and (FromJID /= error) ->
-          ejabberd_router:route(FromJID, ToJID, NewEl);
+    decode_element(El, stream_established, StateData);
+stream_established(El, StateData) when ?is_stanza(El) ->
+    From = xmpp:get_from(El),
+    To = xmpp:get_to(El),
+    Lang = xmpp:get_lang(El),
+    if From == undefined orelse To == undefined ->
+           send_error(StateData, El, xmpp:err_jid_malformed());
        true ->
-          Lang = fxml:get_tag_attr_s(<<"xml:lang">>, El),
-          Txt = <<"Incorrect stanza name or from/to JID">>,
-          Err = jlib:make_error_reply(NewEl, ?ERRT_BAD_REQUEST(Lang, Txt)),
-          send_element(StateData, Err),
-          error
+           FromJID = case StateData#state.check_from of
+                         false ->
+                             %% If the admin does not want to check the from field
+                             %% when accept packets from any address.
+                             %% In this case, the component can send packet of
+                             %% behalf of the server users.
+                             From;
+                         _ ->
+                             %% The default is the standard behaviour in XEP-0114
+                             case From of
+                                 #jid{lserver = Server} ->
+                                     case dict:is_key(Server, StateData#state.host_opts) of
+                                         true -> From;
+                                         false -> error
+                                     end;
+                                 _ -> error
+                             end
+                     end,
+           if FromJID /= error ->
+                   ejabberd_router:route(FromJID, To, El);
+              true ->
+                   Txt = <<"Incorrect value of 'from' or 'to' attribute">>,
+                   send_error(StateData, El, xmpp:err_not_allowed(Txt, Lang))
+           end
     end,
     {next_state, stream_established, StateData};
 stream_established({xmlstreamend, _Name}, StateData) ->
     {stop, normal, StateData};
 stream_established({xmlstreamerror, _}, StateData) ->
-    send_text(StateData,
-             <<(?INVALID_XML_ERR)/binary,
-               (?STREAM_TRAILER)/binary>>),
+    send_element(StateData, xmpp:serr_not_well_formed()),
     {stop, normal, StateData};
 stream_established(closed, StateData) ->
-    {stop, normal, StateData}.
-
-%%----------------------------------------------------------------------
-%% Func: StateName/3
-%% Returns: {next_state, NextStateName, NextStateData}            |
-%%          {next_state, NextStateName, NextStateData, Timeout}   |
-%%          {reply, Reply, NextStateName, NextStateData}          |
-%%          {reply, Reply, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}                          |
-%%          {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
-%state_name(Event, From, StateData) ->
-%    Reply = ok,
-%    {reply, Reply, state_name, StateData}.
+    {stop, normal, StateData};
+stream_established(_Event, StateData) ->
+    {next_state, stream_established, StateData}.
 
-%%----------------------------------------------------------------------
-%% Func: handle_event/3
-%% Returns: {next_state, NextStateName, NextStateData}          |
-%%          {next_state, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}
-%%----------------------------------------------------------------------
 handle_event(_Event, StateName, StateData) ->
     {next_state, StateName, StateData}.
 
-%%----------------------------------------------------------------------
-%% Func: handle_sync_event/4
-%% Returns: {next_state, NextStateName, NextStateData}            |
-%%          {next_state, NextStateName, NextStateData, Timeout}   |
-%%          {reply, Reply, NextStateName, NextStateData}          |
-%%          {reply, Reply, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}                          |
-%%          {stop, Reason, Reply, NewStateData}
-%%----------------------------------------------------------------------
 handle_sync_event(_Event, _From, StateName,
                  StateData) ->
     Reply = ok, {reply, Reply, StateName, StateData}.
@@ -335,12 +251,6 @@ handle_sync_event(_Event, _From, StateName,
 code_change(_OldVsn, StateName, StateData, _Extra) ->
     {ok, StateName, StateData}.
 
-%%----------------------------------------------------------------------
-%% Func: handle_info/3
-%% Returns: {next_state, NextStateName, NextStateData}          |
-%%          {next_state, NextStateName, NextStateData, Timeout} |
-%%          {stop, Reason, NewStateData}
-%%----------------------------------------------------------------------
 handle_info({send_text, Text}, StateName, StateData) ->
     send_text(StateData, Text),
     {next_state, StateName, StateData};
@@ -349,34 +259,20 @@ handle_info({send_element, El}, StateName, StateData) ->
     {next_state, StateName, StateData};
 handle_info({route, From, To, Packet}, StateName,
            StateData) ->
-    case acl:match_rule(global, StateData#state.access,
-                       From)
-       of
+    case acl:match_rule(global, StateData#state.access, From) of
       allow ->
-         #xmlel{name = Name, attrs = Attrs, children = Els} =
-             Packet,
-         Attrs2 =
-             jlib:replace_from_to_attrs(jid:to_string(From),
-                                        jid:to_string(To), Attrs),
-         Text = fxml:element_to_binary(#xmlel{name = Name,
-                                             attrs = Attrs2, children = Els}),
-         send_text(StateData, Text);
-      deny ->
-         Lang = fxml:get_tag_attr_s(<<"xml:lang">>, Packet),
-         Txt = <<"Denied by ACL">>,
-         Err = jlib:make_error_reply(Packet, ?ERRT_NOT_ALLOWED(Lang, Txt)),
-         ejabberd_router:route_error(To, From, Err, Packet)
+           Pkt = xmpp:set_from_to(Packet, From, To),
+           send_element(StateData, Pkt);
+       deny ->
+           Lang = xmpp:get_lang(Packet),
+           Err = xmpp:err_not_allowed(<<"Denied by ACL">>, Lang),
+           ejabberd_router:route_error(To, From, Packet, Err)
     end,
     {next_state, StateName, StateData};
 handle_info(Info, StateName, StateData) ->
     ?ERROR_MSG("Unexpected info: ~p", [Info]),
     {next_state, StateName, StateData}.
 
-%%----------------------------------------------------------------------
-%% Func: terminate/3
-%% Purpose: Shutdown the fsm
-%% Returns: any
-%%----------------------------------------------------------------------
 terminate(Reason, StateName, StateData) ->
     ?INFO_MSG("terminated: ~p", [Reason]),
     case StateName of
@@ -387,6 +283,7 @@ terminate(Reason, StateName, StateData) ->
                        dict:fetch_keys(StateData#state.host_opts));
       _ -> ok
     end,
+    catch send_trailer(StateData),
     (StateData#state.sockmod):close(StateData#state.socket),
     ok.
 
@@ -401,13 +298,69 @@ print_state(State) -> State.
 %%% Internal functions
 %%%----------------------------------------------------------------------
 
+-spec send_text(state(), iodata()) -> ok.
 send_text(StateData, Text) ->
     (StateData#state.sockmod):send(StateData#state.socket,
                                   Text).
 
+-spec send_element(state(), xmpp_element()) -> ok.
 send_element(StateData, El) ->
-    send_text(StateData, fxml:element_to_binary(El)).
+    El1 = fix_ns(xmpp:encode(El)),
+    send_text(StateData, fxml:element_to_binary(El1)).
+
+-spec send_error(state(), xmlel() | stanza(), error()) -> ok.
+send_error(StateData, Stanza, Error) ->
+    Type = xmpp:get_type(Stanza),
+    if Type == error; Type == result;
+       Type == <<"error">>; Type == <<"result">> ->
+           ok;
+       true ->
+           send_element(StateData, xmpp:make_error(Stanza, Error))
+    end.
+
+-spec send_header(state(), binary()) -> ok.
+send_header(StateData, Host) ->
+    send_text(StateData,
+             io_lib:format(
+               <<"<?xml version='1.0'?><stream:stream "
+                 "xmlns:stream='http://etherx.jabber.org/stream"
+                 "s' xmlns='jabber:component:accept' id='~s' "
+                 "from='~s'>">>,
+               [StateData#state.streamid, fxml:crypt(Host)])).
+
+-spec send_trailer(state()) -> ok.
+send_trailer(StateData) ->
+    send_text(StateData, <<"</stream:stream>">>).
+
+-spec fix_ns(xmlel()) -> xmlel().
+fix_ns(#xmlel{name = Name} = El) when Name == <<"message">>;
+                                      Name == <<"iq">>;
+                                      Name == <<"presence">> ->
+    Attrs = lists:filter(
+              fun({<<"xmlns">>, _}) -> false;
+                 (_) -> true
+              end, El#xmlel.attrs),
+    El#xmlel{attrs = Attrs};
+fix_ns(El) ->
+    El.
+
+-spec decode_element(xmlel(), state_name(), state()) -> fsm_transition().
+decode_element(#xmlel{} = El, StateName, StateData) ->
+    try xmpp:decode(El, [ignore_els]) of
+       Pkt -> ?MODULE:StateName(Pkt, StateData)
+    catch error:{xmpp_codec, Why} ->
+            case xmpp:is_stanza(El) of
+                true ->
+                    Lang = xmpp:get_lang(El),
+                    Txt = xmpp:format_error(Why),
+                    send_error(StateData, El, xmpp:err_bad_request(Txt, Lang));
+                false ->
+                    ok
+            end,
+            {next_state, StateName, StateData}
+    end.
 
+-spec new_id() -> binary().
 new_id() -> randoms:get_string().
 
 transform_listen_option({hosts, Hosts, O}, Opts) ->
index 369fb90c5daf902a908b6058a2f258c7c6518886..5b7e3d1ccb5461727414501fc0774b0e7246bf98 100644 (file)
          serr_unsupported_stanza_type/0, serr_unsupported_stanza_type/2,
          serr_unsupported_version/0, serr_unsupported_version/2]).
 
--ifndef(NS_CLIENT).
--define(NS_CLIENT, <<"jabber:client">>).
--endif.
-
 -include("xmpp.hrl").
 
 %%%===================================================================
@@ -246,9 +242,14 @@ get_name(Pkt) ->
 decode(El) ->
     decode(El, []).
 
--spec decode(xmlel() | xmpp_element(), [proplists:property()]) ->
+-spec decode(xmlel() | xmpp_element(),
+            [proplists:property()] |
+            fun((xmlel() | xmpp_element()) -> boolean())) ->
                    {ok, xmpp_element()} | {error, any()}.
-decode(#xmlel{} = El, Opts) ->
+decode(#xmlel{} = El, MatchFun) when is_function(MatchFun) ->
+    Pkt = xmpp_codec:decode(add_ns(El), [ignore_els]),
+    decode_els(Pkt, MatchFun);
+decode(#xmlel{} = El, Opts) when is_list(Opts) ->
     xmpp_codec:decode(add_ns(El), Opts);
 decode(Pkt, _Opts) ->
     Pkt.
index c8a4f002f17b8ea8e58536ca309e31606dd62c79..7eb06b11e66991c9e204c267da5be08e1ec498f1 100644 (file)
@@ -15,6 +15,21 @@ decode(_el) -> decode(_el, []).
 decode({xmlel, _name, _attrs, _} = _el, Opts) ->
     IgnoreEls = proplists:get_bool(ignore_els, Opts),
     case {_name, get_attr(<<"xmlns">>, _attrs)} of
+      {<<"stream:stream">>, <<"jabber:client">>} ->
+         decode_stream_start(<<"jabber:client">>, IgnoreEls,
+                             _el);
+      {<<"stream:stream">>, <<"jabber:server">>} ->
+         decode_stream_start(<<"jabber:server">>, IgnoreEls,
+                             _el);
+      {<<"stream:stream">>, <<"jabber:component:accept">>} ->
+         decode_stream_start(<<"jabber:component:accept">>,
+                             IgnoreEls, _el);
+      {<<"handshake">>, <<"jabber:client">>} ->
+         decode_handshake(<<"jabber:client">>, IgnoreEls, _el);
+      {<<"db:verify">>, <<"jabber:client">>} ->
+         decode_db_verify(<<"jabber:client">>, IgnoreEls, _el);
+      {<<"db:result">>, <<"jabber:client">>} ->
+         decode_db_result(<<"jabber:client">>, IgnoreEls, _el);
       {<<"command">>,
        <<"http://jabber.org/protocol/commands">>} ->
          decode_adhoc_command(<<"http://jabber.org/protocol/commands">>,
@@ -1278,6 +1293,13 @@ decode({xmlel, _name, _attrs, _} = _el, Opts) ->
 
 is_known_tag({xmlel, _name, _attrs, _} = _el) ->
     case {_name, get_attr(<<"xmlns">>, _attrs)} of
+      {<<"stream:stream">>, <<"jabber:client">>} -> true;
+      {<<"stream:stream">>, <<"jabber:server">>} -> true;
+      {<<"stream:stream">>, <<"jabber:component:accept">>} ->
+         true;
+      {<<"handshake">>, <<"jabber:client">>} -> true;
+      {<<"db:verify">>, <<"jabber:client">>} -> true;
+      {<<"db:result">>, <<"jabber:client">>} -> true;
       {<<"command">>,
        <<"http://jabber.org/protocol/commands">>} ->
          true;
@@ -2538,7 +2560,19 @@ encode({adhoc_command, _, _, _, _, _, _, _, _} =
           Command) ->
     encode_adhoc_command(Command,
                         [{<<"xmlns">>,
-                          <<"http://jabber.org/protocol/commands">>}]).
+                          <<"http://jabber.org/protocol/commands">>}]);
+encode({db_result, _, _, _, _, _} = Db_result) ->
+    encode_db_result(Db_result,
+                    [{<<"xmlns">>, <<"jabber:client">>}]);
+encode({db_verify, _, _, _, _, _, _} = Db_verify) ->
+    encode_db_verify(Db_verify,
+                    [{<<"xmlns">>, <<"jabber:client">>}]);
+encode({handshake, _} = Handshake) ->
+    encode_handshake(Handshake,
+                    [{<<"xmlns">>, <<"jabber:client">>}]);
+encode({stream_start, _, _, _, _, _, _, _, _} =
+          Stream_stream) ->
+    encode_stream_start(Stream_stream, []).
 
 get_name({last, _, _}) -> <<"query">>;
 get_name({version, _, _, _}) -> <<"query">>;
@@ -2720,7 +2754,13 @@ get_name({client_id, _}) -> <<"client-id">>;
 get_name({adhoc_actions, _, _, _, _}) -> <<"actions">>;
 get_name({adhoc_note, _, _}) -> <<"note">>;
 get_name({adhoc_command, _, _, _, _, _, _, _, _}) ->
-    <<"command">>.
+    <<"command">>;
+get_name({db_result, _, _, _, _, _}) -> <<"db:result">>;
+get_name({db_verify, _, _, _, _, _, _}) ->
+    <<"db:verify">>;
+get_name({handshake, _}) -> <<"handshake">>;
+get_name({stream_start, _, _, _, _, _, _, _, _}) ->
+    <<"stream:stream">>.
 
 get_ns({last, _, _}) -> <<"jabber:iq:last">>;
 get_ns({version, _, _, _}) -> <<"jabber:iq:version">>;
@@ -2974,7 +3014,14 @@ get_ns({adhoc_actions, _, _, _, _}) ->
 get_ns({adhoc_note, _, _}) ->
     <<"http://jabber.org/protocol/commands">>;
 get_ns({adhoc_command, _, _, _, _, _, _, _, _}) ->
-    <<"http://jabber.org/protocol/commands">>.
+    <<"http://jabber.org/protocol/commands">>;
+get_ns({db_result, _, _, _, _, _}) ->
+    <<"jabber:client">>;
+get_ns({db_verify, _, _, _, _, _, _}) ->
+    <<"jabber:client">>;
+get_ns({handshake, _}) -> <<"jabber:client">>;
+get_ns({stream_start, _, _, _, _, Xmlns, _, _, _}) ->
+    Xmlns.
 
 dec_int(Val) -> dec_int(Val, infinity, infinity).
 
@@ -3210,6 +3257,12 @@ pp(adhoc_note, 2) -> [type, data];
 pp(adhoc_command, 8) ->
     [node, action, sid, status, lang, actions, notes,
      xdata];
+pp(db_result, 5) -> [from, to, type, key, error];
+pp(db_verify, 6) -> [from, to, id, type, key, error];
+pp(handshake, 1) -> [data];
+pp(stream_start, 8) ->
+    [from, to, id, version, xmlns, stream_xmlns, db_xmlns,
+     lang];
 pp(_, _) -> no.
 
 join([], _Sep) -> <<>>;
@@ -3256,6 +3309,478 @@ dec_tzo(Val) ->
     M = jlib:binary_to_integer(M1),
     if H >= -12, H =< 12, M >= 0, M < 60 -> {H, M} end.
 
+decode_stream_start(__TopXMLNS, __IgnoreEls,
+                   {xmlel, <<"stream:stream">>, _attrs, _els}) ->
+    {From, To, Xmlns, Stream_xmlns, Db_xmlns, Lang, Version,
+     Id} =
+       decode_stream_start_attrs(__TopXMLNS, _attrs, undefined,
+                                 undefined, undefined, undefined, undefined,
+                                 undefined, undefined, undefined),
+    {stream_start, From, To, Id, Version, Xmlns,
+     Stream_xmlns, Db_xmlns, Lang}.
+
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"from">>, _val} | _attrs], _From, To, Xmlns,
+                         Stream_xmlns, Db_xmlns, Lang, Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, _val, To,
+                             Xmlns, Stream_xmlns, Db_xmlns, Lang, Version, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"to">>, _val} | _attrs], From, _To, Xmlns,
+                         Stream_xmlns, Db_xmlns, Lang, Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From,
+                             _val, Xmlns, Stream_xmlns, Db_xmlns, Lang,
+                             Version, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"xmlns">>, _val} | _attrs], From, To, _Xmlns,
+                         Stream_xmlns, Db_xmlns, Lang, Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             _val, Stream_xmlns, Db_xmlns, Lang, Version, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"xmlns:stream">>, _val} | _attrs], From, To,
+                         Xmlns, _Stream_xmlns, Db_xmlns, Lang, Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             Xmlns, _val, Db_xmlns, Lang, Version, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"xmlns:db">>, _val} | _attrs], From, To, Xmlns,
+                         Stream_xmlns, _Db_xmlns, Lang, Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             Xmlns, Stream_xmlns, _val, Lang, Version, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"xml:lang">>, _val} | _attrs], From, To, Xmlns,
+                         Stream_xmlns, Db_xmlns, _Lang, Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             Xmlns, Stream_xmlns, Db_xmlns, _val, Version, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"version">>, _val} | _attrs], From, To, Xmlns,
+                         Stream_xmlns, Db_xmlns, Lang, _Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             Xmlns, Stream_xmlns, Db_xmlns, Lang, _val, Id);
+decode_stream_start_attrs(__TopXMLNS,
+                         [{<<"id">>, _val} | _attrs], From, To, Xmlns,
+                         Stream_xmlns, Db_xmlns, Lang, Version, _Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             Xmlns, Stream_xmlns, Db_xmlns, Lang, Version,
+                             _val);
+decode_stream_start_attrs(__TopXMLNS, [_ | _attrs],
+                         From, To, Xmlns, Stream_xmlns, Db_xmlns, Lang,
+                         Version, Id) ->
+    decode_stream_start_attrs(__TopXMLNS, _attrs, From, To,
+                             Xmlns, Stream_xmlns, Db_xmlns, Lang, Version, Id);
+decode_stream_start_attrs(__TopXMLNS, [], From, To,
+                         Xmlns, Stream_xmlns, Db_xmlns, Lang, Version, Id) ->
+    {decode_stream_start_attr_from(__TopXMLNS, From),
+     decode_stream_start_attr_to(__TopXMLNS, To),
+     decode_stream_start_attr_xmlns(__TopXMLNS, Xmlns),
+     'decode_stream_start_attr_xmlns:stream'(__TopXMLNS,
+                                            Stream_xmlns),
+     'decode_stream_start_attr_xmlns:db'(__TopXMLNS,
+                                        Db_xmlns),
+     'decode_stream_start_attr_xml:lang'(__TopXMLNS, Lang),
+     decode_stream_start_attr_version(__TopXMLNS, Version),
+     decode_stream_start_attr_id(__TopXMLNS, Id)}.
+
+encode_stream_start({stream_start, From, To, Id,
+                    Version, Xmlns, Stream_xmlns, Db_xmlns, Lang},
+                   _xmlns_attrs) ->
+    _els = [],
+    _attrs = encode_stream_start_attr_id(Id,
+                                        encode_stream_start_attr_version(Version,
+                                                                         'encode_stream_start_attr_xml:lang'(Lang,
+                                                                                                             'encode_stream_start_attr_xmlns:db'(Db_xmlns,
+                                                                                                                                                 'encode_stream_start_attr_xmlns:stream'(Stream_xmlns,
+                                                                                                                                                                                         encode_stream_start_attr_xmlns(Xmlns,
+                                                                                                                                                                                                                        encode_stream_start_attr_to(To,
+                                                                                                                                                                                                                                                    encode_stream_start_attr_from(From,
+                                                                                                                                                                                                                                                                                  _xmlns_attrs)))))))),
+    {xmlel, <<"stream:stream">>, _attrs, _els}.
+
+decode_stream_start_attr_from(__TopXMLNS, undefined) ->
+    undefined;
+decode_stream_start_attr_from(__TopXMLNS, _val) ->
+    case catch dec_jid(_val) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"from">>, <<"stream:stream">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_stream_start_attr_from(undefined, _acc) -> _acc;
+encode_stream_start_attr_from(_val, _acc) ->
+    [{<<"from">>, enc_jid(_val)} | _acc].
+
+decode_stream_start_attr_to(__TopXMLNS, undefined) ->
+    undefined;
+decode_stream_start_attr_to(__TopXMLNS, _val) ->
+    case catch dec_jid(_val) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"to">>, <<"stream:stream">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_stream_start_attr_to(undefined, _acc) -> _acc;
+encode_stream_start_attr_to(_val, _acc) ->
+    [{<<"to">>, enc_jid(_val)} | _acc].
+
+decode_stream_start_attr_xmlns(__TopXMLNS, undefined) ->
+    undefined;
+decode_stream_start_attr_xmlns(__TopXMLNS, _val) ->
+    _val.
+
+encode_stream_start_attr_xmlns(undefined, _acc) -> _acc;
+encode_stream_start_attr_xmlns(_val, _acc) ->
+    [{<<"xmlns">>, _val} | _acc].
+
+'decode_stream_start_attr_xmlns:stream'(__TopXMLNS,
+                                       undefined) ->
+    <<>>;
+'decode_stream_start_attr_xmlns:stream'(__TopXMLNS,
+                                       _val) ->
+    _val.
+
+'encode_stream_start_attr_xmlns:stream'(<<>>, _acc) ->
+    _acc;
+'encode_stream_start_attr_xmlns:stream'(_val, _acc) ->
+    [{<<"xmlns:stream">>, _val} | _acc].
+
+'decode_stream_start_attr_xmlns:db'(__TopXMLNS,
+                                   undefined) ->
+    <<>>;
+'decode_stream_start_attr_xmlns:db'(__TopXMLNS, _val) ->
+    _val.
+
+'encode_stream_start_attr_xmlns:db'(<<>>, _acc) -> _acc;
+'encode_stream_start_attr_xmlns:db'(_val, _acc) ->
+    [{<<"xmlns:db">>, _val} | _acc].
+
+'decode_stream_start_attr_xml:lang'(__TopXMLNS,
+                                   undefined) ->
+    <<>>;
+'decode_stream_start_attr_xml:lang'(__TopXMLNS, _val) ->
+    _val.
+
+'encode_stream_start_attr_xml:lang'(<<>>, _acc) -> _acc;
+'encode_stream_start_attr_xml:lang'(_val, _acc) ->
+    [{<<"xml:lang">>, _val} | _acc].
+
+decode_stream_start_attr_version(__TopXMLNS,
+                                undefined) ->
+    <<>>;
+decode_stream_start_attr_version(__TopXMLNS, _val) ->
+    _val.
+
+encode_stream_start_attr_version(<<>>, _acc) -> _acc;
+encode_stream_start_attr_version(_val, _acc) ->
+    [{<<"version">>, _val} | _acc].
+
+decode_stream_start_attr_id(__TopXMLNS, undefined) ->
+    <<>>;
+decode_stream_start_attr_id(__TopXMLNS, _val) -> _val.
+
+encode_stream_start_attr_id(<<>>, _acc) -> _acc;
+encode_stream_start_attr_id(_val, _acc) ->
+    [{<<"id">>, _val} | _acc].
+
+decode_handshake(__TopXMLNS, __IgnoreEls,
+                {xmlel, <<"handshake">>, _attrs, _els}) ->
+    Data = decode_handshake_els(__TopXMLNS, __IgnoreEls,
+                               _els, <<>>),
+    {handshake, Data}.
+
+decode_handshake_els(__TopXMLNS, __IgnoreEls, [],
+                    Data) ->
+    decode_handshake_cdata(__TopXMLNS, Data);
+decode_handshake_els(__TopXMLNS, __IgnoreEls,
+                    [{xmlcdata, _data} | _els], Data) ->
+    decode_handshake_els(__TopXMLNS, __IgnoreEls, _els,
+                        <<Data/binary, _data/binary>>);
+decode_handshake_els(__TopXMLNS, __IgnoreEls,
+                    [_ | _els], Data) ->
+    decode_handshake_els(__TopXMLNS, __IgnoreEls, _els,
+                        Data).
+
+encode_handshake({handshake, Data}, _xmlns_attrs) ->
+    _els = encode_handshake_cdata(Data, []),
+    _attrs = _xmlns_attrs,
+    {xmlel, <<"handshake">>, _attrs, _els}.
+
+decode_handshake_cdata(__TopXMLNS, <<>>) -> <<>>;
+decode_handshake_cdata(__TopXMLNS, _val) -> _val.
+
+encode_handshake_cdata(<<>>, _acc) -> _acc;
+encode_handshake_cdata(_val, _acc) ->
+    [{xmlcdata, _val} | _acc].
+
+decode_db_verify(__TopXMLNS, __IgnoreEls,
+                {xmlel, <<"db:verify">>, _attrs, _els}) ->
+    {Key, Error} = decode_db_verify_els(__TopXMLNS,
+                                       __IgnoreEls, _els, <<>>, undefined),
+    {From, To, Id, Type} =
+       decode_db_verify_attrs(__TopXMLNS, _attrs, undefined,
+                              undefined, undefined, undefined),
+    {db_verify, From, To, Id, Type, Key, Error}.
+
+decode_db_verify_els(__TopXMLNS, __IgnoreEls, [], Key,
+                    Error) ->
+    {decode_db_verify_cdata(__TopXMLNS, Key), Error};
+decode_db_verify_els(__TopXMLNS, __IgnoreEls,
+                    [{xmlcdata, _data} | _els], Key, Error) ->
+    decode_db_verify_els(__TopXMLNS, __IgnoreEls, _els,
+                        <<Key/binary, _data/binary>>, Error);
+decode_db_verify_els(__TopXMLNS, __IgnoreEls,
+                    [{xmlel, <<"error">>, _attrs, _} = _el | _els], Key,
+                    Error) ->
+    case get_attr(<<"xmlns">>, _attrs) of
+      <<"">> when __TopXMLNS == <<"jabber:client">> ->
+         decode_db_verify_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                              decode_error(__TopXMLNS, __IgnoreEls, _el));
+      <<"jabber:client">> ->
+         decode_db_verify_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                              decode_error(<<"jabber:client">>, __IgnoreEls,
+                                           _el));
+      _ ->
+         decode_db_verify_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                              Error)
+    end;
+decode_db_verify_els(__TopXMLNS, __IgnoreEls,
+                    [_ | _els], Key, Error) ->
+    decode_db_verify_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                        Error).
+
+decode_db_verify_attrs(__TopXMLNS,
+                      [{<<"from">>, _val} | _attrs], _From, To, Id, Type) ->
+    decode_db_verify_attrs(__TopXMLNS, _attrs, _val, To, Id,
+                          Type);
+decode_db_verify_attrs(__TopXMLNS,
+                      [{<<"to">>, _val} | _attrs], From, _To, Id, Type) ->
+    decode_db_verify_attrs(__TopXMLNS, _attrs, From, _val,
+                          Id, Type);
+decode_db_verify_attrs(__TopXMLNS,
+                      [{<<"id">>, _val} | _attrs], From, To, _Id, Type) ->
+    decode_db_verify_attrs(__TopXMLNS, _attrs, From, To,
+                          _val, Type);
+decode_db_verify_attrs(__TopXMLNS,
+                      [{<<"type">>, _val} | _attrs], From, To, Id, _Type) ->
+    decode_db_verify_attrs(__TopXMLNS, _attrs, From, To, Id,
+                          _val);
+decode_db_verify_attrs(__TopXMLNS, [_ | _attrs], From,
+                      To, Id, Type) ->
+    decode_db_verify_attrs(__TopXMLNS, _attrs, From, To, Id,
+                          Type);
+decode_db_verify_attrs(__TopXMLNS, [], From, To, Id,
+                      Type) ->
+    {decode_db_verify_attr_from(__TopXMLNS, From),
+     decode_db_verify_attr_to(__TopXMLNS, To),
+     decode_db_verify_attr_id(__TopXMLNS, Id),
+     decode_db_verify_attr_type(__TopXMLNS, Type)}.
+
+encode_db_verify({db_verify, From, To, Id, Type, Key,
+                 Error},
+                _xmlns_attrs) ->
+    _els = lists:reverse(encode_db_verify_cdata(Key,
+                                               'encode_db_verify_$error'(Error,
+                                                                         []))),
+    _attrs = encode_db_verify_attr_type(Type,
+                                       encode_db_verify_attr_id(Id,
+                                                                encode_db_verify_attr_to(To,
+                                                                                         encode_db_verify_attr_from(From,
+                                                                                                                    _xmlns_attrs)))),
+    {xmlel, <<"db:verify">>, _attrs, _els}.
+
+'encode_db_verify_$error'(undefined, _acc) -> _acc;
+'encode_db_verify_$error'(Error, _acc) ->
+    [encode_error(Error, []) | _acc].
+
+decode_db_verify_attr_from(__TopXMLNS, undefined) ->
+    erlang:error({xmpp_codec,
+                 {missing_attr, <<"from">>, <<"db:verify">>,
+                  __TopXMLNS}});
+decode_db_verify_attr_from(__TopXMLNS, _val) ->
+    case catch dec_jid(_val) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"from">>, <<"db:verify">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_db_verify_attr_from(_val, _acc) ->
+    [{<<"from">>, enc_jid(_val)} | _acc].
+
+decode_db_verify_attr_to(__TopXMLNS, undefined) ->
+    erlang:error({xmpp_codec,
+                 {missing_attr, <<"to">>, <<"db:verify">>, __TopXMLNS}});
+decode_db_verify_attr_to(__TopXMLNS, _val) ->
+    case catch dec_jid(_val) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"to">>, <<"db:verify">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_db_verify_attr_to(_val, _acc) ->
+    [{<<"to">>, enc_jid(_val)} | _acc].
+
+decode_db_verify_attr_id(__TopXMLNS, undefined) ->
+    erlang:error({xmpp_codec,
+                 {missing_attr, <<"id">>, <<"db:verify">>, __TopXMLNS}});
+decode_db_verify_attr_id(__TopXMLNS, _val) -> _val.
+
+encode_db_verify_attr_id(_val, _acc) ->
+    [{<<"id">>, _val} | _acc].
+
+decode_db_verify_attr_type(__TopXMLNS, undefined) ->
+    undefined;
+decode_db_verify_attr_type(__TopXMLNS, _val) ->
+    case catch dec_enum(_val, [valid, invalid, error]) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"type">>, <<"db:verify">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_db_verify_attr_type(undefined, _acc) -> _acc;
+encode_db_verify_attr_type(_val, _acc) ->
+    [{<<"type">>, enc_enum(_val)} | _acc].
+
+decode_db_verify_cdata(__TopXMLNS, <<>>) -> <<>>;
+decode_db_verify_cdata(__TopXMLNS, _val) -> _val.
+
+encode_db_verify_cdata(<<>>, _acc) -> _acc;
+encode_db_verify_cdata(_val, _acc) ->
+    [{xmlcdata, _val} | _acc].
+
+decode_db_result(__TopXMLNS, __IgnoreEls,
+                {xmlel, <<"db:result">>, _attrs, _els}) ->
+    {Key, Error} = decode_db_result_els(__TopXMLNS,
+                                       __IgnoreEls, _els, <<>>, undefined),
+    {From, To, Type} = decode_db_result_attrs(__TopXMLNS,
+                                             _attrs, undefined, undefined,
+                                             undefined),
+    {db_result, From, To, Type, Key, Error}.
+
+decode_db_result_els(__TopXMLNS, __IgnoreEls, [], Key,
+                    Error) ->
+    {decode_db_result_cdata(__TopXMLNS, Key), Error};
+decode_db_result_els(__TopXMLNS, __IgnoreEls,
+                    [{xmlcdata, _data} | _els], Key, Error) ->
+    decode_db_result_els(__TopXMLNS, __IgnoreEls, _els,
+                        <<Key/binary, _data/binary>>, Error);
+decode_db_result_els(__TopXMLNS, __IgnoreEls,
+                    [{xmlel, <<"error">>, _attrs, _} = _el | _els], Key,
+                    Error) ->
+    case get_attr(<<"xmlns">>, _attrs) of
+      <<"">> when __TopXMLNS == <<"jabber:client">> ->
+         decode_db_result_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                              decode_error(__TopXMLNS, __IgnoreEls, _el));
+      <<"jabber:client">> ->
+         decode_db_result_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                              decode_error(<<"jabber:client">>, __IgnoreEls,
+                                           _el));
+      _ ->
+         decode_db_result_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                              Error)
+    end;
+decode_db_result_els(__TopXMLNS, __IgnoreEls,
+                    [_ | _els], Key, Error) ->
+    decode_db_result_els(__TopXMLNS, __IgnoreEls, _els, Key,
+                        Error).
+
+decode_db_result_attrs(__TopXMLNS,
+                      [{<<"from">>, _val} | _attrs], _From, To, Type) ->
+    decode_db_result_attrs(__TopXMLNS, _attrs, _val, To,
+                          Type);
+decode_db_result_attrs(__TopXMLNS,
+                      [{<<"to">>, _val} | _attrs], From, _To, Type) ->
+    decode_db_result_attrs(__TopXMLNS, _attrs, From, _val,
+                          Type);
+decode_db_result_attrs(__TopXMLNS,
+                      [{<<"type">>, _val} | _attrs], From, To, _Type) ->
+    decode_db_result_attrs(__TopXMLNS, _attrs, From, To,
+                          _val);
+decode_db_result_attrs(__TopXMLNS, [_ | _attrs], From,
+                      To, Type) ->
+    decode_db_result_attrs(__TopXMLNS, _attrs, From, To,
+                          Type);
+decode_db_result_attrs(__TopXMLNS, [], From, To,
+                      Type) ->
+    {decode_db_result_attr_from(__TopXMLNS, From),
+     decode_db_result_attr_to(__TopXMLNS, To),
+     decode_db_result_attr_type(__TopXMLNS, Type)}.
+
+encode_db_result({db_result, From, To, Type, Key,
+                 Error},
+                _xmlns_attrs) ->
+    _els = lists:reverse(encode_db_result_cdata(Key,
+                                               'encode_db_result_$error'(Error,
+                                                                         []))),
+    _attrs = encode_db_result_attr_type(Type,
+                                       encode_db_result_attr_to(To,
+                                                                encode_db_result_attr_from(From,
+                                                                                           _xmlns_attrs))),
+    {xmlel, <<"db:result">>, _attrs, _els}.
+
+'encode_db_result_$error'(undefined, _acc) -> _acc;
+'encode_db_result_$error'(Error, _acc) ->
+    [encode_error(Error, []) | _acc].
+
+decode_db_result_attr_from(__TopXMLNS, undefined) ->
+    erlang:error({xmpp_codec,
+                 {missing_attr, <<"from">>, <<"db:result">>,
+                  __TopXMLNS}});
+decode_db_result_attr_from(__TopXMLNS, _val) ->
+    case catch dec_jid(_val) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"from">>, <<"db:result">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_db_result_attr_from(_val, _acc) ->
+    [{<<"from">>, enc_jid(_val)} | _acc].
+
+decode_db_result_attr_to(__TopXMLNS, undefined) ->
+    erlang:error({xmpp_codec,
+                 {missing_attr, <<"to">>, <<"db:result">>, __TopXMLNS}});
+decode_db_result_attr_to(__TopXMLNS, _val) ->
+    case catch dec_jid(_val) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"to">>, <<"db:result">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_db_result_attr_to(_val, _acc) ->
+    [{<<"to">>, enc_jid(_val)} | _acc].
+
+decode_db_result_attr_type(__TopXMLNS, undefined) ->
+    undefined;
+decode_db_result_attr_type(__TopXMLNS, _val) ->
+    case catch dec_enum(_val, [valid, invalid, error]) of
+      {'EXIT', _} ->
+         erlang:error({xmpp_codec,
+                       {bad_attr_value, <<"type">>, <<"db:result">>,
+                        __TopXMLNS}});
+      _res -> _res
+    end.
+
+encode_db_result_attr_type(undefined, _acc) -> _acc;
+encode_db_result_attr_type(_val, _acc) ->
+    [{<<"type">>, enc_enum(_val)} | _acc].
+
+decode_db_result_cdata(__TopXMLNS, <<>>) -> <<>>;
+decode_db_result_cdata(__TopXMLNS, _val) -> _val.
+
+encode_db_result_cdata(<<>>, _acc) -> _acc;
+encode_db_result_cdata(_val, _acc) ->
+    [{xmlcdata, _val} | _acc].
+
 decode_adhoc_command(__TopXMLNS, __IgnoreEls,
                     {xmlel, <<"command">>, _attrs, _els}) ->
     {Xdata, Notes, Actions} =
index 0e0145f729fc74fe4f08287157e80c73d4ea6bff..7503eab10a600d326fe11f29a490c2e47b8e3649 100644 (file)
                   #ref{name = xdata, min = 0, max = 1},
                   #ref{name = adhoc_command_notes, label = '$notes'}]}).
 
+-xml(db_result,
+     #elem{name = <<"db:result">>,
+          xmlns = <<"jabber:client">>,
+          result = {db_result, '$from', '$to', '$type', '$key', '$error'},
+          refs = [#ref{name = error, min = 0, max = 1}],
+          cdata = #cdata{default = <<"">>, label = '$key'},
+          attrs = [#attr{name = <<"from">>, required = true,
+                         dec = {dec_jid, []}, enc = {enc_jid, []}},
+                   #attr{name = <<"to">>, required = true,
+                         dec = {dec_jid, []}, enc = {enc_jid, []}},
+                   #attr{name = <<"type">>,
+                         dec = {dec_enum, [[valid, invalid, error]]},
+                         enc = {enc_enum, []}}]}).
+
+-xml(db_verify,
+     #elem{name = <<"db:verify">>,
+          xmlns = <<"jabber:client">>,
+          result = {db_verify, '$from', '$to', '$id', '$type', '$key', '$error'},
+          refs = [#ref{name = error, min = 0, max = 1}],
+          cdata = #cdata{default = <<"">>, label = '$key'},
+          attrs = [#attr{name = <<"from">>, required = true,
+                         dec = {dec_jid, []}, enc = {enc_jid, []}},
+                   #attr{name = <<"to">>, required = true,
+                         dec = {dec_jid, []}, enc = {enc_jid, []}},
+                   #attr{name = <<"id">>, required = true},
+                   #attr{name = <<"type">>,
+                         dec = {dec_enum, [[valid, invalid, error]]},
+                         enc = {enc_enum, []}}]}).
+
+-xml(handshake,
+     #elem{name = <<"handshake">>,
+          xmlns = <<"jabber:client">>,
+          result = {handshake, '$data'},
+          cdata = #cdata{default = <<"">>, label = '$data'}}).
+
+-xml(stream_start,
+     #elem{name = <<"stream:stream">>,
+          xmlns = [<<"jabber:client">>, <<"jabber:server">>,
+                   <<"jabber:component:accept">>],
+          result = {stream_start, '$from', '$to', '$id',
+                    '$version', '$xmlns', '$stream_xmlns',
+                    '$db_xmlns', '$lang'},
+          attrs = [#attr{name = <<"from">>,
+                         dec = {dec_jid, []},
+                         enc = {enc_jid, []}},
+                   #attr{name = <<"to">>,
+                         dec = {dec_jid, []},
+                         enc = {enc_jid, []}},
+                   #attr{name = <<"xmlns">>},
+                   #attr{name = <<"xmlns:stream">>,
+                         label = '$stream_xmlns',
+                         default = <<"">>},
+                   #attr{name = <<"xmlns:db">>,
+                         label = '$db_xmlns',
+                         default = <<"">>},
+                   #attr{name = <<"xml:lang">>, label = '$lang',
+                         default = <<"">>},
+                   #attr{name = <<"version">>, default = <<"">>},
+                   #attr{name = <<"id">>, default = <<"">>}]}).
+
 dec_tzo(Val) ->
     [H1, M1] = str:tokens(Val, <<":">>),
     H = jlib:binary_to_integer(H1),