]> granicus.if.org Git - ejabberd/commitdiff
Support certificate verification for outgoing s2s
authorHolger Weiss <holger@zedat.fu-berlin.de>
Sun, 27 Apr 2014 23:42:02 +0000 (01:42 +0200)
committerHolger Weiss <holger@zedat.fu-berlin.de>
Sun, 27 Apr 2014 23:42:02 +0000 (01:42 +0200)
Handle "s2s_use_starttls: required_trusted" the same way for outgoing
s2s connections as for incoming connections.  That is, check the remote
server's certificate (including the host name) and abort the connection
if verification fails.

src/ejabberd_s2s.erl
src/ejabberd_s2s_in.erl
src/ejabberd_s2s_out.erl

index 057c60a98751b9386347ce17ae8a1242fa7b3cdf..eb989435024cb3f8be4ccea103a2be4af63c0cad 100644 (file)
@@ -37,7 +37,8 @@
         incoming_s2s_number/0, outgoing_s2s_number/0,
         clean_temporarily_blocked_table/0,
         list_temporarily_blocked_hosts/0,
-        external_host_overloaded/1, is_temporarly_blocked/1]).
+        external_host_overloaded/1, is_temporarly_blocked/1,
+        check_peer_certificate/3]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2,
 
 -include("ejabberd_commands.hrl").
 
+-include_lib("public_key/include/public_key.hrl").
+
+-define(PKIXEXPLICIT, 'OTP-PUB-KEY').
+
+-define(PKIXIMPLICIT, 'OTP-PUB-KEY').
+
+-include("XmppAddr.hrl").
+
 -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER, 1).
 
 -define(DEFAULT_MAX_S2S_CONNECTIONS_NUMBER_PER_NODE, 1).
@@ -207,6 +216,31 @@ try_register(FromTo) ->
 dirty_get_connections() ->
     mnesia:dirty_all_keys(s2s).
 
+check_peer_certificate(SockMod, Sock, Peer) ->
+    case SockMod:get_peer_certificate(Sock) of
+      {ok, Cert} ->
+         case SockMod:get_verify_result(Sock) of
+           0 ->
+               case idna:domain_utf8_to_ascii(Peer) of
+                 false ->
+                     {error, <<"Cannot decode remote server name">>};
+                 AsciiPeer ->
+                     case
+                       lists:any(fun(D) -> match_domain(AsciiPeer, D) end,
+                                 get_cert_domains(Cert)) of
+                       true ->
+                           {ok, <<"Verification successful">>};
+                       false ->
+                           {error, <<"Certificate host name mismatch">>}
+                     end
+               end;
+           VerifyRes ->
+               {error, p1_tls:get_cert_verify_string(VerifyRes, Cert)}
+         end;
+      error ->
+         {error, <<"Cannot get peer certificate">>}
+    end.
+
 %%====================================================================
 %% gen_server callbacks
 %%====================================================================
@@ -619,3 +653,121 @@ get_s2s_state(S2sPid) ->
              {badrpc, _} -> [{status, error}]
            end,
     [{s2s_pid, S2sPid} | Infos].
+
+get_cert_domains(Cert) ->
+    {rdnSequence, Subject} =
+       (Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.subject,
+    Extensions =
+       (Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.extensions,
+    lists:flatmap(fun (#'AttributeTypeAndValue'{type =
+                                                   ?'id-at-commonName',
+                                               value = Val}) ->
+                         case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
+                           {ok, {_, D1}} ->
+                               D = if is_binary(D1) -> D1;
+                                      is_list(D1) -> list_to_binary(D1);
+                                      true -> error
+                                   end,
+                               if D /= error ->
+                                      case jlib:string_to_jid(D) of
+                                        #jid{luser = <<"">>, lserver = LD,
+                                             lresource = <<"">>} ->
+                                            [LD];
+                                        _ -> []
+                                      end;
+                                  true -> []
+                               end;
+                           _ -> []
+                         end;
+                     (_) -> []
+                 end,
+                 lists:flatten(Subject))
+      ++
+      lists:flatmap(fun (#'Extension'{extnID =
+                                         ?'id-ce-subjectAltName',
+                                     extnValue = Val}) ->
+                           BVal = if is_list(Val) -> list_to_binary(Val);
+                                     true -> Val
+                                  end,
+                           case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
+                               of
+                             {ok, SANs} ->
+                                 lists:flatmap(fun ({otherName,
+                                                     #'AnotherName'{'type-id' =
+                                                                        ?'id-on-xmppAddr',
+                                                                    value =
+                                                                        XmppAddr}}) ->
+                                                       case
+                                                         'XmppAddr':decode('XmppAddr',
+                                                                           XmppAddr)
+                                                           of
+                                                         {ok, D}
+                                                             when
+                                                               is_binary(D) ->
+                                                             case
+                                                               jlib:string_to_jid((D))
+                                                                 of
+                                                               #jid{luser =
+                                                                        <<"">>,
+                                                                    lserver =
+                                                                        LD,
+                                                                    lresource =
+                                                                        <<"">>} ->
+                                                                   case
+                                                                     idna:domain_utf8_to_ascii(LD)
+                                                                       of
+                                                                     false ->
+                                                                         [];
+                                                                     PCLD ->
+                                                                         [PCLD]
+                                                                   end;
+                                                               _ -> []
+                                                             end;
+                                                         _ -> []
+                                                       end;
+                                                   ({dNSName, D})
+                                                       when is_list(D) ->
+                                                       case
+                                                         jlib:string_to_jid(list_to_binary(D))
+                                                           of
+                                                         #jid{luser = <<"">>,
+                                                              lserver = LD,
+                                                              lresource =
+                                                                  <<"">>} ->
+                                                             [LD];
+                                                         _ -> []
+                                                       end;
+                                                   (_) -> []
+                                               end,
+                                               SANs);
+                             _ -> []
+                           end;
+                       (_) -> []
+                   end,
+                   Extensions).
+
+match_domain(Domain, Domain) -> true;
+match_domain(Domain, Pattern) ->
+    DLabels = str:tokens(Domain, <<".">>),
+    PLabels = str:tokens(Pattern, <<".">>),
+    match_labels(DLabels, PLabels).
+
+match_labels([], []) -> true;
+match_labels([], [_ | _]) -> false;
+match_labels([_ | _], []) -> false;
+match_labels([DL | DLabels], [PL | PLabels]) ->
+    case lists:all(fun (C) ->
+                          $a =< C andalso C =< $z orelse
+                            $0 =< C andalso C =< $9 orelse
+                              C == $- orelse C == $*
+                  end,
+                  binary_to_list(PL))
+       of
+      true ->
+         Regexp = ejabberd_regexp:sh_to_awk(PL),
+         case ejabberd_regexp:run(DL, Regexp) of
+           match -> match_labels(DLabels, PLabels);
+           nomatch -> false
+         end;
+      false -> false
+    end.
index 3eb0b71ccd1ae0c9d5d6badc604463feebe2ee07..c490704d8673ee7ad06f57d97c8ed614a1079157 100644 (file)
@@ -30,8 +30,7 @@
 -behaviour(p1_fsm).
 
 %% External exports
--export([start/2, start_link/2, match_domain/2,
-        socket_type/0]).
+-export([start/2, start_link/2, socket_type/0]).
 
 %% gen_fsm callbacks
 -export([init/1, wait_for_stream/2,
 
 -include("jlib.hrl").
 
--include_lib("public_key/include/public_key.hrl").
-
--define(PKIXEXPLICIT, 'OTP-PUB-KEY').
-
--define(PKIXIMPLICIT, 'OTP-PUB-KEY').
-
--include("XmppAddr.hrl").
-
 -define(DICT, dict).
 
 -record(state,
@@ -227,45 +218,11 @@ wait_for_stream({xmlstreamstart, _Name, Attrs},
          Auth = if StateData#state.tls_enabled ->
                        case jlib:nameprep(xml:get_attr_s(<<"from">>, Attrs)) of
                          From when From /= <<"">>, From /= error ->
-                             case
-                               (StateData#state.sockmod):get_peer_certificate(StateData#state.socket)
-                                 of
-                               {ok, Cert} ->
-                                   case
-                                     (StateData#state.sockmod):get_verify_result(StateData#state.socket)
-                                       of
-                                     0 ->
-                                         case
-                                           idna:domain_utf8_to_ascii(From)
-                                             of
-                                           false ->
-                                               {error, From,
-                                                <<"Cannot decode 'from' attribute">>};
-                                           PCAuthDomain ->
-                                               case
-                                                 lists:any(fun (D) ->
-                                                                   match_domain(PCAuthDomain,
-                                                                                D)
-                                                           end,
-                                                           get_cert_domains(Cert))
-                                                   of
-                                                 true ->
-                                                     {ok, From,
-                                                      <<"Success">>};
-                                                 false ->
-                                                     {error, From,
-                                                      <<"Certificate host name mismatch">>}
-                                               end
-                                         end;
-                                     CertVerifyRes ->
-                                         {error, From,
-                                          p1_tls:get_cert_verify_string(CertVerifyRes,
-                                                                        Cert)}
-                                   end;
-                               error ->
-                                   {error, From,
-                                    <<"Cannot get peer certificate">>}
-                             end;
+                             {Result, Message} =
+                                 ejabberd_s2s:check_peer_certificate(StateData#state.sockmod,
+                                                                     StateData#state.socket,
+                                                                     From),
+                             {Result, From, Message};
                          _ ->
                              {error, <<"(unknown)">>,
                               <<"Got no valid 'from' attribute">>}
@@ -746,124 +703,6 @@ is_key_packet(#xmlel{name = Name, attrs = Attrs,
      xml:get_attr_s(<<"id">>, Attrs), xml:get_cdata(Els)};
 is_key_packet(_) -> false.
 
-get_cert_domains(Cert) ->
-    {rdnSequence, Subject} =
-       (Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.subject,
-    Extensions =
-       (Cert#'Certificate'.tbsCertificate)#'TBSCertificate'.extensions,
-    lists:flatmap(fun (#'AttributeTypeAndValue'{type =
-                                                   ?'id-at-commonName',
-                                               value = Val}) ->
-                         case 'OTP-PUB-KEY':decode('X520CommonName', Val) of
-                           {ok, {_, D1}} ->
-                               D = if is_binary(D1) -> D1;
-                                      is_list(D1) -> list_to_binary(D1);
-                                      true -> error
-                                   end,
-                               if D /= error ->
-                                      case jlib:string_to_jid(D) of
-                                        #jid{luser = <<"">>, lserver = LD,
-                                             lresource = <<"">>} ->
-                                            [LD];
-                                        _ -> []
-                                      end;
-                                  true -> []
-                               end;
-                           _ -> []
-                         end;
-                     (_) -> []
-                 end,
-                 lists:flatten(Subject))
-      ++
-      lists:flatmap(fun (#'Extension'{extnID =
-                                         ?'id-ce-subjectAltName',
-                                     extnValue = Val}) ->
-                           BVal = if is_list(Val) -> list_to_binary(Val);
-                                     true -> Val
-                                  end,
-                           case 'OTP-PUB-KEY':decode('SubjectAltName', BVal)
-                               of
-                             {ok, SANs} ->
-                                 lists:flatmap(fun ({otherName,
-                                                     #'AnotherName'{'type-id' =
-                                                                        ?'id-on-xmppAddr',
-                                                                    value =
-                                                                        XmppAddr}}) ->
-                                                       case
-                                                         'XmppAddr':decode('XmppAddr',
-                                                                           XmppAddr)
-                                                           of
-                                                         {ok, D}
-                                                             when
-                                                               is_binary(D) ->
-                                                             case
-                                                               jlib:string_to_jid((D))
-                                                                 of
-                                                               #jid{luser =
-                                                                        <<"">>,
-                                                                    lserver =
-                                                                        LD,
-                                                                    lresource =
-                                                                        <<"">>} ->
-                                                                   case
-                                                                     idna:domain_utf8_to_ascii(LD)
-                                                                       of
-                                                                     false ->
-                                                                         [];
-                                                                     PCLD ->
-                                                                         [PCLD]
-                                                                   end;
-                                                               _ -> []
-                                                             end;
-                                                         _ -> []
-                                                       end;
-                                                   ({dNSName, D})
-                                                       when is_list(D) ->
-                                                       case
-                                                         jlib:string_to_jid(list_to_binary(D))
-                                                           of
-                                                         #jid{luser = <<"">>,
-                                                              lserver = LD,
-                                                              lresource =
-                                                                  <<"">>} ->
-                                                             [LD];
-                                                         _ -> []
-                                                       end;
-                                                   (_) -> []
-                                               end,
-                                               SANs);
-                             _ -> []
-                           end;
-                       (_) -> []
-                   end,
-                   Extensions).
-
-match_domain(Domain, Domain) -> true;
-match_domain(Domain, Pattern) ->
-    DLabels = str:tokens(Domain, <<".">>),
-    PLabels = str:tokens(Pattern, <<".">>),
-    match_labels(DLabels, PLabels).
-
-match_labels([], []) -> true;
-match_labels([], [_ | _]) -> false;
-match_labels([_ | _], []) -> false;
-match_labels([DL | DLabels], [PL | PLabels]) ->
-    case lists:all(fun (C) ->
-                          $a =< C andalso C =< $z orelse
-                            $0 =< C andalso C =< $9 orelse
-                              C == $- orelse C == $*
-                  end,
-                  binary_to_list(PL))
-       of
-      true ->
-         Regexp = ejabberd_regexp:sh_to_awk(PL),
-         case ejabberd_regexp:run(DL, Regexp) of
-           match -> match_labels(DLabels, PLabels);
-           nomatch -> false
-         end;
-      false -> false
-    end.
-
 fsm_limit_opts(Opts) ->
     case lists:keysearch(max_fsm_queue, 1, Opts) of
       {value, {_, N}} when is_integer(N) -> [{max_queue, N}];
index a0a83631d1674bdf5531977cacddd73cfb562627..9977fcd7ebcf28826ebc33cdeaa10a90e8add555 100644 (file)
@@ -69,6 +69,7 @@
          use_v10 = true                   :: boolean(),
          tls = false                      :: boolean(),
         tls_required = false             :: boolean(),
+        tls_certverify = false           :: boolean(),
          tls_enabled = false              :: boolean(),
         tls_options = [connect]          :: list(),
          authenticated = false            :: boolean(),
@@ -160,28 +161,27 @@ stop_connection(Pid) -> p1_fsm:send_event(Pid, closed).
 init([From, Server, Type]) ->
     process_flag(trap_exit, true),
     ?DEBUG("started: ~p", [{From, Server, Type}]),
-    {TLS, TLSRequired} = case
-                          ejabberd_config:get_option(
-                             s2s_use_starttls,
-                             fun(true) -> true;
-                                (false) -> false;
-                                (optional) -> optional;
-                                (required) -> required;
-                                (required_trusted) -> required_trusted
-                             end)
-                            of
-                          UseTls
-                              when (UseTls == undefined) or
-                                     (UseTls == false) ->
-                              {false, false};
-                          UseTls
-                              when (UseTls == true) or (UseTls == optional) ->
-                              {true, false};
-                          UseTls
-                              when (UseTls == required) or
-                                     (UseTls == required_trusted) ->
-                              {true, true}
-                        end,
+    {TLS, TLSRequired, TLSCertverify} =
+       case ejabberd_config:get_option(
+              s2s_use_starttls,
+              fun(true) -> true;
+                 (false) -> false;
+                 (optional) -> optional;
+                 (required) -> required;
+                 (required_trusted) -> required_trusted
+              end)
+           of
+         UseTls
+             when (UseTls == undefined) or (UseTls == false) ->
+             {false, false, false};
+         UseTls
+             when (UseTls == true) or (UseTls == optional) ->
+             {true, false, false};
+         required ->
+             {true, true, false};
+         required_trusted ->
+             {true, true, true}
+       end,
     UseV10 = TLS,
     TLSOpts1 = case
                ejabberd_config:get_option(
@@ -223,9 +223,9 @@ init([From, Server, Type]) ->
     Timer = erlang:start_timer(?S2STIMEOUT, self(), []),
     {ok, open_socket,
      #state{use_v10 = UseV10, tls = TLS,
-           tls_required = TLSRequired, tls_options = TLSOpts,
-           queue = queue:new(), myname = From, server = Server,
-           new = New, verify = Verify, timer = Timer}}.
+           tls_required = TLSRequired, tls_certverify = TLSCertverify,
+           tls_options = TLSOpts, queue = queue:new(), myname = From,
+           server = Server, new = New, verify = Verify, timer = Timer}}.
 
 %%----------------------------------------------------------------------
 %% Func: StateName/2
@@ -345,35 +345,57 @@ open_socket2(Type, Addr, Port) ->
 
 wait_for_stream({xmlstreamstart, _Name, Attrs},
                StateData) ->
+    {CertCheckRes, CertCheckMsg, NewStateData} =
+       if StateData#state.tls_certverify, StateData#state.tls_enabled ->
+              {Res, Msg} =
+                  ejabberd_s2s:check_peer_certificate(ejabberd_socket,
+                                                      StateData#state.socket,
+                                                      StateData#state.server),
+              ?DEBUG("Certificate verification result for ~s: ~s",
+                     [StateData#state.server, Msg]),
+              {Res, Msg, StateData#state{tls_certverify = false}};
+          true ->
+              {no_verify, <<"Not verified">>, StateData}
+       end,
     case {xml:get_attr_s(<<"xmlns">>, Attrs),
          xml:get_attr_s(<<"xmlns:db">>, Attrs),
          xml:get_attr_s(<<"version">>, Attrs) == <<"1.0">>}
        of
+      _ when CertCheckRes == error ->
+         send_text(NewStateData,
+                   <<(xml:element_to_binary(?SERRT_POLICY_VIOLATION(<<"en">>,
+                                                                    CertCheckMsg)))/binary,
+                     (?STREAM_TRAILER)/binary>>),
+         ?INFO_MSG("Closing s2s connection: ~s -> ~s (~s)",
+                   [NewStateData#state.myname,
+                    NewStateData#state.server,
+                    CertCheckMsg]),
+         {stop, normal, NewStateData};
       {<<"jabber:server">>, <<"jabber:server:dialback">>,
        false} ->
-         send_db_request(StateData);
+         send_db_request(NewStateData);
       {<<"jabber:server">>, <<"jabber:server:dialback">>,
        true}
-         when StateData#state.use_v10 ->
-         {next_state, wait_for_features, StateData, ?FSMTIMEOUT};
+         when NewStateData#state.use_v10 ->
+         {next_state, wait_for_features, NewStateData, ?FSMTIMEOUT};
       %% Clause added to handle Tigase's workaround for an old ejabberd bug:
       {<<"jabber:server">>, <<"jabber:server:dialback">>,
        true}
-         when not StateData#state.use_v10 ->
-         send_db_request(StateData);
+         when not NewStateData#state.use_v10 ->
+         send_db_request(NewStateData);
       {<<"jabber:server">>, <<"">>, true}
-         when StateData#state.use_v10 ->
+         when NewStateData#state.use_v10 ->
          {next_state, wait_for_features,
-          StateData#state{db_enabled = false}, ?FSMTIMEOUT};
+          NewStateData#state{db_enabled = false}, ?FSMTIMEOUT};
       {NSProvided, DB, _} ->
-         send_text(StateData, ?INVALID_NAMESPACE_ERR),
+         send_text(NewStateData, ?INVALID_NAMESPACE_ERR),
          ?INFO_MSG("Closing s2s connection: ~s -> ~s (invalid "
                    "namespace).~nNamespace provided: ~p~nNamespac"
                    "e expected: \"jabber:server\"~nxmlns:db "
                    "provided: ~p~nAll attributes: ~p",
-                   [StateData#state.myname, StateData#state.server,
+                   [NewStateData#state.myname, NewStateData#state.server,
                     NSProvided, DB, Attrs]),
-         {stop, normal, StateData}
+         {stop, normal, NewStateData}
     end;
 wait_for_stream({xmlstreamerror, _}, StateData) ->
     send_text(StateData,
@@ -736,8 +758,8 @@ wait_for_starttls_proceed({xmlstreamelement, El},
                                               tls_options = TLSOpts},
                send_text(NewStateData,
                          io_lib:format(?STREAM_HEADER,
-                                       [StateData#state.myname,
-                                        StateData#state.server,
+                                       [NewStateData#state.myname,
+                                        NewStateData#state.server,
                                         <<" version='1.0'">>])),
                {next_state, wait_for_stream, NewStateData,
                 ?FSMTIMEOUT};