%%%-------------------------------------------------------------------
-module(mod_mqtt_session).
-behaviour(p1_server).
--define(VSN, 1).
+-define(VSN, 2).
-vsn(?VSN).
%% API
id = 0 :: non_neg_integer(),
in_flight :: undefined | publish() | pubrel(),
codec :: mqtt_codec:state(),
- queue :: undefined | p1_queue:queue()}).
+ queue :: undefined | p1_queue:queue(),
+ tls :: boolean()}).
-type error_reason() :: {auth, reason_code()} |
{code, reason_code()} |
session_expiry_non_zero | unknown_topic_alias.
-type state() :: #state{}.
--type sockmod() :: gen_tcp | fast_tls | mod_mqtt_ws.
--type socket() :: {sockmod(), inet:socket() | fast_tls:tls_socket() | mod_mqtt_ws:socket()}.
+-type socket() :: {gen_tcp, inet:socket()} |
+ {fast_tls, fast_tls:tls_socket()} |
+ {mod_mqtt_ws, mod_mqtt_ws:socket()}.
-type peername() :: {inet:ip_address(), inet:port_number()}.
-type seconds() :: non_neg_integer().
-type milli_seconds() :: non_neg_integer().
%%%===================================================================
init([SockMod, Socket, ListenOpts]) ->
MaxSize = proplists:get_value(max_payload_size, ListenOpts, infinity),
- SockMod1 = case {SockMod, proplists:get_bool(tls, ListenOpts)} of
- {gen_tcp, true} -> fast_tls;
- {gen_tcp, false} -> gen_tcp;
- {_, _} -> SockMod
- end,
- State1 = #state{socket = {SockMod1, Socket},
+ State1 = #state{socket = {SockMod, Socket},
id = p1_rand:uniform(65535),
+ tls = proplists:get_bool(tls, ListenOpts),
codec = mqtt_codec:new(MaxSize)},
Timeout = timer:seconds(30),
State2 = set_timeout(State1, Timeout),
?WARNING_MSG("Got unexpected call from ~p: ~p", [From, Request]),
noreply(State).
-handle_cast(accept, #state{socket = {_, Sock} = Socket} = State) ->
+handle_cast(accept, #state{socket = {_, Sock}} = State) ->
case peername(State) of
{ok, IPPort} ->
State1 = State#state{peername = IPPort},
- case starttls(Socket) of
+ case starttls(State) of
{ok, Socket1} ->
State2 = State1#state{socket = Socket1},
handle_info({tcp, Sock, <<>>}, State2);
noreply(State4)
end.
--spec upgrade_state(term()) -> state().
+%% Here is the code upgrading state between different
+%% code versions. This is needed when doing session resumption from
+%% remote node running the version of the code with incompatible #state{}
+%% record fields. Also used by code_change/3 callback.
+-spec upgrade_state(tuple()) -> state().
upgrade_state(State) ->
- %% Here will be the code upgrading state between different
- %% code versions. This is needed when doing session resumption from
- %% remote node running the version of the code with incompatible #state{}
- %% record fields. Also used by code_change/3 callback.
- %% Use element(2, State) for vsn comparison.
+ case element(2, State) of
+ ?VSN ->
+ State;
+ VSN when VSN > ?VSN ->
+ erlang:error({downgrade_not_supported, State});
+ VSN ->
+ State1 = upgrade_state(State, VSN),
+ upgrade_state(setelement(2, State1, VSN+1))
+ end.
+
+-spec upgrade_state(tuple(), 1..?VSN) -> tuple().
+upgrade_state(OldState, 1) ->
+ %% Appending 'tls' field
+ erlang:append_element(OldState, false);
+upgrade_state(State, _VSN) ->
State.
%%%===================================================================
self() ! {tcp_closed, Sock},
?DEBUG("MQTT socket error: ~p", [format_inet_error(Why)]).
--spec starttls(socket()) -> {ok, socket()} | {error, error_reason()}.
-starttls({fast_tls, Socket}) ->
+-spec starttls(state()) -> {ok, socket()} | {error, error_reason()}.
+starttls(#state{socket = {gen_tcp, Socket}, tls = true}) ->
case ejabberd_pkix:get_certfile() of
{ok, Cert} ->
case fast_tls:tcp_to_tls(Socket, [{certfile, Cert}]) of
error ->
{error, {tls, no_certfile}}
end;
-starttls(Socket) ->
+starttls(#state{socket = Socket}) ->
{ok, Socket}.
-spec recv_data(socket(), binary()) -> {ok, binary()} | {error, error_reason()}.
end.
-spec format_tls_error(atom() | binary()) -> string() | binary().
-format_tls_error(no_cerfile) ->
- "certificate not found";
+format_tls_error(no_certfile) ->
+ "certificate not configured";
format_tls_error(Reason) when is_atom(Reason) ->
format_inet_error(Reason);
format_tls_error(Reason) ->