]> granicus.if.org Git - ejabberd/commitdiff
Fix usage of TLS in mod_mqtt
authorEvgeny Khramtsov <ekhramtsov@process-one.net>
Fri, 21 Jun 2019 11:39:46 +0000 (14:39 +0300)
committerEvgeny Khramtsov <ekhramtsov@process-one.net>
Fri, 21 Jun 2019 11:39:46 +0000 (14:39 +0300)
Fixes #2919

src/mod_mqtt_session.erl

index c0777d266d145347711bcd246afb8393a45da488..db9a632cae78affd8aecc37a2dc4292d7d1d32c5 100644 (file)
@@ -17,7 +17,7 @@
 %%%-------------------------------------------------------------------
 -module(mod_mqtt_session).
 -behaviour(p1_server).
--define(VSN, 1).
+-define(VSN, 2).
 -vsn(?VSN).
 
 %% API
@@ -46,7 +46,8 @@
                id = 0                :: non_neg_integer(),
                in_flight             :: undefined | publish() | pubrel(),
                codec                 :: mqtt_codec:state(),
-               queue                 :: undefined | p1_queue:queue()}).
+               queue                 :: undefined | p1_queue:queue(),
+               tls                   :: boolean()}).
 
 -type error_reason() :: {auth, reason_code()} |
                         {code, reason_code()} |
@@ -64,8 +65,9 @@
                         session_expiry_non_zero | unknown_topic_alias.
 
 -type state() :: #state{}.
--type sockmod() :: gen_tcp | fast_tls | mod_mqtt_ws.
--type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket() | mod_mqtt_ws:socket()}.
+-type socket() :: {gen_tcp, inet:socket()} |
+                 {fast_tls, fast_tls:tls_socket()} |
+                 {mod_mqtt_ws, mod_mqtt_ws:socket()}.
 -type peername() :: {inet:ip_address(), inet:port_number()}.
 -type seconds() :: non_neg_integer().
 -type milli_seconds() :: non_neg_integer().
@@ -153,13 +155,9 @@ format_error(Reason) ->
 %%%===================================================================
 init([SockMod, Socket, ListenOpts]) ->
     MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity),
-    SockMod1 = case {SockMod, proplists:get_bool(tls, ListenOpts)} of
-                  {gen_tcp, true} -> fast_tls;
-                  {gen_tcp, false} -> gen_tcp;
-                  {_, _} -> SockMod
-              end,
-    State1 = #state{socket = {SockMod1, Socket},
+    State1 = #state{socket = {SockMod, Socket},
                    id = p1_rand:uniform(65535),
+                   tls = proplists:get_bool(tls, ListenOpts),
                    codec = mqtt_codec:new(MaxSize)},
     Timeout = timer:seconds(30),
     State2 = set_timeout(State1, Timeout),
@@ -191,11 +189,11 @@ handle_call(Request, From, State) ->
     ?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]),
     noreply(State).
 
-handle_cast(accept, #state{socket = {_, Sock} = Socket} = State) ->
+handle_cast(accept, #state{socket = {_, Sock}} = State) ->
     case peername(State) of
        {ok, IPPort} ->
            State1 = State#state{peername = IPPort},
-           case starttls(Socket) of
+           case starttls(State) of
                {ok, Socket1} ->
                    State2 = State1#state{socket = Socket1},
                    handle_info({tcp, Sock, <<>>}, State2);
@@ -416,13 +414,27 @@ stop(#state{session_expiry = SessExp} = State, Reason) ->
            noreply(State4)
     end.
 
--spec upgrade_state(term()) -> state().
+%% Here is the code upgrading state between different
+%% code versions. This is needed when doing session resumption from
+%% remote node running the version of the code with incompatible #state{}
+%% record fields. Also used by code_change/3 callback.
+-spec upgrade_state(tuple()) -> state().
 upgrade_state(State) ->
-    %% Here will be the code upgrading state between different
-    %% code versions. This is needed when doing session resumption from
-    %% remote node running the version of the code with incompatible #state{}
-    %% record fields. Also used by code_change/3 callback.
-    %% Use element(2, State) for vsn comparison.
+    case element(2, State) of
+       ?VSN ->
+           State;
+       VSN when VSN > ?VSN ->
+           erlang:error({downgrade_not_supported, State});
+       VSN ->
+           State1 = upgrade_state(State, VSN),
+           upgrade_state(setelement(2, State1, VSN+1))
+    end.
+
+-spec upgrade_state(tuple(), 1..?VSN) -> tuple().
+upgrade_state(OldState, 1) ->
+    %% Appending 'tls' field
+    erlang:append_element(OldState, false);
+upgrade_state(State, _VSN) ->
     State.
 
 %%%===================================================================
@@ -915,8 +927,8 @@ check_sock_result({_, Sock}, {error, Why}) ->
     self() ! {tcp_closed, Sock},
     ?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]).
 
--spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}.
-starttls({fast_tls, Socket}) ->
+-spec starttls(state()) -> {ok, socket()} | {error, error_reason()}.
+starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
     case ejabberd_pkix:get_certfile() of
        {ok, Cert} ->
            case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
@@ -928,7 +940,7 @@ starttls({fast_tls, Socket}) ->
        error ->
            {error, {tls, no_certfile}}
     end;
-starttls(Socket) ->
+starttls(#state{socket = Socket}) ->
     {ok, Socket}.
 
 -spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}.
@@ -961,8 +973,8 @@ format_inet_error(Reason) ->
     end.
 
 -spec format_tls_error(atom() | binary()) -> string() | binary().
-format_tls_error(no_cerfile) ->
-    "certificate not found";
+format_tls_error(no_certfile) ->
+    "certificate not configured";
 format_tls_error(Reason) when is_atom(Reason) ->
     format_inet_error(Reason);
 format_tls_error(Reason) ->