]> granicus.if.org Git - ejabberd/commitdiff
Add DB backend support for ejabberd_oauth
authorAlexey Shchepin <alexey@process-one.net>
Wed, 20 Jul 2016 13:55:45 +0000 (16:55 +0300)
committerAlexey Shchepin <alexey@process-one.net>
Mon, 25 Jul 2016 17:08:30 +0000 (20:08 +0300)
include/ejabberd_oauth.hrl [new file with mode: 0644]
src/ejabberd_oauth.erl
src/ejabberd_oauth_mnesia.erl [new file with mode: 0644]

diff --git a/include/ejabberd_oauth.hrl b/include/ejabberd_oauth.hrl
new file mode 100644 (file)
index 0000000..6b5a9bc
--- /dev/null
@@ -0,0 +1,26 @@
+%%%----------------------------------------------------------------------
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%----------------------------------------------------------------------
+
+-record(oauth_token, {
+          token = <<"">>           :: binary() | '_',
+          us = {<<"">>, <<"">>}    :: {binary(), binary()} | '_',
+          scope = []               :: [binary()] | '_',
+          expire                   :: integer() | '$1'
+         }).
index 81b5f4156bdfdfd68e2e4968c7908d577ef5cbd6..d4b1ff87ebf23f1655ac2f2dba088cb28cc1686d 100644 (file)
@@ -56,6 +56,7 @@
 
 -include("ejabberd_http.hrl").
 -include("ejabberd_web_admin.hrl").
+-include("ejabberd_oauth.hrl").
 
 -include("ejabberd_commands.hrl").
 
 %%   * Using the web form/api results in the token being generated in behalf of the user providing the user/pass
 %%   * Using the command line and oauth_issue_token command, the token is generated in behalf of ejabberd' sysadmin
 %%    (as it has access to ejabberd command line).
--record(oauth_token, {
-          token = {<<"">>, <<"">>} :: {binary(), binary()},
-          us = {<<"">>, <<"">>}    :: {binary(), binary()},
-          scope = []               :: [binary()],
-          expire                   :: integer()
-         }).
 
 -define(EXPIRE, 3600).
 
 start() ->
-    init_db(mnesia, ?MYNAME),
+    DBMod = get_db_backend(),
+    DBMod:init(),
     Expire = expire(),
     application:set_env(oauth2, backend, ejabberd_oauth),
     application:set_env(oauth2, expiry_time, Expire),
@@ -172,15 +168,8 @@ handle_cast(_Msg, State) -> {noreply, State}.
 handle_info(clean, State) ->
     {MegaSecs, Secs, MiniSecs} = os:timestamp(),
     TS = 1000000 * MegaSecs + Secs,
-    F = fun() ->
-               Ts = mnesia:select(
-                      oauth_token,
-                      [{#oauth_token{expire = '$1', _ = '_'},
-                        [{'<', '$1', TS}],
-                        ['$_']}]),
-               lists:foreach(fun mnesia:delete_object/1, Ts)
-        end,
-    mnesia:async_dirty(F),
+    DBMod = get_db_backend(),
+    DBMod:clean(TS),
     erlang:send_after(trunc(expire() * 1000 * (1 + MiniSecs / 1000000)),
                       self(), clean),
     {noreply, State};
@@ -191,16 +180,6 @@ terminate(_Reason, _State) -> ok.
 code_change(_OldVsn, State, _Extra) -> {ok, State}.
 
 
-init_db(mnesia, _Host) ->
-    mnesia:create_table(oauth_token,
-                        [{disc_copies, [node()]},
-                         {attributes,
-                          record_info(fields, oauth_token)}]),
-    mnesia:add_table_copy(oauth_token, node(), disc_copies);
-init_db(_, _) ->
-    ok.
-
-
 get_client_identity(Client, Ctx) -> {ok, {Ctx, {client, Client}}}.
 
 verify_redirection_uri(_, _, Ctx) -> {ok, Ctx}.
@@ -305,7 +284,8 @@ associate_access_token(AccessToken, Context, AppContext) ->
       scope = Scope,
       expire = Expire
      },
-    mnesia:dirty_write(R),
+    DBMod = get_db_backend(),
+    DBMod:store(R),
     {ok, AppContext}.
 
 associate_refresh_token(_RefreshToken, _Context, AppContext) ->
@@ -315,10 +295,11 @@ associate_refresh_token(_RefreshToken, _Context, AppContext) ->
 check_token(User, Server, ScopeList, Token) ->
     LUser = jid:nodeprep(User),
     LServer = jid:nameprep(Server),
-    case catch mnesia:dirty_read(oauth_token, Token) of
-        [#oauth_token{us = {LUser, LServer},
-                      scope = TokenScope,
-                      expire = Expire}] ->
+    DBMod = get_db_backend(),
+    case DBMod:lookup(Token) of
+        #oauth_token{us = {LUser, LServer},
+                     scope = TokenScope,
+                     expire = Expire} ->
             {MegaSecs, Secs, _} = os:timestamp(),
             TS = 1000000 * MegaSecs + Secs,
             TokenScopeSet = oauth2_priv_set:new(TokenScope),
@@ -330,10 +311,11 @@ check_token(User, Server, ScopeList, Token) ->
     end.
 
 check_token(ScopeList, Token) ->
-    case catch mnesia:dirty_read(oauth_token, Token) of
-        [#oauth_token{us = US,
-                      scope = TokenScope,
-                      expire = Expire}] ->
+    DBMod = get_db_backend(),
+    case DBMod:lookup(Token) of
+        #oauth_token{us = US,
+                     scope = TokenScope,
+                     expire = Expire} ->
             {MegaSecs, Secs, _} = os:timestamp(),
             TS = 1000000 * MegaSecs + Secs,
             TokenScopeSet = oauth2_priv_set:new(TokenScope),
@@ -548,6 +530,15 @@ process(_Handlers,
 process(_Handlers, _Request) ->
     ejabberd_web:error(not_found).
 
+-spec get_db_backend() -> module().
+
+get_db_backend() ->
+    DBType = ejabberd_config:get_option(
+               oauth_db_type,
+               fun(T) -> ejabberd_config:v_db(?MODULE, T) end,
+               mnesia),
+    list_to_atom("ejabberd_oauth_" ++ atom_to_list(DBType)).
+
 
 %% Headers as per RFC 6749 
 json_response(Code, Body) ->
@@ -688,4 +679,6 @@ opt_type(oauth_expire) ->
     fun(I) when is_integer(I), I >= 0 -> I end;
 opt_type(oauth_access) ->
     fun acl:access_rules_validator/1;
-opt_type(_) -> [oauth_expire, oauth_access].
+opt_type(oauth_db_type) ->
+    fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
+opt_type(_) -> [oauth_expire, oauth_access, oauth_db_type].
diff --git a/src/ejabberd_oauth_mnesia.erl b/src/ejabberd_oauth_mnesia.erl
new file mode 100644 (file)
index 0000000..a23f443
--- /dev/null
@@ -0,0 +1,65 @@
+%%%-------------------------------------------------------------------
+%%% File    : ejabberd_oauth_mnesia.erl
+%%% Author  : Alexey Shchepin <alexey@process-one.net>
+%%% Purpose : OAUTH2 mnesia backend
+%%% Created : 20 Jul 2016 by Alexey Shchepin <alexey@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2016   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License
+%%% along with this program; if not, write to the Free Software
+%%% Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
+%%% 02111-1307 USA
+%%%
+%%%-------------------------------------------------------------------
+
+-module(ejabberd_oauth_mnesia).
+
+-export([init/0,
+         store/1,
+         lookup/1,
+         clean/1]).
+
+-include("ejabberd_oauth.hrl").
+
+init() ->
+    mnesia:create_table(oauth_token,
+                        [{disc_copies, [node()]},
+                         {attributes,
+                          record_info(fields, oauth_token)}]),
+    mnesia:add_table_copy(oauth_token, node(), disc_copies),
+    ok.
+
+store(R) ->
+    mnesia:dirty_write(R).
+
+lookup(Token) ->
+    case catch mnesia:dirty_read(oauth_token, Token) of
+        [R] ->
+            R;
+        _ ->
+            false
+    end.
+
+clean(TS) ->
+    F = fun() ->
+               Ts = mnesia:select(
+                      oauth_token,
+                      [{#oauth_token{expire = '$1', _ = '_'},
+                        [{'<', '$1', TS}],
+                        ['$_']}]),
+               lists:foreach(fun mnesia:delete_object/1, Ts)
+        end,
+    mnesia:async_dirty(F).
+