-export([init/1,
wait_for_stream/2,
wait_for_auth/2,
- wait_for_sasl_auth/2,
+ wait_for_feature_request/2,
wait_for_bind/2,
wait_for_session/2,
wait_for_sasl_response/2,
{xmlelement, "mechanism", [],
[{xmlcdata, S}]}
end, cyrsasl:listmech()),
+ TLS = StateData#state.tls,
+ TLSEnabled = StateData#state.tls_enabled,
+ SockMod = StateData#state.sockmod,
+ TLSFeature =
+ case (TLS == true) andalso
+ (TLSEnabled == false) andalso
+ (SockMod == gen_tcp) of
+ true ->
+ [{xmlelement, "starttls",
+ [{"xmlns", ?NS_TLS}], []}];
+ false ->
+ []
+ end,
send_element(StateData,
{xmlelement, "stream:features", [],
+ TLSFeature ++
[{xmlelement, "mechanisms",
[{"xmlns", ?NS_SASL}],
Mechs}]}),
- {next_state, wait_for_sasl_auth,
+ {next_state, wait_for_feature_request,
StateData#state{sasl_state = SASLState,
lang = Lang}};
_ ->
{stop, normal, StateData}.
-wait_for_sasl_auth({xmlstreamelement, El}, StateData) ->
+wait_for_feature_request({xmlstreamelement, El}, StateData) ->
{xmlelement, Name, Attrs, Els} = El,
+ TLS = StateData#state.tls,
+ TLSEnabled = StateData#state.tls_enabled,
+ SockMod = StateData#state.sockmod,
case {xml:get_attr_s("xmlns", Attrs), Name} of
{?NS_SASL, "auth"} ->
Mech = xml:get_attr_s("mechanism", Attrs),
{xmlelement, "failure",
[{"xmlns", ?NS_SASL}],
[{xmlelement, Error, [], []}]}),
- {next_state, wait_for_sasl_auth, StateData}
+ {next_state, wait_for_feature_request, StateData}
end;
+ {?NS_TLS, "starttls"} when TLS == true,
+ TLSEnabled == false,
+ SockMod == gen_tcp ->
+ Socket = StateData#state.socket,
+ TLSOpts = StateData#state.tls_options,
+ {ok, TLSSocket} = tls:tcp_to_tls(Socket, TLSOpts),
+ ejabberd_receiver:starttls(StateData#state.receiver, TLSSocket),
+ send_element(StateData,
+ {xmlelement, "proceed", [{"xmlns", ?NS_TLS}], []}),
+ {next_state, wait_for_stream,
+ StateData#state{sockmod = tls,
+ socket = TLSSocket,
+ tls_enabled = true
+ }};
_ ->
case jlib:iq_query_info(El) of
#iq{xmlns = ?NS_REGISTER} = IQ ->
jlib:iq_to_xml(ResIQ)),
Res = jlib:remove_attr("to", Res1),
send_element(StateData, Res),
- {next_state, wait_for_sasl_auth, StateData};
+ {next_state, wait_for_feature_request, StateData};
_ ->
- {next_state, wait_for_sasl_auth, StateData}
+ {next_state, wait_for_feature_request, StateData}
end
end;
-wait_for_sasl_auth({xmlstreamend, _Name}, StateData) ->
+wait_for_feature_request({xmlstreamend, _Name}, StateData) ->
send_text(StateData, ?STREAM_TRAILER),
{stop, normal, StateData};
-wait_for_sasl_auth({xmlstreamerror, _}, StateData) ->
+wait_for_feature_request({xmlstreamerror, _}, StateData) ->
send_text(StateData, ?INVALID_XML_ERR ++ ?STREAM_TRAILER),
{stop, normal, StateData};
-wait_for_sasl_auth(closed, StateData) ->
+wait_for_feature_request(closed, StateData) ->
{stop, normal, StateData}.
{xmlelement, "failure",
[{"xmlns", ?NS_SASL}],
[{xmlelement, Error, [], []}]}),
- {next_state, wait_for_sasl_auth, StateData}
+ {next_state, wait_for_feature_request, StateData}
end;
_ ->
case jlib:iq_query_info(El) of
jlib:iq_to_xml(ResIQ)),
Res = jlib:remove_attr("to", Res1),
send_element(StateData, Res),
- {next_state, wait_for_sasl_auth, StateData};
+ {next_state, wait_for_feature_request, StateData};
_ ->
- {next_state, wait_for_sasl_auth, StateData}
+ {next_state, wait_for_feature_request, StateData}
end
end;
-export([start/3,
receiver/4,
change_shaper/2,
- reset_stream/1]).
+ reset_stream/1,
+ starttls/2]).
-include("ejabberd.hrl").
receiver(Socket, SockMod, ShaperState, C2SPid, XMLStreamPid, Timeout).
receiver(Socket, SockMod, ShaperState, C2SPid, XMLStreamPid, Timeout) ->
- case catch SockMod:recv(Socket, 0, Timeout) of
+ Res = (catch SockMod:recv(Socket, 0, Timeout)),
+ case Res of
+ {ok, Data} ->
+ receive
+ {starttls, TLSSocket} ->
+ exit(XMLStreamPid, closed),
+ XMLStreamPid1 = xml_stream:start(self(), C2SPid),
+ TLSRes = tls:recv_data(TLSSocket, Data),
+ receiver1(TLSSocket, tls,
+ ShaperState, C2SPid, XMLStreamPid1, Timeout,
+ TLSRes)
+ after 0 ->
+ receiver1(Socket, SockMod,
+ ShaperState, C2SPid, XMLStreamPid, Timeout,
+ Res)
+ end;
+ _ ->
+ receiver1(Socket, SockMod,
+ ShaperState, C2SPid, XMLStreamPid, Timeout, Res)
+ end.
+
+
+receiver1(Socket, SockMod, ShaperState, C2SPid, XMLStreamPid, Timeout, Res) ->
+ case Res of
{ok, Text} ->
ShaperSt1 = receive
{change_shaper, Shaper} ->
reset_stream(Pid) ->
Pid ! reset_stream.
+starttls(Pid, TLSSocket) ->
+ Pid ! {starttls, TLSSocket}.
+
-export([start/0, start_link/0,
tcp_to_tls/2, tls_to_tcp/1,
send/2,
- recv/2, recv/3,
+ recv/2, recv/3, recv_data/2,
close/1,
test/0]).
{value, {certfile, CertFile}} ->
ok = erl_ddll:load_driver(ejabberd:get_so_path(), tls_drv),
Port = open_port({spawn, tls_drv}, [binary]),
- io:format("open_port: ~p~n", [Port]),
case port_control(Port, ?SET_CERTIFICATE_FILE,
CertFile ++ [0]) of
<<0>> ->
recv(Socket, Length) ->
recv(Socket, Length, infinity).
-recv(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Length, Timeout) ->
+recv(#tlssock{tcpsock = TCPSocket, tlsport = Port} = TLSSock,
+ Length, Timeout) ->
case gen_tcp:recv(TCPSocket, Length, Timeout) of
{ok, Packet} ->
- case port_control(Port, ?SET_ENCRYPTED_INPUT, Packet) of
- <<0>> ->
- case port_control(Port, ?GET_DECRYPTED_INPUT, []) of
- <<0, In/binary>> ->
- case port_control(Port, ?GET_ENCRYPTED_OUTPUT, []) of
- <<0, Out/binary>> ->
- case gen_tcp:send(TCPSocket, Out) of
- ok ->
- {ok, In};
- Error ->
- Error
- end;
- <<1, Error/binary>> ->
- {error, binary_to_list(Error)}
+ recv_data(TLSSock, Packet);
+ {error, _Reason} = Error ->
+ Error
+ end.
+
+recv_data(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->
+ case port_control(Port, ?SET_ENCRYPTED_INPUT, Packet) of
+ <<0>> ->
+ case port_control(Port, ?GET_DECRYPTED_INPUT, []) of
+ <<0, In/binary>> ->
+ case port_control(Port, ?GET_ENCRYPTED_OUTPUT, []) of
+ <<0, Out/binary>> ->
+ case gen_tcp:send(TCPSocket, Out) of
+ ok ->
+ {ok, In};
+ Error ->
+ Error
end;
<<1, Error/binary>> ->
{error, binary_to_list(Error)}
<<1, Error/binary>> ->
{error, binary_to_list(Error)}
end;
- {error, _Reason} = Error ->
- Error
+ <<1, Error/binary>> ->
+ {error, binary_to_list(Error)}
end.
send(#tlssock{tcpsock = TCPSocket, tlsport = Port}, Packet) ->