+++ /dev/null
-XmppAddr { iso(1) identified-organization(3)
- dod(6) internet(1) security(5) mechanisms(5) pkix(7)
- id-on(8) id-on-xmppAddr(5) }
-
-DEFINITIONS EXPLICIT TAGS ::=
-BEGIN
-
-id-on-xmppAddr OBJECT IDENTIFIER ::= { iso(1) identified-organization(3)
- dod(6) internet(1) security(5) mechanisms(5) pkix(7)
- id-on(8) 5 }
-
-XmppAddr ::= UTF8String
-
-END
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--record(scram, {storedkey = <<"">> :: binary(),
- serverkey = <<"">> :: binary(),
- salt = <<"">> :: binary(),
- iterationcount = 0 :: integer()}).
-
--type scram() :: #scram{}.
-
--define(SCRAM_DEFAULT_ITERATION_COUNT, 4096).
{fast_tls, ".*", {git, "https://github.com/processone/fast_tls", {tag, "1.0.23"}}},
{stringprep, ".*", {git, "https://github.com/processone/stringprep", {tag, "1.0.12"}}},
{fast_xml, ".*", {git, "https://github.com/processone/fast_xml", {tag, "1.1.32"}}},
- {xmpp, ".*", {git, "https://github.com/processone/xmpp", "0e2ef5d"}},
+ {xmpp, ".*", {git, "https://github.com/processone/xmpp", "2a5193c"}},
{fast_yaml, ".*", {git, "https://github.com/processone/fast_yaml", {tag, "1.0.15"}}},
{jiffy, ".*", {git, "https://github.com/davisp/jiffy", {tag, "0.14.8"}}},
{p1_oauth2, ".*", {git, "https://github.com/processone/p1_oauth2", {tag, "0.6.3"}}},
{if_have_fun, {public_key, short_name_hash, 1}, {d, 'SHORT_NAME_HASH'}},
{if_var_true, new_sql_schema, {d, 'NEW_SQL_SCHEMA'}},
{if_var_true, hipe, native},
- {src_dirs, [asn1, src,
+ {src_dirs, [src,
{if_var_true, tools, tools},
{if_var_true, elixir, include}]}]}.
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : cyrsasl.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Cyrus SASL-like library
-%%% Created : 8 Mar 2003 by Alexey Shchepin <alexey@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(cyrsasl).
-
--author('alexey@process-one.net').
--behaviour(gen_server).
-
--export([start_link/0, register_mechanism/3, listmech/1,
- server_new/7, server_start/3, server_step/2,
- get_mech/1, format_error/2]).
-%% gen_server callbacks
--export([init/1, handle_call/3, handle_cast/2, handle_info/2,
- terminate/2, code_change/3]).
-
--include("logger.hrl").
-
--record(state, {}).
-
--record(sasl_mechanism,
- {mechanism = <<"">> :: mechanism() | '$1',
- module :: atom(),
- password_type = plain :: password_type() | '$2'}).
-
--type(mechanism() :: binary()).
--type(mechanisms() :: [mechanism(),...]).
--type(password_type() :: plain | digest | scram).
--type sasl_property() :: {username, binary()} |
- {authzid, binary()} |
- {mechanism, binary()} |
- {auth_module, atom()}.
--type sasl_return() :: {ok, [sasl_property()]} |
- {ok, [sasl_property()], binary()} |
- {continue, binary(), sasl_state()} |
- {error, atom(), binary()}.
-
--type(sasl_mechanism() :: #sasl_mechanism{}).
--type error_reason() :: cyrsasl_digest:error_reason() |
- cyrsasl_oauth:error_reason() |
- cyrsasl_plain:error_reason() |
- cyrsasl_scram:error_reason() |
- unsupported_mechanism | nodeprep_failed |
- empty_username | aborted.
--record(sasl_state,
-{
- service,
- myname,
- realm,
- get_password,
- check_password,
- check_password_digest,
- mech_name = <<"">>,
- mech_mod,
- mech_state
-}).
--type sasl_state() :: #sasl_state{}.
--export_type([mechanism/0, mechanisms/0, sasl_mechanism/0, error_reason/0,
- sasl_state/0, sasl_return/0, sasl_property/0]).
-
--callback start(list()) -> any().
--callback stop() -> any().
--callback mech_new(binary(), fun(), fun(), fun()) -> any().
--callback mech_step(any(), binary()) -> sasl_return().
-
-start_link() ->
- gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
-
-init([]) ->
- ets:new(sasl_mechanism,
- [named_table, public,
- {keypos, #sasl_mechanism.mechanism}]),
- cyrsasl_plain:start([]),
- cyrsasl_digest:start([]),
- cyrsasl_scram:start([]),
- cyrsasl_anonymous:start([]),
- cyrsasl_oauth:start([]),
- {ok, #state{}}.
-
-handle_call(_Request, _From, State) ->
- Reply = ok,
- {reply, Reply, State}.
-
-handle_cast(_Msg, State) ->
- {noreply, State}.
-
-handle_info(_Info, State) ->
- {noreply, State}.
-
-terminate(_Reason, _State) ->
- cyrsasl_plain:stop(),
- cyrsasl_digest:stop(),
- cyrsasl_scram:stop(),
- cyrsasl_anonymous:stop(),
- cyrsasl_oauth:stop().
-
-code_change(_OldVsn, State, _Extra) ->
- {ok, State}.
-
--spec format_error(mechanism() | sasl_state(), error_reason()) -> {atom(), binary()}.
-format_error(_, unsupported_mechanism) ->
- {'invalid-mechanism', <<"Unsupported mechanism">>};
-format_error(_, nodeprep_failed) ->
- {'bad-protocol', <<"Nodeprep failed">>};
-format_error(_, empty_username) ->
- {'bad-protocol', <<"Empty username">>};
-format_error(_, aborted) ->
- {'aborted', <<"Aborted">>};
-format_error(#sasl_state{mech_mod = Mod}, Reason) ->
- Mod:format_error(Reason);
-format_error(Mech, Reason) ->
- case ets:lookup(sasl_mechanism, Mech) of
- [#sasl_mechanism{module = Mod}] ->
- Mod:format_error(Reason);
- [] ->
- {'invalid-mechanism', <<"Unsupported mechanism">>}
- end.
-
--spec register_mechanism(Mechanim :: mechanism(), Module :: module(),
- PasswordType :: password_type()) -> any().
-
-register_mechanism(Mechanism, Module, PasswordType) ->
- ets:insert(sasl_mechanism,
- #sasl_mechanism{mechanism = Mechanism, module = Module,
- password_type = PasswordType}).
-
-check_credentials(_State, Props) ->
- User = proplists:get_value(authzid, Props, <<>>),
- case jid:nodeprep(User) of
- error -> {error, nodeprep_failed};
- <<"">> -> {error, empty_username};
- _LUser -> ok
- end.
-
--spec listmech(Host ::binary()) -> Mechanisms::mechanisms().
-
-listmech(Host) ->
- ets:select(sasl_mechanism,
- [{#sasl_mechanism{mechanism = '$1',
- password_type = '$2', _ = '_'},
- case catch ejabberd_auth:store_type(Host) of
- external -> [{'==', '$2', plain}];
- scram -> [{'/=', '$2', digest}];
- {'EXIT', {undef, [{Module, store_type, []} | _]}} ->
- ?WARNING_MSG("~p doesn't implement the function store_type/0",
- [Module]),
- [];
- _Else -> []
- end,
- ['$1']}]).
-
--spec server_new(binary(), binary(), binary(), term(),
- fun(), fun(), fun()) -> sasl_state().
-server_new(Service, ServerFQDN, UserRealm, _SecFlags,
- GetPassword, CheckPassword, CheckPasswordDigest) ->
- #sasl_state{service = Service, myname = ServerFQDN,
- realm = UserRealm, get_password = GetPassword,
- check_password = CheckPassword,
- check_password_digest = CheckPasswordDigest}.
-
--spec server_start(sasl_state(), mechanism(), binary()) -> sasl_return().
-server_start(State, Mech, ClientIn) ->
- case lists:member(Mech,
- listmech(State#sasl_state.myname))
- of
- true ->
- case ets:lookup(sasl_mechanism, Mech) of
- [#sasl_mechanism{module = Module}] ->
- {ok, MechState} =
- Module:mech_new(State#sasl_state.myname,
- State#sasl_state.get_password,
- State#sasl_state.check_password,
- State#sasl_state.check_password_digest),
- server_step(State#sasl_state{mech_mod = Module,
- mech_name = Mech,
- mech_state = MechState},
- ClientIn);
- _ -> {error, unsupported_mechanism, <<"">>}
- end;
- false -> {error, unsupported_mechanism, <<"">>}
- end.
-
--spec server_step(sasl_state(), binary()) -> sasl_return().
-server_step(State, ClientIn) ->
- Module = State#sasl_state.mech_mod,
- MechState = State#sasl_state.mech_state,
- case Module:mech_step(MechState, ClientIn) of
- {ok, Props} ->
- case check_credentials(State, Props) of
- ok -> {ok, Props};
- {error, Error} -> {error, Error, <<"">>}
- end;
- {ok, Props, ServerOut} ->
- case check_credentials(State, Props) of
- ok -> {ok, Props, ServerOut};
- {error, Error} -> {error, Error, <<"">>}
- end;
- {continue, ServerOut, NewMechState} ->
- {continue, ServerOut, State#sasl_state{mech_state = NewMechState}};
- {error, Error, Username} ->
- {error, Error, Username};
- {error, Error} ->
- {error, Error, <<"">>}
- end.
-
--spec get_mech(sasl_state()) -> binary().
-get_mech(#sasl_state{mech_name = Mech}) ->
- Mech.
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : cyrsasl_anonymous.erl
-%%% Author : Magnus Henoch <henoch@dtek.chalmers.se>
-%%% Purpose : ANONYMOUS SASL mechanism
-%%% See http://www.ietf.org/internet-drafts/draft-ietf-sasl-anon-05.txt
-%%% Created : 23 Aug 2005 by Magnus Henoch <henoch@dtek.chalmers.se>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(cyrsasl_anonymous).
-
--protocol({xep, 175, '1.2'}).
-
--export([start/1, stop/0, mech_new/4, mech_step/2]).
-
--behaviour(cyrsasl).
-
--record(state, {server = <<"">> :: binary()}).
-
-start(_Opts) ->
- cyrsasl:register_mechanism(<<"ANONYMOUS">>, ?MODULE, plain).
-
-stop() -> ok.
-
-mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) ->
- {ok, #state{server = Host}}.
-
-mech_step(#state{}, _ClientIn) ->
- User = iolist_to_binary([p1_rand:get_string(),
- integer_to_binary(p1_time_compat:unique_integer([positive]))]),
- {ok, [{username, User},
- {authzid, User},
- {auth_module, ejabberd_auth_anonymous}]}.
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : cyrsasl_digest.erl
-%%% Author : Alexey Shchepin <alexey@sevcom.net>
-%%% Purpose : DIGEST-MD5 SASL mechanism
-%%% Created : 11 Mar 2003 by Alexey Shchepin <alexey@sevcom.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(cyrsasl_digest).
-
--behaviour(ejabberd_config).
-
--author('alexey@sevcom.net').
-
--export([start/1, stop/0, mech_new/4, mech_step/2,
- parse/1, format_error/1, opt_type/1]).
-
--include("logger.hrl").
-
--behaviour(cyrsasl).
-
--type get_password_fun() :: fun((binary()) -> {false, any()} |
- {binary(), atom()}).
--type check_password_fun() :: fun((binary(), binary(), binary(), binary(),
- fun((binary()) -> binary())) ->
- {boolean(), any()} |
- false).
--type error_reason() :: parser_failed | invalid_digest_uri |
- not_authorized | unexpected_response.
--export_type([error_reason/0]).
-
--record(state, {step = 1 :: 1 | 3 | 5,
- nonce = <<"">> :: binary(),
- username = <<"">> :: binary(),
- authzid = <<"">> :: binary(),
- get_password :: get_password_fun(),
- check_password :: check_password_fun(),
- auth_module :: atom(),
- host = <<"">> :: binary(),
- hostfqdn = [] :: [binary()]}).
-
-start(_Opts) ->
- Fqdn = get_local_fqdn(),
- ?DEBUG("FQDN used to check DIGEST-MD5 SASL authentication: ~s",
- [Fqdn]),
- cyrsasl:register_mechanism(<<"DIGEST-MD5">>, ?MODULE,
- digest).
-
-stop() -> ok.
-
--spec format_error(error_reason()) -> {atom(), binary()}.
-format_error(parser_failed) ->
- {'bad-protocol', <<"Response decoding failed">>};
-format_error(invalid_digest_uri) ->
- {'bad-protocol', <<"Invalid digest URI">>};
-format_error(not_authorized) ->
- {'not-authorized', <<"Invalid username or password">>};
-format_error(unexpected_response) ->
- {'bad-protocol', <<"Unexpected response">>}.
-
-mech_new(Host, GetPassword, _CheckPassword,
- CheckPasswordDigest) ->
- {ok,
- #state{step = 1, nonce = p1_rand:get_string(),
- host = Host, hostfqdn = get_local_fqdn(),
- get_password = GetPassword,
- check_password = CheckPasswordDigest}}.
-
-mech_step(#state{step = 1, nonce = Nonce} = State, _) ->
- {continue,
- <<"nonce=\"", Nonce/binary,
- "\",qop=\"auth\",charset=utf-8,algorithm=md5-sess">>,
- State#state{step = 3}};
-mech_step(#state{step = 3, nonce = Nonce} = State,
- ClientIn) ->
- case parse(ClientIn) of
- bad -> {error, parser_failed};
- KeyVals ->
- DigestURI = proplists:get_value(<<"digest-uri">>, KeyVals, <<>>),
- UserName = proplists:get_value(<<"username">>, KeyVals, <<>>),
- case is_digesturi_valid(DigestURI, State#state.host,
- State#state.hostfqdn)
- of
- false ->
- ?DEBUG("User login not authorized because digest-uri "
- "seems invalid: ~p (checking for Host "
- "~p, FQDN ~p)",
- [DigestURI, State#state.host, State#state.hostfqdn]),
- {error, invalid_digest_uri, UserName};
- true ->
- AuthzId = proplists:get_value(<<"authzid">>, KeyVals, <<>>),
- case (State#state.get_password)(UserName) of
- {false, _} -> {error, not_authorized, UserName};
- {Passwd, AuthModule} ->
- case (State#state.check_password)(UserName, UserName, <<"">>,
- proplists:get_value(<<"response">>, KeyVals, <<>>),
- fun (PW) ->
- response(KeyVals,
- UserName,
- PW,
- Nonce,
- AuthzId,
- <<"AUTHENTICATE">>)
- end)
- of
- {true, _} ->
- RspAuth = response(KeyVals, UserName, Passwd, Nonce,
- AuthzId, <<"">>),
- {continue, <<"rspauth=", RspAuth/binary>>,
- State#state{step = 5, auth_module = AuthModule,
- username = UserName,
- authzid = AuthzId}};
- false -> {error, not_authorized, UserName};
- {false, _} -> {error, not_authorized, UserName}
- end
- end
- end
- end;
-mech_step(#state{step = 5, auth_module = AuthModule,
- username = UserName, authzid = AuthzId},
- <<"">>) ->
- {ok,
- [{username, UserName}, {authzid, case AuthzId of
- <<"">> -> UserName;
- _ -> AuthzId
- end
- },
- {auth_module, AuthModule}]};
-mech_step(A, B) ->
- ?DEBUG("SASL DIGEST: A ~p B ~p", [A, B]),
- {error, unexpected_response}.
-
-parse(S) -> parse1(binary_to_list(S), "", []).
-
-parse1([$= | Cs], S, Ts) ->
- parse2(Cs, lists:reverse(S), "", Ts);
-parse1([$, | Cs], [], Ts) -> parse1(Cs, [], Ts);
-parse1([$\s | Cs], [], Ts) -> parse1(Cs, [], Ts);
-parse1([C | Cs], S, Ts) -> parse1(Cs, [C | S], Ts);
-parse1([], [], T) -> lists:reverse(T);
-parse1([], _S, _T) -> bad.
-
-parse2([$" | Cs], Key, Val, Ts) ->
- parse3(Cs, Key, Val, Ts);
-parse2([C | Cs], Key, Val, Ts) ->
- parse4(Cs, Key, [C | Val], Ts);
-parse2([], _, _, _) -> bad.
-
-parse3([$" | Cs], Key, Val, Ts) ->
- parse4(Cs, Key, Val, Ts);
-parse3([$\\, C | Cs], Key, Val, Ts) ->
- parse3(Cs, Key, [C | Val], Ts);
-parse3([C | Cs], Key, Val, Ts) ->
- parse3(Cs, Key, [C | Val], Ts);
-parse3([], _, _, _) -> bad.
-
-parse4([$, | Cs], Key, Val, Ts) ->
- parse1(Cs, "", [{list_to_binary(Key), list_to_binary(lists:reverse(Val))} | Ts]);
-parse4([$\s | Cs], Key, Val, Ts) ->
- parse4(Cs, Key, Val, Ts);
-parse4([C | Cs], Key, Val, Ts) ->
- parse4(Cs, Key, [C | Val], Ts);
-parse4([], Key, Val, Ts) ->
-%% @doc Check if the digest-uri is valid.
-%% RFC-2831 allows to provide the IP address in Host,
-%% however ejabberd doesn't allow that.
-%% If the service (for example jabber.example.org)
-%% is provided by several hosts (being one of them server3.example.org),
-%% then acceptable digest-uris would be:
-%% xmpp/server3.example.org/jabber.example.org, xmpp/server3.example.org and
-%% xmpp/jabber.example.org
-%% The last version is not actually allowed by the RFC, but implemented by popular clients
- parse1([], "", [{list_to_binary(Key), list_to_binary(lists:reverse(Val))} | Ts]).
-
-is_digesturi_valid(DigestURICase, JabberDomain,
- JabberFQDN) ->
- DigestURI = stringprep:tolower(DigestURICase),
- case catch str:tokens(DigestURI, <<"/">>) of
- [<<"xmpp">>, Host] ->
- IsHostFqdn = is_host_fqdn(Host, JabberFQDN),
- (Host == JabberDomain) or IsHostFqdn;
- [<<"xmpp">>, Host, ServName] ->
- IsHostFqdn = is_host_fqdn(Host, JabberFQDN),
- (ServName == JabberDomain) and IsHostFqdn;
- _ ->
- false
- end.
-
-is_host_fqdn(_Host, []) ->
- false;
-is_host_fqdn(Host, [Fqdn | _FqdnTail]) when Host == Fqdn ->
- true;
-is_host_fqdn(Host, [Fqdn | FqdnTail]) when Host /= Fqdn ->
- is_host_fqdn(Host, FqdnTail).
-
-get_local_fqdn() ->
- case ejabberd_config:get_option(fqdn) of
- undefined ->
- {ok, Hostname} = inet:gethostname(),
- {ok, {hostent, Fqdn, _, _, _, _}} = inet:gethostbyname(Hostname),
- [list_to_binary(Fqdn)];
- Fqdn ->
- Fqdn
- end.
-
-hex(S) ->
- str:to_hexlist(S).
-
-proplists_get_bin_value(Key, Pairs, Default) ->
- case proplists:get_value(Key, Pairs, Default) of
- L when is_list(L) ->
- list_to_binary(L);
- L2 ->
- L2
- end.
-
-response(KeyVals, User, Passwd, Nonce, AuthzId,
- A2Prefix) ->
- Realm = proplists_get_bin_value(<<"realm">>, KeyVals, <<>>),
- CNonce = proplists_get_bin_value(<<"cnonce">>, KeyVals, <<>>),
- DigestURI = proplists_get_bin_value(<<"digest-uri">>, KeyVals, <<>>),
- NC = proplists_get_bin_value(<<"nc">>, KeyVals, <<>>),
- QOP = proplists_get_bin_value(<<"qop">>, KeyVals, <<>>),
- MD5Hash = erlang:md5(<<User/binary, ":", Realm/binary, ":",
- Passwd/binary>>),
- A1 = case AuthzId of
- <<"">> ->
- <<MD5Hash/binary, ":", Nonce/binary, ":", CNonce/binary>>;
- _ ->
- <<MD5Hash/binary, ":", Nonce/binary, ":", CNonce/binary, ":",
- AuthzId/binary>>
- end,
- A2 = case QOP of
- <<"auth">> ->
- <<A2Prefix/binary, ":", DigestURI/binary>>;
- _ ->
- <<A2Prefix/binary, ":", DigestURI/binary,
- ":00000000000000000000000000000000">>
- end,
- T = <<(hex((erlang:md5(A1))))/binary, ":", Nonce/binary,
- ":", NC/binary, ":", CNonce/binary, ":", QOP/binary,
- ":", (hex((erlang:md5(A2))))/binary>>,
- hex((erlang:md5(T))).
-
--spec opt_type(fqdn) -> fun((binary() | [binary()]) -> [binary()]);
- (atom()) -> [atom()].
-opt_type(fqdn) ->
- fun(FQDN) when is_binary(FQDN) ->
- [FQDN];
- (FQDNs) when is_list(FQDNs) ->
- [iolist_to_binary(FQDN) || FQDN <- FQDNs]
- end;
-opt_type(_) -> [fqdn].
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : cyrsasl_oauth.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : X-OAUTH2 SASL mechanism
-%%% Created : 17 Sep 2015 by Alexey Shchepin <alexey@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(cyrsasl_oauth).
-
--author('alexey@process-one.net').
-
--export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]).
-
--behaviour(cyrsasl).
-
--record(state, {host}).
--type error_reason() :: parser_failed | not_authorized.
--export_type([error_reason/0]).
-
-start(_Opts) ->
- cyrsasl:register_mechanism(<<"X-OAUTH2">>, ?MODULE, plain).
-
-stop() -> ok.
-
--spec format_error(error_reason()) -> {atom(), binary()}.
-format_error(parser_failed) ->
- {'bad-protocol', <<"Response decoding failed">>};
-format_error(not_authorized) ->
- {'not-authorized', <<"Invalid token">>}.
-
-mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) ->
- {ok, #state{host = Host}}.
-
-mech_step(State, ClientIn) ->
- case prepare(ClientIn) of
- [AuthzId, User, Token] ->
- case ejabberd_oauth:check_token(
- User, State#state.host, [<<"sasl_auth">>], Token) of
- true ->
- {ok,
- [{username, User}, {authzid, AuthzId},
- {auth_module, ejabberd_oauth}]};
- _ ->
- {error, not_authorized, User}
- end;
- _ -> {error, parser_failed}
- end.
-
-prepare(ClientIn) ->
- case parse(ClientIn) of
- [<<"">>, UserMaybeDomain, Token] ->
- case parse_domain(UserMaybeDomain) of
- %% <NUL>login@domain<NUL>pwd
- [User, _Domain] -> [User, User, Token];
- %% <NUL>login<NUL>pwd
- [User] -> [User, User, Token]
- end;
- %% login@domain<NUL>login<NUL>pwd
- [AuthzId, User, Token] ->
- case parse_domain(AuthzId) of
- %% login@domain<NUL>login<NUL>pwd
- [AuthzUser, _Domain] -> [AuthzUser, User, Token];
- %% login<NUL>login<NUL>pwd
- [AuthzUser] -> [AuthzUser, User, Token]
- end;
- _ -> error
- end.
-
-parse(S) -> parse1(binary_to_list(S), "", []).
-
-parse1([0 | Cs], S, T) ->
- parse1(Cs, "", [list_to_binary(lists:reverse(S)) | T]);
-parse1([C | Cs], S, T) -> parse1(Cs, [C | S], T);
-%parse1([], [], T) ->
-% lists:reverse(T);
-parse1([], S, T) ->
- lists:reverse([list_to_binary(lists:reverse(S)) | T]).
-
-parse_domain(S) -> parse_domain1(binary_to_list(S), "", []).
-
-parse_domain1([$@ | Cs], S, T) ->
- parse_domain1(Cs, "", [list_to_binary(lists:reverse(S)) | T]);
-parse_domain1([C | Cs], S, T) ->
- parse_domain1(Cs, [C | S], T);
-parse_domain1([], S, T) ->
- lists:reverse([list_to_binary(lists:reverse(S)) | T]).
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : cyrsasl_plain.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : PLAIN SASL mechanism
-%%% Created : 8 Mar 2003 by Alexey Shchepin <alexey@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(cyrsasl_plain).
-
--author('alexey@process-one.net').
-
--export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]).
-
--behaviour(cyrsasl).
-
--record(state, {check_password}).
--type error_reason() :: parser_failed | not_authorized.
--export_type([error_reason/0]).
-
-start(_Opts) ->
- cyrsasl:register_mechanism(<<"PLAIN">>, ?MODULE, plain).
-
-stop() -> ok.
-
--spec format_error(error_reason()) -> {atom(), binary()}.
-format_error(parser_failed) ->
- {'bad-protocol', <<"Response decoding failed">>};
-format_error(not_authorized) ->
- {'not-authorized', <<"Invalid username or password">>}.
-
-mech_new(_Host, _GetPassword, CheckPassword, _CheckPasswordDigest) ->
- {ok, #state{check_password = CheckPassword}}.
-
-mech_step(State, ClientIn) ->
- case prepare(ClientIn) of
- [AuthzId, User, Password] ->
- case (State#state.check_password)(User, AuthzId, Password) of
- {true, AuthModule} ->
- {ok,
- [{username, User}, {authzid, AuthzId},
- {auth_module, AuthModule}]};
- _ -> {error, not_authorized, User}
- end;
- _ -> {error, parser_failed}
- end.
-
-prepare(ClientIn) ->
- case parse(ClientIn) of
- [<<"">>, UserMaybeDomain, Password] ->
- case parse_domain(UserMaybeDomain) of
- %% <NUL>login@domain<NUL>pwd
- [User, _Domain] -> [User, User, Password];
- %% <NUL>login<NUL>pwd
- [User] -> [User, User, Password]
- end;
- [AuthzId, User, Password] ->
- case parse_domain(AuthzId) of
- %% login@domain<NUL>login<NUL>pwd
- [AuthzUser, _Domain] -> [AuthzUser, User, Password];
- %% login<NUL>login<NUL>pwd
- [AuthzUser] -> [AuthzUser, User, Password]
- end;
- _ -> error
- end.
-
-parse(S) ->
- binary:split(S, <<0>>, [global]).
-
-parse_domain(S) -> parse_domain1(binary_to_list(S), "", []).
-
-parse_domain1([$@ | Cs], S, T) ->
- parse_domain1(Cs, "", [list_to_binary(lists:reverse(S)) | T]);
-parse_domain1([C | Cs], S, T) ->
- parse_domain1(Cs, [C | S], T);
-parse_domain1([], S, T) ->
- lists:reverse([list_to_binary(lists:reverse(S)) | T]).
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : cyrsasl_scram.erl
-%%% Author : Stephen Röttger <stephen.roettger@googlemail.com>
-%%% Purpose : SASL SCRAM authentication
-%%% Created : 7 Aug 2011 by Stephen Röttger <stephen.roettger@googlemail.com>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(cyrsasl_scram).
-
--author('stephen.roettger@googlemail.com').
-
--protocol({rfc, 5802}).
-
--export([start/1, stop/0, mech_new/4, mech_step/2, format_error/1]).
-
--include("scram.hrl").
--include("logger.hrl").
-
--behaviour(cyrsasl).
-
--record(state,
- {step = 2 :: 2 | 4,
- stored_key = <<"">> :: binary(),
- server_key = <<"">> :: binary(),
- username = <<"">> :: binary(),
- auth_module :: module(),
- get_password :: fun((binary()) ->
- {false | ejabberd_auth:password(), module()}),
- auth_message = <<"">> :: binary(),
- client_nonce = <<"">> :: binary(),
- server_nonce = <<"">> :: binary()}).
-
--define(SALT_LENGTH, 16).
--define(NONCE_LENGTH, 16).
-
--type error_reason() :: unsupported_extension | bad_username |
- not_authorized | saslprep_failed |
- parser_failed | bad_attribute |
- nonce_mismatch | bad_channel_binding.
-
--export_type([error_reason/0]).
-
-start(_Opts) ->
- cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE,
- scram).
-
-stop() -> ok.
-
--spec format_error(error_reason()) -> {atom(), binary()}.
-format_error(unsupported_extension) ->
- {'bad-protocol', <<"Unsupported extension">>};
-format_error(bad_username) ->
- {'invalid-authzid', <<"Malformed username">>};
-format_error(not_authorized) ->
- {'not-authorized', <<"Invalid username or password">>};
-format_error(saslprep_failed) ->
- {'not-authorized', <<"SASLprep failed">>};
-format_error(parser_failed) ->
- {'bad-protocol', <<"Response decoding failed">>};
-format_error(bad_attribute) ->
- {'bad-protocol', <<"Malformed or unexpected attribute">>};
-format_error(nonce_mismatch) ->
- {'bad-protocol', <<"Nonce mismatch">>};
-format_error(bad_channel_binding) ->
- {'bad-protocol', <<"Invalid channel binding">>}.
-
-mech_new(_Host, GetPassword, _CheckPassword,
- _CheckPasswordDigest) ->
- {ok, #state{step = 2, get_password = GetPassword}}.
-
-mech_step(#state{step = 2} = State, ClientIn) ->
- case re:split(ClientIn, <<",">>, [{return, binary}]) of
- [_CBind, _AuthorizationIdentity, _UserNameAttribute, _ClientNonceAttribute, ExtensionAttribute | _]
- when ExtensionAttribute /= <<"">> ->
- {error, unsupported_extension};
- [CBind, _AuthorizationIdentity, UserNameAttribute, ClientNonceAttribute | _]
- when (CBind == <<"y">>) or (CBind == <<"n">>) ->
- case parse_attribute(UserNameAttribute) of
- {error, Reason} -> {error, Reason};
- {_, EscapedUserName} ->
- case unescape_username(EscapedUserName) of
- error -> {error, bad_username};
- UserName ->
- case parse_attribute(ClientNonceAttribute) of
- {$r, ClientNonce} ->
- {Pass, AuthModule} = (State#state.get_password)(UserName),
- LPass = if is_binary(Pass) -> jid:resourceprep(Pass);
- true -> Pass
- end,
- if Pass == false ->
- {error, not_authorized, UserName};
- LPass == error ->
- {error, saslprep_failed, UserName};
- true ->
- {StoredKey, ServerKey, Salt, IterationCount} =
- if is_record(Pass, scram) ->
- {base64:decode(Pass#scram.storedkey),
- base64:decode(Pass#scram.serverkey),
- base64:decode(Pass#scram.salt),
- Pass#scram.iterationcount};
- true ->
- TempSalt =
- p1_rand:bytes(?SALT_LENGTH),
- SaltedPassword =
- scram:salted_password(Pass,
- TempSalt,
- ?SCRAM_DEFAULT_ITERATION_COUNT),
- {scram:stored_key(scram:client_key(SaltedPassword)),
- scram:server_key(SaltedPassword),
- TempSalt,
- ?SCRAM_DEFAULT_ITERATION_COUNT}
- end,
- ClientFirstMessageBare =
- str:substr(ClientIn,
- str:str(ClientIn, <<"n=">>)),
- ServerNonce =
- base64:encode(p1_rand:bytes(?NONCE_LENGTH)),
- ServerFirstMessage =
- iolist_to_binary(
- ["r=",
- ClientNonce,
- ServerNonce,
- ",", "s=",
- base64:encode(Salt),
- ",", "i=",
- integer_to_list(IterationCount)]),
- {continue, ServerFirstMessage,
- State#state{step = 4, stored_key = StoredKey,
- server_key = ServerKey,
- auth_module = AuthModule,
- auth_message =
- <<ClientFirstMessageBare/binary,
- ",", ServerFirstMessage/binary>>,
- client_nonce = ClientNonce,
- server_nonce = ServerNonce,
- username = UserName}}
- end;
- _ -> {error, bad_attribute}
- end
- end
- end;
- _Else -> {error, parser_failed}
- end;
-mech_step(#state{step = 4} = State, ClientIn) ->
- case str:tokens(ClientIn, <<",">>) of
- [GS2ChannelBindingAttribute, NonceAttribute,
- ClientProofAttribute] ->
- case parse_attribute(GS2ChannelBindingAttribute) of
- {$c, CVal} ->
- ChannelBindingSupport = try binary:first(base64:decode(CVal))
- catch _:badarg -> 0
- end,
- if (ChannelBindingSupport == $n)
- or (ChannelBindingSupport == $y) ->
- Nonce = <<(State#state.client_nonce)/binary,
- (State#state.server_nonce)/binary>>,
- case parse_attribute(NonceAttribute) of
- {$r, CompareNonce} when CompareNonce == Nonce ->
- case parse_attribute(ClientProofAttribute) of
- {$p, ClientProofB64} ->
- ClientProof = try base64:decode(ClientProofB64)
- catch _:badarg -> <<>>
- end,
- AuthMessage = iolist_to_binary(
- [State#state.auth_message,
- ",",
- str:substr(ClientIn, 1,
- str:str(ClientIn, <<",p=">>)
- - 1)]),
- ClientSignature =
- scram:client_signature(State#state.stored_key,
- AuthMessage),
- ClientKey = scram:client_key(ClientProof,
- ClientSignature),
- CompareStoredKey = scram:stored_key(ClientKey),
- if CompareStoredKey == State#state.stored_key ->
- ServerSignature =
- scram:server_signature(State#state.server_key,
- AuthMessage),
- {ok, [{username, State#state.username},
- {auth_module, State#state.auth_module},
- {authzid, State#state.username}],
- <<"v=",
- (base64:encode(ServerSignature))/binary>>};
- true -> {error, not_authorized, State#state.username}
- end;
- _ -> {error, bad_attribute}
- end;
- {$r, _} -> {error, nonce_mismatch};
- _ -> {error, bad_attribute}
- end;
- true -> {error, bad_channel_binding}
- end;
- _ -> {error, bad_attribute}
- end;
- _ -> {error, parser_failed}
- end.
-
-parse_attribute(<<Name, $=, Val/binary>>) when Val /= <<>> ->
- case is_alpha(Name) of
- true -> {Name, Val};
- false -> {error, bad_attribute}
- end;
-parse_attribute(_) ->
- {error, bad_attribute}.
-
-unescape_username(<<"">>) -> <<"">>;
-unescape_username(EscapedUsername) ->
- Pos = str:str(EscapedUsername, <<"=">>),
- if Pos == 0 -> EscapedUsername;
- true ->
- Start = str:substr(EscapedUsername, 1, Pos - 1),
- End = str:substr(EscapedUsername, Pos),
- EndLen = byte_size(End),
- if EndLen < 3 -> error;
- true ->
- case str:substr(End, 1, 3) of
- <<"=2C">> ->
- <<Start/binary, ",",
- (unescape_username(str:substr(End, 4)))/binary>>;
- <<"=3D">> ->
- <<Start/binary, "=",
- (unescape_username(str:substr(End, 4)))/binary>>;
- _Else -> error
- end
- end
- end.
-
-is_alpha(Char) when Char >= $a, Char =< $z -> true;
-is_alpha(Char) when Char >= $A, Char =< $Z -> true;
-is_alpha(_) -> false.
{next_state, StateName, State1};
handle_event({change_shaper, Shaper}, StateName,
State) ->
- NewShaperState = ejabberd_shaper:new(Shaper),
- {next_state, StateName,
- State#state{shaper_state = NewShaperState}};
+ {next_state, StateName, State#state{shaper_state = Shaper}};
handle_event(_Event, StateName, State) ->
?ERROR_MSG("unexpected event in '~s': ~p",
[StateName, _Event]),
%% xmpp_stream_in callbacks
-export([init/1, handle_call/3, handle_cast/2,
handle_info/2, terminate/2, code_change/3]).
--export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
+-export([tls_options/1, tls_required/1, tls_enabled/1,
compress_methods/1, bind/2, sasl_mechanisms/2,
- get_password_fun/1, check_password_fun/1, check_password_digest_fun/1,
+ get_password_fun/2, check_password_fun/2, check_password_digest_fun/2,
unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2,
handle_unauthenticated_packet/2, handle_authenticated_packet/2,
tls_required(#{tls_required := TLSRequired}) ->
TLSRequired.
-tls_verify(#{tls_verify := TLSVerify}) ->
- TLSVerify.
-
tls_enabled(#{tls_enabled := TLSEnabled,
tls_required := TLSRequired,
tls_verify := TLSVerify}) ->
authenticated_stream_features(#{lserver := LServer}) ->
ejabberd_hooks:run_fold(c2s_post_auth_features, LServer, [], [LServer]).
-sasl_mechanisms(Mechs, #{lserver := LServer}) ->
+sasl_mechanisms(Mechs, #{lserver := LServer} = State) ->
+ Type = ejabberd_auth:store_type(LServer),
Mechs1 = ejabberd_config:get_option({disable_sasl_mechanisms, LServer}, []),
- Mechs2 = case ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer) of
- true -> Mechs1;
- false -> [<<"ANONYMOUS">>|Mechs1]
- end,
- Mechs -- Mechs2.
-
-get_password_fun(#{lserver := LServer}) ->
+ %% I re-created it from cyrsasl ets magic, but I think it's wrong
+ %% TODO: need to check before 18.09 release
+ lists:filter(
+ fun(<<"ANONYMOUS">>) ->
+ ejabberd_auth_anonymous:is_sasl_anonymous_enabled(LServer);
+ (<<"DIGEST-MD5">>) -> Type == plain;
+ (<<"SCRAM-SHA-1">>) -> Type /= external;
+ (<<"PLAIN">>) -> true;
+ (<<"X-OAUTH2">>) -> true;
+ (<<"EXTERNAL">>) -> maps:get(tls_verify, State, false);
+ (_) -> false
+ end, Mechs -- Mechs1).
+
+get_password_fun(_Mech, #{lserver := LServer}) ->
fun(U) ->
ejabberd_auth:get_password_with_authmodule(U, LServer)
end.
-check_password_fun(#{lserver := LServer}) ->
+check_password_fun(<<"X-OAUTH2">>, #{lserver := LServer}) ->
+ fun(User, _AuthzId, Token) ->
+ case ejabberd_oauth:check_token(
+ User, LServer, [<<"sasl_auth">>], Token) of
+ true -> {true, ejabberd_oauth};
+ _ -> {false, ejabberd_oauth}
+ end
+ end;
+check_password_fun(_Mech, #{lserver := LServer}) ->
fun(U, AuthzId, P) ->
ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P)
end.
-check_password_digest_fun(#{lserver := LServer}) ->
+check_password_digest_fun(_Mech, #{lserver := LServer}) ->
fun(U, AuthzId, P, D, DG) ->
ejabberd_auth:check_password_with_authmodule(U, AuthzId, LServer, P, D, DG)
end.
Shaper = acl:access_matches(ShaperName,
#{usr => jid:split(JID), ip => IP},
LServer),
- xmpp_stream_in:change_shaper(State, Shaper).
+ xmpp_stream_in:change_shaper(State, ejabberd_shaper:new(Shaper)).
-spec format_reason(state(), term()) -> binary().
format_reason(#{stop_reason := Reason}, _) ->
fun(#local_config{key = Key, value = Val}) ->
{Key, Val}
end, Opts)),
+ set_fqdn(),
set_log_level().
+set_fqdn() ->
+ FQDNs = get_option(fqdn, []),
+ xmpp:set_config([{fqdn, FQDNs}]).
+
set_log_level() ->
Level = get_option(loglevel, 4),
ejabberd_logger:set(Level).
fun(I) when is_integer(I), I>=0 -> I end;
opt_type(validate_stream) ->
fun(B) when is_boolean(B) -> B end;
+opt_type(fqdn) ->
+ fun(Domain) when is_binary(Domain) ->
+ [Domain];
+ (Domains) ->
+ [iolist_to_binary(Domain) || Domain <- Domains]
+ end;
opt_type(_) ->
[hide_sensitive_log_data, hosts, language, max_fsm_queue,
default_db, default_ram_db, queue_type, queue_dir, loglevel,
- use_cache, cache_size, cache_missed, cache_life_time,
+ use_cache, cache_size, cache_missed, cache_life_time, fqdn,
shared_key, node_start, validate_stream, negotiation_timeout].
-spec may_hide_data(any()) -> any().
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : ejabberd_idna.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Support for IDNA (RFC3490)
-%%% Created : 10 Apr 2004 by Alexey Shchepin <alexey@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(ejabberd_idna).
-
--author('alexey@process-one.net').
-
--export([domain_utf8_to_ascii/1,
- domain_ucs2_to_ascii/1,
- utf8_to_ucs2/1]).
-
--ifdef(TEST).
--include_lib("eunit/include/eunit.hrl").
--endif.
-
--spec domain_utf8_to_ascii(binary()) -> false | binary().
-
-domain_utf8_to_ascii(Domain) ->
- domain_ucs2_to_ascii(utf8_to_ucs2(Domain)).
-
-utf8_to_ucs2(S) ->
- utf8_to_ucs2(binary_to_list(S), "").
-
-utf8_to_ucs2([], R) -> lists:reverse(R);
-utf8_to_ucs2([C | S], R) when C < 128 ->
- utf8_to_ucs2(S, [C | R]);
-utf8_to_ucs2([C1, C2 | S], R) when C1 < 224 ->
- utf8_to_ucs2(S, [C1 band 31 bsl 6 bor C2 band 63 | R]);
-utf8_to_ucs2([C1, C2, C3 | S], R) when C1 < 240 ->
- utf8_to_ucs2(S,
- [C1 band 15 bsl 12 bor (C2 band 63 bsl 6) bor C3 band 63
- | R]).
-
--spec domain_ucs2_to_ascii(list()) -> false | binary().
-
-domain_ucs2_to_ascii(Domain) ->
- case catch domain_ucs2_to_ascii1(Domain) of
- {'EXIT', _Reason} -> false;
- Res -> iolist_to_binary(Res)
- end.
-
-domain_ucs2_to_ascii1(Domain) ->
- Parts = string:tokens(Domain,
- [46, 12290, 65294, 65377]),
- ASCIIParts = lists:map(fun (P) -> to_ascii(P) end,
- Parts),
- string:strip(lists:flatmap(fun (P) -> [$. | P] end,
- ASCIIParts),
- left, $.).
-
-%% Domain names are already nameprep'ed in ejabberd, so we skiping this step
-to_ascii(Name) ->
- false = lists:any(fun (C)
- when (0 =< C) and (C =< 44) or
- (46 =< C) and (C =< 47)
- or (58 =< C) and (C =< 64)
- or (91 =< C) and (C =< 96)
- or (123 =< C) and (C =< 127) ->
- true;
- (_) -> false
- end,
- Name),
- case Name of
- [H | _] when H /= $- -> true = lists:last(Name) /= $-
- end,
- ASCIIName = case lists:any(fun (C) -> C > 127 end, Name)
- of
- true ->
- true = case Name of
- "xn--" ++ _ -> false;
- _ -> true
- end,
- "xn--" ++ punycode_encode(Name);
- false -> Name
- end,
- L = length(ASCIIName),
- true = (1 =< L) and (L =< 63),
- ASCIIName.
-
-%%% PUNYCODE (RFC3492)
-
--define(BASE, 36).
-
--define(TMIN, 1).
-
--define(TMAX, 26).
-
--define(SKEW, 38).
-
--define(DAMP, 700).
-
--define(INITIAL_BIAS, 72).
-
--define(INITIAL_N, 128).
-
-punycode_encode(Input) ->
- N = (?INITIAL_N),
- Delta = 0,
- Bias = (?INITIAL_BIAS),
- Basic = lists:filter(fun (C) -> C =< 127 end, Input),
- NonBasic = lists:filter(fun (C) -> C > 127 end, Input),
- L = length(Input),
- B = length(Basic),
- SNonBasic = lists:usort(NonBasic),
- Output1 = if B > 0 -> Basic ++ "-";
- true -> ""
- end,
- Output2 = punycode_encode1(Input, SNonBasic, B, B, L, N,
- Delta, Bias, ""),
- Output1 ++ Output2.
-
-punycode_encode1(Input, [M | SNonBasic], B, H, L, N,
- Delta, Bias, Out)
- when H < L ->
- Delta1 = Delta + (M - N) * (H + 1),
- % let n = m
- {NewDelta, NewBias, NewH, NewOut} = lists:foldl(fun (C,
- {ADelta, ABias, AH,
- AOut}) ->
- if C < M ->
- {ADelta + 1,
- ABias, AH,
- AOut};
- C == M ->
- NewOut =
- punycode_encode_delta(ADelta,
- ABias,
- AOut),
- NewBias =
- adapt(ADelta,
- H +
- 1,
- H
- ==
- B),
- {0, NewBias,
- AH + 1,
- NewOut};
- true ->
- {ADelta,
- ABias, AH,
- AOut}
- end
- end,
- {Delta1, Bias, H, Out},
- Input),
- punycode_encode1(Input, SNonBasic, B, NewH, L, M + 1,
- NewDelta + 1, NewBias, NewOut);
-punycode_encode1(_Input, _SNonBasic, _B, _H, _L, _N,
- _Delta, _Bias, Out) ->
- lists:reverse(Out).
-
-punycode_encode_delta(Delta, Bias, Out) ->
- punycode_encode_delta(Delta, Bias, Out, ?BASE).
-
-punycode_encode_delta(Delta, Bias, Out, K) ->
- T = if K =< Bias -> ?TMIN;
- K >= Bias + (?TMAX) -> ?TMAX;
- true -> K - Bias
- end,
- if Delta < T -> [codepoint(Delta) | Out];
- true ->
- C = T + (Delta - T) rem ((?BASE) - T),
- punycode_encode_delta((Delta - T) div ((?BASE) - T),
- Bias, [codepoint(C) | Out], K + (?BASE))
- end.
-
-adapt(Delta, NumPoints, FirstTime) ->
- Delta1 = if FirstTime -> Delta div (?DAMP);
- true -> Delta div 2
- end,
- Delta2 = Delta1 + Delta1 div NumPoints,
- adapt1(Delta2, 0).
-
-adapt1(Delta, K) ->
- if Delta > ((?BASE) - (?TMIN)) * (?TMAX) div 2 ->
- adapt1(Delta div ((?BASE) - (?TMIN)), K + (?BASE));
- true ->
- K +
- ((?BASE) - (?TMIN) + 1) * Delta div (Delta + (?SKEW))
- end.
-
-codepoint(C) ->
- if (0 =< C) and (C =< 25) -> C + 97;
- (26 =< C) and (C =< 35) -> C + 22
- end.
-
-%%%===================================================================
-%%% Unit tests
-%%%===================================================================
--ifdef(TEST).
-
-acsii_test() ->
- ?assertEqual(<<"test.org">>, domain_utf8_to_ascii(<<"test.org">>)).
-
-utf8_test() ->
- ?assertEqual(
- <<"xn--d1acufc.xn--p1ai">>,
- domain_utf8_to_ascii(
- <<208,180,208,190,208,188,208,181,208,189,46,209,128,209,132>>)).
-
--endif.
ok
end, gen_event:which_handlers(lager_event))
end,
+ case LogLevel of
+ 5 -> xmpp:set_config([{debug, true}]);
+ _ -> ok
+ end,
{module, lager};
set({_LogLevel, _}) ->
error_logger:error_msg("custom loglevels are not supported for 'lager'"),
-spec get_certfile_no_default(binary()) -> {ok, binary()} | error.
get_certfile_no_default(Domain) ->
- case ejabberd_idna:domain_utf8_to_ascii(Domain) of
+ case xmpp_idna:domain_utf8_to_ascii(Domain) of
false ->
error;
ASCIIDomain ->
%% xmpp_stream_in callbacks
-export([init/1, handle_call/3, handle_cast/2,
handle_info/2, terminate/2, code_change/3]).
--export([tls_options/1, tls_required/1, tls_verify/1, tls_enabled/1,
- compress_methods/1,
+-export([tls_options/1, tls_required/1, tls_enabled/1,
+ compress_methods/1, sasl_mechanisms/2,
unauthenticated_stream_features/1, authenticated_stream_features/1,
handle_stream_start/2, handle_stream_end/2,
handle_stream_established/1, handle_auth_success/4,
tls_required(#{server_host := LServer}) ->
ejabberd_s2s:tls_required(LServer).
-tls_verify(#{server_host := LServer}) ->
- ejabberd_s2s:tls_verify(LServer).
-
tls_enabled(#{server_host := LServer}) ->
ejabberd_s2s:tls_enabled(LServer).
+sasl_mechanisms(Mechs, #{server_host := LServer}) ->
+ lists:filter(
+ fun(<<"EXTERNAL">>) -> ejabberd_s2s:tls_verify(LServer);
+ (_) -> false
+ end, Mechs).
+
compress_methods(#{server_host := LServer}) ->
case ejabberd_s2s:zlib_enabled(LServer) of
true -> [<<"zlib">>];
change_shaper(#{shaper := ShaperName, server_host := ServerHost} = State,
RServer) ->
Shaper = acl:match_rule(ServerHost, ShaperName, jid:make(RServer)),
- xmpp_stream_in:change_shaper(State, Shaper).
+ xmpp_stream_in:change_shaper(State, ejabberd_shaper:new(Shaper)).
-spec listen_opt_type(shaper) -> fun((any()) -> any());
(certfile) -> fun((binary()) -> binary());
end,
GlobalRoutes = proplists:get_value(global_routes, Opts, true),
Timeout = ejabberd_config:negotiation_timeout(),
- State1 = xmpp_stream_in:change_shaper(State, Shaper),
+ State1 = xmpp_stream_in:change_shaper(State, ejabberd_shaper:new(Shaper)),
State2 = xmpp_stream_in:set_timeout(State1, Timeout),
State3 = State2#{access => Access,
xmlns => ?NS_COMPONENT,
{ok, {{one_for_one, 10, 1},
[worker(ejabberd_hooks),
worker(ejabberd_cluster),
- worker(cyrsasl),
worker(translate),
worker(ejabberd_access_permissions),
worker(ejabberd_ctl),
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : scram.erl
-%%% Author : Stephen Röttger <stephen.roettger@googlemail.com>
-%%% Purpose : SCRAM (RFC 5802)
-%%% Created : 7 Aug 2011 by Stephen Röttger <stephen.roettger@googlemail.com>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(scram).
-
--author('stephen.roettger@googlemail.com').
-
-%% External exports
-%% ejabberd doesn't implement SASLPREP, so we use the similar RESOURCEPREP instead
--export([salted_password/3, stored_key/1, server_key/1,
- server_signature/2, client_signature/2, client_key/1,
- client_key/2]).
-
--spec salted_password(binary(), binary(), non_neg_integer()) -> binary().
-
-salted_password(Password, Salt, IterationCount) ->
- hi(jid:resourceprep(Password), Salt, IterationCount).
-
--spec client_key(binary()) -> binary().
-
-client_key(SaltedPassword) ->
- sha_mac(SaltedPassword, <<"Client Key">>).
-
--spec stored_key(binary()) -> binary().
-
-stored_key(ClientKey) -> crypto:hash(sha, ClientKey).
-
--spec server_key(binary()) -> binary().
-
-server_key(SaltedPassword) ->
- sha_mac(SaltedPassword, <<"Server Key">>).
-
--spec client_signature(binary(), binary()) -> binary().
-
-client_signature(StoredKey, AuthMessage) ->
- sha_mac(StoredKey, AuthMessage).
-
--spec client_key(binary(), binary()) -> binary().
-
-client_key(ClientProof, ClientSignature) ->
- crypto:exor(ClientProof, ClientSignature).
-
--spec server_signature(binary(), binary()) -> binary().
-
-server_signature(ServerKey, AuthMessage) ->
- sha_mac(ServerKey, AuthMessage).
-
-hi(Password, Salt, IterationCount) ->
- U1 = sha_mac(Password, <<Salt/binary, 0, 0, 0, 1>>),
- crypto:exor(U1, hi_round(Password, U1, IterationCount - 1)).
-
-hi_round(Password, UPrev, 1) ->
- sha_mac(Password, UPrev);
-hi_round(Password, UPrev, IterationCount) ->
- U = sha_mac(Password, UPrev),
- crypto:exor(U, hi_round(Password, U, IterationCount - 1)).
-
-sha_mac(Key, Data) ->
- crypto:hmac(sha, Key, Data).
+++ /dev/null
-%%%----------------------------------------------------------------------
-%%% File : xmpp_socket.erl
-%%% Author : Alexey Shchepin <alexey@process-one.net>
-%%% Purpose : Socket with zlib and TLS support library
-%%% Created : 23 Aug 2006 by Alexey Shchepin <alexey@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%----------------------------------------------------------------------
-
--module(xmpp_socket).
-
--author('alexey@process-one.net').
-
-%% API
--export([start/4,
- connect/3,
- connect/4,
- connect/5,
- starttls/2,
- compress/1,
- compress/2,
- reset_stream/1,
- send_element/2,
- send_header/2,
- send_trailer/1,
- send/2,
- send_xml/2,
- recv/2,
- activate/1,
- change_shaper/2,
- monitor/1,
- get_sockmod/1,
- get_transport/1,
- get_peer_certificate/2,
- get_verify_result/1,
- close/1,
- pp/1,
- sockname/1, peername/1]).
-
--include("xmpp.hrl").
--include("logger.hrl").
-
--type sockmod() :: ejabberd_bosh |
- ejabberd_http_ws |
- gen_tcp | fast_tls | ezlib.
--type receiver() :: atom().
--type socket() :: pid() | inet:socket() |
- fast_tls:tls_socket() |
- ezlib:zlib_socket() |
- ejabberd_bosh:bosh_socket() |
- ejabberd_http_ws:ws_socket().
-
--record(socket_state, {sockmod = gen_tcp :: sockmod(),
- socket :: socket(),
- max_stanza_size = infinity :: timeout(),
- xml_stream :: undefined | fxml_stream:xml_stream_state(),
- shaper = none :: none | ejabberd_shaper:shaper(),
- receiver :: receiver()}).
-
--type socket_state() :: #socket_state{}.
-
--export_type([socket/0, socket_state/0, sockmod/0]).
-
--callback start({module(), socket_state()},
- [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
--callback start_link({module(), socket_state()},
- [proplists:property()]) -> {ok, pid()} | {error, term()} | ignore.
--callback socket_type() -> xml_stream | independent | raw.
-
--define(is_http_socket(S),
- (S#socket_state.sockmod == ejabberd_bosh orelse
- S#socket_state.sockmod == ejabberd_http_ws)).
-
-%%====================================================================
-%% API
-%%====================================================================
--spec start(atom(), sockmod(), socket(), [proplists:property()])
- -> {ok, pid() | independent} | {error, inet:posix() | any()} | ignore.
-start(Module, SockMod, Socket, Opts) ->
- try
- case Module:socket_type() of
- independent ->
- {ok, independent};
- xml_stream ->
- MaxStanzaSize = proplists:get_value(max_stanza_size, Opts, infinity),
- Receiver = proplists:get_value(receiver, Opts),
- SocketData = #socket_state{sockmod = SockMod,
- socket = Socket,
- receiver = Receiver,
- max_stanza_size = MaxStanzaSize},
- {ok, Pid} = Module:start({?MODULE, SocketData}, Opts),
- Receiver1 = if is_pid(Receiver) -> Receiver;
- true -> Pid
- end,
- ok = controlling_process(SocketData, Receiver1),
- ok = become_controller(SocketData, Pid),
- {ok, Receiver1};
- raw ->
- {ok, Pid} = Module:start({SockMod, Socket}, Opts),
- ok = SockMod:controlling_process(Socket, Pid),
- {ok, Pid}
- end
- catch
- _:{badmatch, {error, _} = Err} ->
- SockMod:close(Socket),
- Err;
- _:{badmatch, ignore} ->
- SockMod:close(Socket),
- ignore
- end.
-
-connect(Addr, Port, Opts) ->
- connect(Addr, Port, Opts, infinity, self()).
-
-connect(Addr, Port, Opts, Timeout) ->
- connect(Addr, Port, Opts, Timeout, self()).
-
-connect(Addr, Port, Opts, Timeout, Owner) ->
- case gen_tcp:connect(Addr, Port, Opts, Timeout) of
- {ok, Socket} ->
- SocketData = #socket_state{sockmod = gen_tcp, socket = Socket},
- case controlling_process(SocketData, Owner) of
- ok ->
- activate_after(Socket, Owner, 0),
- {ok, SocketData};
- {error, _Reason} = Error ->
- gen_tcp:close(Socket),
- Error
- end;
- {error, _Reason} = Error ->
- Error
- end.
-
-starttls(#socket_state{socket = Socket,
- receiver = undefined} = SocketData, TLSOpts) ->
- case fast_tls:tcp_to_tls(Socket, TLSOpts) of
- {ok, TLSSocket} ->
- SocketData1 = SocketData#socket_state{socket = TLSSocket,
- sockmod = fast_tls},
- SocketData2 = reset_stream(SocketData1),
- case fast_tls:recv_data(TLSSocket, <<>>) of
- {ok, TLSData} ->
- parse(SocketData2, TLSData);
- {error, _} = Err ->
- Err
- end;
- {error, _} = Err ->
- Err
- end.
-
-compress(SocketData) -> compress(SocketData, undefined).
-
-compress(#socket_state{receiver = undefined,
- sockmod = SockMod,
- socket = Socket} = SocketData, Data) ->
- {ok, ZlibSocket} = ezlib:enable_zlib(SockMod, Socket),
- case Data of
- undefined -> ok;
- _ -> send(SocketData, Data)
- end,
- SocketData1 = SocketData#socket_state{socket = ZlibSocket,
- sockmod = ezlib},
- SocketData2 = reset_stream(SocketData1),
- case ezlib:recv_data(ZlibSocket, <<"">>) of
- {ok, ZlibData} ->
- parse(SocketData2, ZlibData);
- {error, _} = Err ->
- Err
- end.
-
-reset_stream(#socket_state{xml_stream = XMLStream,
- receiver = Receiver,
- sockmod = SockMod, socket = Socket,
- max_stanza_size = MaxStanzaSize} = SocketData) ->
- XMLStream1 = try fxml_stream:reset(XMLStream)
- catch error:_ ->
- close_stream(XMLStream),
- fxml_stream:new(self(), MaxStanzaSize)
- end,
- case Receiver of
- undefined ->
- SocketData#socket_state{xml_stream = XMLStream1};
- _ ->
- Socket1 = SockMod:reset_stream(Socket),
- SocketData#socket_state{xml_stream = XMLStream1, socket = Socket1}
- end.
-
--spec send_element(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}.
-send_element(SocketData, El) when ?is_http_socket(SocketData) ->
- send_xml(SocketData, {xmlstreamelement, El});
-send_element(SocketData, El) ->
- send(SocketData, fxml:element_to_binary(El)).
-
--spec send_header(socket_state(), fxml:xmlel()) -> ok | {error, inet:posix()}.
-send_header(SocketData, El) when ?is_http_socket(SocketData) ->
- send_xml(SocketData, {xmlstreamstart, El#xmlel.name, El#xmlel.attrs});
-send_header(SocketData, El) ->
- send(SocketData, fxml:element_to_header(El)).
-
--spec send_trailer(socket_state()) -> ok | {error, inet:posix()}.
-send_trailer(SocketData) when ?is_http_socket(SocketData) ->
- send_xml(SocketData, {xmlstreamend, <<"stream:stream">>});
-send_trailer(SocketData) ->
- send(SocketData, <<"</stream:stream>">>).
-
--spec send(socket_state(), iodata()) -> ok | {error, closed | inet:posix()}.
-send(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) ->
- ?DEBUG("(~s) Send XML on stream = ~p", [pp(SocketData), Data]),
- try SockMod:send(Socket, Data) of
- {error, einval} -> {error, closed};
- Result -> Result
- catch _:badarg ->
- %% Some modules throw badarg exceptions on closed sockets
- %% TODO: their code should be improved
- {error, closed}
- end.
-
--spec send_xml(socket_state(),
- {xmlstreamelement, fxml:xmlel()} |
- {xmlstreamstart, binary(), [{binary(), binary()}]} |
- {xmlstreamend, binary()} |
- {xmlstreamraw, iodata()}) -> term().
-send_xml(SocketData, El) ->
- (SocketData#socket_state.sockmod):send_xml(SocketData#socket_state.socket, El).
-
-recv(#socket_state{xml_stream = undefined} = SocketData, Data) ->
- XMLStream = fxml_stream:new(self(), SocketData#socket_state.max_stanza_size),
- recv(SocketData#socket_state{xml_stream = XMLStream}, Data);
-recv(#socket_state{sockmod = SockMod, socket = Socket} = SocketData, Data) ->
- case SockMod of
- fast_tls ->
- case fast_tls:recv_data(Socket, Data) of
- {ok, TLSData} ->
- parse(SocketData, TLSData);
- {error, _} = Err ->
- Err
- end;
- ezlib ->
- case ezlib:recv_data(Socket, Data) of
- {ok, ZlibData} ->
- parse(SocketData, ZlibData);
- {error, _} = Err ->
- Err
- end;
- _ ->
- parse(SocketData, Data)
- end.
-
-change_shaper(#socket_state{receiver = undefined} = SocketData, Shaper) ->
- ShaperState = ejabberd_shaper:new(Shaper),
- SocketData#socket_state{shaper = ShaperState};
-change_shaper(#socket_state{sockmod = SockMod,
- socket = Socket} = SocketData, Shaper) ->
- SockMod:change_shaper(Socket, Shaper),
- SocketData.
-
-monitor(#socket_state{receiver = undefined}) ->
- make_ref();
-monitor(#socket_state{sockmod = SockMod, socket = Socket}) ->
- SockMod:monitor(Socket).
-
-controlling_process(#socket_state{sockmod = SockMod,
- socket = Socket}, Pid) ->
- SockMod:controlling_process(Socket, Pid).
-
-become_controller(#socket_state{receiver = Receiver,
- sockmod = SockMod,
- socket = Socket}, Pid) ->
- if is_pid(Receiver) ->
- SockMod:become_controller(Receiver, Pid);
- true ->
- activate_after(Socket, Pid, 0)
- end.
-
-get_sockmod(SocketData) ->
- SocketData#socket_state.sockmod.
-
-get_transport(#socket_state{sockmod = SockMod,
- socket = Socket}) ->
- case SockMod of
- gen_tcp -> tcp;
- fast_tls -> tls;
- ezlib ->
- case ezlib:get_sockmod(Socket) of
- gen_tcp -> tcp_zlib;
- fast_tls -> tls_zlib
- end;
- ejabberd_bosh -> http_bind;
- ejabberd_http_ws -> websocket
- end.
-
-get_peer_certificate(SocketData, Type) ->
- fast_tls:get_peer_certificate(SocketData#socket_state.socket, Type).
-
-get_verify_result(SocketData) ->
- fast_tls:get_verify_result(SocketData#socket_state.socket).
-
-close(#socket_state{sockmod = SockMod, socket = Socket}) ->
- SockMod:close(Socket).
-
-sockname(#socket_state{sockmod = SockMod,
- socket = Socket}) ->
- case SockMod of
- gen_tcp -> inet:sockname(Socket);
- _ -> SockMod:sockname(Socket)
- end.
-
-peername(#socket_state{sockmod = SockMod,
- socket = Socket}) ->
- case SockMod of
- gen_tcp -> inet:peername(Socket);
- _ -> SockMod:peername(Socket)
- end.
-
-activate(#socket_state{sockmod = SockMod, socket = Socket}) ->
- case SockMod of
- gen_tcp -> inet:setopts(Socket, [{active, once}]);
- _ -> SockMod:setopts(Socket, [{active, once}])
- end.
-
-activate_after(Socket, Pid, Pause) ->
- if Pause > 0 ->
- erlang:send_after(Pause, Pid, {tcp, Socket, <<>>});
- true ->
- Pid ! {tcp, Socket, <<>>}
- end,
- ok.
-
-pp(#socket_state{receiver = Receiver} = State) ->
- Transport = get_transport(State),
- Receiver1 = case Receiver of
- undefined -> self();
- _ -> Receiver
- end,
- io_lib:format("~s|~w", [Transport, Receiver1]).
-
-parse(SocketData, Data) when Data == <<>>; Data == [] ->
- case activate(SocketData) of
- ok ->
- {ok, SocketData};
- {error, _} = Err ->
- Err
- end;
-parse(SocketData, [El | Els]) when is_record(El, xmlel) ->
- self() ! {'$gen_event', {xmlstreamelement, El}},
- parse(SocketData, Els);
-parse(SocketData, [El | Els]) when
- element(1, El) == xmlstreamstart;
- element(1, El) == xmlstreamelement;
- element(1, El) == xmlstreamend;
- element(1, El) == xmlstreamerror ->
- self() ! {'$gen_event', El},
- parse(SocketData, Els);
-parse(#socket_state{xml_stream = XMLStream,
- socket = Socket,
- shaper = ShaperState} = SocketData, Data)
- when is_binary(Data) ->
- ?DEBUG("(~s) Received XML on stream = ~p", [pp(SocketData), Data]),
- XMLStream1 = fxml_stream:parse(XMLStream, Data),
- {ShaperState1, Pause} = ejabberd_shaper:update(ShaperState, byte_size(Data)),
- Ret = if Pause > 0 ->
- activate_after(Socket, self(), Pause);
- true ->
- activate(SocketData)
- end,
- case Ret of
- ok ->
- {ok, SocketData#socket_state{xml_stream = XMLStream1,
- shaper = ShaperState1}};
- {error, _} = Err ->
- Err
- end.
-
-close_stream(undefined) ->
- ok;
-close_stream(XMLStream) ->
- fxml_stream:close(XMLStream).
+++ /dev/null
-%%%-------------------------------------------------------------------
-%%% Created : 26 Nov 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%-------------------------------------------------------------------
--module(xmpp_stream_in).
--define(GEN_SERVER, p1_server).
--behaviour(?GEN_SERVER).
-
--protocol({rfc, 6120}).
--protocol({xep, 114, '1.6'}).
-
-%% API
--export([start/3, start_link/3, call/3, cast/2, reply/2, stop/1,
- send/2, close/1, close/2, send_error/3, establish/1,
- get_transport/1, change_shaper/2, set_timeout/2, format_error/1]).
-
-%% gen_server callbacks
--export([init/1, handle_cast/2, handle_call/3, handle_info/2,
- terminate/2, code_change/3]).
-
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
-
--include("xmpp.hrl").
--type state() :: map().
--type stop_reason() :: {stream, reset | {in | out, stream_error()}} |
- {tls, inet:posix() | atom() | binary()} |
- {socket, inet:posix() | atom()} |
- internal_failure.
--export_type([state/0, stop_reason/0]).
--callback init(list()) -> {ok, state()} | {error, term()} | ignore.
--callback handle_cast(term(), state()) -> state().
--callback handle_call(term(), term(), state()) -> state().
--callback handle_info(term(), state()) -> state().
--callback terminate(term(), state()) -> any().
--callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
--callback handle_stream_start(stream_start(), state()) -> state().
--callback handle_stream_established(state()) -> state().
--callback handle_stream_end(stop_reason(), state()) -> state().
--callback handle_cdata(binary(), state()) -> state().
--callback handle_unauthenticated_packet(xmpp_element(), state()) -> state().
--callback handle_authenticated_packet(xmpp_element(), state()) -> state().
--callback handle_unbinded_packet(xmpp_element(), state()) -> state().
--callback handle_auth_success(binary(), binary(), module(), state()) -> state().
--callback handle_auth_failure(binary(), binary(), binary(), state()) -> state().
--callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
--callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
--callback handle_timeout(state()) -> state().
--callback get_password_fun(state()) -> fun().
--callback check_password_fun(state()) -> fun().
--callback check_password_digest_fun(state()) -> fun().
--callback bind(binary(), state()) -> {ok, state()} | {error, stanza_error(), state()}.
--callback compress_methods(state()) -> [binary()].
--callback tls_options(state()) -> [proplists:property()].
--callback tls_required(state()) -> boolean().
--callback tls_verify(state()) -> boolean().
--callback tls_enabled(state()) -> boolean().
--callback sasl_mechanisms([cyrsasl:mechanism()], state()) -> [cyrsasl:mechanism()].
--callback unauthenticated_stream_features(state()) -> [xmpp_element()].
--callback authenticated_stream_features(state()) -> [xmpp_element()].
-
-%% All callbacks are optional
--optional_callbacks([init/1,
- handle_cast/2,
- handle_call/3,
- handle_info/2,
- terminate/2,
- code_change/3,
- handle_stream_start/2,
- handle_stream_established/1,
- handle_stream_end/2,
- handle_cdata/2,
- handle_authenticated_packet/2,
- handle_unauthenticated_packet/2,
- handle_unbinded_packet/2,
- handle_auth_success/4,
- handle_auth_failure/4,
- handle_send/3,
- handle_recv/3,
- handle_timeout/1,
- get_password_fun/1,
- check_password_fun/1,
- check_password_digest_fun/1,
- bind/2,
- compress_methods/1,
- tls_options/1,
- tls_required/1,
- tls_verify/1,
- tls_enabled/1,
- sasl_mechanisms/2,
- unauthenticated_stream_features/1,
- authenticated_stream_features/1]).
-
-%%%===================================================================
-%%% API
-%%%===================================================================
-start(Mod, Args, Opts) ->
- ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
-
-start_link(Mod, Args, Opts) ->
- ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
-
-call(Ref, Msg, Timeout) ->
- ?GEN_SERVER:call(Ref, Msg, Timeout).
-
-cast(Ref, Msg) ->
- ?GEN_SERVER:cast(Ref, Msg).
-
-reply(Ref, Reply) ->
- ?GEN_SERVER:reply(Ref, Reply).
-
--spec stop(pid()) -> ok;
- (state()) -> no_return().
-stop(Pid) when is_pid(Pid) ->
- cast(Pid, stop);
-stop(#{owner := Owner} = State) when Owner == self() ->
- terminate(normal, State),
- exit(normal);
-stop(_) ->
- erlang:error(badarg).
-
--spec send(pid(), xmpp_element()) -> ok;
- (state(), xmpp_element()) -> state().
-send(Pid, Pkt) when is_pid(Pid) ->
- cast(Pid, {send, Pkt});
-send(#{owner := Owner} = State, Pkt) when Owner == self() ->
- send_pkt(State, Pkt);
-send(_, _) ->
- erlang:error(badarg).
-
--spec close(pid()) -> ok;
- (state()) -> state().
-close(Pid) when is_pid(Pid) ->
- close(Pid, closed);
-close(#{owner := Owner} = State) when Owner == self() ->
- close_socket(State);
-close(_) ->
- erlang:error(badarg).
-
--spec close(pid(), atom()) -> ok.
-close(Pid, Reason) ->
- cast(Pid, {close, Reason}).
-
--spec establish(state()) -> state().
-establish(State) ->
- process_stream_established(State).
-
--spec set_timeout(state(), non_neg_integer() | infinity) -> state().
-set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
- case Timeout of
- infinity -> State#{stream_timeout => infinity};
- _ ->
- Time = p1_time_compat:monotonic_time(milli_seconds),
- State#{stream_timeout => {Timeout, Time}}
- end;
-set_timeout(_, _) ->
- erlang:error(badarg).
-
-get_transport(#{socket := Socket, owner := Owner})
- when Owner == self() ->
- xmpp_socket:get_transport(Socket);
-get_transport(_) ->
- erlang:error(badarg).
-
--spec change_shaper(state(), ejabberd_shaper:shaper()) -> state().
-change_shaper(#{socket := Socket, owner := Owner} = State, Shaper)
- when Owner == self() ->
- Socket1 = xmpp_socket:change_shaper(Socket, Shaper),
- State#{socket => Socket1};
-change_shaper(_, _) ->
- erlang:error(badarg).
-
--spec format_error(stop_reason()) -> binary().
-format_error({socket, Reason}) ->
- format("Connection failed: ~s", [format_inet_error(Reason)]);
-format_error({stream, reset}) ->
- <<"Stream reset by peer">>;
-format_error({stream, {in, #stream_error{} = Err}}) ->
- format("Stream closed by peer: ~s", [xmpp:format_stream_error(Err)]);
-format_error({stream, {out, #stream_error{} = Err}}) ->
- format("Stream closed by us: ~s", [xmpp:format_stream_error(Err)]);
-format_error({tls, Reason}) ->
- format("TLS failed: ~s", [format_tls_error(Reason)]);
-format_error(internal_failure) ->
- <<"Internal server error">>;
-format_error(Err) ->
- format("Unrecognized error: ~w", [Err]).
-
-%%%===================================================================
-%%% gen_server callbacks
-%%%===================================================================
-init([Mod, {_SockMod, Socket}, Opts]) ->
- Encrypted = proplists:get_bool(tls, Opts),
- SocketMonitor = xmpp_socket:monitor(Socket),
- case xmpp_socket:peername(Socket) of
- {ok, IP} ->
- Time = p1_time_compat:monotonic_time(milli_seconds),
- State = #{owner => self(),
- mod => Mod,
- socket => Socket,
- socket_monitor => SocketMonitor,
- stream_timeout => {timer:seconds(30), Time},
- stream_direction => in,
- stream_id => new_id(),
- stream_state => wait_for_stream,
- stream_header_sent => false,
- stream_restarted => false,
- stream_compressed => false,
- stream_encrypted => Encrypted,
- stream_version => {1,0},
- stream_authenticated => false,
- codec_options => [ignore_els],
- xmlns => ?NS_CLIENT,
- lang => <<"">>,
- user => <<"">>,
- server => <<"">>,
- resource => <<"">>,
- lserver => <<"">>,
- ip => IP},
- case try Mod:init([State, Opts])
- catch _:undef -> {ok, State}
- end of
- {ok, State1} when not Encrypted ->
- {_, State2, Timeout} = noreply(State1),
- {ok, State2, Timeout};
- {ok, State1} when Encrypted ->
- TLSOpts = try callback(tls_options, State1)
- catch _:{?MODULE, undef} -> []
- end,
- case xmpp_socket:starttls(Socket, TLSOpts) of
- {ok, TLSSocket} ->
- State2 = State1#{socket => TLSSocket},
- {_, State3, Timeout} = noreply(State2),
- {ok, State3, Timeout};
- {error, Reason} ->
- {stop, Reason}
- end;
- {error, Reason} ->
- {stop, Reason};
- ignore ->
- ignore
- end;
- {error, _Reason} ->
- ignore
- end.
-
-handle_cast({send, Pkt}, State) ->
- noreply(send_pkt(State, Pkt));
-handle_cast(stop, State) ->
- {stop, normal, State};
-handle_cast({close, Reason}, State) ->
- State1 = close_socket(State),
- noreply(
- case is_disconnected(State) of
- true -> State1;
- false -> process_stream_end({socket, Reason}, State)
- end);
-handle_cast(Cast, State) ->
- noreply(try callback(handle_cast, Cast, State)
- catch _:{?MODULE, undef} -> State
- end).
-
-handle_call(Call, From, State) ->
- noreply(try callback(handle_call, Call, From, State)
- catch _:{?MODULE, undef} -> State
- end).
-
-handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
- #{stream_state := wait_for_stream,
- xmlns := XMLNS, lang := MyLang} = State) ->
- El = #xmlel{name = Name, attrs = Attrs},
- noreply(
- try xmpp:decode(El, XMLNS, []) of
- #stream_start{} = Pkt ->
- State1 = send_header(State, Pkt),
- case is_disconnected(State1) of
- true -> State1;
- false -> process_stream(Pkt, State1)
- end;
- _ ->
- State1 = send_header(State),
- case is_disconnected(State1) of
- true -> State1;
- false -> send_pkt(State1, xmpp:serr_invalid_xml())
- end
- catch _:{xmpp_codec, Why} ->
- State1 = send_header(State),
- case is_disconnected(State1) of
- true -> State1;
- false ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- Err = xmpp:serr_invalid_xml(Txt, Lang),
- send_pkt(State1, Err)
- end
- end);
-handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
- noreply(process_stream_end({stream, reset}, State));
-handle_info({'$gen_event', closed}, State) ->
- noreply(process_stream_end({socket, closed}, State));
-handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
- State1 = send_header(State),
- noreply(
- case is_disconnected(State1) of
- true -> State1;
- false ->
- Err = case Reason of
- <<"XML stanza is too big">> ->
- xmpp:serr_policy_violation(Reason, Lang);
- {_, Txt} ->
- xmpp:serr_not_well_formed(Txt, Lang)
- end,
- send_pkt(State1, Err)
- end);
-handle_info({'$gen_event', El}, #{stream_state := wait_for_stream} = State) ->
- error_logger:warning_msg("unexpected event from XML driver: ~p; "
- "xmlstreamstart was expected", [El]),
- State1 = send_header(State),
- noreply(
- case is_disconnected(State1) of
- true -> State1;
- false -> send_pkt(State1, xmpp:serr_invalid_xml())
- end);
-handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, codec_options := Opts} = State) ->
- noreply(
- try xmpp:decode(El, NS, Opts) of
- Pkt ->
- State1 = try callback(handle_recv, El, Pkt, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false -> process_element(Pkt, State1)
- end
- catch _:{xmpp_codec, Why} ->
- State1 = try callback(handle_recv, El, {error, Why}, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false -> process_invalid_xml(State1, El, Why)
- end
- end);
-handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}},
- State) ->
- noreply(try callback(handle_cdata, Data, State)
- catch _:{?MODULE, undef} -> State
- end);
-handle_info(timeout, #{lang := Lang} = State) ->
- Disconnected = is_disconnected(State),
- noreply(try callback(handle_timeout, State)
- catch _:{?MODULE, undef} when not Disconnected ->
- Txt = <<"Idle connection">>,
- send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
- _:{?MODULE, undef} ->
- stop(State)
- end);
-handle_info({'DOWN', MRef, _Type, _Object, _Info},
- #{socket_monitor := MRef} = State) ->
- noreply(process_stream_end({socket, closed}, State));
-handle_info({tcp, _, Data}, #{socket := Socket} = State) ->
- noreply(
- case xmpp_socket:recv(Socket, Data) of
- {ok, NewSocket} ->
- State#{socket => NewSocket};
- {error, Reason} when is_atom(Reason) ->
- process_stream_end({socket, Reason}, State);
- {error, Reason} ->
- %% TODO: make fast_tls return atoms
- process_stream_end({tls, Reason}, State)
- end);
-handle_info({tcp_closed, _}, State) ->
- handle_info({'$gen_event', closed}, State);
-handle_info({tcp_error, _, Reason}, State) ->
- noreply(process_stream_end({socket, Reason}, State));
-handle_info(Info, State) ->
- noreply(try callback(handle_info, Info, State)
- catch _:{?MODULE, undef} -> State
- end).
-
-terminate(Reason, State) ->
- case get(already_terminated) of
- true ->
- State;
- _ ->
- put(already_terminated, true),
- try callback(terminate, Reason, State)
- catch _:{?MODULE, undef} -> ok
- end,
- send_trailer(State)
- end.
-
-code_change(OldVsn, State, Extra) ->
- callback(code_change, OldVsn, State, Extra).
-
-%%%===================================================================
-%%% Internal functions
-%%%===================================================================
--spec noreply(state()) -> {noreply, state(), non_neg_integer() | infinity}.
-noreply(#{stream_timeout := infinity} = State) ->
- {noreply, State, infinity};
-noreply(#{stream_timeout := {MSecs, StartTime}} = State) ->
- CurrentTime = p1_time_compat:monotonic_time(milli_seconds),
- Timeout = max(0, MSecs - CurrentTime + StartTime),
- {noreply, State, Timeout}.
-
--spec new_id() -> binary().
-new_id() ->
- p1_rand:get_string().
-
--spec is_disconnected(state()) -> boolean().
-is_disconnected(#{stream_state := StreamState}) ->
- StreamState == disconnected.
-
--spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
-process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
- case xmpp:is_stanza(El) of
- true ->
- Txt = xmpp:io_format_error(Reason),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- send_error(State, El, xmpp:err_bad_request(Txt, Lang));
- false ->
- case {xmpp:get_name(El), xmpp:get_ns(El)} of
- {Tag, ?NS_SASL} when Tag == <<"auth">>;
- Tag == <<"response">>;
- Tag == <<"abort">> ->
- Txt = xmpp:io_format_error(Reason),
- Err = #sasl_failure{reason = 'malformed-request',
- text = xmpp:mk_text(Txt, MyLang)},
- send_pkt(State, Err);
- {<<"starttls">>, ?NS_TLS} ->
- send_pkt(State, #starttls_failure{});
- {<<"compress">>, ?NS_COMPRESS} ->
- Err = #compress_failure{reason = 'setup-failed'},
- send_pkt(State, Err);
- _ ->
- %% Maybe add something more?
- State
- end
- end.
-
--spec process_stream_end(stop_reason(), state()) -> state().
-process_stream_end(_, #{stream_state := disconnected} = State) ->
- State;
-process_stream_end(Reason, State) ->
- State1 = State#{stream_timeout => infinity,
- stream_state => disconnected},
- try callback(handle_stream_end, Reason, State1)
- catch _:{?MODULE, undef} -> stop(State1)
- end.
-
--spec process_stream(stream_start(), state()) -> state().
-process_stream(#stream_start{xmlns = XML_NS,
- stream_xmlns = STREAM_NS},
- #{xmlns := NS} = State)
- when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
- send_pkt(State, xmpp:serr_invalid_namespace());
-process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
- send_pkt(State, xmpp:serr_unsupported_version());
-process_stream(#stream_start{lang = Lang},
- #{xmlns := ?NS_CLIENT, lang := DefaultLang} = State)
- when size(Lang) > 35 ->
- %% As stated in BCP47, 4.4.1:
- %% Protocols or specifications that specify limited buffer sizes for
- %% language tags MUST allow for language tags of at least 35 characters.
- %% Do not store long language tag to avoid possible DoS/flood attacks
- Txt = <<"Too long value of 'xml:lang' attribute">>,
- send_pkt(State, xmpp:serr_policy_violation(Txt, DefaultLang));
-process_stream(#stream_start{to = undefined, version = Version} = StreamStart,
- #{lang := Lang, server := Server, xmlns := NS} = State) ->
- if Version < {1,0} andalso NS /= ?NS_COMPONENT ->
- %% Work-around for gmail servers
- To = jid:make(Server),
- process_stream(StreamStart#stream_start{to = To}, State);
- true ->
- Txt = <<"Missing 'to' attribute">>,
- send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang))
- end;
-process_stream(#stream_start{to = #jid{luser = U, lresource = R}},
- #{lang := Lang} = State) when U /= <<"">>; R /= <<"">> ->
- Txt = <<"Improper 'to' attribute">>,
- send_pkt(State, xmpp:serr_improper_addressing(Txt, Lang));
-process_stream(#stream_start{to = #jid{lserver = RemoteServer}} = StreamStart,
- #{xmlns := ?NS_COMPONENT} = State) ->
- State1 = State#{remote_server => RemoteServer,
- stream_state => wait_for_handshake},
- try callback(handle_stream_start, StreamStart, State1)
- catch _:{?MODULE, undef} -> State1
- end;
-process_stream(#stream_start{to = #jid{server = Server, lserver = LServer},
- from = From} = StreamStart,
- #{stream_authenticated := Authenticated,
- stream_restarted := StreamWasRestarted,
- xmlns := NS, resource := Resource,
- stream_encrypted := Encrypted} = State) ->
- State1 = if not StreamWasRestarted ->
- State#{server => Server, lserver => LServer};
- true ->
- State
- end,
- State2 = case From of
- #jid{lserver = RemoteServer} when NS == ?NS_SERVER ->
- State1#{remote_server => RemoteServer};
- _ ->
- State1
- end,
- State3 = try callback(handle_stream_start, StreamStart, State2)
- catch _:{?MODULE, undef} -> State2
- end,
- case is_disconnected(State3) of
- true -> State3;
- false ->
- State4 = send_features(State3),
- case is_disconnected(State4) of
- true -> State4;
- false ->
- TLSRequired = is_starttls_required(State4),
- if not Authenticated and (TLSRequired and not Encrypted) ->
- State4#{stream_state => wait_for_starttls};
- not Authenticated ->
- State4#{stream_state => wait_for_sasl_request};
- (NS == ?NS_CLIENT) and (Resource == <<"">>) ->
- State4#{stream_state => wait_for_bind};
- true ->
- process_stream_established(State4)
- end
- end
- end.
-
--spec process_element(xmpp_element(), state()) -> state().
-process_element(Pkt, #{stream_state := StateName, lang := Lang} = State) ->
- case Pkt of
- #starttls{} when StateName == wait_for_starttls;
- StateName == wait_for_sasl_request ->
- process_starttls(State);
- #starttls{} ->
- process_starttls_failure(unexpected_starttls_request, State);
- #sasl_auth{} when StateName == wait_for_starttls ->
- send_pkt(State, #sasl_failure{reason = 'encryption-required'});
- #sasl_auth{} when StateName == wait_for_sasl_request ->
- process_sasl_request(Pkt, State);
- #sasl_auth{} when StateName == wait_for_sasl_response ->
- process_sasl_request(Pkt, maps:remove(sasl_state, State));
- #sasl_auth{} ->
- Txt = <<"SASL negotiation is not allowed in this state">>,
- send_pkt(State, #sasl_failure{reason = 'not-authorized',
- text = xmpp:mk_text(Txt, Lang)});
- #sasl_response{} when StateName == wait_for_starttls ->
- send_pkt(State, #sasl_failure{reason = 'encryption-required'});
- #sasl_response{} when StateName == wait_for_sasl_response ->
- process_sasl_response(Pkt, State);
- #sasl_response{} ->
- Txt = <<"SASL negotiation is not allowed in this state">>,
- send_pkt(State, #sasl_failure{reason = 'not-authorized',
- text = xmpp:mk_text(Txt, Lang)});
- #sasl_abort{} when StateName == wait_for_sasl_response ->
- process_sasl_abort(State);
- #sasl_abort{} ->
- send_pkt(State, #sasl_failure{reason = 'aborted'});
- #sasl_success{} ->
- State;
- #compress{} ->
- process_compress(Pkt, State);
- #handshake{} when StateName == wait_for_handshake ->
- process_handshake(Pkt, State);
- #handshake{} ->
- State;
- #stream_error{} ->
- process_stream_end({stream, {in, Pkt}}, State);
- _ when StateName == wait_for_sasl_request;
- StateName == wait_for_handshake;
- StateName == wait_for_sasl_response ->
- process_unauthenticated_packet(Pkt, State);
- _ when StateName == wait_for_starttls ->
- Txt = <<"Use of STARTTLS required">>,
- Err = xmpp:serr_policy_violation(Txt, Lang),
- send_pkt(State, Err);
- _ when StateName == wait_for_bind ->
- process_bind(Pkt, State);
- _ when StateName == established ->
- process_authenticated_packet(Pkt, State)
- end.
-
--spec process_unauthenticated_packet(xmpp_element(), state()) -> state().
-process_unauthenticated_packet(Pkt, State) ->
- NewPkt = set_lang(Pkt, State),
- try callback(handle_unauthenticated_packet, NewPkt, State)
- catch _:{?MODULE, undef} ->
- Err = xmpp:serr_not_authorized(),
- send(State, Err)
- end.
-
--spec process_authenticated_packet(xmpp_element(), state()) -> state().
-process_authenticated_packet(Pkt, State) ->
- Pkt1 = set_lang(Pkt, State),
- case set_from_to(Pkt1, State) of
- {ok, Pkt2} ->
- try callback(handle_authenticated_packet, Pkt2, State)
- catch _:{?MODULE, undef} ->
- Err = xmpp:err_service_unavailable(),
- send_error(State, Pkt, Err)
- end;
- {error, Err} ->
- send_pkt(State, Err)
- end.
-
--spec process_bind(xmpp_element(), state()) -> state().
-process_bind(#iq{type = set, sub_els = [_]} = Pkt,
- #{xmlns := ?NS_CLIENT, lang := MyLang} = State) ->
- try xmpp:try_subtag(Pkt, #bind{}) of
- #bind{resource = R} ->
- case callback(bind, R, State) of
- {ok, #{user := U, server := S, resource := NewR} = State1}
- when NewR /= <<"">> ->
- Reply = #bind{jid = jid:make(U, S, NewR)},
- State2 = send_pkt(State1, xmpp:make_iq_result(Pkt, Reply)),
- process_stream_established(State2);
- {error, #stanza_error{} = Err, State1} ->
- send_error(State1, Pkt, Err)
- end;
- _ ->
- try callback(handle_unbinded_packet, Pkt, State)
- catch _:{?MODULE, undef} ->
- Err = xmpp:err_not_authorized(),
- send_error(State, Pkt, Err)
- end
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(Pkt)),
- Err = xmpp:err_bad_request(Txt, Lang),
- send_error(State, Pkt, Err)
- end;
-process_bind(Pkt, State) ->
- try callback(handle_unbinded_packet, Pkt, State)
- catch _:{?MODULE, undef} ->
- Err = xmpp:err_not_authorized(),
- send_error(State, Pkt, Err)
- end.
-
--spec process_handshake(handshake(), state()) -> state().
-process_handshake(#handshake{data = Digest},
- #{stream_id := StreamID,
- remote_server := RemoteServer} = State) ->
- GetPW = try callback(get_password_fun, State)
- catch _:{?MODULE, undef} -> fun(_) -> {false, undefined} end
- end,
- AuthRes = case GetPW(<<"">>) of
- {false, _} ->
- false;
- {Password, _} ->
- str:sha(<<StreamID/binary, Password/binary>>) == Digest
- end,
- case AuthRes of
- true ->
- State1 = try callback(handle_auth_success,
- RemoteServer, <<"handshake">>, undefined, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false ->
- State2 = send_pkt(State1, #handshake{}),
- process_stream_established(State2)
- end;
- false ->
- State1 = try callback(handle_auth_failure,
- RemoteServer, <<"handshake">>, <<"not authorized">>, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false -> send_pkt(State1, xmpp:serr_not_authorized())
- end
- end.
-
--spec process_stream_established(state()) -> state().
-process_stream_established(#{stream_state := StateName} = State)
- when StateName == disconnected; StateName == established ->
- State;
-process_stream_established(State) ->
- State1 = State#{stream_authenticated => true,
- stream_state => established,
- stream_timeout => infinity},
- try callback(handle_stream_established, State1)
- catch _:{?MODULE, undef} -> State1
- end.
-
--spec process_compress(compress(), state()) -> state().
-process_compress(#compress{},
- #{stream_compressed := Compressed,
- stream_authenticated := Authenticated} = State)
- when Compressed or not Authenticated ->
- send_pkt(State, #compress_failure{reason = 'setup-failed'});
-process_compress(#compress{methods = HisMethods},
- #{socket := Socket} = State) ->
- MyMethods = try callback(compress_methods, State)
- catch _:{?MODULE, undef} -> []
- end,
- CommonMethods = lists_intersection(MyMethods, HisMethods),
- case lists:member(<<"zlib">>, CommonMethods) of
- true ->
- case xmpp_socket:compress(Socket) of
- {ok, ZlibSocket} ->
- State1 = send_pkt(State, #compressed{}),
- case is_disconnected(State1) of
- true -> State1;
- false ->
- State1#{socket => ZlibSocket,
- stream_id => new_id(),
- stream_header_sent => false,
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_compressed => true}
- end;
- {error, _} ->
- Err = #compress_failure{reason = 'setup-failed'},
- send_pkt(State, Err)
- end;
- false ->
- send_pkt(State, #compress_failure{reason = 'unsupported-method'})
- end.
-
--spec process_starttls(state()) -> state().
-process_starttls(#{stream_encrypted := true} = State) ->
- process_starttls_failure(already_encrypted, State);
-process_starttls(#{socket := Socket} = State) ->
- case is_starttls_available(State) of
- true ->
- TLSOpts = try callback(tls_options, State)
- catch _:{?MODULE, undef} -> []
- end,
- case xmpp_socket:starttls(Socket, TLSOpts) of
- {ok, TLSSocket} ->
- State1 = send_pkt(State, #starttls_proceed{}),
- case is_disconnected(State1) of
- true -> State1;
- false ->
- State1#{socket => TLSSocket,
- stream_id => new_id(),
- stream_header_sent => false,
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_encrypted => true}
- end;
- {error, Reason} ->
- process_starttls_failure(Reason, State)
- end;
- false ->
- process_starttls_failure(starttls_unsupported, State)
- end.
-
--spec process_starttls_failure(term(), state()) -> state().
-process_starttls_failure(Why, State) ->
- State1 = send_pkt(State, #starttls_failure{}),
- case is_disconnected(State1) of
- true -> State1;
- false -> process_stream_end({tls, Why}, State1)
- end.
-
--spec process_sasl_request(sasl_auth(), state()) -> state().
-process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
- #{lserver := LServer} = State) ->
- State1 = State#{sasl_mech => Mech},
- Mechs = get_sasl_mechanisms(State1),
- case lists:member(Mech, Mechs) of
- true when Mech == <<"EXTERNAL">> ->
- Res = case xmpp_stream_pkix:authenticate(State1, ClientIn) of
- {ok, Peer} ->
- {ok, [{auth_module, pkix}, {username, Peer}]};
- {error, Reason, Peer} ->
- {error, Reason, Peer}
- end,
- process_sasl_result(Res, State1);
- true ->
- GetPW = try callback(get_password_fun, State1)
- catch _:{?MODULE, undef} -> fun(_) -> false end
- end,
- CheckPW = try callback(check_password_fun, State1)
- catch _:{?MODULE, undef} -> fun(_, _, _) -> false end
- end,
- CheckPWDigest = try callback(check_password_digest_fun, State1)
- catch _:{?MODULE, undef} -> fun(_, _, _, _, _) -> false end
- end,
- SASLState = cyrsasl:server_new(<<"jabber">>, LServer, <<"">>, [],
- GetPW, CheckPW, CheckPWDigest),
- Res = cyrsasl:server_start(SASLState, Mech, ClientIn),
- process_sasl_result(Res, State1#{sasl_state => SASLState});
- false ->
- process_sasl_result({error, unsupported_mechanism, <<"">>}, State1)
- end.
-
--spec process_sasl_response(sasl_response(), state()) -> state().
-process_sasl_response(#sasl_response{text = ClientIn},
- #{sasl_state := SASLState} = State) ->
- SASLResult = cyrsasl:server_step(SASLState, ClientIn),
- process_sasl_result(SASLResult, State).
-
--spec process_sasl_result(cyrsasl:sasl_return(), state()) -> state().
-process_sasl_result({ok, Props}, State) ->
- process_sasl_success(Props, <<"">>, State);
-process_sasl_result({ok, Props, ServerOut}, State) ->
- process_sasl_success(Props, ServerOut, State);
-process_sasl_result({continue, ServerOut, NewSASLState}, State) ->
- process_sasl_continue(ServerOut, NewSASLState, State);
-process_sasl_result({error, Reason, User}, State) ->
- process_sasl_failure(Reason, User, State).
-
--spec process_sasl_success([cyrsasl:sasl_property()], binary(), state()) -> state().
-process_sasl_success(Props, ServerOut,
- #{socket := Socket,
- sasl_mech := Mech} = State) ->
- User = identity(Props),
- AuthModule = proplists:get_value(auth_module, Props),
- Socket1 = xmpp_socket:reset_stream(Socket),
- State0 = State#{socket => Socket1},
- State1 = try callback(handle_auth_success, User, Mech, AuthModule, State0)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false ->
- State2 = send_pkt(State1, #sasl_success{text = ServerOut}),
- case is_disconnected(State2) of
- true -> State2;
- false ->
- State3 = maps:remove(sasl_state,
- maps:remove(sasl_mech, State2)),
- State3#{stream_id => new_id(),
- stream_authenticated => true,
- stream_header_sent => false,
- stream_restarted => true,
- stream_state => wait_for_stream,
- user => User}
- end
- end.
-
--spec process_sasl_continue(binary(), cyrsasl:sasl_state(), state()) -> state().
-process_sasl_continue(ServerOut, NewSASLState, State) ->
- State1 = State#{sasl_state => NewSASLState,
- stream_state => wait_for_sasl_response},
- send_pkt(State1, #sasl_challenge{text = ServerOut}).
-
--spec process_sasl_failure(atom(), binary(), state()) -> state().
-process_sasl_failure(Err, User,
- #{sasl_mech := Mech, lang := Lang} = State) ->
- {Reason, Text} = format_sasl_error(Mech, Err),
- State1 = try callback(handle_auth_failure, User, Mech, Text, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false ->
- State2 = send_pkt(State1,
- #sasl_failure{reason = Reason,
- text = xmpp:mk_text(Text, Lang)}),
- case is_disconnected(State2) of
- true -> State2;
- false ->
- State3 = maps:remove(sasl_state,
- maps:remove(sasl_mech, State2)),
- State3#{stream_state => wait_for_sasl_request}
- end
- end.
-
--spec process_sasl_abort(state()) -> state().
-process_sasl_abort(State) ->
- process_sasl_failure(aborted, <<"">>, State).
-
--spec send_features(state()) -> state().
-send_features(#{stream_version := {1,0},
- stream_encrypted := Encrypted} = State) ->
- TLSRequired = is_starttls_required(State),
- Features = if TLSRequired and not Encrypted ->
- get_tls_feature(State);
- true ->
- get_sasl_feature(State) ++ get_compress_feature(State)
- ++ get_tls_feature(State) ++ get_bind_feature(State)
- ++ get_session_feature(State) ++ get_other_features(State)
- end,
- send_pkt(State, #stream_features{sub_els = Features});
-send_features(State) ->
- %% clients and servers from stone age
- State.
-
--spec get_sasl_mechanisms(state()) -> [cyrsasl:mechanism()].
-get_sasl_mechanisms(#{stream_encrypted := Encrypted,
- xmlns := NS, lserver := LServer} = State) ->
- Mechs = if NS == ?NS_CLIENT -> cyrsasl:listmech(LServer);
- true -> []
- end,
- TLSVerify = try callback(tls_verify, State)
- catch _:{?MODULE, undef} -> false
- end,
- Mechs1 = if Encrypted andalso (TLSVerify orelse NS == ?NS_SERVER) ->
- [<<"EXTERNAL">>|Mechs];
- true ->
- Mechs
- end,
- try callback(sasl_mechanisms, Mechs1, State)
- catch _:{?MODULE, undef} -> Mechs1
- end.
-
--spec get_sasl_feature(state()) -> [sasl_mechanisms()].
-get_sasl_feature(#{stream_authenticated := false,
- stream_encrypted := Encrypted} = State) ->
- TLSRequired = is_starttls_required(State),
- if Encrypted or not TLSRequired ->
- Mechs = get_sasl_mechanisms(State),
- [#sasl_mechanisms{list = Mechs}];
- true ->
- []
- end;
-get_sasl_feature(_) ->
- [].
-
--spec get_compress_feature(state()) -> [compression()].
-get_compress_feature(#{stream_compressed := false,
- stream_authenticated := true} = State) ->
- try callback(compress_methods, State) of
- [] -> [];
- Ms -> [#compression{methods = Ms}]
- catch _:{?MODULE, undef} ->
- []
- end;
-get_compress_feature(_) ->
- [].
-
--spec get_tls_feature(state()) -> [starttls()].
-get_tls_feature(#{stream_authenticated := false,
- stream_encrypted := false} = State) ->
- case is_starttls_available(State) of
- true ->
- TLSRequired = is_starttls_required(State),
- [#starttls{required = TLSRequired}];
- false ->
- []
- end;
-get_tls_feature(_) ->
- [].
-
--spec get_bind_feature(state()) -> [bind()].
-get_bind_feature(#{xmlns := ?NS_CLIENT,
- stream_authenticated := true,
- resource := <<"">>}) ->
- [#bind{}];
-get_bind_feature(_) ->
- [].
-
--spec get_session_feature(state()) -> [xmpp_session()].
-get_session_feature(#{xmlns := ?NS_CLIENT,
- stream_authenticated := true,
- resource := <<"">>}) ->
- [#xmpp_session{optional = true}];
-get_session_feature(_) ->
- [].
-
--spec get_other_features(state()) -> [xmpp_element()].
-get_other_features(#{stream_authenticated := Auth} = State) ->
- try
- if Auth -> callback(authenticated_stream_features, State);
- true -> callback(unauthenticated_stream_features, State)
- end
- catch _:{?MODULE, undef} ->
- []
- end.
-
--spec is_starttls_available(state()) -> boolean().
-is_starttls_available(State) ->
- try callback(tls_enabled, State)
- catch _:{?MODULE, undef} -> true
- end.
-
--spec is_starttls_required(state()) -> boolean().
-is_starttls_required(State) ->
- try callback(tls_required, State)
- catch _:{?MODULE, undef} -> false
- end.
-
--spec set_from_to(xmpp_element(), state()) -> {ok, xmpp_element()} |
- {error, stream_error()}.
-set_from_to(Pkt, _State) when not ?is_stanza(Pkt) ->
- {ok, Pkt};
-set_from_to(Pkt, #{user := U, server := S, resource := R,
- lang := Lang, xmlns := ?NS_CLIENT}) ->
- JID = jid:make(U, S, R),
- From = case xmpp:get_from(Pkt) of
- undefined -> JID;
- F -> F
- end,
- if JID#jid.luser == From#jid.luser andalso
- JID#jid.lserver == From#jid.lserver andalso
- (JID#jid.lresource == From#jid.lresource
- orelse From#jid.lresource == <<"">>) ->
- To = case xmpp:get_to(Pkt) of
- undefined -> jid:make(U, S);
- T -> T
- end,
- {ok, xmpp:set_from_to(Pkt, JID, To)};
- true ->
- Txt = <<"Improper 'from' attribute">>,
- {error, xmpp:serr_invalid_from(Txt, Lang)}
- end;
-set_from_to(Pkt, #{lang := Lang}) ->
- From = xmpp:get_from(Pkt),
- To = xmpp:get_to(Pkt),
- if From == undefined ->
- Txt = <<"Missing 'from' attribute">>,
- {error, xmpp:serr_improper_addressing(Txt, Lang)};
- To == undefined ->
- Txt = <<"Missing 'to' attribute">>,
- {error, xmpp:serr_improper_addressing(Txt, Lang)};
- true ->
- {ok, Pkt}
- end.
-
--spec send_header(state()) -> state().
-send_header(#{stream_version := Version} = State) ->
- send_header(State, #stream_start{version = Version}).
-
--spec send_header(state(), stream_start()) -> state().
-send_header(#{stream_id := StreamID,
- stream_version := MyVersion,
- stream_header_sent := false,
- lang := MyLang,
- xmlns := NS} = State,
- #stream_start{to = HisTo, from = HisFrom,
- lang = HisLang, version = HisVersion}) ->
- Lang = select_lang(MyLang, HisLang),
- NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
- true -> <<"">>
- end,
- Version = case HisVersion of
- undefined -> undefined;
- {0,_} -> HisVersion;
- _ -> MyVersion
- end,
- StreamStart = #stream_start{version = Version,
- lang = Lang,
- xmlns = NS,
- stream_xmlns = ?NS_STREAM,
- db_xmlns = NS_DB,
- id = StreamID,
- to = HisFrom,
- from = HisTo},
- State1 = State#{lang => Lang,
- stream_version => Version,
- stream_header_sent => true},
- case socket_send(State1, StreamStart) of
- ok -> State1;
- {error, Why} -> process_stream_end({socket, Why}, State1)
- end;
-send_header(State, _) ->
- State.
-
--spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
-send_pkt(State, Pkt) ->
- Result = socket_send(State, Pkt),
- State1 = try callback(handle_send, Pkt, Result, State)
- catch _:{?MODULE, undef} -> State
- end,
- case Result of
- _ when is_record(Pkt, stream_error) ->
- process_stream_end({stream, {out, Pkt}}, State1);
- ok ->
- State1;
- {error, Why} ->
- process_stream_end({socket, Why}, State1)
- end.
-
--spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
-send_error(State, Pkt, Err) ->
- case xmpp:is_stanza(Pkt) of
- true ->
- case xmpp:get_type(Pkt) of
- result -> State;
- error -> State;
- <<"result">> -> State;
- <<"error">> -> State;
- _ ->
- ErrPkt = xmpp:make_error(Pkt, Err),
- send_pkt(State, ErrPkt)
- end;
- false ->
- State
- end.
-
--spec send_trailer(state()) -> state().
-send_trailer(State) ->
- socket_send(State, trailer),
- close_socket(State).
-
--spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
-socket_send(#{socket := Sock,
- stream_state := StateName,
- xmlns := NS,
- stream_header_sent := true}, Pkt) ->
- case Pkt of
- trailer ->
- xmpp_socket:send_trailer(Sock);
- #stream_start{} when StateName /= disconnected ->
- xmpp_socket:send_header(Sock, xmpp:encode(Pkt));
- _ when StateName /= disconnected ->
- xmpp_socket:send_element(Sock, xmpp:encode(Pkt, NS));
- _ ->
- {error, closed}
- end;
-socket_send(_, _) ->
- {error, closed}.
-
--spec close_socket(state()) -> state().
-close_socket(#{socket := Socket} = State) ->
- xmpp_socket:close(Socket),
- State#{stream_timeout => infinity,
- stream_state => disconnected}.
-
--spec select_lang(binary(), binary()) -> binary().
-select_lang(Lang, <<"">>) -> Lang;
-select_lang(_, Lang) -> Lang.
-
--spec set_lang(xmpp_element(), state()) -> xmpp_element().
-set_lang(Pkt, #{lang := MyLang, xmlns := ?NS_CLIENT}) when ?is_stanza(Pkt) ->
- HisLang = xmpp:get_lang(Pkt),
- Lang = select_lang(MyLang, HisLang),
- xmpp:set_lang(Pkt, Lang);
-set_lang(Pkt, _) ->
- Pkt.
-
--spec format_inet_error(atom()) -> string().
-format_inet_error(closed) ->
- "connection closed";
-format_inet_error(Reason) ->
- case inet:format_error(Reason) of
- "unknown POSIX error" -> atom_to_list(Reason);
- Txt -> Txt
- end.
-
--spec format_sasl_error(cyrsasl:mechanism(), atom()) -> {atom(), binary()}.
-format_sasl_error(<<"EXTERNAL">>, Err) ->
- xmpp_stream_pkix:format_error(Err);
-format_sasl_error(Mech, Err) ->
- cyrsasl:format_error(Mech, Err).
-
--spec format_tls_error(atom() | binary()) -> list().
-format_tls_error(Reason) when is_atom(Reason) ->
- format_inet_error(Reason);
-format_tls_error(Reason) ->
- Reason.
-
--spec format(io:format(), list()) -> binary().
-format(Fmt, Args) ->
- iolist_to_binary(io_lib:format(Fmt, Args)).
-
--spec lists_intersection(list(), list()) -> list().
-lists_intersection(L1, L2) ->
- lists:filter(
- fun(E) ->
- lists:member(E, L2)
- end, L1).
-
--spec identity([cyrsasl:sasl_property()]) -> binary().
-identity(Props) ->
- case proplists:get_value(authzid, Props, <<>>) of
- <<>> -> proplists:get_value(username, Props, <<>>);
- AuthzId -> AuthzId
- end.
-
-%%%===================================================================
-%%% Callbacks
-%%%===================================================================
-callback(F, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 1) of
- true -> Mod:F(State);
- false -> erlang:error({?MODULE, undef})
- end.
-
-callback(F, Arg1, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 2) of
- true -> Mod:F(Arg1, State);
- false -> erlang:error({?MODULE, undef})
- end.
-
-callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
- %% code_change/3 callback is a special snowflake
- case erlang:function_exported(Mod, code_change, 3) of
- true -> Mod:code_change(OldVsn, State, Extra);
- false -> {ok, State}
- end;
-callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 3) of
- true -> Mod:F(Arg1, Arg2, State);
- false -> erlang:error({?MODULE, undef})
- end.
-
-callback(F, Arg1, Arg2, Arg3, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 4) of
- true -> Mod:F(Arg1, Arg2, Arg3, State);
- false -> erlang:error({?MODULE, undef})
- end.
+++ /dev/null
-%%%-------------------------------------------------------------------
-%%% Created : 14 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%-------------------------------------------------------------------
--module(xmpp_stream_out).
--define(GEN_SERVER, p1_server).
--behaviour(?GEN_SERVER).
-
--protocol({rfc, 6120}).
--protocol({xep, 114, '1.6'}).
--protocol({xep, 368, '1.0.0'}).
-
-%% API
--export([start/3, start_link/3, call/3, cast/2, reply/2, connect/1,
- stop/1, send/2, close/1, close/2, bind/2, establish/1, format_error/1,
- set_timeout/2, get_transport/1, change_shaper/2]).
-%% gen_server callbacks
--export([init/1, handle_call/3, handle_cast/2, handle_info/2,
- terminate/2, code_change/3]).
-
-%%-define(DBGFSM, true).
--ifdef(DBGFSM).
--define(FSMOPTS, [{debug, [trace]}]).
--else.
--define(FSMOPTS, []).
--endif.
-
--define(TCP_SEND_TIMEOUT, 15000).
-
--include("xmpp.hrl").
--include_lib("kernel/include/inet.hrl").
-
--type state() :: map().
--type noreply() :: {noreply, state(), timeout()}.
--type host_port() :: {inet:hostname(), inet:port_number(), boolean()} | ip_port().
--type ip_port() :: {inet:ip_address(), inet:port_number(), boolean()}.
--type h_addr_list() :: {{integer(), integer(), inet:port_number(), string()}, boolean()}.
--type network_error() :: {error, inet:posix() | inet_res:res_error()}.
--type tls_error_reason() :: inet:posix() | atom() | binary().
--type socket_error_reason() :: inet:posix() | atom().
--type stop_reason() :: {idna, bad_string} |
- {dns, inet:posix() | inet_res:res_error()} |
- {stream, reset | {in | out, stream_error()}} |
- {tls, tls_error_reason()} |
- {pkix, binary()} |
- {auth, atom() | binary() | string()} |
- {bind, stanza_error()} |
- {socket, socket_error_reason()} |
- internal_failure.
--export_type([state/0, stop_reason/0]).
--callback init(list()) -> {ok, state()} | {error, term()} | ignore.
--callback handle_cast(term(), state()) -> state().
--callback handle_call(term(), term(), state()) -> state().
--callback handle_info(term(), state()) -> state().
--callback terminate(term(), state()) -> any().
--callback code_change(term(), state(), term()) -> {ok, state()} | {error, term()}.
--callback handle_stream_start(stream_start(), state()) -> state().
--callback handle_stream_established(state()) -> state().
--callback handle_stream_downgraded(stream_start(), state()) -> state().
--callback handle_stream_end(stop_reason(), state()) -> state().
--callback handle_cdata(binary(), state()) -> state().
--callback handle_send(xmpp_element(), ok | {error, inet:posix()}, state()) -> state().
--callback handle_recv(fxml:xmlel(), xmpp_element() | {error, term()}, state()) -> state().
--callback handle_timeout(state()) -> state().
--callback handle_authenticated_features(stream_features(), state()) -> state().
--callback handle_unauthenticated_features(stream_features(), state()) -> state().
--callback handle_auth_success(cyrsasl:mechanism(), state()) -> state().
--callback handle_auth_failure(cyrsasl:mechanism(), binary(), state()) -> state().
--callback handle_bind_success(state()) -> state().
--callback handle_bind_failure(stanza_error(), state()) -> state().
--callback handle_packet(xmpp_element(), state()) -> state().
--callback tls_options(state()) -> [proplists:property()].
--callback tls_required(state()) -> boolean().
--callback tls_verify(state()) -> boolean().
--callback tls_enabled(state()) -> boolean().
--callback resolve(string(), state()) -> [host_port()].
--callback sasl_mechanisms(state()) -> [binary()].
--callback dns_timeout(state()) -> timeout().
--callback dns_retries(state()) -> non_neg_integer().
--callback default_port(state()) -> inet:port_number().
--callback connect_options(inet:ip_address(), list(), state()) -> list().
--callback address_families(state()) -> [inet:address_family()].
--callback connect_timeout(state()) -> timeout().
-
--optional_callbacks([init/1,
- handle_cast/2,
- handle_call/3,
- handle_info/2,
- terminate/2,
- code_change/3,
- handle_stream_start/2,
- handle_stream_established/1,
- handle_stream_downgraded/2,
- handle_stream_end/2,
- handle_cdata/2,
- handle_send/3,
- handle_recv/3,
- handle_timeout/1,
- handle_authenticated_features/2,
- handle_unauthenticated_features/2,
- handle_auth_success/2,
- handle_auth_failure/3,
- handle_bind_success/1,
- handle_bind_failure/2,
- handle_packet/2,
- tls_options/1,
- tls_required/1,
- tls_verify/1,
- tls_enabled/1,
- resolve/2,
- sasl_mechanisms/1,
- dns_timeout/1,
- dns_retries/1,
- default_port/1,
- connect_options/3,
- address_families/1,
- connect_timeout/1]).
-
-%%%===================================================================
-%%% API
-%%%===================================================================
-start({local, Mod}, Args, Opts) ->
- ?GEN_SERVER:start({local, Mod}, ?MODULE, [Mod|Args], Opts ++ ?FSMOPTS);
-start(Mod, Args, Opts) ->
- ?GEN_SERVER:start(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
-
-start_link({local, Mod}, Args, Opts) ->
- ?GEN_SERVER:start_link({local, Mod}, ?MODULE, [Mod|Args], Opts ++ ?FSMOPTS);
-start_link(Mod, Args, Opts) ->
- ?GEN_SERVER:start_link(?MODULE, [Mod|Args], Opts ++ ?FSMOPTS).
-
-call(Ref, Msg, Timeout) ->
- ?GEN_SERVER:call(Ref, Msg, Timeout).
-
-cast(Ref, Msg) ->
- ?GEN_SERVER:cast(Ref, Msg).
-
-reply(Ref, Reply) ->
- ?GEN_SERVER:reply(Ref, Reply).
-
--spec connect(pid()) -> ok.
-connect(Ref) ->
- cast(Ref, connect).
-
--spec stop(pid()) -> ok;
- (state()) -> no_return().
-stop(Pid) when is_pid(Pid) ->
- cast(Pid, stop);
-stop(#{owner := Owner} = State) when Owner == self() ->
- terminate(normal, State),
- exit(normal);
-stop(_) ->
- erlang:error(badarg).
-
--spec send(pid(), xmpp_element()) -> ok;
- (state(), xmpp_element()) -> state().
-send(Pid, Pkt) when is_pid(Pid) ->
- cast(Pid, {send, Pkt});
-send(#{owner := Owner} = State, Pkt) when Owner == self() ->
- send_pkt(State, Pkt);
-send(_, _) ->
- erlang:error(badarg).
-
--spec close(pid()) -> ok;
- (state()) -> state().
-close(Pid) when is_pid(Pid) ->
- close(Pid, closed);
-close(#{owner := Owner} = State) when Owner == self() ->
- close_socket(State);
-close(_) ->
- erlang:error(badarg).
-
--spec close(pid(), atom()) -> ok.
-close(Pid, Reason) ->
- cast(Pid, {close, Reason}).
-
--spec bind(state(), stream_features()) -> state().
-bind(#{stream_authenticated := true} = State, StreamFeatures) ->
- process_bind(StreamFeatures, State).
-
--spec establish(state()) -> state().
-establish(State) ->
- process_stream_established(State).
-
--spec set_timeout(state(), timeout()) -> state().
-set_timeout(#{owner := Owner} = State, Timeout) when Owner == self() ->
- case Timeout of
- infinity -> State#{stream_timeout => infinity};
- _ ->
- Time = p1_time_compat:monotonic_time(milli_seconds),
- State#{stream_timeout => {Timeout, Time}}
- end;
-set_timeout(_, _) ->
- erlang:error(badarg).
-
-get_transport(#{socket := Socket, owner := Owner})
- when Owner == self() ->
- xmpp_socket:get_transport(Socket);
-get_transport(_) ->
- erlang:error(badarg).
-
--spec change_shaper(state(), ejabberd_shaper:shaper()) -> state().
-change_shaper(#{socket := Socket, owner := Owner} = State, Shaper)
- when Owner == self() ->
- Socket1 = xmpp_socket:change_shaper(Socket, Shaper),
- State#{socket => Socket1};
-change_shaper(_, _) ->
- erlang:error(badarg).
-
--spec format_error(stop_reason()) -> binary().
-format_error({idna, _}) ->
- <<"Remote domain is not an IDN hostname">>;
-format_error({dns, Reason}) ->
- format("DNS lookup failed: ~s", [format_inet_error(Reason)]);
-format_error({socket, Reason}) ->
- format("Connection failed: ~s", [format_inet_error(Reason)]);
-format_error({pkix, Reason}) ->
- {_, ErrTxt} = xmpp_stream_pkix:format_error(Reason),
- format("Peer certificate rejected: ~s", [ErrTxt]);
-format_error({stream, reset}) ->
- <<"Stream reset by peer">>;
-format_error({stream, {in, #stream_error{} = Err}}) ->
- format("Stream closed by peer: ~s", [xmpp:format_stream_error(Err)]);
-format_error({stream, {out, #stream_error{} = Err}}) ->
- format("Stream closed by us: ~s", [xmpp:format_stream_error(Err)]);
-format_error({bind, #stanza_error{} = Err}) ->
- format("Resource binding failure: ~s", [xmpp:format_stanza_error(Err)]);
-format_error({tls, Reason}) ->
- format("TLS failed: ~s", [format_tls_error(Reason)]);
-format_error({auth, Reason}) ->
- format("Authentication failed: ~s", [Reason]);
-format_error(internal_failure) ->
- <<"Internal server error">>;
-format_error(Err) ->
- format("Unrecognized error: ~w", [Err]).
-
-%%%===================================================================
-%%% gen_server callbacks
-%%%===================================================================
--spec init(list()) -> {ok, state(), timeout()} | {stop, term()} | ignore.
-init([Mod, From, To, Opts]) ->
- Time = p1_time_compat:monotonic_time(milli_seconds),
- State = #{owner => self(),
- mod => Mod,
- server => From,
- user => <<"">>,
- resource => <<"">>,
- password => <<"">>,
- lang => <<"">>,
- remote_server => To,
- xmlns => ?NS_SERVER,
- codec_options => [ignore_els],
- stream_direction => out,
- stream_timeout => {timer:seconds(30), Time},
- stream_id => new_id(),
- stream_encrypted => false,
- stream_verified => false,
- stream_authenticated => false,
- stream_restarted => false,
- stream_state => connecting},
- case try Mod:init([State, Opts])
- catch _:undef -> {ok, State}
- end of
- {ok, State1} ->
- {_, State2, Timeout} = noreply(State1),
- {ok, State2, Timeout};
- {error, Reason} ->
- {stop, Reason};
- ignore ->
- ignore
- end.
-
--spec handle_call(term(), term(), state()) -> noreply().
-handle_call(Call, From, State) ->
- noreply(try callback(handle_call, Call, From, State)
- catch _:{?MODULE, undef} -> State
- end).
-
--spec handle_cast(term(), state()) -> noreply().
-handle_cast(connect, #{remote_server := RemoteServer,
- stream_state := connecting} = State) ->
- noreply(
- case idna_to_ascii(RemoteServer) of
- false ->
- process_stream_end({idna, bad_string}, State);
- ASCIIName ->
- case resolve(binary_to_list(ASCIIName), State) of
- {ok, AddrPorts} ->
- case connect(AddrPorts, State) of
- {ok, Socket, {Addr, Port, Encrypted}} ->
- SocketMonitor = xmpp_socket:monitor(Socket),
- State1 = State#{ip => {Addr, Port},
- socket => Socket,
- stream_encrypted => Encrypted,
- socket_monitor => SocketMonitor},
- State2 = State1#{stream_state => wait_for_stream},
- send_header(State2);
- {error, {Class, Why}} ->
- process_stream_end({Class, Why}, State)
- end;
- {error, Why} ->
- process_stream_end({dns, Why}, State)
- end
- end);
-handle_cast(connect, #{stream_state := disconnected} = State) ->
- State1 = reset_state(State),
- handle_cast(connect, State1);
-handle_cast(connect, State) ->
- %% Ignoring connection attempts in other states
- noreply(State);
-handle_cast({send, Pkt}, State) ->
- noreply(send_pkt(State, Pkt));
-handle_cast(stop, State) ->
- {stop, normal, State};
-handle_cast({close, Reason}, State) ->
- State1 = close_socket(State),
- noreply(
- case is_disconnected(State) of
- true -> State1;
- false -> process_stream_end({socket, Reason}, State)
- end);
-handle_cast(Cast, State) ->
- noreply(try callback(handle_cast, Cast, State)
- catch _:{?MODULE, undef} -> State
- end).
-
--spec handle_info(term(), state()) -> noreply().
-handle_info({'$gen_event', {xmlstreamstart, Name, Attrs}},
- #{stream_state := wait_for_stream,
- xmlns := XMLNS, lang := MyLang} = State) ->
- El = #xmlel{name = Name, attrs = Attrs},
- noreply(
- try xmpp:decode(El, XMLNS, []) of
- #stream_start{} = Pkt ->
- process_stream(Pkt, State);
- _ ->
- send_pkt(State, xmpp:serr_invalid_xml())
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:io_format_error(Why),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- Err = xmpp:serr_invalid_xml(Txt, Lang),
- send_pkt(State, Err)
- end);
-handle_info({'$gen_event', {xmlstreamerror, Reason}}, #{lang := Lang}= State) ->
- State1 = send_header(State),
- noreply(
- case is_disconnected(State1) of
- true -> State1;
- false ->
- Err = case Reason of
- <<"XML stanza is too big">> ->
- xmpp:serr_policy_violation(Reason, Lang);
- {_, Txt} ->
- xmpp:serr_not_well_formed(Txt, Lang)
- end,
- send_pkt(State1, Err)
- end);
-handle_info({'$gen_event', {xmlstreamelement, El}},
- #{xmlns := NS, codec_options := Opts} = State) ->
- noreply(
- try xmpp:decode(El, NS, Opts) of
- Pkt ->
- State1 = try callback(handle_recv, El, Pkt, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false -> process_element(Pkt, State1)
- end
- catch _:{xmpp_codec, Why} ->
- State1 = try callback(handle_recv, El, {error, Why}, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false -> process_invalid_xml(State1, El, Why)
- end
- end);
-handle_info({'$gen_all_state_event', {xmlstreamcdata, Data}}, State) ->
- noreply(try callback(handle_cdata, Data, State)
- catch _:{?MODULE, undef} -> State
- end);
-handle_info({'$gen_event', {xmlstreamend, _}}, State) ->
- noreply(process_stream_end({stream, reset}, State));
-handle_info({'$gen_event', closed}, State) ->
- noreply(process_stream_end({socket, closed}, State));
-handle_info(timeout, #{lang := Lang} = State) ->
- Disconnected = is_disconnected(State),
- noreply(try callback(handle_timeout, State)
- catch _:{?MODULE, undef} when not Disconnected ->
- Txt = <<"Idle connection">>,
- send_pkt(State, xmpp:serr_connection_timeout(Txt, Lang));
- _:{?MODULE, undef} ->
- stop(State)
- end);
-handle_info({'DOWN', MRef, _Type, _Object, _Info},
- #{socket_monitor := MRef} = State) ->
- noreply(process_stream_end({socket, closed}, State));
-handle_info({tcp, _, Data}, #{socket := Socket} = State) ->
- noreply(
- case xmpp_socket:recv(Socket, Data) of
- {ok, NewSocket} ->
- State#{socket => NewSocket};
- {error, Reason} when is_atom(Reason) ->
- process_stream_end({socket, Reason}, State);
- {error, Reason} ->
- %% TODO: make fast_tls return atoms
- process_stream_end({tls, Reason}, State)
- end);
-handle_info({tcp_closed, _}, State) ->
- handle_info({'$gen_event', closed}, State);
-handle_info({tcp_error, _, Reason}, State) ->
- noreply(process_stream_end({socket, Reason}, State));
-handle_info({'EXIT', _, Reason}, State) ->
- {stop, Reason, State};
-handle_info(Info, State) ->
- noreply(try callback(handle_info, Info, State)
- catch _:{?MODULE, undef} -> State
- end).
-
--spec terminate(term(), state()) -> any().
-terminate(Reason, State) ->
- case get(already_terminated) of
- true ->
- State;
- _ ->
- put(already_terminated, true),
- try callback(terminate, Reason, State)
- catch _:{?MODULE, undef} -> ok
- end,
- send_trailer(State)
- end.
-
-code_change(OldVsn, State, Extra) ->
- callback(code_change, OldVsn, State, Extra).
-
-%%%===================================================================
-%%% Internal functions
-%%%===================================================================
--spec noreply(state()) -> noreply().
-noreply(#{stream_timeout := infinity} = State) ->
- {noreply, State, infinity};
-noreply(#{stream_timeout := {MSecs, OldTime}} = State) ->
- NewTime = p1_time_compat:monotonic_time(milli_seconds),
- Timeout = max(0, MSecs - NewTime + OldTime),
- {noreply, State, Timeout}.
-
--spec new_id() -> binary().
-new_id() ->
- p1_rand:get_string().
-
--spec is_disconnected(state()) -> boolean().
-is_disconnected(#{stream_state := StreamState}) ->
- StreamState == disconnected.
-
--spec process_invalid_xml(state(), fxml:xmlel(), term()) -> state().
-process_invalid_xml(#{lang := MyLang} = State, El, Reason) ->
- case xmpp:is_stanza(El) of
- true ->
- Txt = xmpp:io_format_error(Reason),
- Lang = select_lang(MyLang, xmpp:get_lang(El)),
- send_error(State, El, xmpp:err_bad_request(Txt, Lang));
- false ->
- State
- end.
-
--spec process_stream_end(stop_reason(), state()) -> state().
-process_stream_end(_, #{stream_state := disconnected} = State) ->
- State;
-process_stream_end(Reason, State) ->
- State1 = send_trailer(State),
- try callback(handle_stream_end, Reason, State1)
- catch _:{?MODULE, undef} -> stop(State1)
- end.
-
--spec process_stream(stream_start(), state()) -> state().
-process_stream(#stream_start{xmlns = XML_NS,
- stream_xmlns = STREAM_NS},
- #{xmlns := NS} = State)
- when XML_NS /= NS; STREAM_NS /= ?NS_STREAM ->
- send_pkt(State, xmpp:serr_invalid_namespace());
-process_stream(#stream_start{version = {N, _}}, State) when N > 1 ->
- send_pkt(State, xmpp:serr_unsupported_version());
-process_stream(#stream_start{lang = Lang, id = ID,
- version = Version} = StreamStart,
- State) ->
- State1 = State#{stream_remote_id => ID, lang => Lang},
- State2 = try callback(handle_stream_start, StreamStart, State1)
- catch _:{?MODULE, undef} -> State1
- end,
- case is_disconnected(State2) of
- true -> State2;
- false ->
- case Version of
- {1, _} ->
- State2#{stream_state => wait_for_features};
- _ ->
- process_stream_downgrade(StreamStart, State2)
- end
- end.
-
--spec process_element(xmpp_element(), state()) -> state().
-process_element(Pkt, #{stream_state := StateName} = State) ->
- case Pkt of
- #stream_features{} when StateName == wait_for_features ->
- process_features(Pkt, State);
- #starttls_proceed{} when StateName == wait_for_starttls_response ->
- process_starttls(State);
- #sasl_success{} when StateName == wait_for_sasl_response ->
- process_sasl_success(State);
- #sasl_failure{} when StateName == wait_for_sasl_response ->
- process_sasl_failure(Pkt, State);
- #stream_error{} ->
- process_stream_end({stream, {in, Pkt}}, State);
- _ when is_record(Pkt, stream_features);
- is_record(Pkt, starttls_proceed);
- is_record(Pkt, starttls);
- is_record(Pkt, sasl_auth);
- is_record(Pkt, sasl_success);
- is_record(Pkt, sasl_failure);
- is_record(Pkt, sasl_response);
- is_record(Pkt, sasl_abort);
- is_record(Pkt, compress);
- is_record(Pkt, handshake) ->
- %% Do not pass this crap upstream
- State;
- _ when StateName == wait_for_bind_response ->
- process_bind_response(Pkt, State);
- _ ->
- process_packet(Pkt, State)
- end.
-
--spec process_features(stream_features(), state()) -> state().
-process_features(StreamFeatures,
- #{stream_authenticated := true} = State) ->
- try callback(handle_authenticated_features, StreamFeatures, State)
- catch _:{?MODULE, undef} -> process_bind(StreamFeatures, State)
- end;
-process_features(StreamFeatures,
- #{stream_encrypted := Encrypted, lang := Lang} = State) ->
- State1 = try callback(handle_unauthenticated_features, StreamFeatures, State)
- catch _:{?MODULE, undef} -> State
- end,
- case is_disconnected(State1) of
- true -> State1;
- false ->
- TLSRequired = is_starttls_required(State1),
- TLSAvailable = is_starttls_available(State1),
- try xmpp:try_subtag(StreamFeatures, #starttls{}) of
- false when TLSRequired and not Encrypted ->
- Txt = <<"Use of STARTTLS required">>,
- send_pkt(State1, xmpp:serr_policy_violation(Txt, Lang));
- #starttls{required = true} when not TLSAvailable and not Encrypted ->
- Txt = <<"Use of STARTTLS forbidden">>,
- send_pkt(State1, xmpp:serr_unsupported_feature(Txt, Lang));
- #starttls{} when TLSAvailable and not Encrypted ->
- State2 = State1#{stream_state => wait_for_starttls_response},
- send_pkt(State2, #starttls{});
- _ ->
- State2 = process_cert_verification(State1),
- case is_disconnected(State2) of
- true -> State2;
- false -> process_sasl_mechanisms(StreamFeatures, State2)
- end
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:io_format_error(Why),
- send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang))
- end
- end.
-
--spec process_stream_established(state()) -> state().
-process_stream_established(#{stream_state := StateName} = State)
- when StateName == disconnected; StateName == established ->
- State;
-process_stream_established(State) ->
- State1 = State#{stream_authenticated := true,
- stream_state => established,
- stream_timeout => infinity},
- try callback(handle_stream_established, State1)
- catch _:{?MODULE, undef} -> State1
- end.
-
--spec process_sasl_mechanisms(stream_features(), state()) -> state().
-process_sasl_mechanisms(StreamFeatures, State) ->
- AvailMechs = sasl_mechanisms(State),
- State1 = State#{sasl_mechs_available => AvailMechs},
- try xmpp:try_subtag(StreamFeatures, #sasl_mechanisms{}) of
- #sasl_mechanisms{list = ProvidedMechs} ->
- process_sasl_auth(State1#{sasl_mechs_provided => ProvidedMechs});
- false ->
- process_sasl_auth(State1#{sasl_mechs_provided => []})
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:io_format_error(Why),
- Lang = maps:get(lang, State),
- send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang))
- end.
-
-process_sasl_auth(#{stream_encrypted := false, xmlns := ?NS_SERVER} = State) ->
- State1 = State#{sasl_mechs_available => []},
- Txt = case is_starttls_available(State) of
- true -> <<"Peer doesn't support STARTTLS">>;
- false -> <<"STARTTLS is disabled in local configuration">>
- end,
- process_sasl_failure(Txt, State1);
-process_sasl_auth(#{sasl_mechs_provided := [],
- stream_encrypted := Encrypted} = State) ->
- State1 = State#{sasl_mechs_available => []},
- Hint = case Encrypted of
- true -> <<"; most likely it doesn't accept our certificate">>;
- false -> <<"">>
- end,
- Txt = <<"Peer provided no SASL mechanisms", Hint/binary>>,
- process_sasl_failure(Txt, State1);
-process_sasl_auth(#{sasl_mechs_available := []} = State) ->
- Err = maps:get(sasl_error, State,
- <<"No mutually supported SASL mechanisms found">>),
- process_sasl_failure(Err, State);
-process_sasl_auth(#{sasl_mechs_available := [Mech|AvailMechs],
- sasl_mechs_provided := ProvidedMechs} = State) ->
- State1 = State#{sasl_mechs_available => AvailMechs},
- if Mech == <<"EXTERNAL">> orelse Mech == <<"PLAIN">> ->
- case lists:member(Mech, ProvidedMechs) of
- true ->
- Text = make_sasl_authzid(Mech, State1),
- State2 = State1#{sasl_mech => Mech,
- stream_state => wait_for_sasl_response},
- send(State2, #sasl_auth{mechanism = Mech, text = Text});
- false ->
- process_sasl_auth(State1)
- end;
- true ->
- process_sasl_auth(State1)
- end.
-
--spec process_starttls(state()) -> state().
-process_starttls(#{socket := Socket} = State) ->
- case starttls(Socket, State) of
- {ok, TLSSocket} ->
- State1 = State#{socket => TLSSocket,
- stream_id => new_id(),
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_encrypted => true},
- send_header(State1);
- {error, Why} ->
- process_stream_end({tls, Why}, State)
- end.
-
--spec process_stream_downgrade(stream_start(), state()) -> state().
-process_stream_downgrade(StreamStart,
- #{lang := Lang,
- stream_encrypted := Encrypted} = State) ->
- TLSRequired = is_starttls_required(State),
- if not Encrypted and TLSRequired ->
- Txt = <<"Use of STARTTLS required">>,
- send_pkt(State, xmpp:serr_policy_violation(Txt, Lang));
- true ->
- State1 = State#{stream_state => downgraded},
- try callback(handle_stream_downgraded, StreamStart, State1)
- catch _:{?MODULE, undef} ->
- send_pkt(State1, xmpp:serr_unsupported_version())
- end
- end.
-
--spec process_cert_verification(state()) -> state().
-process_cert_verification(#{stream_encrypted := true,
- stream_verified := false} = State) ->
- case try callback(tls_verify, State)
- catch _:{?MODULE, undef} -> true
- end of
- true ->
- case xmpp_stream_pkix:authenticate(State) of
- {ok, _} ->
- State#{stream_verified => true};
- {error, Why, _Peer} ->
- process_stream_end({pkix, Why}, State)
- end;
- false ->
- State#{stream_verified => true}
- end;
-process_cert_verification(State) ->
- State.
-
--spec process_sasl_success(state()) -> state().
-process_sasl_success(#{socket := Socket, sasl_mech := Mech} = State) ->
- Socket1 = xmpp_socket:reset_stream(Socket),
- State1 = State#{socket => Socket1},
- State2 = State1#{stream_id => new_id(),
- stream_restarted => true,
- stream_state => wait_for_stream,
- stream_authenticated => true},
- State3 = reset_sasl_state(State2),
- State4 = send_header(State3),
- case is_disconnected(State4) of
- true -> State4;
- false ->
- try callback(handle_auth_success, Mech, State4)
- catch _:{?MODULE, undef} -> State4
- end
- end.
-
--spec process_sasl_failure(sasl_failure() | binary(), state()) -> state().
-process_sasl_failure(Failure, #{sasl_mechs_available := [_|_]} = State) ->
- process_sasl_auth(State#{sasl_failure => Failure});
-process_sasl_failure(#sasl_failure{} = Failure, State) ->
- Reason = format("Peer responded with error: ~s",
- [xmpp:format_sasl_error(Failure)]),
- process_sasl_failure(Reason, State);
-process_sasl_failure(Reason, State) ->
- Mech = case maps:get(sasl_mech, State, undefined) of
- undefined ->
- case sasl_mechanisms(State) of
- [] -> <<"EXTERNAL">>;
- [M|_] -> M
- end;
- M -> M
- end,
- State1 = reset_sasl_state(State),
- try callback(handle_auth_failure, Mech, {auth, Reason}, State1)
- catch _:{?MODULE, undef} -> process_stream_end({auth, Reason}, State1)
- end.
-
--spec process_bind(stream_features(), state()) -> state().
-process_bind(StreamFeatures, #{lang := Lang, xmlns := ?NS_CLIENT,
- resource := R,
- stream_state := StateName} = State)
- when StateName /= established, StateName /= disconnected ->
- case xmpp:has_subtag(StreamFeatures, #bind{}) of
- true ->
- ID = new_id(),
- Pkt = #iq{id = ID, type = set,
- sub_els = [#bind{resource = R}]},
- State1 = State#{stream_state => wait_for_bind_response,
- bind_id => ID},
- send_pkt(State1, Pkt);
- false ->
- Txt = <<"Missing resource binding feature">>,
- send_pkt(State, xmpp:serr_invalid_xml(Txt, Lang))
- end;
-process_bind(_, State) ->
- process_stream_established(State).
-
--spec process_bind_response(xmpp_element(), state()) -> state().
-process_bind_response(#iq{type = result, id = ID} = IQ,
- #{lang := Lang, bind_id := ID} = State) ->
- State1 = reset_bind_state(State),
- try xmpp:try_subtag(IQ, #bind{}) of
- #bind{jid = #jid{user = U, server = S, resource = R}} ->
- State2 = State1#{user => U, server => S, resource => R},
- State3 = try callback(handle_bind_success, State2)
- catch _:{?MODULE, undef} -> State2
- end,
- process_stream_established(State3);
- #bind{} ->
- Txt = <<"Missing <jid/> element in resource binding response">>,
- send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang));
- false ->
- Txt = <<"Missing <bind/> element in resource binding response">>,
- send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang))
- catch _:{xmpp_codec, Why} ->
- Txt = xmpp:io_format_error(Why),
- send_pkt(State1, xmpp:serr_invalid_xml(Txt, Lang))
- end;
-process_bind_response(#iq{type = error, id = ID} = IQ,
- #{bind_id := ID} = State) ->
- Err = xmpp:get_error(IQ),
- State1 = reset_bind_state(State),
- try callback(handle_bind_failure, Err, State1)
- catch _:{?MODULE, undef} -> process_stream_end({bind, Err}, State1)
- end;
-process_bind_response(Pkt, State) ->
- process_packet(Pkt, State).
-
--spec process_packet(xmpp_element(), state()) -> state().
-process_packet(Pkt, State) ->
- Pkt1 = fix_from(Pkt, State),
- try callback(handle_packet, Pkt1, State)
- catch _:{?MODULE, undef} -> State
- end.
-
--spec is_starttls_required(state()) -> boolean().
-is_starttls_required(State) ->
- try callback(tls_required, State)
- catch _:{?MODULE, undef} -> false
- end.
-
--spec is_starttls_available(state()) -> boolean().
-is_starttls_available(State) ->
- try callback(tls_enabled, State)
- catch _:{?MODULE, undef} -> true
- end.
-
--spec sasl_mechanisms(state()) -> [binary()].
-sasl_mechanisms(#{stream_encrypted := Encrypted} = State) ->
- try callback(sasl_mechanisms, State) of
- Ms when Encrypted -> Ms;
- Ms -> lists:delete(<<"EXTERNAL">>, Ms)
- catch _:{?MODULE, undef} ->
- if Encrypted -> [<<"EXTERNAL">>];
- true -> []
- end
- end.
-
--spec send_header(state()) -> state().
-send_header(#{remote_server := RemoteServer,
- stream_encrypted := Encrypted,
- lang := Lang,
- xmlns := NS,
- user := User,
- resource := Resource,
- server := Server} = State) ->
- NS_DB = if NS == ?NS_SERVER -> ?NS_SERVER_DIALBACK;
- true -> <<"">>
- end,
- From = if Encrypted ->
- jid:make(User, Server, Resource);
- NS == ?NS_SERVER ->
- jid:make(Server);
- true ->
- undefined
- end,
- StreamStart = #stream_start{xmlns = NS,
- lang = Lang,
- stream_xmlns = ?NS_STREAM,
- db_xmlns = NS_DB,
- from = From,
- to = jid:make(RemoteServer),
- version = {1,0}},
- case socket_send(State, StreamStart) of
- ok -> State;
- {error, Why} -> process_stream_end({socket, Why}, State)
- end.
-
--spec send_pkt(state(), xmpp_element() | xmlel()) -> state().
-send_pkt(State, Pkt) ->
- Result = socket_send(State, Pkt),
- State1 = try callback(handle_send, Pkt, Result, State)
- catch _:{?MODULE, undef} -> State
- end,
- case Result of
- _ when is_record(Pkt, stream_error) ->
- process_stream_end({stream, {out, Pkt}}, State1);
- ok ->
- State1;
- {error, Why} ->
- process_stream_end({socket, Why}, State1)
- end.
-
--spec send_error(state(), xmpp_element() | xmlel(), stanza_error()) -> state().
-send_error(State, Pkt, Err) ->
- case xmpp:is_stanza(Pkt) of
- true ->
- case xmpp:get_type(Pkt) of
- result -> State;
- error -> State;
- <<"result">> -> State;
- <<"error">> -> State;
- _ ->
- ErrPkt = xmpp:make_error(Pkt, Err),
- send_pkt(State, ErrPkt)
- end;
- false ->
- State
- end.
-
--spec socket_send(state(), xmpp_element() | xmlel() | trailer) -> ok | {error, inet:posix()}.
-socket_send(#{socket := Socket, xmlns := NS,
- stream_state := StateName}, Pkt) ->
- case Pkt of
- trailer ->
- xmpp_socket:send_trailer(Socket);
- #stream_start{} when StateName /= disconnected ->
- xmpp_socket:send_header(Socket, xmpp:encode(Pkt));
- _ when StateName /= disconnected ->
- xmpp_socket:send_element(Socket, xmpp:encode(Pkt, NS));
- _ ->
- {error, closed}
- end;
-socket_send(_, _) ->
- {error, closed}.
-
--spec send_trailer(state()) -> state().
-send_trailer(State) ->
- socket_send(State, trailer),
- close_socket(State).
-
--spec close_socket(state()) -> state().
-close_socket(State) ->
- case State of
- #{socket := Socket} ->
- xmpp_socket:close(Socket);
- _ ->
- ok
- end,
- State#{stream_timeout => infinity,
- stream_state => disconnected}.
-
--spec starttls(term(), state()) -> {ok, term()} | {error, tls_error_reason()}.
-starttls(Socket, #{xmlns := NS,
- remote_server := RemoteServer} = State) ->
- TLSOpts = try callback(tls_options, State)
- catch _:{?MODULE, undef} -> []
- end,
- SNI = idna_to_ascii(RemoteServer),
- ALPN = case NS of
- ?NS_SERVER -> <<"xmpp-server">>;
- ?NS_CLIENT -> <<"xmpp-client">>
- end,
- xmpp_socket:starttls(Socket, [connect, {sni, SNI}, {alpn, [ALPN]}|TLSOpts]).
-
--spec select_lang(binary(), binary()) -> binary().
-select_lang(Lang, <<"">>) -> Lang;
-select_lang(_, Lang) -> Lang.
-
--spec format_inet_error(atom()) -> string().
-format_inet_error(closed) ->
- "connection closed";
-format_inet_error(Reason) ->
- case inet:format_error(Reason) of
- "unknown POSIX error" -> atom_to_list(Reason);
- Txt -> Txt
- end.
-
--spec format_tls_error(atom() | binary()) -> list().
-format_tls_error(Reason) when is_atom(Reason) ->
- format_inet_error(Reason);
-format_tls_error(Reason) ->
- binary_to_list(Reason).
-
--spec format(io:format(), list()) -> binary().
-format(Fmt, Args) ->
- iolist_to_binary(io_lib:format(Fmt, Args)).
-
--spec make_sasl_authzid(binary(), state()) -> binary().
-make_sasl_authzid(Mech, #{user := User, server := Server,
- password := Password}) ->
- case Mech of
- <<"EXTERNAL">> ->
- jid:encode(jid:make(User, Server));
- <<"PLAIN">> ->
- JID = jid:encode(jid:make(User, Server)),
- <<JID/binary, 0, User/binary, 0, Password/binary>>
- end.
--spec fix_from(xmpp_element(), state()) -> xmpp_element().
-fix_from(Pkt, #{xmlns := ?NS_CLIENT} = State) ->
- case xmpp:is_stanza(Pkt) of
- true ->
- case xmpp:get_from(Pkt) of
- undefined ->
- #{user := U, server := S, resource := R} = State,
- From = jid:make(U, S, R),
- xmpp:set_from(Pkt, From);
- _ ->
- Pkt
- end;
- false ->
- Pkt
- end;
-fix_from(Pkt, _State) ->
- Pkt.
-
-%%%===================================================================
-%%% State resets
-%%%===================================================================
--spec reset_sasl_state(state()) -> state().
-reset_sasl_state(State) ->
- State1 = maps:remove(sasl_mech, State),
- State2 = maps:remove(sasl_failure, State1),
- State3 = maps:remove(sasl_mechs_provided, State2),
- maps:remove(sasl_mechs_available, State3).
-
--spec reset_connection_state(state()) -> state().
-reset_connection_state(State) ->
- State1 = maps:remove(ip, State),
- State2 = maps:remove(socket, State1),
- maps:remove(socket_monitor, State2).
-
--spec reset_stream_state(state()) -> state().
-reset_stream_state(State) ->
- State1 = State#{stream_id => new_id(),
- stream_encrypted => false,
- stream_verified => false,
- stream_authenticated => false,
- stream_restarted => false,
- stream_state => connecting},
- maps:remove(stream_remote_id, State1).
-
--spec reset_bind_state(state()) -> state().
-reset_bind_state(State) ->
- maps:remove(bind_id, State).
-
--spec reset_state(state()) -> state().
-reset_state(State) ->
- State1 = reset_bind_state(State),
- State2 = reset_sasl_state(State1),
- State3 = reset_connection_state(State2),
- reset_stream_state(State3).
-
-%%%===================================================================
-%%% Connection stuff
-%%%===================================================================
--spec idna_to_ascii(binary()) -> binary() | false.
-idna_to_ascii(<<$[, _/binary>> = Host) ->
- %% This is an IPv6 address in 'IP-literal' format (as per RFC7622)
- %% We remove brackets here
- case binary:last(Host) of
- $] ->
- IPv6 = binary:part(Host, {1, size(Host)-2}),
- case inet:parse_ipv6strict_address(binary_to_list(IPv6)) of
- {ok, _} -> IPv6;
- {error, _} -> false
- end;
- _ ->
- false
- end;
-idna_to_ascii(Host) ->
- case inet:parse_address(binary_to_list(Host)) of
- {ok, _} -> Host;
- {error, _} -> ejabberd_idna:domain_utf8_to_ascii(Host)
- end.
-
--spec resolve(string(), state()) -> {ok, [ip_port()]} | network_error().
-resolve(Host, State) ->
- try callback(resolve, Host, State) of
- [] ->
- do_resolve(Host, State);
- HostPorts ->
- a_lookup(HostPorts, State)
- catch _:{?MODULE, undef} ->
- do_resolve(Host, State)
- end.
-
--spec do_resolve(string(), state()) -> {ok, [ip_port()]} | network_error().
-do_resolve(Host, State) ->
- case srv_lookup(Host, State) of
- {error, _Reason} ->
- DefaultPort = get_default_port(State),
- a_lookup([{Host, DefaultPort, false}], State);
- {ok, HostPorts} ->
- a_lookup(HostPorts, State)
- end.
-
--spec srv_lookup(string(), state()) -> {ok, [host_port()]} | network_error().
-srv_lookup(_Host, #{xmlns := ?NS_COMPONENT}) ->
- %% Do not attempt to lookup SRV for component connections
- {error, nxdomain};
-srv_lookup(Host, State) ->
- %% Only perform SRV lookups for FQDN names
- case string:chr(Host, $.) of
- 0 ->
- {error, nxdomain};
- _ ->
- case inet:parse_address(Host) of
- {ok, _} ->
- {error, nxdomain};
- {error, _} ->
- Timeout = get_dns_timeout(State),
- Retries = get_dns_retries(State),
- case srv_lookup(Host, State, Timeout, Retries) of
- {ok, AddrList} ->
- h_addr_list_to_host_ports(AddrList);
- {error, _} = Err ->
- Err
- end
- end
- end.
-
-srv_lookup(Host, #{xmlns := NS} = State, Timeout, Retries) ->
- SRVType = case NS of
- ?NS_SERVER -> "-server._tcp.";
- ?NS_CLIENT -> "-client._tcp."
- end,
- TLSAddrs = case is_starttls_available(State) of
- true ->
- case srv_lookup("_xmpps" ++ SRVType ++ Host,
- Timeout, Retries) of
- {ok, HostEnt} ->
- [{A, true} || A <- HostEnt#hostent.h_addr_list];
- {error, _} ->
- []
- end;
- false ->
- []
- end,
- case srv_lookup("_xmpp" ++ SRVType ++ Host, Timeout, Retries) of
- {ok, HostEntry} ->
- Addrs = [{A, false} || A <- HostEntry#hostent.h_addr_list],
- {ok, TLSAddrs ++ Addrs};
- {error, _} when TLSAddrs /= [] ->
- {ok, TLSAddrs};
- {error, _} = Err ->
- Err
- end.
-
--spec srv_lookup(string(), timeout(), integer()) ->
- {ok, inet:hostent()} | network_error().
-srv_lookup(_SRVName, _Timeout, Retries) when Retries < 1 ->
- {error, timeout};
-srv_lookup(SRVName, Timeout, Retries) ->
- case inet_res:getbyname(SRVName, srv, Timeout) of
- {ok, HostEntry} ->
- {ok, HostEntry};
- {error, timeout} ->
- srv_lookup(SRVName, Timeout, Retries - 1);
- {error, _} = Err ->
- Err
- end.
-
--spec a_lookup([host_port()], state()) ->
- {ok, [ip_port()]} | network_error().
-a_lookup(HostPorts, State) ->
- HostPortFamilies = [{Host, Port, TLS, Family}
- || {Host, Port, TLS} <- HostPorts,
- Family <- get_address_families(State)],
- a_lookup(HostPortFamilies, State, [], {error, nxdomain}).
-
--spec a_lookup([{inet:hostname() | inet:ip_address(), inet:port_number(),
- boolean(), inet:address_family()}],
- state(), [ip_port()], network_error()) -> {ok, [ip_port()]} | network_error().
-a_lookup([{Addr, Port, TLS, Family}|HostPortFamilies], State, Acc, Err)
- when is_tuple(Addr) ->
- Acc1 = if tuple_size(Addr) == 4 andalso Family == inet ->
- [{Addr, Port, TLS}|Acc];
- tuple_size(Addr) == 8 andalso Family == inet6 ->
- [{Addr, Port, TLS}|Acc];
- true ->
- Acc
- end,
- a_lookup(HostPortFamilies, State, Acc1, Err);
-a_lookup([{Host, Port, TLS, Family}|HostPortFamilies], State, Acc, Err) ->
- Timeout = get_dns_timeout(State),
- Retries = get_dns_retries(State),
- case a_lookup(Host, Port, TLS, Family, Timeout, Retries) of
- {error, Reason} ->
- a_lookup(HostPortFamilies, State, Acc, {error, Reason});
- {ok, AddrPorts} ->
- a_lookup(HostPortFamilies, State, Acc ++ AddrPorts, Err)
- end;
-a_lookup([], _State, [], Err) ->
- Err;
-a_lookup([], _State, Acc, _) ->
- {ok, Acc}.
-
--spec a_lookup(inet:hostname(), inet:port_number(), boolean(), inet:address_family(),
- timeout(), integer()) -> {ok, [ip_port()]} | network_error().
-a_lookup(_Host, _Port, _TLS, _Family, _Timeout, Retries) when Retries < 1 ->
- {error, timeout};
-a_lookup(Host, Port, TLS, Family, Timeout, Retries) ->
- Start = p1_time_compat:monotonic_time(milli_seconds),
- case inet:gethostbyname(Host, Family, Timeout) of
- {error, nxdomain} = Err ->
- %% inet:gethostbyname/3 doesn't return {error, timeout},
- %% so we should check if 'nxdomain' is in fact a result
- %% of a timeout.
- %% We also cannot use inet_res:gethostbyname/3 because
- %% it ignores DNS configuration settings (/etc/hosts, etc)
- End = p1_time_compat:monotonic_time(milli_seconds),
- if (End - Start) >= Timeout ->
- a_lookup(Host, Port, TLS, Family, Timeout, Retries - 1);
- true ->
- Err
- end;
- {error, _} = Err ->
- Err;
- {ok, HostEntry} ->
- host_entry_to_addr_ports(HostEntry, Port, TLS)
- end.
-
--spec h_addr_list_to_host_ports(h_addr_list()) -> {ok, [host_port()]} |
- {error, nxdomain}.
-h_addr_list_to_host_ports(AddrList) ->
- PrioHostPorts = lists:flatmap(
- fun({{Priority, Weight, Port, Host}, TLS}) ->
- N = case Weight of
- 0 -> 0;
- _ -> (Weight + 1) * p1_rand:uniform()
- end,
- [{Priority * 65536 - N, Host, Port, TLS}];
- (_) ->
- []
- end, AddrList),
- HostPorts = [{Host, Port, TLS}
- || {_Priority, Host, Port, TLS} <- lists:usort(PrioHostPorts)],
- case HostPorts of
- [] -> {error, nxdomain};
- _ -> {ok, HostPorts}
- end.
-
--spec host_entry_to_addr_ports(inet:hostent(), inet:port_number(), boolean()) ->
- {ok, [ip_port()]} | {error, nxdomain}.
-host_entry_to_addr_ports(#hostent{h_addr_list = AddrList}, Port, TLS) ->
- AddrPorts = lists:flatmap(
- fun(Addr) ->
- try get_addr_type(Addr) of
- _ -> [{Addr, Port, TLS}]
- catch _:_ ->
- []
- end
- end, AddrList),
- case AddrPorts of
- [] -> {error, nxdomain};
- _ -> {ok, AddrPorts}
- end.
-
--spec connect([ip_port()], state()) -> {ok, term(), ip_port()} |
- {error, {socket, socket_error_reason()}} |
- {error, {tls, tls_error_reason()}}.
-connect(AddrPorts, State) ->
- Timeout = get_connect_timeout(State),
- case connect(AddrPorts, Timeout, State, {error, nxdomain}) of
- {ok, Socket, {Addr, Port, TLS = true}} ->
- case starttls(Socket, State) of
- {ok, TLSSocket} -> {ok, TLSSocket, {Addr, Port, TLS}};
- {error, Why} -> {error, {tls, Why}}
- end;
- {ok, Socket, {Addr, Port, TLS = false}} ->
- {ok, Socket, {Addr, Port, TLS}};
- {error, Why} ->
- {error, {socket, Why}}
- end.
-
--spec connect([ip_port()], timeout(), state(), network_error()) ->
- {ok, term(), ip_port()} | network_error().
-connect([{Addr, Port, TLS}|AddrPorts], Timeout, State, _) ->
- Type = get_addr_type(Addr),
- Opts = [binary, {packet, 0},
- {send_timeout, ?TCP_SEND_TIMEOUT},
- {send_timeout_close, true},
- {active, false}, Type],
- Opts1 = try callback(connect_options, Addr, Opts, State)
- catch _:{?MODULE, undef} -> Opts
- end,
- try xmpp_socket:connect(Addr, Port, Opts1, Timeout) of
- {ok, Socket} ->
- {ok, Socket, {Addr, Port, TLS}};
- Err ->
- connect(AddrPorts, Timeout, State, Err)
- catch _:badarg ->
- connect(AddrPorts, Timeout, State, {error, einval})
- end;
-connect([], _Timeout, _State, Err) ->
- Err.
-
--spec get_addr_type(inet:ip_address()) -> inet:address_family().
-get_addr_type({_, _, _, _}) -> inet;
-get_addr_type({_, _, _, _, _, _, _, _}) -> inet6.
-
--spec get_dns_timeout(state()) -> timeout().
-get_dns_timeout(State) ->
- try callback(dns_timeout, State)
- catch _:{?MODULE, undef} -> timer:seconds(10)
- end.
-
--spec get_dns_retries(state()) -> non_neg_integer().
-get_dns_retries(State) ->
- try callback(dns_retries, State)
- catch _:{?MODULE, undef} -> 2
- end.
-
--spec get_default_port(state()) -> inet:port_number().
-get_default_port(#{xmlns := NS} = State) ->
- try callback(default_port, State)
- catch _:{?MODULE, undef} when NS == ?NS_SERVER -> 5269;
- _:{?MODULE, undef} when NS == ?NS_CLIENT -> 5222
- end.
-
--spec get_address_families(state()) -> [inet:address_family()].
-get_address_families(State) ->
- try callback(address_families, State)
- catch _:{?MODULE, undef} -> [inet, inet6]
- end.
-
--spec get_connect_timeout(state()) -> timeout().
-get_connect_timeout(State) ->
- try callback(connect_timeout, State)
- catch _:{?MODULE, undef} -> timer:seconds(10)
- end.
-
-%%%===================================================================
-%%% Callbacks
-%%%===================================================================
-callback(F, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 1) of
- true -> Mod:F(State);
- false -> erlang:error({?MODULE, undef})
- end.
-
-callback(F, Arg1, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 2) of
- true -> Mod:F(Arg1, State);
- false -> erlang:error({?MODULE, undef})
- end.
-
-callback(code_change, OldVsn, #{mod := Mod} = State, Extra) ->
- %% code_change/3 callback is a special snowflake
- case erlang:function_exported(Mod, code_change, 3) of
- true -> Mod:code_change(OldVsn, State, Extra);
- false -> {ok, State}
- end;
-callback(F, Arg1, Arg2, #{mod := Mod} = State) ->
- case erlang:function_exported(Mod, F, 3) of
- true -> Mod:F(Arg1, Arg2, State);
- false -> erlang:error({?MODULE, undef})
- end.
+++ /dev/null
-%%%-------------------------------------------------------------------
-%%% Created : 13 Dec 2016 by Evgeny Khramtsov <ekhramtsov@process-one.net>
-%%%
-%%%
-%%% ejabberd, Copyright (C) 2002-2018 ProcessOne
-%%%
-%%% This program is free software; you can redistribute it and/or
-%%% modify it under the terms of the GNU General Public License as
-%%% published by the Free Software Foundation; either version 2 of the
-%%% License, or (at your option) any later version.
-%%%
-%%% This program is distributed in the hope that it will be useful,
-%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
-%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
-%%% General Public License for more details.
-%%%
-%%% You should have received a copy of the GNU General Public License along
-%%% with this program; if not, write to the Free Software Foundation, Inc.,
-%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
-%%%
-%%%-------------------------------------------------------------------
--module(xmpp_stream_pkix).
-
-%% API
--export([authenticate/1, authenticate/2, get_cert_domains/1, format_error/1]).
-
--include("xmpp.hrl").
--include_lib("public_key/include/public_key.hrl").
--include("XmppAddr.hrl").
-
--type cert() :: #'OTPCertificate'{}.
-
-%%%===================================================================
-%%% API
-%%%===================================================================
--spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state())
- -> {ok, binary()} | {error, atom(), binary()}.
-authenticate(State) ->
- authenticate(State, <<"">>).
-
--spec authenticate(xmpp_stream_in:state() | xmpp_stream_out:state(), binary())
- -> {ok, binary()} | {error, atom(), binary()}.
-authenticate(#{xmlns := ?NS_SERVER,
- socket := Socket} = State, Authzid) ->
- Peer = maps:get(remote_server, State, Authzid),
- case verify_cert(Socket) of
- {ok, Cert} ->
- case ejabberd_idna:domain_utf8_to_ascii(Peer) of
- false ->
- {error, idna_failed, Peer};
- AsciiPeer ->
- case lists:any(
- fun(D) -> match_domain(AsciiPeer, D) end,
- get_cert_domains(Cert)) of
- true ->
- {ok, Peer};
- false ->
- {error, hostname_mismatch, Peer}
- end
- end;
- {error, Reason} ->
- {error, Reason, Peer}
- end;
-authenticate(#{xmlns := ?NS_CLIENT,
- socket := Socket, lserver := LServer}, Authzid) ->
- JID = try jid:decode(Authzid)
- catch _:{bad_jid, <<>>} -> jid:make(LServer);
- _:{bad_jid, _} -> {error, invalid_authzid, Authzid}
- end,
- case JID of
- #jid{user = User} ->
- case verify_cert(Socket) of
- {ok, Cert} ->
- JIDs = get_xmpp_addrs(Cert),
- get_username(JID, JIDs, LServer);
- {error, Reason} ->
- {error, Reason, User}
- end;
- Err ->
- Err
- end.
-
-format_error(idna_failed) ->
- {'bad-protocol', <<"Remote domain is not an IDN hostname">>};
-format_error(hostname_mismatch) ->
- {'not-authorized', <<"Certificate host name mismatch">>};
-format_error(jid_mismatch) ->
- {'not-authorized', <<"Certificate JID mismatch">>};
-format_error(get_cert_failed) ->
- {'bad-protocol', <<"Failed to get peer certificate">>};
-format_error(invalid_authzid) ->
- {'invalid-authzid', <<"Malformed JID">>};
-format_error(Other) ->
- {'not-authorized', erlang:atom_to_binary(Other, utf8)}.
-
--spec get_cert_domains(cert()) -> [binary()].
-get_cert_domains(Cert) ->
- TBSCert = Cert#'OTPCertificate'.tbsCertificate,
- {rdnSequence, Subject} = TBSCert#'OTPTBSCertificate'.subject,
- Extensions = TBSCert#'OTPTBSCertificate'.extensions,
- get_domain_from_subject(lists:flatten(Subject)) ++
- get_domains_from_san(Extensions).
-
-%%%===================================================================
-%%% Internal functions
-%%%===================================================================
--spec verify_cert(xmpp_socket:socket()) -> {ok, cert()} | {error, atom()}.
-verify_cert(Socket) ->
- case xmpp_socket:get_peer_certificate(Socket, otp) of
- {ok, Cert} ->
- case xmpp_socket:get_verify_result(Socket) of
- 0 ->
- {ok, Cert};
- VerifyRes ->
- %% TODO: return atomic errors
- %% This should be improved in fast_tls
- Reason = fast_tls:get_cert_verify_string(VerifyRes, Cert),
- {error, erlang:binary_to_atom(Reason, utf8)}
- end;
- {error, _Reason} ->
- {error, get_cert_failed};
- error ->
- {error, get_cert_failed}
- end.
-
--spec get_domain_from_subject([#'AttributeTypeAndValue'{}]) -> [binary()].
-get_domain_from_subject(AttrVals) ->
- case lists:keyfind(?'id-at-commonName',
- #'AttributeTypeAndValue'.type,
- AttrVals) of
- #'AttributeTypeAndValue'{value = {_, S}} ->
- try jid:decode(iolist_to_binary(S)) of
- #jid{luser = <<"">>, lresource = <<"">>, lserver = Domain} ->
- [Domain];
- _ ->
- []
- catch _:{bad_jid, _} ->
- []
- end;
- _ ->
- []
- end.
-
--spec get_domains_from_san([#'Extension'{}] | asn1_NOVALUE) -> [binary()].
-get_domains_from_san(Extensions) when is_list(Extensions) ->
- case lists:keyfind(?'id-ce-subjectAltName',
- #'Extension'.extnID,
- Extensions) of
- #'Extension'{extnValue = Vals} ->
- lists:flatmap(
- fun({dNSName, S}) ->
- [iolist_to_binary(S)];
- ({otherName, AnotherName}) ->
- case decode_xmpp_addr(AnotherName) of
- {ok, #jid{luser = <<"">>,
- lresource = <<"">>,
- lserver = Domain}} ->
- case ejabberd_idna:domain_utf8_to_ascii(Domain) of
- false ->
- [];
- ASCIIDomain ->
- [ASCIIDomain]
- end;
- _ ->
- []
- end;
- (_) ->
- []
- end, Vals);
- _ ->
- []
- end;
-get_domains_from_san(_) ->
- [].
-
--spec decode_xmpp_addr(#'AnotherName'{}) -> {ok, jid()} | error.
-decode_xmpp_addr(#'AnotherName'{'type-id' = ?'id-on-xmppAddr',
- value = XmppAddr}) ->
- try 'XmppAddr':decode('XmppAddr', XmppAddr) of
- {ok, JIDStr} ->
- try {ok, jid:decode(iolist_to_binary(JIDStr))}
- catch _:{bad_jid, _} -> error
- end;
- _ ->
- error
- catch _:_ ->
- error
- end;
-decode_xmpp_addr(_) ->
- error.
-
--spec get_xmpp_addrs(cert()) -> [jid()].
-get_xmpp_addrs(Cert) ->
- TBSCert = Cert#'OTPCertificate'.tbsCertificate,
- case TBSCert#'OTPTBSCertificate'.extensions of
- Extensions when is_list(Extensions) ->
- case lists:keyfind(?'id-ce-subjectAltName',
- #'Extension'.extnID,
- Extensions) of
- #'Extension'{extnValue = Vals} ->
- lists:flatmap(
- fun({otherName, AnotherName}) ->
- case decode_xmpp_addr(AnotherName) of
- {ok, JID} -> [JID];
- _ -> []
- end;
- (_) ->
- []
- end, Vals);
- _ ->
- []
- end;
- _ ->
- []
- end.
-
-match_domain(Domain, Domain) -> true;
-match_domain(Domain, Pattern) ->
- DLabels = str:tokens(Domain, <<".">>),
- PLabels = str:tokens(Pattern, <<".">>),
- match_labels(DLabels, PLabels).
-
-match_labels([], []) -> true;
-match_labels([], [_ | _]) -> false;
-match_labels([_ | _], []) -> false;
-match_labels([DL | DLabels], [PL | PLabels]) ->
- case lists:all(fun (C) ->
- $a =< C andalso C =< $z orelse
- $0 =< C andalso C =< $9 orelse
- C == $- orelse C == $*
- end,
- binary_to_list(PL))
- of
- true ->
- Regexp = ejabberd_regexp:sh_to_awk(PL),
- case ejabberd_regexp:run(DL, Regexp) of
- match -> match_labels(DLabels, PLabels);
- nomatch -> false
- end;
- false -> false
- end.
-
--spec get_username(jid(), [jid()], binary()) ->
- {ok, binary()} | {error, jid_mismatch, binary()}.
-get_username(#jid{user = User, lserver = LS}, _, LServer) when LS /= LServer ->
- %% The user provided JID from different domain
- {error, jid_mismatch, User};
-get_username(#jid{user = <<>>}, [#jid{user = U, lserver = LS}], LServer)
- when U /= <<>> andalso LS == LServer ->
- %% The user didn't provide JID or username, and there is only
- %% one 'non-global' JID matching current domain
- {ok, U};
-get_username(#jid{user = User, luser = LUser}, JIDs, LServer) when User /= <<>> ->
- %% The user provided username
- lists:foldl(
- fun(_, {ok, _} = OK) ->
- OK;
- (#jid{user = <<>>, lserver = LS}, _) when LS == LServer ->
- %% Found "global" JID in the certficate
- %% (i.e. in the form of 'domain.com')
- %% within current domain, so we force matching
- {ok, User};
- (#jid{luser = LU, lserver = LS}, _) when LU == LUser, LS == LServer ->
- %% Found exact JID matching
- {ok, User};
- (_, Err) ->
- Err
- end, {error, jid_mismatch, User}, JIDs);
-get_username(#jid{user = User}, _, _) ->
- %% Nothing from above is true
- {error, jid_mismatch, User}.
start_module(:jid)
:ejabberd_hooks.start_link
:ok = :ejabberd_config.start(["domain1"], [])
- {:ok, _} = :cyrsasl.start_link
- cyrstate = :cyrsasl.server_new("domain1", "domain1", "domain1", :ok, &get_password/1,
+ {:ok, _} = :xmpp_sasl.start_link
+ cyrstate = :xmpp_sasl.server_new("domain1", "domain1", "domain1", :ok, &get_password/1,
&check_password/3, &check_password_digest/5)
setup_anonymous_mocks()
{:ok, cyrstate: cyrstate}
end
test "Plain text (correct user and pass)", context do
- step1 = :cyrsasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"pass">>)
+ step1 = :xmpp_sasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"pass">>)
assert {:ok, _} = step1
{:ok, kv} = step1
assert kv[:authzid] == "user1", "got correct user"
end
test "Plain text (correct user wrong pass)", context do
- step1 = :cyrsasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"badpass">>)
+ step1 = :xmpp_sasl.server_start(context[:cyrstate], "PLAIN", <<0,"user1",0,"badpass">>)
assert step1 == {:error, :not_authorized, "user1"}
end
test "Plain text (wrong user wrong pass)", context do
- step1 = :cyrsasl.server_start(context[:cyrstate], "PLAIN", <<0,"nouser1",0,"badpass">>)
+ step1 = :xmpp_sasl.server_start(context[:cyrstate], "PLAIN", <<0,"nouser1",0,"badpass">>)
assert step1 == {:error, :not_authorized, "nouser1"}
end
test "Anonymous", context do
- step1 = :cyrsasl.server_start(context[:cyrstate], "ANONYMOUS", "domain1")
+ step1 = :xmpp_sasl.server_start(context[:cyrstate], "ANONYMOUS", "domain1")
assert {:ok, _} = step1
end
end
defp process_digest_md5(cyrstate, user, domain, pass) do
- assert {:continue, init_str, state1} = :cyrsasl.server_start(cyrstate, "DIGEST-MD5", "")
+ assert {:continue, init_str, state1} = :xmpp_sasl.server_start(cyrstate, "DIGEST-MD5", "")
assert [_, nonce] = Regex.run(~r/nonce="(.*?)"/, init_str)
digest_uri = "xmpp/#{domain}"
cnonce = "abcd"
response = "username=\"#{user}\",realm=\"#{domain}\",nonce=\"#{nonce}\",cnonce=\"#{cnonce}\"," <>
"nc=\"#{nc}\",qop=auth,digest-uri=\"#{digest_uri}\",response=\"#{response_hash}\"," <>
"charset=utf-8,algorithm=md5-sess"
- case :cyrsasl.server_step(state1, response) do
- {:continue, _calc_str, state2} -> :cyrsasl.server_step(state2, "")
+ case :xmpp_sasl.server_step(state1, response) do
+ {:continue, _calc_str, state2} -> :xmpp_sasl.server_step(state2, "")
other -> other
end
end
sasl_new(<<"DIGEST-MD5">>, {User, Server, Password}) ->
{<<"">>,
fun (ServerIn) ->
- case cyrsasl_digest:parse(ServerIn) of
+ case xmpp_sasl_digest:parse(ServerIn) of
bad -> {error, <<"Invalid SASL challenge">>};
KeyVals ->
Nonce = fxml:get_attr_s(<<"nonce">>, KeyVals),
MyResponse/binary, "\"">>,
{Resp,
fun (ServerIn2) ->
- case cyrsasl_digest:parse(ServerIn2) of
+ case xmpp_sasl_digest:parse(ServerIn2) of
bad -> {error, <<"Invalid SASL challenge">>};
_KeyVals2 ->
{<<"">>,