]> granicus.if.org Git - ejabberd/commitdiff
Improve return values in cyrsasl API
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Sat, 31 Dec 2016 10:47:35 +0000 (13:47 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Sat, 31 Dec 2016 10:47:35 +0000 (13:47 +0300)
src/cyrsasl.erl
src/cyrsasl_anonymous.erl
src/cyrsasl_digest.erl
src/cyrsasl_oauth.erl
src/cyrsasl_plain.erl
src/cyrsasl_scram.erl

index 874a417a17a98edee1eb6abe119e9841e98803e9..1edf44678fc90fc1ffd448f36eb04ffe9e1cd670 100644 (file)
@@ -31,7 +31,7 @@
 
 -export([start/0, register_mechanism/3, listmech/1,
         server_new/7, server_start/3, server_step/2,
-        get_mech/1, opt_type/1]).
+        get_mech/1, format_error/2, opt_type/1]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
                         {auth_module, atom()}.
 -type sasl_return() :: {ok, [sasl_property()]} |
                       {ok, [sasl_property()], binary()} |
-                      {continue, binary(), any()} |
-                      {error, atom()} |
+                      {continue, binary(), sasl_state()} |
                       {error, atom(), binary()}.
 
 -type(sasl_mechanism() :: #sasl_mechanism{}).
-
+-type error_reason() :: cyrsasl_digest:error_reason() |
+                       cyrsasl_oauth:error_reason() |
+                       cyrsasl_plain:error_reason() |
+                       cyrsasl_scram:error_reason() |
+                       unsupported_mechanism | nodeprep_failed |
+                       empty_username | aborted.
 -record(sasl_state,
 {
     service,
@@ -69,7 +73,7 @@
     mech_state
 }).
 -type sasl_state() :: #sasl_state{}.
--export_type([mechanism/0, mechanisms/0, sasl_mechanism/0,
+-export_type([mechanism/0, mechanisms/0, sasl_mechanism/0, error_reason/0,
              sasl_state/0, sasl_return/0, sasl_property/0]).
 
 -callback mech_new(binary(), fun(), fun(), fun()) -> any().
@@ -86,7 +90,25 @@ start() ->
     cyrsasl_oauth:start([]),
     ok.
 
-%%
+-spec format_error(mechanism() | sasl_state(), error_reason()) -> {atom(), binary()}.
+format_error(_, unsupported_mechanism) ->
+    {'invalid-mechanism', <<"Unsupported mechanism">>};
+format_error(_, nodeprep_failed) ->
+    {'bad-protocol', <<"Nodeprep failed">>};
+format_error(_, empty_username) ->
+    {'bad-protocol', <<"Empty username">>};
+format_error(_, aborted) ->
+    {'aborted', <<"Aborted">>};
+format_error(#sasl_state{mech_mod = Mod}, Reason) ->
+    Mod:format_error(Reason);
+format_error(Mech, Reason) ->
+    case ets:lookup(sasl_mechanism, Mech) of
+       [#sasl_mechanism{module = Mod}] ->
+           Mod:format_error(Reason);
+       [] ->
+           {'invalid-mechanism', <<"Unsupported mechanism">>}
+    end.
+
 -spec register_mechanism(Mechanim :: mechanism(), Module :: module(),
                         PasswordType :: password_type()) -> any().
 
@@ -104,8 +126,8 @@ register_mechanism(Mechanism, Module, PasswordType) ->
 check_credentials(_State, Props) ->
     User = proplists:get_value(authzid, Props, <<>>),
     case jid:nodeprep(User) of
-      error -> {error, 'not-authorized'};
-      <<"">> -> {error, 'not-authorized'};
+      error -> {error, nodeprep_failed};
+      <<"">> -> {error, empty_username};
       _LUser -> ok
     end.
 
@@ -127,6 +149,8 @@ listmech(Host) ->
                         ['$1']}]),
     filter_anonymous(Host, Mechs).
 
+-spec server_new(binary(), binary(), binary(), term(),
+                fun(), fun(), fun()) -> sasl_state().
 server_new(Service, ServerFQDN, UserRealm, _SecFlags,
           GetPassword, CheckPassword, CheckPasswordDigest) ->
     #sasl_state{service = Service, myname = ServerFQDN,
@@ -134,8 +158,7 @@ server_new(Service, ServerFQDN, UserRealm, _SecFlags,
                check_password = CheckPassword,
                check_password_digest = CheckPasswordDigest}.
 
-server_start(State, Mech, undefined) ->
-    server_start(State, Mech, <<"">>);
+-spec server_start(sasl_state(), mechanism(), binary()) -> sasl_return().
 server_start(State, Mech, ClientIn) ->
     case lists:member(Mech,
                      listmech(State#sasl_state.myname))
@@ -152,13 +175,12 @@ server_start(State, Mech, ClientIn) ->
                                             mech_name = Mech,
                                             mech_state = MechState},
                            ClientIn);
-           _ -> {error, 'no-mechanism'}
+           _ -> {error, unsupported_mechanism, <<"">>}
          end;
-      false -> {error, 'no-mechanism'}
+      false -> {error, unsupported_mechanism, <<"">>}
     end.
 
-server_step(State, undefined) ->
-    server_step(State, <<"">>);
+-spec server_step(sasl_state(), binary()) -> sasl_return().
 server_step(State, ClientIn) ->
     Module = State#sasl_state.mech_mod,
     MechState = State#sasl_state.mech_state,
@@ -166,19 +188,19 @@ server_step(State, ClientIn) ->
         {ok, Props} ->
             case check_credentials(State, Props) of
                 ok             -> {ok, Props};
-                {error, Error} -> {error, Error}
+                {error, Error} -> {error, Error, <<"">>}
             end;
         {ok, Props, ServerOut} ->
             case check_credentials(State, Props) of
                 ok             -> {ok, Props, ServerOut};
-                {error, Error} -> {error, Error}
+                {error, Error} -> {error, Error, <<"">>}
             end;
         {continue, ServerOut, NewMechState} ->
             {continue, ServerOut, State#sasl_state{mech_state = NewMechState}};
         {error, Error, Username} ->
             {error, Error, Username};
         {error, Error} ->
-            {error, Error}
+            {error, Error, <<"">>}
     end.
 
 -spec get_mech(sasl_state()) -> binary().
index 15980afc53a53aa73341915606a2b55f87de5f63..cad9cdf9399c73ffe38cb2056ccacd415e2c19c3 100644 (file)
@@ -43,10 +43,9 @@ stop() -> ok.
 mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) ->
     {ok, #state{server = Host}}.
 
-mech_step(#state{server = Server} = S, ClientIn) ->
+mech_step(#state{}, _ClientIn) ->
     User = iolist_to_binary([randoms:get_string(),
                              integer_to_binary(p1_time_compat:unique_integer([positive]))]),
-    case ejabberd_auth:is_user_exists(User, Server) of
-        true  -> mech_step(S, ClientIn);
-        false -> {ok, [{username, User}, {authzid, User}, {auth_module, ejabberd_auth_anonymous}]}
-    end.
+    {ok, [{username, User},
+         {authzid, User},
+         {auth_module, ejabberd_auth_anonymous}]}.
index 150aa854cea12be8449a1f5d0b0f64aa18a25e9d..9b4faca204ef87e2c216a28a73ae2e568d450d9e 100644 (file)
@@ -30,7 +30,7 @@
 -author('alexey@sevcom.net').
 
 -export([start/1, stop/0, mech_new/4, mech_step/2,
-        parse/1, opt_type/1]).
+        parse/1, format_error/1, opt_type/1]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
 
 -type get_password_fun() :: fun((binary()) -> {false, any()} |
                                               {binary(), atom()}).
-
 -type check_password_fun() :: fun((binary(), binary(), binary(),
                                    fun((binary()) -> binary())) ->
                                            {boolean(), any()} |
                                            false).
+-type error_reason() :: parser_failed | invalid_digest_uri |
+                       not_authorized | unexpected_response.
+-export_type([error_reason/0]).
 
 -record(state, {step = 1 :: 1 | 3 | 5,
                 nonce = <<"">> :: binary(),
@@ -64,6 +66,16 @@ start(_Opts) ->
 
 stop() -> ok.
 
+-spec format_error(error_reason()) -> {atom(), binary()}.
+format_error(parser_failed) ->
+    {'bad-protocol', <<"Response decoding failed">>};
+format_error(invalid_digest_uri) ->
+    {'bad-protocol', <<"Invalid digest URI">>};
+format_error(not_authorized) ->
+    {'not-authorized', <<"Invalid username or password">>};
+format_error(unexpected_response) ->
+    {'bad-protocol', <<"Unexpected response">>}.
+
 mech_new(Host, GetPassword, _CheckPassword,
         CheckPasswordDigest) ->
     {ok,
@@ -80,8 +92,8 @@ mech_step(#state{step = 1, nonce = Nonce} = State, _) ->
 mech_step(#state{step = 3, nonce = Nonce} = State,
          ClientIn) ->
     case parse(ClientIn) of
-      bad -> {error, 'bad-protocol'};
-      KeyVals ->
+       bad -> {error, parser_failed};
+       KeyVals ->
          DigestURI = proplists:get_value(<<"digest-uri">>, KeyVals, <<>>),
          UserName = proplists:get_value(<<"username">>, KeyVals, <<>>),
          case is_digesturi_valid(DigestURI, State#state.host,
@@ -92,11 +104,11 @@ mech_step(#state{step = 3, nonce = Nonce} = State,
                       "seems invalid: ~p (checking for Host "
                       "~p, FQDN ~p)",
                       [DigestURI, State#state.host, State#state.hostfqdn]),
-               {error, 'not-authorized', UserName};
+               {error, invalid_digest_uri, UserName};
            true ->
                AuthzId = proplists:get_value(<<"authzid">>, KeyVals, <<>>),
                case (State#state.get_password)(UserName) of
-                 {false, _} -> {error, 'not-authorized', UserName};
+                 {false, _} -> {error, not_authorized, UserName};
                  {Passwd, AuthModule} ->
                      case (State#state.check_password)(UserName, UserName, <<"">>,
                                    proplists:get_value(<<"response">>, KeyVals, <<>>),
@@ -116,8 +128,8 @@ mech_step(#state{step = 3, nonce = Nonce} = State,
                             State#state{step = 5, auth_module = AuthModule,
                                         username = UserName,
                                         authzid = AuthzId}};
-                       false -> {error, 'not-authorized', UserName};
-                       {false, _} -> {error, 'not-authorized', UserName}
+                       false -> {error, not_authorized, UserName};
+                       {false, _} -> {error, not_authorized, UserName}
                      end
                end
          end
@@ -134,7 +146,7 @@ mech_step(#state{step = 5, auth_module = AuthModule,
       {auth_module, AuthModule}]};
 mech_step(A, B) ->
     ?DEBUG("SASL DIGEST: A ~p B ~p", [A, B]),
-    {error, 'bad-protocol'}.
+    {error, unexpected_response}.
 
 parse(S) -> parse1(binary_to_list(S), "", []).
 
index 21dedc6db1e040e198ff5bb1eab032e146f1869b..be7e9a68d67d69c9dfec71f24a217f2ec0891020 100644 (file)
 
 -author('alexey@process-one.net').
 
--export([start/1, stop/0, mech_new/4, mech_step/2, parse/1]).
+-export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]).
 
 -behaviour(cyrsasl).
 
 -record(state, {host}).
+-type error_reason() :: parser_failed | not_authorized.
+-export_type([error_reason/0]).
 
 start(_Opts) ->
     cyrsasl:register_mechanism(<<"X-OAUTH2">>, ?MODULE, plain),
@@ -39,6 +41,12 @@ start(_Opts) ->
 
 stop() -> ok.
 
+-spec format_error(error_reason()) -> {atom(), binary()}.
+format_error(parser_failed) ->
+    {'bad-protocol', <<"Response decoding failed">>};
+format_error(not_authorized) ->
+    {'not-authorized', <<"Invalid token">>}.
+
 mech_new(Host, _GetPassword, _CheckPassword, _CheckPasswordDigest) ->
     {ok, #state{host = Host}}.
 
@@ -52,9 +60,9 @@ mech_step(State, ClientIn) ->
                      [{username, User}, {authzid, AuthzId},
                       {auth_module, ejabberd_oauth}]};
                 _ ->
-                    {error, 'not-authorized', User}
+                    {error, not_authorized, User}
             end;
-        _ -> {error, 'bad-protocol'}
+        _ -> {error, parser_failed}
     end.
 
 prepare(ClientIn) ->
index 8e9b32b9931c7a0116e972b84c382d0c0d7a4eb5..bbac8deff64b1d6a02f8515c11ebb38a9c68e543 100644 (file)
 
 -author('alexey@process-one.net').
 
--export([start/1, stop/0, mech_new/4, mech_step/2, parse/1]).
+-export([start/1, stop/0, mech_new/4, mech_step/2, parse/1, format_error/1]).
 
 -behaviour(cyrsasl).
 
 -record(state, {check_password}).
+-type error_reason() :: parser_failed | not_authorized.
+-export_type([error_reason/0]).
 
 start(_Opts) ->
     cyrsasl:register_mechanism(<<"PLAIN">>, ?MODULE, plain),
@@ -39,6 +41,12 @@ start(_Opts) ->
 
 stop() -> ok.
 
+-spec format_error(error_reason()) -> {atom(), binary()}.
+format_error(parser_failed) ->
+    {'bad-protocol', <<"Response decoding failed">>};
+format_error(not_authorized) ->
+    {'not-authorized', <<"Invalid username or password">>}.
+
 mech_new(_Host, _GetPassword, CheckPassword, _CheckPasswordDigest) ->
     {ok, #state{check_password = CheckPassword}}.
 
@@ -50,9 +58,9 @@ mech_step(State, ClientIn) ->
                {ok,
                 [{username, User}, {authzid, AuthzId},
                  {auth_module, AuthModule}]};
-           _ -> {error, 'not-authorized', User}
+           _ -> {error, not_authorized, User}
          end;
-      _ -> {error, 'bad-protocol'}
+      _ -> {error, parser_failed}
     end.
 
 prepare(ClientIn) ->
index 1e2a5c68198d84e249891652d696034a90d91eff..55e06fd250ce440ee3df22831a0490e131bca624 100644 (file)
@@ -29,7 +29,7 @@
 
 -protocol({rfc, 5802}).
 
--export([start/1, stop/0, mech_new/4, mech_step/2]).
+-export([start/1, stop/0, mech_new/4, mech_step/2, format_error/1]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
@@ -41,6 +41,7 @@
          stored_key = <<"">>   :: binary(),
          server_key = <<"">>   :: binary(),
          username = <<"">>     :: binary(),
+        auth_module           :: module(),
          get_password          :: fun(),
         check_password        :: fun(),
          auth_message = <<"">> :: binary(),
         server_nonce = <<"">> :: binary()}).
 
 -define(SALT_LENGTH, 16).
-
 -define(NONCE_LENGTH, 16).
 
+-type error_reason() :: unsupported_extension | bad_username |
+                       not_authorized | saslprep_failed |
+                       parser_failed | bad_attribute |
+                       nonce_mismatch | bad_channel_binding.
+
+-export_type([error_reason/0]).
+
 start(_Opts) ->
     cyrsasl:register_mechanism(<<"SCRAM-SHA-1">>, ?MODULE,
                               scram).
 
 stop() -> ok.
 
+-spec format_error(error_reason()) -> {atom(), binary()}.
+format_error(unsupported_extension) ->
+    {'bad-protocol', <<"Unsupported extension">>};
+format_error(bad_username) ->
+    {'invalid-authzid', <<"Malformed username">>};
+format_error(not_authorized) ->
+    {'not-authorized', <<"Invalid username or password">>};
+format_error(saslprep_failed) ->
+    {'not-authorized', <<"SASLprep failed">>};
+format_error(parser_failed) ->
+    {'bad-protocol', <<"Response decoding failed">>};
+format_error(bad_attribute) ->
+    {'bad-protocol', <<"Malformed or unexpected attribute">>};
+format_error(nonce_mismatch) ->
+    {'bad-protocol', <<"Nonce mismatch">>};
+format_error(bad_channel_binding) ->
+    {'bad-protocol', <<"Invalid channel binding">>}.
+
 mech_new(_Host, GetPassword, _CheckPassword,
         _CheckPasswordDigest) ->
     {ok, #state{step = 2, get_password = GetPassword}}.
@@ -64,22 +89,22 @@ mech_new(_Host, GetPassword, _CheckPassword,
 mech_step(#state{step = 2} = State, ClientIn) ->
     case re:split(ClientIn, <<",">>, [{return, binary}]) of
       [_CBind, _AuthorizationIdentity, _UserNameAttribute, _ClientNonceAttribute, ExtensionAttribute | _]
-         when ExtensionAttribute /= [] ->
-         {error, 'protocol-error-extension-not-supported'};
+         when ExtensionAttribute /= <<"">> ->
+         {error, unsupported_extension};
       [CBind, _AuthorizationIdentity, UserNameAttribute, ClientNonceAttribute | _]
          when (CBind == <<"y">>) or (CBind == <<"n">>) ->
          case parse_attribute(UserNameAttribute) of
            {error, Reason} -> {error, Reason};
            {_, EscapedUserName} ->
                case unescape_username(EscapedUserName) of
-                 error -> {error, 'protocol-error-bad-username'};
+                 error -> {error, bad_username};
                  UserName ->
                      case parse_attribute(ClientNonceAttribute) of
                        {$r, ClientNonce} ->
-                           {Ret, _AuthModule} = (State#state.get_password)(UserName),
+                           {Ret, AuthModule} = (State#state.get_password)(UserName),
                            case {Ret, jid:resourceprep(Ret)} of
-                             {false, _} -> {error, 'not-authorized', UserName};
-                             {_, error} when is_binary(Ret) -> ?WARNING_MSG("invalid plain password", []), {error, 'not-authorized', UserName};
+                             {false, _} -> {error, not_authorized, UserName};
+                             {_, error} when is_binary(Ret) -> {error, saslprep_failed, UserName};
                              {Ret, _} ->
                                  {StoredKey, ServerKey, Salt, IterationCount} =
                                      if is_tuple(Ret) -> Ret;
@@ -112,6 +137,7 @@ mech_step(#state{step = 2} = State, ClientIn) ->
                                  {continue, ServerFirstMessage,
                                   State#state{step = 4, stored_key = StoredKey,
                                               server_key = ServerKey,
+                                              auth_module = AuthModule,
                                               auth_message =
                                                   <<ClientFirstMessageBare/binary,
                                                     ",", ServerFirstMessage/binary>>,
@@ -119,11 +145,11 @@ mech_step(#state{step = 2} = State, ClientIn) ->
                                               server_nonce = ServerNonce,
                                               username = UserName}}
                            end;
-                       _Else -> {error, 'not-supported'}
+                         _ -> {error, bad_attribute}
                      end
                end
          end;
-      _Else -> {error, 'bad-protocol'}
+      _Else -> {error, parser_failed}
     end;
 mech_step(#state{step = 4} = State, ClientIn) ->
     case str:tokens(ClientIn, <<",">>) of
@@ -158,39 +184,31 @@ mech_step(#state{step = 4} = State, ClientIn) ->
                                             scram:server_signature(State#state.server_key,
                                                                    AuthMessage),
                                         {ok, [{username, State#state.username},
+                                              {auth_module, State#state.auth_module},
                                               {authzid, State#state.username}],
                                          <<"v=",
                                            (jlib:encode_base64(ServerSignature))/binary>>};
-                                    true -> {error, 'bad-auth', State#state.username}
+                                    true -> {error, not_authorized, State#state.username}
                                  end;
-                           _Else -> {error, 'bad-protocol'}
+                           _ -> {error, bad_attribute}
                            end;
-                       {$r, _} -> {error, 'bad-nonce'};
-                       _Else -> {error, 'bad-protocol'}
+                       {$r, _} -> {error, nonce_mismatch};
+                       _ -> {error, bad_attribute}
                    end;
-                 true -> {error, 'bad-channel-binding'}
+                 true -> {error, bad_channel_binding}
                end;
-           _Else -> {error, 'bad-protocol'}
+           _ -> {error, bad_attribute}
          end;
-      _Else -> {error, 'bad-protocol'}
+      _ -> {error, parser_failed}
     end.
 
-parse_attribute(Attribute) ->
-    AttributeLen = byte_size(Attribute),
-    if AttributeLen >= 3 ->
-           AttributeS = binary_to_list(Attribute),
-          SecondChar = lists:nth(2, AttributeS),
-          case is_alpha(lists:nth(1, AttributeS)) of
-            true ->
-                if SecondChar == $= ->
-                       String = str:substr(Attribute, 3),
-                       {lists:nth(1, AttributeS), String};
-                   true -> {error, 'bad-format-second-char-not-equal-sign'}
-                end;
-            _Else -> {error, 'bad-format-first-char-not-a-letter'}
-          end;
-       true -> {error, 'bad-format-attribute-too-short'}
-    end.
+parse_attribute(<<Name, $=, Val/binary>>) when Val /= <<>> ->
+    case is_alpha(Name) of
+       true -> {Name, Val};
+       false -> {error, bad_attribute}
+    end;
+parse_attribute(_) ->
+    {error, bad_attribute}.
 
 unescape_username(<<"">>) -> <<"">>;
 unescape_username(EscapedUsername) ->