]> granicus.if.org Git - ejabberd/commitdiff
Reflect cyrsasl API changes in remaining code
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Sat, 31 Dec 2016 10:48:55 +0000 (13:48 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Sat, 31 Dec 2016 10:48:55 +0000 (13:48 +0300)
src/ejabberd_c2s.erl
src/mod_s2s_dialback.erl
src/mod_sm.erl
src/xmpp_stream_in.erl
src/xmpp_stream_out.erl
src/xmpp_stream_pkix.erl

index f22960c50019ebfd8d72aa4c1d161dee942c7c0f..a10ee59a5f73bd2588ddc6478d5716e4c8a23306 100644 (file)
@@ -221,7 +221,7 @@ process_closed(State, Reason) ->
 process_terminated(#{socket := Socket, jid := JID} = State,
                   Reason) ->
     Status = format_reason(State, Reason),
-    ?INFO_MSG("(~s) Closing c2s connection for ~s: ~s",
+    ?INFO_MSG("(~s) Closing c2s session for ~s: ~s",
              [ejabberd_socket:pp(Socket), jid:to_string(JID), Status]),
     Pres = #presence{type = unavailable,
                     status = xmpp:mk_text(Status),
@@ -292,12 +292,12 @@ bind(R, #{user := U, server := S, access := Access, lang := Lang,
                    State1 = open_session(State#{resource => Resource}),
                    State2 = ejabberd_hooks:run_fold(
                               c2s_session_opened, LServer, State1, []),
-                   ?INFO_MSG("(~s) Opened session for ~s",
+                   ?INFO_MSG("(~s) Opened c2s session for ~s",
                              [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
                    {ok, State2};
                deny ->
                    ejabberd_hooks:run(forbidden_session_hook, LServer, [JID]),
-                   ?INFO_MSG("(~s) Forbidden session for ~s",
+                   ?INFO_MSG("(~s) Forbidden c2s session for ~s",
                              [ejabberd_socket:pp(Socket), jid:to_string(JID)]),
                    Txt = <<"Denied by ACL">>,
                    {error, xmpp:err_not_allowed(Txt, Lang), State}
index 4bdda2ca7a9dbec2b2ad2a5db69e9511cf06744a..d0d78a30c272b9cf07b5d7a07e4b003e58d026ce 100644 (file)
@@ -29,7 +29,7 @@
 -export([start/2, stop/1, depends/2, mod_opt_type/1]).
 %% Hooks
 -export([s2s_out_auth_result/2, s2s_out_downgraded/2,
-        s2s_in_packet/2, s2s_out_packet/2,
+        s2s_in_packet/2, s2s_out_packet/2, s2s_in_recv/3,
         s2s_in_features/2, s2s_out_init/2, s2s_out_closed/2]).
 
 -include("ejabberd.hrl").
@@ -52,6 +52,8 @@ start(Host, _Opts) ->
                               s2s_in_features, 50),
            ejabberd_hooks:add(s2s_in_post_auth_features, Host, ?MODULE,
                               s2s_in_features, 50),
+           ejabberd_hooks:add(s2s_in_handle_recv, Host, ?MODULE,
+                              s2s_in_recv, 50),
            ejabberd_hooks:add(s2s_in_unauthenticated_packet, Host, ?MODULE,
                               s2s_in_packet, 50),
            ejabberd_hooks:add(s2s_in_authenticated_packet, Host, ?MODULE,
@@ -71,6 +73,8 @@ stop(Host) ->
                          s2s_in_features, 50),
     ejabberd_hooks:delete(s2s_in_post_auth_features, Host, ?MODULE,
                          s2s_in_features, 50),
+    ejabberd_hooks:delete(s2s_in_handle_recv, Host, ?MODULE,
+                         s2s_in_recv, 50),
     ejabberd_hooks:delete(s2s_in_unauthenticated_packet, Host, ?MODULE,
                          s2s_in_packet, 50),
     ejabberd_hooks:delete(s2s_in_authenticated_packet, Host, ?MODULE,
@@ -191,6 +195,25 @@ s2s_in_packet(State, Pkt) when is_record(Pkt, db_result);
 s2s_in_packet(State, _) ->
     State.
 
+s2s_in_recv(State, El, {error, Why}) ->
+    case xmpp:get_name(El) of
+       Tag when Tag == <<"db:result">>;
+                Tag == <<"db:verify">> ->
+           case xmpp:get_type(El) of
+               T when T /= <<"valid">>,
+                      T /= <<"invalid">>,
+                      T /= <<"error">> ->
+                   Err = xmpp:make_error(El, mk_error({codec_error, Why})),
+                   {stop, ejabberd_s2s_in:send(State, Err)};
+               _ ->
+                   State
+           end;
+       _ ->
+           State
+    end;
+s2s_in_recv(State, _El, _Pkt) ->
+    State.
+
 s2s_out_packet(#{server := LServer,
                 remote_server := RServer,
                 db_verify := {StreamID, _Key, Pid}} = State,
@@ -286,6 +309,8 @@ mk_error(forbidden) ->
     xmpp:err_forbidden(<<"Denied by ACL">>, ?MYLANG);
 mk_error(host_unknown) ->
     xmpp:err_not_allowed(<<"Host unknown">>, ?MYLANG);
+mk_error({codec_error, Why}) ->
+    xmpp:err_bad_request(xmpp:io_format_error(Why), ?MYLANG);
 mk_error({_Class, _Reason} = Why) ->
     Txt = xmpp_stream_out:format_error(Why),
     xmpp:err_remote_server_not_found(Txt, ?MYLANG);
index 7032344196a113d63f16cdcdfb96dbf7e7d82337..7e64e6a000360b477aa21b23f581aa44abfc2b48 100644 (file)
@@ -179,16 +179,14 @@ c2s_handle_recv(#{lang := Lang} = State, El, {error, Why}) ->
 c2s_handle_recv(State, _, _) ->
     State.
 
-c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, Result)
+c2s_handle_send(#{mgmt_state := MgmtState} = State, Pkt, _Result)
   when MgmtState == pending; MgmtState == active ->
     State1 = mgmt_queue_add(State, Pkt),
-    case Result of
-       ok when ?is_stanza(Pkt) ->
+    case xmpp:is_stanza(Pkt) of
+       true ->
            send_rack(State1);
-       ok ->
-           State1;
-       {error, _} ->
-           transition_to_pending(State1)
+       false ->
+           State1
     end;
 c2s_handle_send(State, _Pkt, _Result) ->
     State.
@@ -210,8 +208,9 @@ c2s_handle_info(#{mgmt_ack_timer := TRef, jid := JID} = State,
                {timeout, TRef, ack_timeout}) ->
     ?DEBUG("Timed out waiting for stream management acknowledgement of ~s",
           [jid:to_string(JID)]),
-    State1 = ejabberd_c2s:close(State, _SendTrailer = false),
-    {stop, transition_to_pending(State1)};
+    State1 = State#{stop_reason => {socket, timeout}},
+    State2 = ejabberd_c2s:close(State1, _SendTrailer = false),
+    {stop, transition_to_pending(State2)};
 c2s_handle_info(#{mgmt_state := pending, jid := JID} = State,
                {timeout, _, pending_timeout}) ->
     ?DEBUG("Timed out waiting for resumption of stream for ~s",
@@ -222,8 +221,8 @@ c2s_handle_info(State, _) ->
 
 c2s_closed(State, {stream, _}) ->
     State;
-c2s_closed(#{mgmt_state := active} = State, Reason) ->
-    {stop, transition_to_pending(State#{stop_reason => Reason})};
+c2s_closed(#{mgmt_state := active} = State, _Reason) ->
+    {stop, transition_to_pending(State)};
 c2s_closed(State, _Reason) ->
     State.
 
@@ -368,10 +367,9 @@ transition_to_pending(#{mgmt_state := active, jid := JID,
                        lserver := LServer, mgmt_timeout := Timeout} = State) ->
     State1 = cancel_ack_timer(State),
     ?INFO_MSG("Waiting for resumption of stream for ~s", [jid:to_string(JID)]),
-    State2 = ejabberd_hooks:run_fold(c2s_session_pending, LServer, State1, []),
-    State3 = ejabberd_c2s:close(State2, _SendTrailer = false),
     erlang:start_timer(timer:seconds(Timeout), self(), pending_timeout),
-    State3#{mgmt_state => pending};
+    State2 = State1#{mgmt_state => pending},
+    ejabberd_hooks:run_fold(c2s_session_pending, LServer, State2, []);
 transition_to_pending(State) ->
     State.
 
@@ -405,8 +403,8 @@ update_num_stanzas_in(State, _El) ->
 send_rack(#{mgmt_ack_timer := _} = State) ->
     State;
 send_rack(#{mgmt_xmlns := Xmlns,
-          mgmt_stanzas_out := NumStanzasOut,
-          mgmt_ack_timeout := AckTimeout} = State) ->
+           mgmt_stanzas_out := NumStanzasOut,
+           mgmt_ack_timeout := AckTimeout} = State) ->
     State1 = send(State, #sm_r{xmlns = Xmlns}),
     TRef = erlang:start_timer(AckTimeout, self(), ack_timeout),
     State1#{mgmt_ack_timer => TRef, mgmt_stanzas_req => NumStanzasOut}.
@@ -425,16 +423,19 @@ resend_rack(State) ->
 
 -spec mgmt_queue_add(state(), xmpp_element()) -> state().
 mgmt_queue_add(#{mgmt_stanzas_out := NumStanzasOut,
-                mgmt_queue := Queue} = State, Stanza) when ?is_stanza(Stanza) ->
-    NewNum = case NumStanzasOut of
-              4294967295 -> 0;
-              Num -> Num + 1
-            end,
-    Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Stanza}, Queue),
-    State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
-    check_queue_length(State1);
-mgmt_queue_add(State, _Nonza) ->
-    State.
+                mgmt_queue := Queue} = State, Pkt) ->
+    case xmpp:is_stanza(Pkt) of
+       true ->
+           NewNum = case NumStanzasOut of
+                        4294967295 -> 0;
+                        Num -> Num + 1
+                    end,
+           Queue1 = queue_in({NewNum, p1_time_compat:timestamp(), Pkt}, Queue),
+           State1 = State#{mgmt_queue => Queue1, mgmt_stanzas_out => NewNum},
+           check_queue_length(State1);
+       false ->
+           State
+    end.
 
 -spec mgmt_queue_drop(state(), non_neg_integer()) -> state().
 mgmt_queue_drop(#{mgmt_queue := Queue} = State, NumHandled) ->
index a0387064376521f22d2cb2584559f8ec0e0b3569..1ad78d45b5b0501a58813f6078cf51c570f97e75 100644 (file)
@@ -42,7 +42,7 @@
 
 -include("xmpp.hrl").
 -type state() :: map().
--type stop_reason() :: {stream, reset | stream_error()} |
+-type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
                       {tls, term()} |
                       {socket, inet:posix() | closed | timeout} |
                       internal_failure.
@@ -188,8 +188,10 @@ format_error({socket, Reason}) ->
     format("Connection failed: ~s", [format_inet_error(Reason)]);
 format_error({stream, reset}) ->
     <<"Stream reset by peer">>;
-format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
-    format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
+    format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
+    format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
 format_error({tls, Reason}) ->
     format("TLS failed: ~w", [Reason]);
 format_error(internal_failure) ->
@@ -304,7 +306,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
              send_element(State1, Err)
       end);
 handle_info({'$gen_event', {xmlstreamelement, El}},
-           #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+           #{xmlns := NS, mod := Mod} = State) ->
     noreply(
       try xmpp:decode(El, NS, [ignore_els]) of
          Pkt ->
@@ -321,10 +323,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
                       end,
              case is_disconnected(State1) of
                  true -> State1;
-                 false ->
-                     Txt = xmpp:io_format_error(Why),
-                     Lang = select_lang(MyLang, xmpp:get_lang(El)),
-                     send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+                 false -> process_invalid_xml(State1, El, Why)
              end
       end);
 handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
@@ -394,6 +393,33 @@ peername(SockMod, Socket) ->
        _ -> SockMod:peername(Socket)
     end.
 
+-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
+process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
+    case xmpp:is_stanza(El) of
+       true ->
+           Txt = xmpp:io_format_error(Reason),
+           Lang = select_lang(MyLang, xmpp:get_lang(El)),
+           send_error(State, El, xmpp:err_bad_request(Txt, Lang));
+       false ->
+           case {xmpp:get_name(El), xmpp:get_ns(El)} of
+               {Tag, ?NS_SASL} when Tag == <<"auth">>;
+                                    Tag == <<"response">>;
+                                    Tag == <<"abort">> ->
+                   Txt = xmpp:io_format_error(Reason),
+                   Err = #sasl_failure{reason = 'malformed-request',
+                                       text = xmpp:mk_text(Txt, MyLang)},
+                   send_element(State, Err);
+               {<<"starttls">>, ?NS_TLS} ->
+                   send_element(State, #starttls_failure{});
+               {<<"compress">>, ?NS_COMPRESS} ->
+                   Err = #compress_failure{reason = 'setup-failed'},
+                   send_element(State, Err);
+               _ ->
+                   %% Maybe add something more?
+                   State
+           end
+    end.
+
 -spec process_stream_end(stop_reason(), state()) -> state().
 process_stream_end(_, #{stream_state := disconnected} = State) ->
     State;
@@ -423,11 +449,6 @@ process_stream(#stream_start{lang = Lang},
 process_stream(#stream_start{to = undefined}, #{lang := Lang} = State) ->
     Txt = <<"Missing 'to' attribute">>,
     send_element(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{from = undefined, version = {1,0}},
-              #{lang := Lang, xmlns := ?NS_SERVER,
-                stream_encrypted := true} = State) ->
-    Txt = <<"Missing 'from' attribute">>,
-    send_element(State, xmpp:serr_invalid_from(Txt, Lang));
 process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
               #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
     Txt = <<"Improper 'to' attribute">>,
@@ -450,9 +471,10 @@ process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
                true ->
                     State
             end,
-    State2 = if NS == ?NS_SERVER andalso Encrypted ->
-                    State1#{remote_server => From#jid.lserver};
-               true ->
+    State2 = case From of
+                #jid{lserver = RemoteServer} when NS == ?NS_SERVER ->
+                    State1#{remote_server => RemoteServer};
+                _ ->
                     State1
             end,
     State3 = try Mod:handle_stream_start(StreamStart, State2)
@@ -517,7 +539,7 @@ process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
        #handshake{} ->
            State;
        #stream_error{} ->
-           process_stream_end({stream, Pkt}, State);
+           process_stream_end({stream, {in, Pkt}}, State);
        _ when StateName == wait_for_sasl_request;
               StateName == wait_for_handshake;
               StateName == wait_for_sasl_response ->
@@ -707,35 +729,34 @@ process_starttls_failure(Why, State) ->
 -spec process_sasl_request(sasl_auth(), state()) -> state().
 process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
                     #{mod := Mod, lserver := LServer} = State) ->
-    GetPW = try Mod:get_password_fun(State)
-           catch _:undef -> fun(_) -> false end
-           end,
-    CheckPW = try Mod:check_password_fun(State)
-             catch _:undef -> fun(_, _, _) -> false end
-             end,
-    CheckPWDigest = try Mod:check_password_digest_fun(State)
-                   catch _:undef -> fun(_, _, _, _, _) -> false end
-                   end,
-    SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
-                                  GetPW, CheckPW, CheckPWDigest),
-    State1 = State#{sasl_state => SASLState, sasl_mech => Mech},
+    State1 = State#{sasl_mech => Mech},
     Mechs = get_sasl_mechanisms(State1),
-    SASLResult = case lists:member(Mech, Mechs) of
-                    true when Mech == <<"EXTERNAL">> ->
-                        case xmpp_stream_pkix:authenticate(State1, ClientIn) of
-                            {ok, Peer} ->
-                                {ok, [{auth_module, pkix},
-                                      {username, Peer}]};
-                            {error, _Reason, Peer} ->
-                                %% TODO: return meaningful error
-                                {error, 'not-authorized', Peer}
-                        end;
-                    true ->
-                        cyrsasl:server_start(SASLState, Mech, ClientIn);
-                    false ->
-                        {error, 'invalid-mechanism'}
-                end,
-    process_sasl_result(SASLResult, State1).
+    case lists:member(Mech, Mechs) of
+       true when Mech == <<"EXTERNAL">> ->
+           Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of
+                     {ok, Peer} ->
+                         {ok, [{auth_module, pkix}, {username, Peer}]};
+                     {error, Reason, Peer} ->
+                         {error, Reason, Peer}
+                 end,
+           process_sasl_result(Res, State1);
+       true ->
+           GetPW = try Mod:get_password_fun(State1)
+                   catch _:undef -> fun(_) -> false end
+                   end,
+           CheckPW = try Mod:check_password_fun(State1)
+                     catch _:undef -> fun(_, _, _) -> false end
+                     end,
+           CheckPWDigest = try Mod:check_password_digest_fun(State1)
+                           catch _:undef -> fun(_, _, _, _, _) -> false end
+                           end,
+           SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
+                                          GetPW, CheckPW, CheckPWDigest),
+           Res = cyrsasl:server_start(SASLState, Mech, ClientIn),
+           process_sasl_result(Res, State1#{sasl_state => SASLState});
+       false ->
+           process_sasl_result({error, unsupported_mechanism, <<"">>}, State1)
+    end.
 
 -spec process_sasl_response(sasl_response(), state()) -> state().
 process_sasl_response(#sasl_response{text = ClientIn},
@@ -751,9 +772,7 @@ process_sasl_result({ok, Props, ServerOut}, State) ->
 process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
     process_sasl_continue(ServerOut, NewSASLState, State);
 process_sasl_result({error, Reason, User}, State) ->
-    process_sasl_failure(Reason, User, State);
-process_sasl_result({error, Reason}, State) ->
-    process_sasl_failure(Reason, <<"">>, State).
+    process_sasl_failure(Reason, User, State).
 
 -spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
 process_sasl_success(Props, ServerOut,
@@ -790,18 +809,20 @@ process_sasl_continue(ServerOut, NewSASLState, State) ->
     send_element(State1, #sasl_challenge{text = ServerOut}).
 
 -spec process_sasl_failure(atom(), binary(), state()) -> state().
-process_sasl_failure(Reason, User,
-                    #{mod := Mod, sasl_mech := Mech} = State) ->
-    State1 = try Mod:handle_auth_failure(User, Mech, Reason, State)
+process_sasl_failure(Err, User,
+                    #{mod := Mod, sasl_mech := Mech, lang := Lang} = State) ->
+    {Reason, Text} = format_sasl_error(Mech, Err),
+    State1 = try Mod:handle_auth_failure(User, Mech, Text, State)
             catch _:undef -> State
             end,
     State2 = maps:remove(sasl_state, maps:remove(sasl_mech, State1)),
     State3 = State2#{stream_state => wait_for_sasl_request},
-    send_element(State3, #sasl_failure{reason = Reason}).
+    send_element(State3, #sasl_failure{reason = Reason,
+                                      text = xmpp:mk_text(Text, Lang)}).
 
 -spec process_sasl_abort(state()) -> state().
 process_sasl_abort(State) ->
-    process_sasl_failure('aborted', <<"">>, State).
+    process_sasl_failure(aborted, <<"">>, State).
 
 -spec send_features(state()) -> state().
 send_features(#{stream_version := {1,0},
@@ -985,13 +1006,17 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
     State1 = try Mod:handle_send(Pkt, Result, State)
             catch _:undef -> State
             end,
-    case Result of
-       _ when is_record(Pkt, stream_error) ->
-           process_stream_end({stream, Pkt}, State1);
-       ok ->
-           State1;
-       {error, Why} ->
-           process_stream_end({socket, Why}, State1)
+    case is_disconnected(State1) of
+       true -> State1;
+       false ->
+           case Result of
+               _ when is_record(Pkt, stream_error) ->
+                   process_stream_end({stream, {out, Pkt}}, State1);
+               ok ->
+                   State1;
+               {error, Why} ->
+                   process_stream_end({socket, Why}, State1)
+           end
     end.
 
 -spec send_error(state(), xmpp_element(), stanza_error()) -> state().
@@ -1025,6 +1050,8 @@ send_text(_, _) ->
     {error, closed}.
 
 -spec close_socket(state()) -> state().
+close_socket(#{stream_state := disconnected} = State) ->
+    State;
 close_socket(#{sockmod := SockMod, socket := Socket} = State) ->
     SockMod:close(Socket),
     State#{stream_timeout => infinity,
@@ -1052,6 +1079,7 @@ format_inet_error(Reason) ->
 -spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
 format_stream_error(Reason, Txt) ->
     Slogan = case Reason of
+                undefined -> "no reason";
                 #'see-other-host'{} -> "see-other-host";
                 _ -> atom_to_list(Reason)
             end,
@@ -1062,6 +1090,12 @@ format_stream_error(Reason, Txt) ->
            binary_to_list(Data) ++ " (" ++ Slogan ++ ")"
     end.
 
+-spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}.
+format_sasl_error(<<"EXTERNAL">>, Err) ->
+    xmpp_stream_pkix:format_error(Err);
+format_sasl_error(Mech, Err) ->
+    cyrsasl:format_error(Mech, Err).
+
 -spec format(io:format(), list()) -> binary().
 format(Fmt, Args) ->
     iolist_to_binary(io_lib:format(Fmt, Args)).
index 08804e43282927e29cb06bf6dea763b871ead57b..290a92a49fafc8ec303654dd7102c4f47c1b190d 100644 (file)
@@ -1,10 +1,23 @@
 %%%-------------------------------------------------------------------
-%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
-%%% @copyright (C) 2016, Evgeny Khramtsov
-%%% @doc
-%%%
-%%% @end
 %%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
 %%%-------------------------------------------------------------------
 -module(xmpp_stream_out).
 -behaviour(gen_server).
@@ -39,7 +52,7 @@
 -type network_error() :: {error, inet:posix() | inet_res:res_error()}.
 -type stop_reason() :: {idna, bad_string} |
                       {dns, inet:posix() | inet_res:res_error()} |
-                      {stream, reset | stream_error()} |
+                      {stream, reset | {in | out, stream_error()}} |
                       {tls, term()} |
                       {pkix, binary()} |
                       {auth, atom() | binary() | string()} |
@@ -135,7 +148,7 @@ change_shaper(_, _) ->
 
 -spec format_error(stop_reason()) ->  binary().
 format_error({idna, _}) ->
-    <<"Not an IDN hostname">>;
+    <<"Remote domain is not an IDN hostname">>;
 format_error({dns, Reason}) ->
     format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
 format_error({socket, Reason}) ->
@@ -144,8 +157,10 @@ format_error({pkix, Reason}) ->
     format("Peer certificate rejected: ~s", [Reason]);
 format_error({stream, reset}) ->
     <<"Stream reset by peer">>;
-format_error({stream, #stream_error{reason = Reason, text = Txt}}) ->
-    format("Stream failed: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {in, #stream_error{reason = Reason, text = Txt}}}) ->
+    format("Stream closed by peer: ~s", [format_stream_error(Reason, Txt)]);
+format_error({stream, {out, #stream_error{reason = Reason, text = Txt}}}) ->
+    format("Stream closed by us: ~s", [format_stream_error(Reason, Txt)]);
 format_error({tls, Reason}) ->
     format("TLS failed: ~w", [Reason]);
 format_error({auth, Reason}) ->
@@ -264,7 +279,7 @@ handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
              send_element(State1, Err)
       end);
 handle_info({'$gen_event', {xmlstreamelement, El}},
-           #{xmlns := NS, lang := MyLang, mod := Mod} = State) ->
+           #{xmlns := NS, mod := Mod} = State) ->
     noreply(
       try xmpp:decode(El, NS, [ignore_els]) of
          Pkt ->
@@ -281,10 +296,7 @@ handle_info({'$gen_event', {xmlstreamelement, El}},
                       end,
              case is_disconnected(State1) of
                  true -> State1;
-                 false ->
-                     Txt = xmpp:io_format_error(Why),
-                     Lang = select_lang(MyLang, xmpp:get_lang(El)),
-                     send_error(State1, El, xmpp:err_bad_request(Txt, Lang))
+                 false -> process_invalid_xml(State1, El, Why)
              end
       end);
 handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
@@ -347,6 +359,17 @@ new_id() ->
 is_disconnected(#{stream_state := StreamState}) ->
     StreamState == disconnected.
 
+-spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
+process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
+    case xmpp:is_stanza(El) of
+       true ->
+           Txt = xmpp:io_format_error(Reason),
+           Lang = select_lang(MyLang, xmpp:get_lang(El)),
+           send_error(State, El, xmpp:err_bad_request(Txt, Lang));
+       false ->
+           State
+    end.
+
 -spec process_stream_end(stop_reason(), state()) -> state().
 process_stream_end(_, #{stream_state := disconnected} = State) ->
     State;
@@ -394,7 +417,7 @@ process_element(Pkt, #{stream_state := StateName} = State) ->
        #sasl_failure{} when StateName == wait_for_sasl_response ->
            process_sasl_failure(Pkt, State);
        #stream_error{} ->
-           process_stream_end({stream, Pkt}, State);
+           process_stream_end({stream, {in, Pkt}}, State);
        _ when is_record(Pkt, stream_features);
               is_record(Pkt, starttls_proceed);
               is_record(Pkt, starttls);
@@ -612,7 +635,7 @@ send_element(#{xmlns := NS, mod := Mod} = State, Pkt) ->
        false ->
            case send_text(State1, Data) of
                _ when is_record(Pkt, stream_error) ->
-                   process_stream_end({stream, Pkt}, State1);
+                   process_stream_end({stream, {out, Pkt}}, State1);
                ok ->
                    State1;
                {error, Why} ->
@@ -650,6 +673,8 @@ send_trailer(State) ->
     close_socket(State).
 
 -spec close_socket(state()) -> state().
+close_socket(#{stream_state := disconnected} = State) ->
+    State;
 close_socket(State) ->
     case State of
        #{sockmod := SockMod, socket := Socket} ->
@@ -674,6 +699,7 @@ format_inet_error(Reason) ->
 -spec format_stream_error(atom() | 'see-other-host'(), undefined | text()) -> string().
 format_stream_error(Reason, Txt) ->
     Slogan = case Reason of
+                undefined -> "no reason";
                 #'see-other-host'{} -> "see-other-host";
                 _ -> atom_to_list(Reason)
             end,
index 59f5d820eaa6d730c05bdb540b3e6dee7b6fc859..5d64c5eb6eb6932faeaf0dc12acf89da9e6b1e24 100644 (file)
@@ -9,7 +9,7 @@
 -module(xmpp_stream_pkix).
 
 %% API
--export([authenticate/1, authenticate/2]).
+-export([authenticate/1, authenticate/2, format_error/1]).
 
 -include("xmpp.hrl").
 -include_lib("public_key/include/public_key.hrl").
 %%% API
 %%%===================================================================
 -spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state())
-      -> {ok, binary()} | {error, binary(), binary()}.
+      -> {ok, binary()} | {error, atom(), binary()}.
 authenticate(State) ->
     authenticate(State, <<"">>).
 
 -spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary())
-      -> {ok, binary()} | {error, binary(), binary()}.
-authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
-              sockmod := SockMod, socket := Socket}, _Authzid) ->
+      -> {ok, binary()} | {error, atom(), binary()}.
+authenticate(#{xmlns := ?NS_SERVER, sockmod := SockMod,
+              socket := Socket} = State, Authzid) ->
+    Peer = try maps:get(remote_server, State)
+          catch _:{badkey, _} -> Authzid
+          end,
     case SockMod:get_peer_certificate(Socket) of
        {ok, Cert} ->
            case SockMod:get_verify_result(Socket) of
                0 ->
                    case ejabberd_idna:domain_utf8_to_ascii(Peer) of
                        false ->
-                           {error, <<"Cannot decode remote server name">>, Peer};
+                           {error, idna_failed, Peer};
                        AsciiPeer ->
                            case lists:any(
                                   fun(D) -> match_domain(AsciiPeer, D) end,
@@ -41,20 +44,34 @@ authenticate(#{xmlns := ?NS_SERVER, remote_server := Peer,
                                true ->
                                    {ok, Peer};
                                false ->
-                                   {error, <<"Certificate host name mismatch">>, Peer}
+                                   {error, hostname_mismatch, Peer}
                            end
                    end;
                VerifyRes ->
-                   {error, fast_tls:get_cert_verify_string(VerifyRes, Cert), Peer}
+                   %% TODO: return atomic errors
+                   %% This should be improved in fast_tls
+                   Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert),
+                   {error, erlang:binary_to_atom(Reason, utf8), Peer}
            end;
        {error, _Reason} ->
-           {error, <<"Cannot get peer certificate">>, Peer};
+           {error, get_cert_failed, Peer};
        error ->
-           {error, <<"Cannot get peer certificate">>, Peer}
+           {error, get_cert_failed, Peer}
     end;
 authenticate(_State, _Authzid) ->
     %% TODO: client PKIX authentication
-    {error, <<"Client certificate verification not implemented">>, <<"">>}.
+    {error, client_not_supported, <<"">>}.
+
+format_error(idna_failed) ->
+    {'bad-protocol', <<"Remote domain is not an IDN hostname">>};
+format_error(hostname_mismatch) ->
+    {'not-authorized', <<"Certificate host name mismatch">>};
+format_error(get_cert_failed) ->
+    {'bad-protocol', <<"Failed to get peer certificate">>};
+format_error(client_not_supported) ->
+    {'invalid-mechanism', <<"Client certificate verification is not supported">>};
+format_error(Other) ->
+    {'not-authorized', erlang:atom_to_binary(Other, utf8)}.
 
 %%%===================================================================
 %%% Internal functions