]> granicus.if.org Git - ejabberd/commitdiff
Support SASL PLAIN by xmpp_stream_out
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Mon, 25 Jun 2018 16:16:33 +0000 (19:16 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Mon, 25 Jun 2018 16:16:33 +0000 (19:16 +0300)
Also, SASL mechanisms chaining is now supported:
if several mechanisms are supported and authentication
fails, next mechanism in the list is picked, until the
list is exhausted. In the case of a failure, the latest
SASL failure reason is returned within handle_auth_failure/3
callback.

src/xmpp_stream_out.erl

index 5856fa58acf5a1984ff86c2d787195ff49dac4e2..6bdd162135265b5b566cb348155e2b34b446b88f 100644 (file)
@@ -91,6 +91,7 @@
 -callback tls_verify(state()) -> boolean().
 -callback tls_enabled(state()) -> boolean().
 -callback resolve(string(), state()) -> [host_port()].
+-callback sasl_mechanisms(state()) -> [binary()].
 -callback dns_timeout(state()) -> timeout().
 -callback dns_retries(state()) -> non_neg_integer().
 -callback default_port(state()) -> inet:port_number().
                     tls_verify/1,
                     tls_enabled/1,
                     resolve/2,
+                    sasl_mechanisms/1,
                     dns_timeout/1,
                     dns_retries/1,
                     default_port/1,
@@ -260,6 +262,7 @@ init([Mod, From, To, Opts]) ->
              server => From,
              user => <<"">>,
              resource => <<"">>,
+             password => <<"">>,
              lang => <<"">>,
              remote_server => To,
              xmlns => ?NS_SERVER,
@@ -317,16 +320,8 @@ handle_cast(connect, #{remote_server := RemoteServer,
              end
       end);
 handle_cast(connect, #{stream_state := disconnected} = State) ->
-    State1 = State#{stream_id => new_id(),
-                   stream_encrypted => false,
-                   stream_verified => false,
-                   stream_authenticated => false,
-                   stream_restarted => false,
-                   stream_state => connecting},
-    State2 = maps:remove(ip, State1),
-    State3 = maps:remove(socket, State2),
-    State4 = maps:remove(socket_monitor, State3),
-    handle_cast(connect, State4);
+    State1 = reset_state(State),
+    handle_cast(connect, State1);
 handle_cast(connect, State) ->
     %% Ignoring connection attempts in other states
     noreply(State);
@@ -559,8 +554,7 @@ process_features(StreamFeatures,
     catch _:{?MODULE, undef} -> process_bind(StreamFeatures, State)
     end;
 process_features(StreamFeatures,
-                #{stream_encrypted := Encrypted,
-                  lang := Lang, xmlns := NS} = State) ->
+                #{stream_encrypted := Encrypted, lang := Lang} = State) ->
     State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
             catch _:{?MODULE, undef} -> State
             end,
@@ -573,39 +567,21 @@ process_features(StreamFeatures,
                false when TLSRequired and not Encrypted ->
                    Txt = <<"Use of STARTTLS required">>,
                    send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang));
-               false when NS == ?NS_SERVER andalso not Encrypted ->
-                   process_sasl_failure(
-                     <<"Peer doesn't support STARTTLS">>, State1);
                #starttls{required = true} when not TLSAvailable and not Encrypted ->
                    Txt = <<"Use of STARTTLS forbidden">>,
                    send_pkt(State1, xmpp:serr_unsupported_feature(Txt, Lang));
                #starttls{} when TLSAvailable and not Encrypted ->
                    State2 = State1#{stream_state => wait_for_starttls_response},
                    send_pkt(State2, #starttls{});
-               #starttls{} when NS == ?NS_SERVER andalso not Encrypted ->
-                   process_sasl_failure(
-                     <<"STARTTLS is disabled in local configuration">>, State1);
                _ ->
                    State2 = process_cert_verification(State1),
                    case is_disconnected(State2) of
                        true -> State2;
-                       false ->
-                           try xmpp:try_subtag(StreamFeatures, #sasl_mechanisms{}) of
-                               #sasl_mechanisms{list = Mechs} ->
-                                   process_sasl_mechanisms(Mechs, State2);
-                               false ->
-                                   Txt = <<"Peer provided no SASL mechanisms; "
-                                           "most likely it doesn't accept "
-                                           "our certificate">>,
-                                   process_sasl_failure(Txt, State2)
-                           catch _:{xmpp_codec, Why} ->
-                                   Txt = xmpp:io_format_error(Why),
-                                   process_sasl_failure(Txt, State1)
-                           end
+                       false -> process_sasl_mechanisms(StreamFeatures, State2)
                    end
            catch _:{xmpp_codec, Why} ->
                    Txt = xmpp:io_format_error(Why),
-                   process_sasl_failure(Txt, State1)
+                   send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang))
            end
     end.
 
@@ -621,18 +597,56 @@ process_stream_established(State) ->
     catch _:{?MODULE, undef} -> State1
     end.
 
--spec process_sasl_mechanisms([binary()], state()) -> state().
-process_sasl_mechanisms(Mechs, #{user := User, server := Server} = State) ->
-    %% TODO: support other mechanisms
-    Mech = <<"EXTERNAL">>,
-    case lists:member(<<"EXTERNAL">>, Mechs) of
-       true ->
-           State1 = State#{stream_state => wait_for_sasl_response},
-           Authzid = jid:encode(jid:make(User, Server)),
-           send_pkt(State1, #sasl_auth{mechanism = Mech, text = Authzid});
+-spec process_sasl_mechanisms(stream_features(), state()) -> state().
+process_sasl_mechanisms(StreamFeatures, State) ->
+    AvailMechs = sasl_mechanisms(State),
+    State1 = State#{sasl_mechs_available => AvailMechs},
+    try xmpp:try_subtag(StreamFeatures, #sasl_mechanisms{}) of
+       #sasl_mechanisms{list = ProvidedMechs} ->
+           process_sasl_auth(State1#{sasl_mechs_provided => ProvidedMechs});
        false ->
-           process_sasl_failure(
-             <<"Peer doesn't support EXTERNAL authentication">>, State)
+           process_sasl_auth(State1#{sasl_mechs_provided => []})
+    catch _:{xmpp_codec, Why} ->
+           Txt = xmpp:io_format_error(Why),
+           Lang = maps:get(lang, State),
+           send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang))
+    end.
+
+process_sasl_auth(#{stream_encrypted := false, xmlns := ?NS_SERVER} = State) ->
+    State1 = State#{sasl_mechs_available => []},
+    Txt = case is_starttls_available(State) of
+             true -> <<"Peer doesn't support STARTTLS">>;
+             false -> <<"STARTTLS is disabled in local configuration">>
+         end,
+    process_sasl_failure(Txt, State1);
+process_sasl_auth(#{sasl_mechs_provided := [],
+                   stream_encrypted := Encrypted} = State) ->
+    State1 = State#{sasl_mechs_available => []},
+    Hint = case Encrypted of
+              true -> <<"; most likely it doesn't accept our certificate">>;
+              false -> <<"">>
+          end,
+    Txt = <<"Peer provided no SASL mechanisms", Hint/binary>>,
+    process_sasl_failure(Txt, State1);
+process_sasl_auth(#{sasl_mechs_available := []} = State) ->
+    Err = maps:get(sasl_error, State,
+                  <<"No mutually supported SASL mechanisms found">>),
+    process_sasl_failure(Err, State);
+process_sasl_auth(#{sasl_mechs_available := [Mech|AvailMechs],
+                   sasl_mechs_provided := ProvidedMechs} = State) ->
+    State1 = State#{sasl_mechs_available => AvailMechs},
+    if Mech == <<"EXTERNAL">> orelse Mech == <<"PLAIN">> ->
+           case lists:member(Mech, ProvidedMechs) of
+               true ->
+                   Text = make_sasl_authzid(Mech, State1),
+                   State2 = State1#{sasl_mech => Mech,
+                                    stream_state => wait_for_sasl_response},
+                   send(State2, #sasl_auth{mechanism = Mech, text = Text});
+               false ->
+                   process_sasl_auth(State1)
+           end;
+       true ->
+           process_sasl_auth(State1)
     end.
 
 -spec process_starttls(state()) -> state().
@@ -685,30 +699,42 @@ process_cert_verification(State) ->
     State.
 
 -spec process_sasl_success(state()) -> state().
-process_sasl_success(#{socket := Socket} = State) ->
+process_sasl_success(#{socket := Socket, sasl_mech := Mech} = State) ->
     Socket1 = xmpp_socket:reset_stream(Socket),
-    State0 = State#{socket => Socket1},
-    State1 = State0#{stream_id => new_id(),
+    State1 = State#{socket => Socket1},
+    State2 = State1#{stream_id => new_id(),
                     stream_restarted => true,
                     stream_state => wait_for_stream,
                     stream_authenticated => true},
-    State2 = send_header(State1),
-    case is_disconnected(State2) of
-       true -> State2;
+    State3 = reset_sasl_state(State2),
+    State4 = send_header(State3),
+    case is_disconnected(State4) of
+       true -> State4;
        false ->
-           try callback(handle_auth_success, <<"EXTERNAL">>, State2)
-           catch _:{?MODULE, undef} -> State2
+           try callback(handle_auth_success, Mech, State4)
+           catch _:{?MODULE, undef} -> State4
            end
     end.
 
 -spec process_sasl_failure(sasl_failure() | binary(), state()) -> state().
+process_sasl_failure(Failure, #{sasl_mechs_available := [_|_]} = State) ->
+    process_sasl_auth(State#{sasl_failure => Failure});
 process_sasl_failure(#sasl_failure{} = Failure, State) ->
     Reason = format("Peer responded with error: ~s",
                    [xmpp:format_sasl_error(Failure)]),
     process_sasl_failure(Reason, State);
 process_sasl_failure(Reason, State) ->
-    try callback(handle_auth_failure, <<"EXTERNAL">>, {auth, Reason}, State)
-    catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State)
+    Mech = case maps:get(sasl_mech, State, undefined) of
+              undefined ->
+                  case sasl_mechanisms(State) of
+                      [] -> <<"EXTERNAL">>;
+                      [M|_] -> M
+                  end;
+              M -> M
+          end,
+    State1 = reset_sasl_state(State),
+    try callback(handle_auth_failure, Mech, {auth, Reason}, State1)
+    catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State1)
     end.
 
 -spec process_bind(stream_features(), state()) -> state().
@@ -736,7 +762,7 @@ process_bind(_, State) ->
 -spec process_bind_response(xmpp_element(), state()) -> state().
 process_bind_response(#iq{type = result, id = ID} = IQ,
                      #{lang := Lang, bind_id := ID} = State) ->
-    State1 = maps:remove(bind_id, State),
+    State1 = reset_bind_state(State),
     try xmpp:try_subtag(IQ, #bind{}) of
        #bind{jid = #jid{user = U, server = S, resource = R}} ->
            State2 = State1#{user => U, server => S, resource => R},
@@ -757,7 +783,7 @@ process_bind_response(#iq{type = result, id = ID} = IQ,
 process_bind_response(#iq{type = error, id = ID} = IQ,
                      #{bind_id := ID} = State) ->
     Err = xmpp:get_error(IQ),
-    State1 = maps:remove(bind_id, State),
+    State1 = reset_bind_state(State),
     try callback(handle_bind_failure, Err, State1)
     catch _:{?MODULE, undef} -> process_stream_end({bind, Err}, State1)
     end;
@@ -782,6 +808,17 @@ is_starttls_available(State) ->
     catch _:{?MODULE, undef} -> true
     end.
 
+-spec sasl_mechanisms(state()) -> [binary()].
+sasl_mechanisms(#{stream_encrypted := Encrypted} = State) ->
+    try callback(sasl_mechanisms, State) of
+       Ms when Encrypted -> Ms;
+       Ms -> lists:delete(<<"EXTERNAL">>, Ms)
+    catch _:{?MODULE, undef} ->
+           if Encrypted -> [<<"EXTERNAL">>];
+              true -> []
+           end
+    end.
+
 -spec send_header(state()) -> state().
 send_header(#{remote_server := RemoteServer,
              stream_encrypted := Encrypted,
@@ -912,6 +949,54 @@ format_tls_error(Reason) ->
 format(Fmt, Args) ->
     iolist_to_binary(io_lib:format(Fmt, Args)).
 
+-spec make_sasl_authzid(binary(), state()) -> binary().
+make_sasl_authzid(Mech, #{user := User, server := Server,
+                         password := Password}) ->
+    case Mech of
+       <<"EXTERNAL">> ->
+           jid:encode(jid:make(User, Server));
+       <<"PLAIN">> ->
+           JID = jid:encode(jid:make(User, Server)),
+           <<JID/binary, 0, User/binary, 0, Password/binary>>
+    end.
+
+%%%===================================================================
+%%% State resets
+%%%===================================================================
+-spec reset_sasl_state(state()) -> state().
+reset_sasl_state(State) ->
+    State1 = maps:remove(sasl_mech, State),
+    State2 = maps:remove(sasl_failure, State1),
+    State3 = maps:remove(sasl_mechs_provided, State2),
+    maps:remove(sasl_mechs_available, State3).
+
+-spec reset_connection_state(state()) -> state().
+reset_connection_state(State) ->
+    State1 = maps:remove(ip, State),
+    State2 = maps:remove(socket, State1),
+    maps:remove(socket_monitor, State2).
+
+-spec reset_stream_state(state()) -> state().
+reset_stream_state(State) ->
+    State1 = State#{stream_id => new_id(),
+                   stream_encrypted => false,
+                   stream_verified => false,
+                   stream_authenticated => false,
+                   stream_restarted => false,
+                   stream_state => connecting},
+    maps:remove(stream_remote_id, State1).
+
+-spec reset_bind_state(state()) -> state().
+reset_bind_state(State) ->
+    maps:remove(bind_id, State).
+
+-spec reset_state(state()) -> state().
+reset_state(State) ->
+    State1 = reset_bind_state(State),
+    State2 = reset_sasl_state(State1),
+    State3 = reset_connection_state(State2),
+    reset_stream_state(State3).
+
 %%%===================================================================
 %%% Connection stuff
 %%%===================================================================