]> granicus.if.org Git - ejabberd/commitdiff
Use new cache API in ejabberd_oauth
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 21 Apr 2017 06:02:10 +0000 (09:02 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 21 Apr 2017 06:02:10 +0000 (09:02 +0300)
src/ejabberd_oauth.erl
src/ejabberd_oauth_mnesia.erl
src/ejabberd_oauth_rest.erl
src/ejabberd_oauth_sql.erl

index 3a0b276d1661328ae63e505020afd06d234e2e5b..8527c9271a3a952bc905479f49986fdeb0a3f705 100644 (file)
@@ -28,6 +28,8 @@
 
 -behaviour(gen_server).
 
+-compile(export_all).
+
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2,
         handle_info/2, terminate/2, code_change/3]).
@@ -46,6 +48,7 @@
          check_token/2,
          scope_in_scope_list/2,
          process/2,
+        config_reloaded/0,
          opt_type/1]).
 
 -export([oauth_issue_token/3, oauth_list_tokens/0, oauth_revoke_token/1, oauth_list_scopes/0]).
@@ -140,8 +143,14 @@ oauth_revoke_token(Token) ->
 oauth_list_scopes() ->
     [ {Scope, string:join([atom_to_list(Cmd) || Cmd <- Cmds], ",")}   || {Scope, Cmds} <- dict:to_list(get_cmd_scopes())].
 
-
-
+config_reloaded() ->
+    DBMod = get_db_backend(),
+    case init_cache(DBMod) of
+       true ->
+           ets_cache:setopts(oauth_cache, cache_opts());
+       false ->
+           ok
+    end.
 
 start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
@@ -150,23 +159,13 @@ start_link() ->
 init([]) ->
     DBMod = get_db_backend(),
     DBMod:init(),
-    MaxSize =
-        ejabberd_config:get_option(
-          oauth_cache_size,
-          fun(I) when is_integer(I), I>0 -> I end,
-          1000),
-    LifeTime =
-        ejabberd_config:get_option(
-          oauth_cache_life_time,
-          fun(I) when is_integer(I), I>0 -> I end,
-          timer:hours(1) div 1000),
-    cache_tab:new(oauth_token,
-                 [{max_size, MaxSize}, {life_time, LifeTime}]),
+    init_cache(DBMod),
     Expire = expire(),
     application:set_env(oauth2, backend, ejabberd_oauth),
     application:set_env(oauth2, expiry_time, Expire),
     application:start(oauth2),
     ejabberd_commands:register_commands(get_commands_spec()),
+    ejabberd_hooks:add(config_reloaded, ?MODULE, config_reloaded, 50),
     erlang:send_after(expire() * 1000, self(), clean),
     {ok, ok}.
 
@@ -371,24 +370,59 @@ check_token(ScopeList, Token) ->
 
 
 store(R) ->
-    cache_tab:insert(
-      oauth_token, R#oauth_token.token, R,
-      fun() ->
-              DBMod = get_db_backend(),
-              DBMod:store(R)
-      end).
+    DBMod = get_db_backend(),
+    case DBMod:store(R) of
+       ok ->
+           ets_cache:delete(oauth_cache, R#oauth_token.token,
+                            ejabberd_cluster:get_nodes());
+       {error, _} = Err ->
+           Err
+    end.
 
 lookup(Token) ->
-    cache_tab:lookup(
-      oauth_token, Token,
-      fun() ->
-              DBMod = get_db_backend(),
-              case DBMod:lookup(Token) of
-                  #oauth_token{} = R -> {ok, R};
-                  _ -> error
-              end
-      end).
+    ets_cache:lookup(oauth_cache, Token,
+                    fun() ->
+                            DBMod = get_db_backend(),
+                            DBMod:lookup(Token)
+                    end).
+
+-spec init_cache(module()) -> boolean().
+init_cache(DBMod) ->
+    UseCache = use_cache(DBMod),
+    case UseCache of
+       true ->
+           ets_cache:new(oauth_cache, cache_opts());
+       false ->
+           ets_cache:delete(oauth_cache)
+    end,
+    UseCache.
+
+use_cache(DBMod) ->
+    case erlang:function_exported(DBMod, use_cache, 0) of
+       true -> DBMod:use_cache();
+       false ->
+           ejabberd_config:get_option(
+             oauth_use_cache, opt_type(oauth_use_cache),
+             ejabberd_config:use_cache(global))
+    end.
 
+cache_opts() ->
+    MaxSize = ejabberd_config:get_option(
+               oauth_cache_size,
+               opt_type(oauth_cache_size),
+               ejabberd_config:cache_size(global)),
+    CacheMissed = ejabberd_config:get_option(
+                   oauth_cache_missed,
+                   opt_type(oauth_cache_missed),
+                   ejabberd_config:cache_missed(global)),
+    LifeTime = case ejabberd_config:get_option(
+                     oauth_cache_life_time,
+                     opt_type(oauth_cache_life_time),
+                     ejabberd_config:cache_life_time(global)) of
+                  infinity -> infinity;
+                  I -> timer:seconds(I)
+              end,
+    [{max_size, MaxSize}, {life_time, LifeTime}, {cache_missed, CacheMissed}].
 
 expire() ->
     ejabberd_config:get_option(
@@ -746,8 +780,13 @@ opt_type(oauth_access) ->
     fun acl:access_rules_validator/1;
 opt_type(oauth_db_type) ->
     fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
-opt_type(oauth_cache_life_time) ->
-    fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(oauth_cache_size) ->
-    fun (I) when is_integer(I), I > 0 -> I end;
-opt_type(_) -> [oauth_expire, oauth_access, oauth_db_type].
+opt_type(O) when O == oauth_cache_life_time; O == oauth_cache_size ->
+    fun (I) when is_integer(I), I > 0 -> I;
+       (infinity) -> infinity
+    end;
+opt_type(O) when O == oauth_use_cache; O == oauth_cache_missed ->
+    fun (B) when is_boolean(B) -> B end;
+opt_type(_) ->
+    [oauth_expire, oauth_access, oauth_db_type,
+     oauth_cache_life_time, oauth_cache_size, oauth_use_cache,
+     oauth_cache_missed].
index c9ef6dcac094632025888a044cedc39628a44a9e..8a99979294bc622e5b2bc76a621b46038e46793f 100644 (file)
@@ -47,9 +47,9 @@ store(R) ->
 lookup(Token) ->
     case catch mnesia:dirty_read(oauth_token, Token) of
         [R] ->
-            R;
+            {ok, R};
         _ ->
-            false
+            error
     end.
 
 clean(TS) ->
index b9614eb096d427384b09c4d985abbab5a901d3e2..15e118a0bccc99e6972d9b141f76f2daf9b833dd 100644 (file)
@@ -58,7 +58,7 @@ store(R) ->
             ok;
         Err ->
             ?ERROR_MSG("failed to store oauth record ~p: ~p", [R, Err]),
-            {error, Err}
+            {error, db_failure}
     end.
 
 lookup(Token) ->
@@ -72,15 +72,15 @@ lookup(Token) ->
             US = {JID#jid.luser, JID#jid.lserver},
             Scope = proplists:get_value(<<"scope">>, Data, []),
             Expire = proplists:get_value(<<"expire">>, Data, 0),
-            #oauth_token{token = Token,
-                         us = US,
-                         scope = Scope,
-                         expire = Expire};
+            {ok, #oauth_token{token = Token,
+                             us = US,
+                             scope = Scope,
+                             expire = Expire}};
         {ok, 404, _Resp} ->
-            false;
+            error;
         Other ->
             ?ERROR_MSG("Unexpected response for oauth lookup: ~p", [Other]),
-            {error, rest_failed}
+           error
     end.
 
 clean(_TS) ->
index 10ca49844ff319eb8faf344f0c150d8982405254..5c4a9664165026f9d577a1f22da616ab6d319421 100644 (file)
@@ -37,6 +37,7 @@
 -include("ejabberd.hrl").
 -include("ejabberd_sql_pt.hrl").
 -include("jid.hrl").
+-include("logger.hrl").
 
 init() ->
     ok.
@@ -47,13 +48,20 @@ store(R) ->
     SJID = jid:encode({User, Server, <<"">>}),
     Scope = str:join(R#oauth_token.scope, <<" ">>),
     Expire = R#oauth_token.expire,
-    ?SQL_UPSERT(
-       ?MYNAME,
-       "oauth_token",
-       ["!token=%(Token)s",
-        "jid=%(SJID)s",
-        "scope=%(Scope)s",
-        "expire=%(Expire)d"]).
+    case ?SQL_UPSERT(
+           ?MYNAME,
+           "oauth_token",
+           ["!token=%(Token)s",
+            "jid=%(SJID)s",
+            "scope=%(Scope)s",
+            "expire=%(Expire)d"]) of
+       ok ->
+           ok;
+       Err ->
+           ?ERROR_MSG("Failed to write to SQL 'oauth_token' table: ~p",
+                      [Err]),
+           {error, db_failure}
+    end.
 
 lookup(Token) ->
     case ejabberd_sql:sql_query(
@@ -63,12 +71,12 @@ lookup(Token) ->
         {selected, [{SJID, Scope, Expire}]} ->
             JID = jid:decode(SJID),
             US = {JID#jid.luser, JID#jid.lserver},
-            #oauth_token{token = Token,
-                         us = US,
-                         scope = str:tokens(Scope, <<" ">>),
-                         expire = Expire};
+            {ok, #oauth_token{token = Token,
+                             us = US,
+                             scope = str:tokens(Scope, <<" ">>),
+                             expire = Expire}};
         _ ->
-            false
+            error
     end.
 
 clean(TS) ->