]> granicus.if.org Git - ejabberd/commitdiff
Mix _xmpp-server and _xmpps-server SRV records
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Wed, 25 Oct 2017 08:39:14 +0000 (11:39 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Wed, 25 Oct 2017 08:39:20 +0000 (11:39 +0300)
XEP-0368 describes this procedure as following:
> Both 'xmpp-' and 'xmpps-' records SHOULD be treated as
> the same record with regard to connection order as specified
> by RFC 2782 [3], in that all priorities and weights are mixed.
> This enables the server operator to decide if they would
> rather clients connect with STARTTLS or direct TLS.

src/xmpp_stream_out.erl

index 718db1eac45ef8d82b4c70fa3efc7da153076a73..dfcebf96d6b87f31fa5a41cfea34a2dbbe153c96 100644 (file)
@@ -49,8 +49,9 @@
 
 -type state() :: map().
 -type noreply() :: {noreply, state(), timeout()}.
--type host_port() :: {inet:hostname(), inet:port_number()}.
--type ip_port() :: {inet:ip_address(), inet:port_number()}.
+-type host_port() :: {inet:hostname(), inet:port_number(), boolean()}.
+-type ip_port() :: {inet:ip_address(), inet:port_number(), boolean()}.
+-type h_addr_list() :: {{integer(), integer(), inet:port_number(), string()}, boolean()}.
 -type network_error() :: {error, inet:posix() | inet_res:res_error()}.
 -type tls_error_reason() :: inet:posix() | atom() | binary().
 -type socket_error_reason() :: inet:posix() | atom().
@@ -279,11 +280,11 @@ handle_cast(connect, #{remote_server := RemoteServer,
              process_stream_end({idna, bad_string}, State);
          ASCIIName ->
              case resolve(binary_to_list(ASCIIName), State) of
-                 {{ok, AddrPorts}, Encrypted} ->
-                     case connect(AddrPorts, State, Encrypted) of
-                         {ok, Socket, AddrPort} ->
+                 {ok, AddrPorts} ->
+                     case connect(AddrPorts, State) of
+                         {ok, Socket, {Addr, Port, Encrypted}} ->
                              SocketMonitor = SockMod:monitor(Socket),
-                             State1 = State#{ip => AddrPort,
+                             State1 = State#{ip => {Addr, Port},
                                              socket => Socket,
                                              stream_encrypted => Encrypted,
                                              socket_monitor => SocketMonitor},
@@ -292,7 +293,7 @@ handle_cast(connect, #{remote_server := RemoteServer,
                          {error, {Class, Why}} ->
                              process_stream_end({Class, Why}, State)
                      end;
-                 {{error, Why}, _} ->
+                 {error, Why} ->
                      process_stream_end({dns, Why}, State)
              end
       end);
@@ -855,17 +856,17 @@ idna_to_ascii(Host) ->
        {error, _} -> ejabberd_idna:domain_utf8_to_ascii(Host)
     end.
 
--spec resolve(string(), state()) -> {{ok, [ip_port()]} | network_error(), boolean()}.
+-spec resolve(string(), state()) -> {ok, [ip_port()]} | network_error().
 resolve(Host, State) ->
     case srv_lookup(Host, State) of
-       {{error, _Reason}, _} ->
+       {error, _Reason} ->
            DefaultPort = get_default_port(State),
-           {a_lookup([{Host, DefaultPort}], State), false};
-       {{ok, HostPorts}, TLS} ->
-           {a_lookup(HostPorts, State), TLS}
+           a_lookup([{Host, DefaultPort, false}], State);
+       {ok, HostPorts} ->
+           a_lookup(HostPorts, State)
     end.
 
--spec srv_lookup(string(), state()) -> {{ok, [host_port()]} | network_error(), boolean()}.
+-spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
 srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) ->
     %% Do not attempt to lookup SRV for component connections
     {{error, nxdomain}, false};
@@ -873,61 +874,74 @@ srv_lookup(Host, State) ->
     %% Only perform SRV lookups for FQDN names
     case string:chr(Host, $.) of
        0 ->
-           {{error, nxdomain}, false};
+           {error, nxdomain};
        _ ->
            case inet:parse_address(Host) of
                {ok, _} ->
-                   {{error, nxdomain}, false};
+                   {error, nxdomain};
                {error, _} ->
                    Timeout = get_dns_timeout(State),
                    Retries = get_dns_retries(State),
-                   case is_starttls_available(State) of
-                       true ->
-                           case srv_lookup("_xmpps-server._tcp." ++ Host,
-                                           Timeout, Retries) of
-                               {error, _} ->
-                                   {srv_lookup("_xmpp-server._tcp." ++ Host,
-                                               Timeout, Retries),
-                                    false};
-                               {ok, Res} ->
-                                   {{ok, Res}, true}
-                           end;
-                       false ->
-                           {srv_lookup("_xmpp-server._tcp." ++ Host,
-                                       Timeout, Retries),
-                            false}
+                   case srv_lookup(Host, State, Timeout, Retries) of
+                       {ok, AddrList} ->
+                           h_addr_list_to_host_ports(AddrList);
+                       {error, _} = Err ->
+                           {Err, false}
                    end
            end
     end.
 
+srv_lookup(Host, State, Timeout, Retries) ->
+    TLSAddrs = case is_starttls_available(State) of
+                  true ->
+                      case srv_lookup("_xmpps-server._tcp." ++ Host,
+                                      Timeout, Retries) of
+                          {ok, HostEnt} ->
+                              [{A, true} || A <- HostEnt#hostent.h_addr_list];
+                          {error, _} ->
+                              []
+                      end;
+                  false ->
+                      []
+              end,
+    case srv_lookup("_xmpp-server._tcp." ++ Host, Timeout, Retries) of
+       {ok, HostEntry} ->
+           Addrs = [{A, false} || A <- HostEntry#hostent.h_addr_list],
+           {ok, TLSAddrs ++ Addrs};
+       {error, _} when TLSAddrs /= [] ->
+           {ok, TLSAddrs};
+       {error, _} = Err ->
+           Err
+    end.
+
 -spec srv_lookup(string(), timeout(), integer()) ->
-                       {ok, [host_port()]} | network_error().
+                       {ok, inet:hostent()} | network_error().
 srv_lookup(_SRVName, _Timeout, Retries) when Retries < 1 ->
     {error, timeout};
 srv_lookup(SRVName, Timeout, Retries) ->
     case inet_res:getbyname(SRVName, srv, Timeout) of
        {ok, HostEntry} ->
-           host_entry_to_host_ports(HostEntry);
+           {ok, HostEntry};
        {error, timeout} ->
            srv_lookup(SRVName, Timeout, Retries - 1);
        {error, _} = Err ->
            Err
     end.
 
--spec a_lookup([{inet:hostname(), inet:port_number()}], state()) ->
+-spec a_lookup([host_port()], state()) ->
                      {ok, [ip_port()]} | network_error().
 a_lookup(HostPorts, State) ->
-    HostPortFamilies = [{Host, Port, Family}
-                       || {Host, Port} <- HostPorts,
+    HostPortFamilies = [{Host, Port, TLS, Family}
+                       || {Host, Port, TLS} <- HostPorts,
                           Family <- get_address_families(State)],
     a_lookup(HostPortFamilies, State, [], {error, nxdomain}).
 
--spec a_lookup([{inet:hostname(), inet:port_number(), inet:address_family()}],
+-spec a_lookup([{inet:hostname(), inet:port_number(), boolean(), inet:address_family()}],
               state(), [ip_port()], network_error()) -> {ok, [ip_port()]} | network_error().
-a_lookup([{Host, Port, Family}|HostPortFamilies], State, Acc, Err) ->
+a_lookup([{Host, Port, TLS, Family}|HostPortFamilies], State, Acc, Err) ->
     Timeout = get_dns_timeout(State),
     Retries = get_dns_retries(State),
-    case a_lookup(Host, Port, Family, Timeout, Retries) of
+    case a_lookup(Host, Port, TLS, Family, Timeout, Retries) of
        {error, Reason} ->
            a_lookup(HostPortFamilies, State, Acc, {error, Reason});
        {ok, AddrPorts} ->
@@ -938,11 +952,11 @@ a_lookup([], _State, [], Err) ->
 a_lookup([], _State, Acc, _) ->
     {ok, Acc}.
 
--spec a_lookup(inet:hostname(), inet:port_number(), inet:address_family(),
+-spec a_lookup(inet:hostname(), inet:port_number(), boolean(), inet:address_family(),
               timeout(), integer()) -> {ok, [ip_port()]} | network_error().
-a_lookup(_Host, _Port, _Family, _Timeout, Retries) when Retries < 1 ->
+a_lookup(_Host, _Port, _TLS, _Family, _Timeout, Retries) when Retries < 1 ->
     {error, timeout};
-a_lookup(Host, Port, Family, Timeout, Retries) ->
+a_lookup(Host, Port, TLS, Family, Timeout, Retries) ->
     Start = p1_time_compat:monotonic_time(milli_seconds),
     case inet:gethostbyname(Host, Family, Timeout) of
        {error, nxdomain} = Err ->
@@ -953,43 +967,43 @@ a_lookup(Host, Port, Family, Timeout, Retries) ->
            %% it ignores DNS configuration settings (/etc/hosts, etc)
            End = p1_time_compat:monotonic_time(milli_seconds),
            if (End - Start) >= Timeout ->
-                   a_lookup(Host, Port, Family, Timeout, Retries - 1);
+                   a_lookup(Host, Port, TLS, Family, Timeout, Retries - 1);
               true ->
                    Err
            end;
        {error, _} = Err ->
            Err;
        {ok, HostEntry} ->
-           host_entry_to_addr_ports(HostEntry, Port)
+           host_entry_to_addr_ports(HostEntry, Port, TLS)
     end.
 
--spec host_entry_to_host_ports(inet:hostent()) -> {ok, [host_port()]} |
+-spec h_addr_list_to_host_ports(h_addr_list()) -> {ok, [host_port()]} |
                                                  {error, nxdomain}.
-host_entry_to_host_ports(#hostent{h_addr_list = AddrList}) ->
+h_addr_list_to_host_ports(AddrList) ->
     PrioHostPorts = lists:flatmap(
-                     fun({Priority, Weight, Port, Host}) ->
+                     fun({{Priority, Weight, Port, Host}, TLS}) ->
                              N = case Weight of
                                      0 -> 0;
                                      _ -> (Weight + 1) * randoms:uniform()
                                  end,
-                             [{Priority * 65536 - N, Host, Port}];
+                             [{Priority * 65536 - N, Host, Port, TLS}];
                         (_) ->
                              []
                      end, AddrList),
-    HostPorts = [{Host, Port}
-                || {_Priority, Host, Port} <- lists:usort(PrioHostPorts)],
+    HostPorts = [{Host, Port, TLS}
+                || {_Priority, Host, Port, TLS} <- lists:usort(PrioHostPorts)],
     case HostPorts of
        [] -> {error, nxdomain};
        _ -> {ok, HostPorts}
     end.
 
--spec host_entry_to_addr_ports(inet:hostent(), inet:port_number()) ->
+-spec host_entry_to_addr_ports(inet:hostent(), inet:port_number(), boolean()) ->
                                      {ok, [ip_port()]} | {error, nxdomain}.
-host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) ->
+host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port, TLS) ->
     AddrPorts = lists:flatmap(
                  fun(Addr) ->
                          try get_addr_type(Addr) of
-                             _ -> [{Addr, Port}]
+                             _ -> [{Addr, Port, TLS}]
                          catch _:_ ->
                                  []
                          end
@@ -999,26 +1013,26 @@ host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port) ->
        _ -> {ok, AddrPorts}
     end.
 
--spec connect([ip_port()], state(), boolean()) -> {ok, term(), ip_port()} |
-                                                 {error, {socket, socket_error_reason()}} |
-                                                 {error, {tls, tls_error_reason()}}.
-connect(AddrPorts, #{sockmod := SockMod} = State, TLS) ->
+-spec connect([ip_port()], state()) -> {ok, term(), ip_port()} |
+                                      {error, {socket, socket_error_reason()}} |
+                                      {error, {tls, tls_error_reason()}}.
+connect(AddrPorts, #{sockmod := SockMod} = State) ->
     Timeout = get_connect_timeout(State),
     case connect(AddrPorts, SockMod, Timeout, {error, nxdomain}) of
-       {ok, Socket, AddrPort} when TLS ->
+       {ok, Socket, {Addr, Port, TLS = true}} ->
            case starttls(Socket, State) of
-               {ok, TLSSocket} -> {ok, TLSSocket, AddrPort};
+               {ok, TLSSocket} -> {ok, TLSSocket, {Addr, Port, TLS}};
                {error, Why} -> {error, {tls, Why}}
            end;
-       {ok, _Socket, _AddrPort} = OK ->
-           OK;
+       {ok, Socket, {Addr, Port, TLS = false}} ->
+           {ok, Socket, {Addr, Port, TLS}};
        {error, Why} ->
            {error, {socket, Why}}
     end.
 
 -spec connect([ip_port()], module(), timeout(), network_error()) ->
                     {ok, term(), ip_port()} | network_error().
-connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
+connect([{Addr, Port, TLS}|AddrPorts], SockMod, Timeout, _) ->
     Type = get_addr_type(Addr),
     try SockMod:connect(Addr, Port,
                        [binary, {packet, 0},
@@ -1027,7 +1041,7 @@ connect([{Addr, Port}|AddrPorts], SockMod, Timeout, _) ->
                         {active, false}, Type],
                        Timeout) of
        {ok, Socket} ->
-           {ok, Socket, {Addr, Port}};
+           {ok, Socket, {Addr, Port, TLS}};
        Err ->
            connect(AddrPorts, SockMod, Timeout, Err)
     catch _:badarg ->