Add SCRAM support to ejabberd_auth_odbc
authorAlexey Shchepin <alexey@process-one.net>
Tue, 17 Feb 2015 20:26:31 +0000 (23:26 +0300)
committerAlexey Shchepin <alexey@process-one.net>
Fri, 20 Mar 2015 10:45:24 +0000 (13:45 +0300)
sql/mysql.sql
sql/pg.sql
src/ejabberd_auth_odbc.erl
src/odbc_queries.erl

index e591092f75794d53afef4bca3ed4b40434d93dfb..c79d3062b2fa0bbd73817398eeede7a56fcb69e2 100644 (file)
@@ -22,6 +22,10 @@ CREATE TABLE users (
     created_at timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP
 ) ENGINE=InnoDB CHARACTER SET utf8;
 
+-- To support SCRAM auth:
+-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
+-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
+-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
 
 CREATE TABLE last (
     username varchar(250) PRIMARY KEY,
index 9df8ffd9f5d865e946272820f5e7701b3e5c6c56..8412c3c6b513970aaa8776f97dca51a535b0f613 100644 (file)
@@ -22,6 +22,10 @@ CREATE TABLE users (
     created_at TIMESTAMP NOT NULL DEFAULT now()
 );
 
+-- To support SCRAM auth:
+-- ALTER TABLE users ADD COLUMN serverkey text NOT NULL DEFAULT '';
+-- ALTER TABLE users ADD COLUMN salt text NOT NULL DEFAULT '';
+-- ALTER TABLE users ADD COLUMN iterationcount integer NOT NULL DEFAULT 0;
 
 CREATE TABLE last (
     username text PRIMARY KEY,
index aea039c1bba1ad0eda5a0725a0865bd1806fb20e..23694ea875b0d9b2197dad23766919acc3c487f4 100644 (file)
 -include("ejabberd.hrl").
 -include("logger.hrl").
 
+-define(SALT_LENGTH, 16).
+
 %%%----------------------------------------------------------------------
 %%% API
 %%%----------------------------------------------------------------------
 start(_Host) -> ok.
 
-plain_password_required() -> false.
+plain_password_required() ->
+    case is_scrammed() of
+      false -> false;
+      true -> true
+    end.
 
-store_type() -> plain.
+store_type() ->
+    case is_scrammed() of
+      false -> plain; %% allows: PLAIN DIGEST-MD5 SCRAM
+      true -> scram %% allows: PLAIN SCRAM
+    end.
 
 %% @spec (User, Server, Password) -> true | false | {error, Error}
 check_password(User, Server, Password) ->
-    case jlib:nodeprep(User) of
-      error -> false;
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         LServer = jlib:nameprep(Server),
-         try odbc_queries:get_password(LServer, Username) of
-           {selected, [<<"password">>], [[Password]]} ->
-               Password /= <<"">>;
-           {selected, [<<"password">>], [[_Password2]]} ->
-               false; %% Password is not correct
-           {selected, [<<"password">>], []} ->
-               false; %% Account does not exist
-           {error, _Error} ->
-               false %% Typical error is that table doesn't exist
-         catch
-           _:_ ->
-               false %% Typical error is database not accessible
-         end
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            false;
+       (LUser == <<>>) or (LServer == <<>>) ->
+            false;
+       true ->
+            Username = ejabberd_odbc:escape(LUser),
+            case is_scrammed() of
+                true ->
+                    try odbc_queries:get_password_scram(LServer, Username) of
+                        {selected, [<<"password">>, <<"serverkey">>,
+                                    <<"salt">>, <<"iterationcount">>],
+                         [[StoredKey, ServerKey, Salt, IterationCount]]} ->
+                            Scram =
+                                #scram{storedkey = StoredKey,
+                                       serverkey = ServerKey,
+                                       salt = Salt,
+                                       iterationcount = binary_to_integer(
+                                                          IterationCount)},
+                            is_password_scram_valid(Password, Scram);
+                        {selected, [<<"password">>, <<"serverkey">>,
+                                    <<"salt">>, <<"iterationcount">>], []} ->
+                            false; %% Account does not exist
+                        {error, _Error} ->
+                            false %% Typical error is that table doesn't exist
+                    catch
+                        _:_ ->
+                            false %% Typical error is database not accessible
+                    end;
+                false ->
+                    try odbc_queries:get_password(LServer, Username) of
+                        {selected, [<<"password">>], [[Password]]} ->
+                            Password /= <<"">>;
+                        {selected, [<<"password">>], [[_Password2]]} ->
+                            false; %% Password is not correct
+                        {selected, [<<"password">>], []} ->
+                            false; %% Account does not exist
+                        {error, _Error} ->
+                            false %% Typical error is that table doesn't exist
+                    catch
+                        _:_ ->
+                            false %% Typical error is database not accessible
+                    end
+            end
     end.
 
 %% @spec (User, Server, Password, Digest, DigestGen) -> true | false | {error, Error}
 check_password(User, Server, Password, Digest,
               DigestGen) ->
-    case jlib:nodeprep(User) of
-      error -> false;
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         LServer = jlib:nameprep(Server),
-         try odbc_queries:get_password(LServer, Username) of
-           %% Account exists, check if password is valid
-           {selected, [<<"password">>], [[Passwd]]} ->
-               DigRes = if Digest /= <<"">> ->
-                               Digest == DigestGen(Passwd);
-                           true -> false
-                        end,
-               if DigRes -> true;
-                  true -> (Passwd == Password) and (Password /= <<"">>)
-               end;
-           {selected, [<<"password">>], []} ->
-               false; %% Account does not exist
-           {error, _Error} ->
-               false %% Typical error is that table doesn't exist
-         catch
-           _:_ ->
-               false %% Typical error is database not accessible
-         end
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            false;
+       (LUser == <<>>) or (LServer == <<>>) ->
+            false;
+       true ->
+            case is_scrammed() of
+                false ->
+                    Username = ejabberd_odbc:escape(LUser),
+                    try odbc_queries:get_password(LServer, Username) of
+                        %% Account exists, check if password is valid
+                        {selected, [<<"password">>], [[Passwd]]} ->
+                            DigRes = if Digest /= <<"">> ->
+                                             Digest == DigestGen(Passwd);
+                                        true -> false
+                                     end,
+                            if DigRes -> true;
+                               true -> (Passwd == Password) and (Password /= <<"">>)
+                            end;
+                        {selected, [<<"password">>], []} ->
+                            false; %% Account does not exist
+                        {error, _Error} ->
+                            false %% Typical error is that table doesn't exist
+                    catch
+                        _:_ ->
+                            false %% Typical error is database not accessible
+                    end;
+                true ->
+                    false
+            end
     end.
 
 %% @spec (User::string(), Server::string(), Password::string()) ->
 %%       ok | {error, invalid_jid}
 set_password(User, Server, Password) ->
-    case jlib:nodeprep(User) of
-      error -> {error, invalid_jid};
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         Pass = ejabberd_odbc:escape(Password),
-         LServer = jlib:nameprep(Server),
-         case catch odbc_queries:set_password_t(LServer,
-                                                Username, Pass)
-             of
-           {atomic, ok} -> ok;
-           Other -> {error, Other}
-         end
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            {error, invalid_jid};
+       (LUser == <<>>) or (LServer == <<>>) ->
+            {error, invalid_jid};
+       true ->
+            Username = ejabberd_odbc:escape(LUser),
+            case is_scrammed() of
+                true ->
+                    Scram = password_to_scram(Password),
+                    case catch odbc_queries:set_password_scram_t(
+                                 LServer,
+                                 Username,
+                                 ejabberd_odbc:escape(Scram#scram.storedkey),
+                                 ejabberd_odbc:escape(Scram#scram.serverkey),
+                                 ejabberd_odbc:escape(Scram#scram.salt),
+                                 integer_to_binary(Scram#scram.iterationcount)
+                                )
+                        of
+                        {atomic, ok} -> ok;
+                        Other -> {error, Other}
+                    end;
+                false ->
+                    Pass = ejabberd_odbc:escape(Password),
+                    case catch odbc_queries:set_password_t(LServer,
+                                                           Username, Pass)
+                        of
+                        {atomic, ok} -> ok;
+                        Other -> {error, Other}
+                    end
+            end
     end.
 
 %% @spec (User, Server, Password) -> {atomic, ok} | {atomic, exists} | {error, invalid_jid}
 try_register(User, Server, Password) ->
-    case jlib:nodeprep(User) of
-      error -> {error, invalid_jid};
-      LUser ->
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            {error, invalid_jid};
+       (LUser == <<>>) or (LServer == <<>>) ->
+            {error, invalid_jid};
+       true ->
          Username = ejabberd_odbc:escape(LUser),
-         Pass = ejabberd_odbc:escape(Password),
-         LServer = jlib:nameprep(Server),
-         case catch odbc_queries:add_user(LServer, Username,
-                                          Pass)
-             of
-           {updated, 1} -> {atomic, ok};
-           _ -> {atomic, exists}
-         end
+            case is_scrammed() of
+                true ->
+                    Scram = password_to_scram(Password),
+                    case catch odbc_queries:add_user_scram(
+                                 LServer,
+                                 Username,
+                                 ejabberd_odbc:escape(Scram#scram.storedkey),
+                                 ejabberd_odbc:escape(Scram#scram.serverkey),
+                                 ejabberd_odbc:escape(Scram#scram.salt),
+                                 integer_to_binary(Scram#scram.iterationcount)
+                                ) of
+                        {updated, 1} -> {atomic, ok};
+                        _ -> {atomic, exists}
+                    end;
+                false ->
+                    Pass = ejabberd_odbc:escape(Password),
+                    case catch odbc_queries:add_user(LServer, Username,
+                                                     Pass)
+                        of
+                        {updated, 1} -> {atomic, ok};
+                        _ -> {atomic, exists}
+                    end
+            end
     end.
 
 dirty_get_registered_users() ->
@@ -175,29 +259,53 @@ get_vh_registered_users_number(Server, Opts) ->
     end.
 
 get_password(User, Server) ->
-    case jlib:nodeprep(User) of
-      error -> false;
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         LServer = jlib:nameprep(Server),
-         case catch odbc_queries:get_password(LServer, Username)
-             of
-           {selected, [<<"password">>], [[Password]]} -> Password;
-           _ -> false
-         end
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            false;
+       (LUser == <<>>) or (LServer == <<>>) ->
+            false;
+       true ->
+            Username = ejabberd_odbc:escape(LUser),
+            case is_scrammed() of
+                true ->
+                    case catch odbc_queries:get_password_scram(
+                                 LServer, Username) of
+                        {selected, [<<"password">>, <<"serverkey">>,
+                                    <<"salt">>, <<"iterationcount">>],
+                         [[StoredKey, ServerKey, Salt, IterationCount]]} ->
+                            {jlib:decode_base64(StoredKey),
+                             jlib:decode_base64(ServerKey),
+                             jlib:decode_base64(Salt),
+                             binary_to_integer(IterationCount)};
+                        _ -> false
+                    end;
+                false ->
+                    case catch odbc_queries:get_password(LServer, Username)
+                        of
+                        {selected, [<<"password">>], [[Password]]} -> Password;
+                        _ -> false
+                    end
+            end
     end.
 
 get_password_s(User, Server) ->
-    case jlib:nodeprep(User) of
-      error -> <<"">>;
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         LServer = jlib:nameprep(Server),
-         case catch odbc_queries:get_password(LServer, Username)
-             of
-           {selected, [<<"password">>], [[Password]]} -> Password;
-           _ -> <<"">>
-         end
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            <<"">>;
+       (LUser == <<>>) or (LServer == <<>>) ->
+            <<"">>;
+       true ->
+            case is_scrammed() of
+                false ->
+                    Username = ejabberd_odbc:escape(LUser),
+                    case catch odbc_queries:get_password(LServer, Username) of
+                        {selected, [<<"password">>], [[Password]]} -> Password;
+                        _ -> <<"">>
+                    end;
+                true -> <<"">>
+            end
     end.
 
 %% @spec (User, Server) -> true | false | {error, Error}
@@ -234,23 +342,72 @@ remove_user(User, Server) ->
 %% @spec (User, Server, Password) -> ok | error | not_exists | not_allowed
 %% @doc Remove user if the provided password is correct.
 remove_user(User, Server, Password) ->
-    case jlib:nodeprep(User) of
-      error -> error;
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         Pass = ejabberd_odbc:escape(Password),
-         LServer = jlib:nameprep(Server),
-         F = fun () ->
-                     Result = odbc_queries:del_user_return_password(LServer,
-                                                                    Username,
-                                                                    Pass),
-                     case Result of
-                       {selected, [<<"password">>], [[Password]]} -> ok;
-                       {selected, [<<"password">>], []} -> not_exists;
-                       _ -> not_allowed
-                     end
-             end,
-         {atomic, Result} = odbc_queries:sql_transaction(LServer,
-                                                         F),
-         Result
+    LServer = jlib:nameprep(Server),
+    LUser = jlib:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            error;
+       (LUser == <<>>) or (LServer == <<>>) ->
+            error;
+       true ->
+            case is_scrammed() of
+                true ->
+                    case check_password(User, Server, Password) of
+                        true ->
+                            remove_user(User, Server),
+                            ok;
+                        false -> not_allowed
+                    end;
+                false ->
+                    Username = ejabberd_odbc:escape(LUser),
+                    Pass = ejabberd_odbc:escape(Password),
+                    F = fun () ->
+                                Result = odbc_queries:del_user_return_password(
+                                           LServer, Username, Pass),
+                                case Result of
+                                    {selected, [<<"password">>],
+                                     [[Password]]} -> ok;
+                                    {selected, [<<"password">>],
+                                     []} -> not_exists;
+                                    _ -> not_allowed
+                                end
+                        end,
+                    {atomic, Result} = odbc_queries:sql_transaction(
+                                         LServer, F),
+                    Result
+            end
     end.
+
+%%%
+%%% SCRAM
+%%%
+
+is_scrammed() ->
+    scram ==
+      ejabberd_config:get_option({auth_password_format, ?MYNAME},
+                                 fun(V) -> V end).
+
+password_to_scram(Password) ->
+    password_to_scram(Password,
+                     ?SCRAM_DEFAULT_ITERATION_COUNT).
+
+password_to_scram(Password, IterationCount) ->
+    Salt = crypto:rand_bytes(?SALT_LENGTH),
+    SaltedPassword = scram:salted_password(Password, Salt,
+                                          IterationCount),
+    StoredKey =
+       scram:stored_key(scram:client_key(SaltedPassword)),
+    ServerKey = scram:server_key(SaltedPassword),
+    #scram{storedkey = jlib:encode_base64(StoredKey),
+          serverkey = jlib:encode_base64(ServerKey),
+          salt = jlib:encode_base64(Salt),
+          iterationcount = IterationCount}.
+
+is_password_scram_valid(Password, Scram) ->
+    IterationCount = Scram#scram.iterationcount,
+    Salt = jlib:decode_base64(Scram#scram.salt),
+    SaltedPassword = scram:salted_password(Password, Salt,
+                                          IterationCount),
+    StoredKey =
+       scram:stored_key(scram:client_key(SaltedPassword)),
+    jlib:decode_base64(Scram#scram.storedkey) == StoredKey.
+
index f2771e52f31fb2e9f64e870a44f9e76465370a94..106e09940be49d8e1cf5627e28f71ce0f5c71a28 100644 (file)
 -author("mremond@process-one.net").
 
 -export([get_db_type/0, update/5, update_t/4, sql_transaction/2,
-        get_last/2, set_last_t/4, del_last/2, get_password/2,
-        set_password_t/3, add_user/3, del_user/2,
+        get_last/2, set_last_t/4, del_last/2,
+         get_password/2, get_password_scram/2,
+        set_password_t/3, set_password_scram_t/6,
+        add_user/3, add_user_scram/6, del_user/2,
         del_user_return_password/3, list_users/1, list_users/2,
         users_number/1, users_number/2, add_spool_sql/2,
         add_spool/2, get_and_del_spool_msg_t/2, del_spool_msg/2,
@@ -157,6 +159,12 @@ get_password(LServer, Username) ->
                            [<<"select password from users where username='">>,
                             Username, <<"';">>]).
 
+get_password_scram(LServer, Username) ->
+    ejabberd_odbc:sql_query(
+      LServer,
+      [<<"select password, serverkey, salt, iterationcount from users where "
+        "username='">>, Username, <<"';">>]).
+
 set_password_t(LServer, Username, Pass) ->
     ejabberd_odbc:sql_transaction(LServer,
                                  fun () ->
@@ -168,12 +176,39 @@ set_password_t(LServer, Username, Pass) ->
                                                    <<"'">>])
                                  end).
 
+set_password_scram_t(LServer, Username,
+                     StoredKey, ServerKey, Salt, IterationCount) ->
+    ejabberd_odbc:sql_transaction(LServer,
+                                 fun () ->
+                                         update_t(<<"users">>,
+                                                  [<<"username">>,
+                                                   <<"password">>,
+                                                   <<"serverkey">>,
+                                                   <<"salt">>,
+                                                   <<"iterationcount">>],
+                                                  [Username, StoredKey,
+                                                    ServerKey, Salt,
+                                                    IterationCount],
+                                                  [<<"username='">>, Username,
+                                                   <<"'">>])
+                                 end).
+
 add_user(LServer, Username, Pass) ->
     ejabberd_odbc:sql_query(LServer,
                            [<<"insert into users(username, password) "
                               "values ('">>,
                             Username, <<"', '">>, Pass, <<"');">>]).
 
+add_user_scram(LServer, Username,
+               StoredKey, ServerKey, Salt, IterationCount) ->
+    ejabberd_odbc:sql_query(LServer,
+                           [<<"insert into users(username, password, serverkey, salt, iterationcount) "
+                              "values ('">>,
+                            Username, <<"', '">>, StoredKey, <<"', '">>,
+                             ServerKey, <<"', '">>,
+                             Salt, <<"', '">>,
+                             IterationCount, <<"');">>]).
+
 del_user(LServer, Username) ->
     ejabberd_odbc:sql_query(LServer,
                            [<<"delete from users where username='">>, Username,