]> granicus.if.org Git - ejabberd/commitdiff
Use cache in front of Redis/SQL RAM backends
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 14 Apr 2017 10:57:52 +0000 (13:57 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 14 Apr 2017 10:57:52 +0000 (13:57 +0300)
26 files changed:
include/bosh.hrl
include/ejabberd_router.hrl
include/ejabberd_sm.hrl
include/mod_carboncopy.hrl
rebar.config
src/ejabberd_cluster.erl
src/ejabberd_config.erl
src/ejabberd_redis.erl
src/ejabberd_redis_sup.erl
src/ejabberd_router.erl
src/ejabberd_router_mnesia.erl
src/ejabberd_router_redis.erl
src/ejabberd_router_sql.erl
src/ejabberd_sm.erl
src/ejabberd_sm_mnesia.erl
src/ejabberd_sm_redis.erl
src/ejabberd_sm_sql.erl
src/mod_bosh.erl
src/mod_bosh_mnesia.erl
src/mod_bosh_redis.erl
src/mod_bosh_sql.erl
src/mod_carboncopy.erl
src/mod_carboncopy_mnesia.erl
src/mod_carboncopy_redis.erl
src/mod_carboncopy_sql.erl
src/randoms.erl

index 3f9095e58481159e147760af8ef331391a690f8d..d95784c08408a7c7c429074d7023b1eef0c0024a 100644 (file)
@@ -47,3 +47,5 @@
 
 -define(HEADER(CType),
        [CType, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]).
+
+-define(BOSH_CACHE, bosh_cache).
index f22bd723ba76af6366f722cbff8f8b35f49633d1..04ea6e304a83d07a740ffb9870a7eaedf545f73e 100644 (file)
@@ -1,3 +1,5 @@
+-define(ROUTES_CACHE, routes_cache).
+
 -type local_hint() :: integer() | {apply, atom(), atom()}.
 
 -record(route, {domain :: binary() | '_',
index 71cfc9ee9a18cecdb1b842a63a6e192c40f7761a..377c98a4dbefc5318ce26dc98231118f3b0e3394 100644 (file)
@@ -21,6 +21,8 @@
 -ifndef(EJABBERD_SM_HRL).
 -define(EJABBERD_SM_HRL, true).
 
+-define(SM_CACHE, sm_cache).
+
 -record(session, {sid, usr, us, priority, info = []}).
 -record(session_counter, {vhost, count}).
 -type sid() :: {erlang:timestamp(), pid()}.
index b58a5044e2b32dc3280e066279977244351071fb..1da76ffbc02b88691ad496f6c6e091b793ef64ae 100644 (file)
@@ -22,3 +22,5 @@
 -record(carboncopy, {us       :: {binary(), binary()} | matchspec_atom(), 
                     resource :: binary() | matchspec_atom(),
                     version  :: binary() | matchspec_atom()}).
+
+-define(CARBONCOPY_CACHE, carboncopy_cache).
index 03cf910b1d4c4ef872161b57d8e967ab3f410d10..05d5d29a21e5349b1c576ed0d4f27676e712527d 100644 (file)
@@ -20,7 +20,7 @@
 
 {deps, [{lager, ".*", {git, "https://github.com/basho/lager", {tag, "3.2.1"}}},
         {p1_utils, ".*", {git, "https://github.com/processone/p1_utils", {tag, "1.0.8"}}},
-        {cache_tab, ".*", {git, "https://github.com/processone/cache_tab", {tag, "1.0.7"}}},
+        {cache_tab, ".*", {git, "https://github.com/processone/cache_tab", "35cc9904fde"}},
         {fast_tls, ".*", {git, "https://github.com/processone/fast_tls", {tag, "1.0.11"}}},
         {stringprep, ".*", {git, "https://github.com/processone/stringprep", {tag, "1.0.8"}}},
         {fast_xml, ".*", {git, "https://github.com/processone/fast_xml", {tag, "1.1.21"}}},
index a331a008438fbf14f8b91016af8933a83e6324c9..aeae294b0e9566cffdeff1174bfd8c105f5d79e9 100644 (file)
@@ -26,7 +26,8 @@
 -module(ejabberd_cluster).
 
 %% API
--export([get_nodes/0, call/4, multicall/3, multicall/4]).
+-export([get_nodes/0, call/4, multicall/3, multicall/4,
+        eval_everywhere/3, eval_everywhere/4]).
 -export([join/1, leave/1, get_known_nodes/0]).
 -export([node_id/0, get_node_by_id/1]).
 
@@ -59,6 +60,18 @@ multicall(Module, Function, Args) ->
 multicall(Nodes, Module, Function, Args) ->
     rpc:multicall(Nodes, Module, Function, Args, 5000).
 
+-spec eval_everywhere(module(), atom(), [any()]) -> ok.
+
+eval_everywhere(Module, Function, Args) ->
+    eval_everywhere(get_nodes(), Module, Function, Args),
+    ok.
+
+-spec eval_everywhere([node()], module(), atom(), [any()]) -> ok.
+
+eval_everywhere(Nodes, Module, Function, Args) ->
+    rpc:eval_everywhere(Nodes, Module, Function, Args),
+    ok.
+
 -spec join(node()) -> ok | {error, any()}.
 
 join(Node) ->
index 03b893ac36b0deb5765c44778b4899931a69ab89..1f357c2940444043b9c8d9eb13ec9dc25d7659cb 100644 (file)
@@ -37,7 +37,8 @@
         env_binary_to_list/2, opt_type/1, may_hide_data/1,
         is_elixir_enabled/0, v_dbs/1, v_dbs_mods/1,
         default_db/1, default_db/2, default_ram_db/1, default_ram_db/2,
-        default_queue_type/1, queue_dir/0, fsm_limit_opts/1]).
+        default_queue_type/1, queue_dir/0, fsm_limit_opts/1,
+        use_cache/1, cache_size/1, cache_missed/1, cache_life_time/1]).
 
 -export([start/2]).
 
@@ -1460,9 +1461,24 @@ opt_type(queue_dir) ->
     fun iolist_to_binary/1;
 opt_type(queue_type) ->
     fun(ram) -> ram; (file) -> file end;
+opt_type(use_cache) ->
+    fun(B) when is_boolean(B) -> B end;
+opt_type(cache_size) ->
+    fun(I) when is_integer(I), I>0 -> I;
+       (infinity) -> infinity;
+       (unlimited) -> infinity
+    end;
+opt_type(cache_missed) ->
+    fun(B) when is_boolean(B) -> B end;
+opt_type(cache_life_time) ->
+    fun(I) when is_integer(I), I>0 -> I;
+       (infinity) -> infinity;
+       (unlimited) -> infinity
+    end;
 opt_type(_) ->
     [hide_sensitive_log_data, hosts, language, max_fsm_queue,
-     default_db, default_ram_db, queue_type, queue_dir, loglevel].
+     default_db, default_ram_db, queue_type, queue_dir, loglevel,
+     use_cache, cache_size, cache_missed, cache_life_time].
 
 -spec may_hide_data(any()) -> any().
 may_hide_data(Data) ->
@@ -1499,3 +1515,20 @@ queue_dir() ->
 -spec default_queue_type(binary()) -> ram | file.
 default_queue_type(Host) ->
     get_option({queue_type, Host}, opt_type(queue_type), ram).
+
+-spec use_cache(binary() | global) -> boolean().
+use_cache(Host) ->
+    get_option({use_cache, Host}, opt_type(use_cache), true).
+
+-spec cache_size(binary() | global) -> pos_integer() | infinity.
+cache_size(Host) ->
+    get_option({cache_size, Host}, opt_type(cache_size), 1000).
+
+-spec cache_missed(binary() | global) -> boolean().
+cache_missed(Host) ->
+    get_option({cache_missed, Host}, opt_type(cache_missed), true).
+
+-spec cache_life_time(binary() | global) -> pos_integer() | infinity.
+%% NOTE: the integer value returned is in *seconds*
+cache_life_time(Host) ->
+    get_option({cache_life_time, Host}, opt_type(cache_life_time), 3600).
index e7cc74d98cf6716dbfc1819eecede809ea4b4944..bd85f0ee56ff320cd2828904a2c333b6f41c77c1 100644 (file)
 -compile({no_auto_import, [get/1, put/2]}).
 
 %% API
--export([start_link/1, get_proc/1, q/1, qp/1, format_error/1]).
+-export([start_link/1, get_proc/1, get_connection/1, q/1, qp/1, format_error/1]).
 %% Commands
 -export([multi/1, get/1, set/2, del/1,
         sadd/2, srem/2, smembers/1, sismember/2, scard/1,
-        hget/2, hset/3, hdel/2, hlen/1, hgetall/1, hkeys/1]).
+        hget/2, hset/3, hdel/2, hlen/1, hgetall/1, hkeys/1,
+        subscribe/1, publish/2]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
 
 -record(state, {connection :: pid() | undefined,
                num :: pos_integer(),
+               subscriptions = #{} :: map(),
                pending_q :: p1_queue:queue()}).
 
--type redis_error() :: {error, binary() | timeout | disconnected | overloaded}.
+-type error_reason() :: binary() | timeout | disconnected | overloaded.
+-type redis_error() :: {error, error_reason()}.
 -type redis_reply() :: binary() | [binary()].
 -type redis_command() :: [binary()].
 -type redis_pipeline() :: [redis_command()].
 -type state() :: #state{}.
 
+-export_type([error_reason/0]).
+
 %%%===================================================================
 %%% API
 %%%===================================================================
@@ -79,11 +84,11 @@ get_connection(I) ->
 
 -spec q(redis_command()) -> {ok, redis_reply()} | redis_error().
 q(Command) ->
-    call(get_worker(), {q, Command}, ?MAX_RETRIES).
+    call(get_rnd_id(), {q, Command}, ?MAX_RETRIES).
 
 -spec qp(redis_pipeline()) -> {ok, [redis_reply()]} | redis_error().
 qp(Pipeline) ->
-    call(get_worker(), {qp, Pipeline}, ?MAX_RETRIES).
+    call(get_rnd_id(), {qp, Pipeline}, ?MAX_RETRIES).
 
 -spec multi(fun(() -> any())) -> {ok, [redis_reply()]} | redis_error().
 multi(F) ->
@@ -288,6 +293,30 @@ hkeys(Key) ->
            erlang:error(transaction_unsupported)
     end.
 
+-spec subscribe([binary()]) -> ok | redis_error().
+subscribe(Channels) ->
+    try ?GEN_SERVER:call(get_proc(1), {subscribe, self(), Channels}, ?CALL_TIMEOUT)
+    catch exit:{Why, {?GEN_SERVER, call, _}} ->
+           Reason = case Why of
+                        timeout -> timeout;
+                        _ -> disconnected
+                    end,
+           {error, Reason}
+    end.
+
+-spec publish(iodata(), iodata()) -> {ok, non_neg_integer()} | redis_error() | queued.
+publish(Channel, Data) ->
+    Cmd = [<<"PUBLISH">>, Channel, Data],
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           case q(Cmd) of
+               {ok, N} -> {ok, binary_to_integer(N)};
+               {error, _} = Err -> Err
+           end;
+       Stack ->
+           tr_enq(Cmd, Stack)
+    end.
+
 %%%===================================================================
 %%% gen_server callbacks
 %%%===================================================================
@@ -315,6 +344,15 @@ handle_call(connect, From, #state{connection = Pid} = State) ->
            self() ! connect,
            handle_call(connect, From, State#state{connection = undefined})
     end;
+handle_call({subscribe, Caller, Channels}, _From,
+           #state{connection = Pid, subscriptions = Subs} = State) ->
+    Subs1 = lists:foldl(
+             fun(Channel, Acc) ->
+                     Callers = maps:get(Channel, Acc, []) -- [Caller],
+                     maps:put(Channel, [Caller|Callers], Acc)
+             end, Subs, Channels),
+    eredis_subscribe(Pid, Channels),
+    {reply, ok, State#state{subscriptions = Subs1}};
 handle_call(Request, _From, State) ->
     ?WARNING_MSG("unexepected call: ~p", [Request]),
     {noreply, State}.
@@ -326,6 +364,7 @@ handle_info(connect, #state{connection = undefined} = State) ->
     NewState = case connect(State) of
                   {ok, Connection} ->
                       Q1 = flush_queue(State#state.pending_q),
+                      re_subscribe(Connection, State#state.subscriptions),
                       State#state{connection = Connection, pending_q = Q1};
                   {error, _} ->
                       State
@@ -342,6 +381,31 @@ handle_info({'EXIT', Pid, _}, State) ->
        _ ->
            {noreply, State}
     end;
+handle_info({subscribed, Channel, Pid}, State) ->
+    case State#state.connection of
+       Pid ->
+           case maps:is_key(Channel, State#state.subscriptions) of
+               true -> eredis_sub:ack_message(Pid);
+               false ->
+                   ?WARNING_MSG("got subscription ack for unknown channel ~s",
+                                [Channel])
+           end;
+       _ ->
+           ok
+    end,
+    {noreply, State};
+handle_info({message, Channel, Data, Pid}, State) ->
+    case State#state.connection of
+       Pid ->
+           lists:foreach(
+             fun(Subscriber) ->
+                     erlang:send(Subscriber, {redis_message, Channel, Data})
+             end, maps:get(Channel, State#state.subscriptions, [])),
+           eredis_sub:ack_message(Pid);
+       _ ->
+           ok
+    end,
+    {noreply, State};
 handle_info(Info, State) ->
     ?WARNING_MSG("unexpected info = ~p", [Info]),
     {noreply, State}.
@@ -377,8 +441,7 @@ connect(#state{num = Num}) ->
                      redis_connect_timeout,
                      fun(I) when is_integer(I), I>0 -> I end,
                      1)),
-    try case eredis:start_link(Server, Port, DB, Pass,
-                              no_reconnect, ConnTimeout) of
+    try case do_connect(Num, Server, Port, Pass, DB, ConnTimeout) of
            {ok, Client} ->
                ?DEBUG("Connection #~p established to Redis at ~s:~p",
                       [Num, Server, Port]),
@@ -397,12 +460,24 @@ connect(#state{num = Num}) ->
            {error, Reason}
     end.
 
--spec call({atom(), atom()}, {q, redis_command()}, integer()) ->
+do_connect(1, Server, Port, Pass, _DB, _ConnTimeout) ->
+    %% First connection in the pool is always a subscriber
+    Res = eredis_sub:start_link(Server, Port, Pass, no_reconnect, infinity, drop),
+    case Res of
+       {ok, Pid} -> eredis_sub:controlling_process(Pid);
+       _ -> ok
+    end,
+    Res;
+do_connect(_, Server, Port, Pass, DB, ConnTimeout) ->
+    eredis:start_link(Server, Port, DB, Pass, no_reconnect, ConnTimeout).
+
+-spec call(pos_integer(), {q, redis_command()}, integer()) ->
                  {ok, redis_reply()} | redis_error();
-         ({atom(), atom()}, {qp, redis_pipeline()}, integer()) ->
+         (pos_integer(), {qp, redis_pipeline()}, integer()) ->
                  {ok, [redis_reply()]} | redis_error().
-call({Conn, Parent}, {F, Cmd}, Retries) ->
+call(I, {F, Cmd}, Retries) ->
     ?DEBUG("redis query: ~p", [Cmd]),
+    Conn = get_connection(I),
     Res = try eredis:F(Conn, Cmd, ?CALL_TIMEOUT) of
              {error, Reason} when is_atom(Reason) ->
                  try exit(whereis(Conn), kill) catch _:_ -> ok end,
@@ -414,8 +489,8 @@ call({Conn, Parent}, {F, Cmd}, Retries) ->
          end,
     case Res of
        {error, disconnected} when Retries > 0 ->
-           try ?GEN_SERVER:call(Parent, connect, ?CALL_TIMEOUT) of
-               ok -> call({Conn, Parent}, {F, Cmd}, Retries-1);
+           try ?GEN_SERVER:call(get_proc(I), connect, ?CALL_TIMEOUT) of
+               ok -> call(I, {F, Cmd}, Retries-1);
                {error, _} = Err -> Err
            catch exit:{Why, {?GEN_SERVER, call, _}} ->
                    Reason1 = case Why of
@@ -439,11 +514,9 @@ log_error(Cmd, Reason) ->
               "** response = ~s",
               [Cmd, format_error(Reason)]).
 
--spec get_worker() -> {atom(), atom()}.
-get_worker() ->
-    Time = p1_time_compat:system_time(),
-    I = erlang:phash2(Time, ejabberd_redis_sup:get_pool_size()) + 1,
-    {get_connection(I), get_proc(I)}.
+-spec get_rnd_id() -> pos_integer().
+get_rnd_id() ->
+    randoms:uniform(2, ejabberd_redis_sup:get_pool_size()).
 
 -spec get_result([{error, atom() | binary()} | {ok, iodata()}]) ->
                        {ok, [redis_reply()]} | {error, binary()}.
@@ -531,3 +604,13 @@ clean_queue(Q, CurrTime) ->
        true ->
            Q1
     end.
+
+re_subscribe(Pid, Subs) ->
+    case maps:keys(Subs) of
+       [] -> ok;
+       Channels -> eredis_subscribe(Pid, Channels)
+    end.
+
+eredis_subscribe(Pid, Channels) ->
+    ?DEBUG("redis query: ~p", [[<<"SUBSCRIBE">>|Channels]]),
+    eredis_sub:subscribe(Pid, Channels).
index 23330f87cde028969e2462b6547365d90d4df113..7e2953c11126ab473cc3286f6c323137b3e1f327 100644 (file)
@@ -136,7 +136,7 @@ get_pool_size() ->
     ejabberd_config:get_option(
       redis_pool_size,
       fun(N) when is_integer(N), N >= 1 -> N end,
-      ?DEFAULT_POOL_SIZE).
+      ?DEFAULT_POOL_SIZE) + 1.
 
 iolist_to_list(IOList) ->
     binary_to_list(iolist_to_binary(IOList)).
index 7474f9a6715dadd3a47fb1f854da227cff0e1f55..30654a03b953b56116eb964213b3c4490bc410ad 100644 (file)
@@ -49,7 +49,8 @@
         get_all_routes/0,
         is_my_route/1,
         is_my_host/1,
-        find_routes/0,
+        clean_cache/1,
+        config_reloaded/0,
         get_backend/0]).
 
 -export([start_link/0]).
 -callback register_route(binary(), binary(), local_hint(),
                         undefined | pos_integer(), pid()) -> ok | {error, term()}.
 -callback unregister_route(binary(), undefined | pos_integer(), pid()) -> ok | {error, term()}.
--callback find_routes(binary()) -> [#route{}].
--callback find_routes() -> [#route{}].
--callback host_of_route(binary()) -> {ok, binary()} | error.
--callback is_my_route(binary()) -> boolean().
--callback is_my_host(binary()) -> boolean().
--callback get_all_routes() -> [binary()].
+-callback find_routes(binary()) -> {ok, [#route{}]} | {error, any()}.
+-callback get_all_routes() -> {ok, [binary()]} | {error, any()}.
 
 -record(state, {}).
 
@@ -159,7 +156,8 @@ register_route(Domain, ServerHost, LocalHint, Pid) ->
            case Mod:register_route(LDomain, LServerHost, LocalHint,
                                    get_component_number(LDomain), Pid) of
                ok ->
-                   ?DEBUG("Route registered: ~s", [LDomain]);
+                   ?DEBUG("Route registered: ~s", [LDomain]),
+                   delete_cache(Mod, LDomain);
                {error, Err} ->
                    ?ERROR_MSG("Failed to register route ~s: ~p",
                               [LDomain, Err])
@@ -186,7 +184,8 @@ unregister_route(Domain, Pid) ->
            case Mod:unregister_route(
                   LDomain, get_component_number(LDomain), Pid) of
                ok ->
-                   ?DEBUG("Route unregistered: ~s", [LDomain]);
+                   ?DEBUG("Route unregistered: ~s", [LDomain]),
+                   delete_cache(Mod, LDomain);
                {error, Err} ->
                    ?ERROR_MSG("Failed to unregister route ~s: ~p",
                               [LDomain, Err])
@@ -199,15 +198,55 @@ unregister_routes(Domains) ->
                  end,
                  Domains).
 
--spec get_all_routes() -> [binary()].
-get_all_routes() ->
+-spec find_routes(binary()) -> [#route{}].
+find_routes(Domain) ->
     Mod = get_backend(),
-    Mod:get_all_routes().
+    case use_cache(Mod) of
+       true ->
+           case ets_cache:lookup(
+                  ?ROUTES_CACHE, {route, Domain},
+                  fun() ->
+                          case Mod:find_routes(Domain) of
+                              {ok, Rs} when Rs /= [] ->
+                                  {ok, Rs};
+                              _ ->
+                                  error
+                          end
+                  end) of
+               {ok, Rs} -> Rs;
+               error -> []
+           end;
+       false ->
+           case Mod:find_routes(Domain) of
+               {ok, Rs} -> Rs;
+               _ -> []
+           end
+    end.
 
--spec find_routes() -> [#route{}].
-find_routes() ->
+-spec get_all_routes() -> [binary()].
+get_all_routes() ->
     Mod = get_backend(),
-    Mod:find_routes().
+    case use_cache(Mod) of
+       true ->
+           case ets_cache:lookup(
+                  ?ROUTES_CACHE, routes,
+                  fun() ->
+                          case Mod:get_all_routes() of
+                              {ok, Rs} when Rs /= [] ->
+                                  {ok, Rs};
+                              _ ->
+                                  error
+                          end
+                  end) of
+               {ok, Rs} -> Rs;
+               error -> []
+           end;
+       false ->
+           case Mod:get_all_routes() of
+               {ok, Rs} -> Rs;
+               _ -> []
+           end
+    end.
 
 -spec host_of_route(binary()) -> binary().
 host_of_route(Domain) ->
@@ -215,10 +254,11 @@ host_of_route(Domain) ->
        error ->
            erlang:error({invalid_domain, Domain});
        LDomain ->
-           Mod = get_backend(),
-           case Mod:host_of_route(LDomain) of
-               {ok, ServerHost} -> ServerHost;
-               error -> erlang:error({unregistered_route, Domain})
+           case find_routes(LDomain) of
+               [#route{server_host = ServerHost}|_] ->
+                   ServerHost;
+               _ ->
+                   erlang:error({unregistered_route, Domain})
            end
     end.
 
@@ -228,8 +268,7 @@ is_my_route(Domain) ->
        error ->
            erlang:error({invalid_domain, Domain});
        LDomain ->
-           Mod = get_backend(),
-           Mod:is_my_route(LDomain)
+           lists:member(LDomain, get_all_routes())
     end.
 
 -spec is_my_host(binary()) -> boolean().
@@ -238,8 +277,10 @@ is_my_host(Domain) ->
        error ->
            erlang:error({invalid_domain, Domain});
        LDomain ->
-           Mod = get_backend(),
-           Mod:is_my_host(LDomain)
+           case find_routes(LDomain) of
+               [#route{server_host = LDomain}|_] -> true;
+               _ -> false
+           end
     end.
 
 -spec process_iq(iq()) -> any().
@@ -250,12 +291,20 @@ process_iq(#iq{to = To} = IQ) ->
            ejabberd_sm:process_iq(IQ)
     end.
 
+-spec config_reloaded() -> ok.
+config_reloaded() ->
+    Mod = get_backend(),
+    init_cache(Mod).
+
 %%====================================================================
 %% gen_server callbacks
 %%====================================================================
 init([]) ->
+    ejabberd_hooks:add(config_reloaded, ?MODULE, config_reloaded, 50),
     Mod = get_backend(),
+    init_cache(Mod),
     Mod:init(),
+    clean_cache(),
     {ok, #state{}}.
 
 handle_call(_Request, _From, State) ->
@@ -273,7 +322,7 @@ handle_info(Info, State) ->
     {noreply, State}.
 
 terminate(_Reason, _State) ->
-    ok.
+    ejabberd_hooks:add(config_reloaded, ?MODULE, config_reloaded, 50).
 
 code_change(_OldVsn, State, _Extra) ->
     {ok, State}.
@@ -290,8 +339,7 @@ do_route(OrigPacket) ->
        Packet ->
            To = xmpp:get_to(Packet),
            LDstDomain = To#jid.lserver,
-           Mod = get_backend(),
-           case Mod:find_routes(LDstDomain) of
+           case find_routes(LDstDomain) of
                [] ->
                    ejabberd_s2s:route(Packet);
                [Route] ->
@@ -366,6 +414,80 @@ get_backend() ->
             end,
     list_to_atom("ejabberd_router_" ++ atom_to_list(DBType)).
 
+-spec cache_nodes(module()) -> [node()].
+cache_nodes(Mod) ->
+    case erlang:function_exported(Mod, cache_nodes, 0) of
+       true -> Mod:cache_nodes();
+       false -> ejabberd_cluster:get_nodes()
+    end.
+
+-spec use_cache(module()) -> boolean().
+use_cache(Mod) ->
+    case erlang:function_exported(Mod, use_cache, 0) of
+       true -> Mod:use_cache();
+       false ->
+           ejabberd_config:get_option(
+             router_use_cache, opt_type(router_use_cache),
+             ejabberd_config:use_cache(global))
+    end.
+
+-spec delete_cache(module(), binary()) -> ok.
+delete_cache(Mod, Domain) ->
+    case use_cache(Mod) of
+       true ->
+           ets_cache:delete(?ROUTES_CACHE, {route, Domain}, cache_nodes(Mod)),
+           ets_cache:delete(?ROUTES_CACHE, routes, cache_nodes(Mod));
+       false ->
+           ok
+    end.
+
+-spec init_cache(module()) -> ok.
+init_cache(Mod) ->
+    case use_cache(Mod) of
+       true ->
+           ets_cache:new(?ROUTES_CACHE, cache_opts());
+       false ->
+           ets_cache:delete(?ROUTES_CACHE)
+    end.
+
+-spec cache_opts() -> [proplists:property()].
+cache_opts() ->
+    MaxSize = ejabberd_config:get_option(
+               router_cache_size,
+               opt_type(router_cache_size),
+               ejabberd_config:cache_size(global)),
+    CacheMissed = ejabberd_config:get_option(
+                   router_cache_missed,
+                   opt_type(router_cache_missed),
+                   ejabberd_config:cache_missed(global)),
+    LifeTime = case ejabberd_config:get_option(
+                     router_cache_life_time,
+                     opt_type(router_cache_life_time),
+                     ejabberd_config:cache_life_time(global)) of
+                  infinity -> infinity;
+                  I -> timer:seconds(I)
+              end,
+    [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].
+
+-spec clean_cache(node()) -> ok.
+clean_cache(Node) ->
+    ets_cache:filter(
+      ?ROUTES_CACHE,
+      fun(_, error) ->
+             false;
+        (routes, _) ->
+             false;
+        ({route, _}, {ok, Rs}) ->
+             not lists:any(
+                   fun(#route{pid = Pid}) ->
+                           node(Pid) == Node
+                   end, Rs)
+      end).
+
+-spec clean_cache() -> ok.
+clean_cache() ->
+    ejabberd_cluster:eval_everywhere(?MODULE, clean_cache, [node()]).
+
 opt_type(domain_balancing) ->
     fun (random) -> random;
        (source) -> source;
@@ -376,6 +498,14 @@ opt_type(domain_balancing) ->
 opt_type(domain_balancing_component_number) ->
     fun (N) when is_integer(N), N > 1 -> N end;
 opt_type(router_db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
+opt_type(O) when O == router_use_cache; O == router_cache_missed ->
+    fun(B) when is_boolean(B) -> B end;
+opt_type(O) when O == router_cache_size; O == router_cache_life_time ->
+    fun(I) when is_integer(I), I>0 -> I;
+       (unlimited) -> infinity;
+       (infinity) -> infinity
+    end;
 opt_type(_) ->
     [domain_balancing, domain_balancing_component_number,
-     router_db_type].
+     router_db_type, router_use_cache, router_cache_size,
+     router_cache_missed, router_cache_life_time].
index e3b550a751231de59baa03e64a1d91eb7854e898..d8664fee9def629c42bdfb557b69578070729872 100644 (file)
@@ -25,8 +25,7 @@
 
 %% API
 -export([init/0, register_route/5, unregister_route/3, find_routes/1,
-        host_of_route/1, is_my_route/1, is_my_host/1, get_all_routes/0,
-        find_routes/0]).
+        get_all_routes/0, use_cache/0]).
 %% gen_server callbacks
 -export([init/1, handle_cast/2, handle_call/3, handle_info/2,
         terminate/2, code_change/3, start_link/0]).
@@ -54,6 +53,9 @@ init() ->
 start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
+use_cache() ->
+    false.
+
 register_route(Domain, ServerHost, LocalHint, undefined, Pid) ->
     F = fun () ->
                mnesia:write(#route{domain = Domain,
@@ -124,37 +126,15 @@ unregister_route(Domain, _, Pid) ->
     transaction(F).
 
 find_routes(Domain) ->
-    mnesia:dirty_read(route, Domain).
-
-host_of_route(Domain) ->
-    case mnesia:dirty_read(route, Domain) of
-       [#route{server_host = ServerHost}|_] ->
-           {ok, ServerHost};
-       [] ->
-           error
-    end.
-
-is_my_route(Domain) ->
-    mnesia:dirty_read(route, Domain) /= [].
-
-is_my_host(Domain) ->
-    case mnesia:dirty_read(route, Domain) of
-       [#route{server_host = Host}|_] ->
-           Host == Domain;
-       [] ->
-           false
-    end.
+    {ok, mnesia:dirty_read(route, Domain)}.
 
 get_all_routes() ->
-    mnesia:dirty_select(
-      route,
-      ets:fun2ms(
-       fun(#route{domain = Domain, server_host = ServerHost})
-             when Domain /= ServerHost -> Domain
-       end)).
-
-find_routes() ->
-    ets:tab2list(route).
+    {ok, mnesia:dirty_select(
+          route,
+          ets:fun2ms(
+            fun(#route{domain = Domain, server_host = ServerHost})
+                  when Domain /= ServerHost -> Domain
+            end))}.
 
 %%%===================================================================
 %%% gen_server callbacks
@@ -227,7 +207,7 @@ transaction(F) ->
            ok;
        {aborted, Reason} ->
            ?ERROR_MSG("Mnesia transaction failed: ~p", [Reason]),
-           {error, Reason}
+           {error, db_failure}
     end.
 
 -spec update_tables() -> ok.
index 58c1fca4ad905790d9023b55ef3a5997743812e9..2b02a7595143cb7e46f2b0fe5c2a3d6f46e9d1f1 100644 (file)
 %%%-------------------------------------------------------------------
 -module(ejabberd_router_redis).
 -behaviour(ejabberd_router).
+-behaviour(gen_server).
 
 %% API
 -export([init/0, register_route/5, unregister_route/3, find_routes/1,
-        host_of_route/1, is_my_route/1, is_my_host/1, get_all_routes/0,
-        find_routes/0]).
+        get_all_routes/0]).
+%% gen_server callbacks
+-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
+        terminate/2, code_change/3, start_link/0]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
 -include("ejabberd_router.hrl").
 
--define(ROUTES_KEY, "ejabberd:routes").
+-record(state, {}).
+
+-define(ROUTES_KEY, <<"ejabberd:routes">>).
 
 %%%===================================================================
 %%% API
 %%%===================================================================
 init() ->
-    clean_table().
+    Spec = {?MODULE, {?MODULE, start_link, []},
+           transient, 5000, worker, [?MODULE]},
+    case supervisor:start_child(ejabberd_backend_sup, Spec) of
+       {ok, _Pid} -> ok;
+       Err -> Err
+    end.
+
+-spec start_link() -> {ok, pid()} | {error, any()}.
+start_link() ->
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
 register_route(Domain, ServerHost, LocalHint, _, Pid) ->
     DomKey = domain_key(Domain),
@@ -83,47 +97,48 @@ find_routes(Domain) ->
     DomKey = domain_key(Domain),
     case ejabberd_redis:hgetall(DomKey) of
        {ok, Vals} ->
-           decode_routes(Domain, Vals);
-       {error, _} ->
-           []
-    end.
-
-host_of_route(Domain) ->
-    DomKey = domain_key(Domain),
-    case ejabberd_redis:hgetall(DomKey) of
-       {ok, [{_Pid, Data}|_]} ->
-           {ServerHost, _} = binary_to_term(Data),
-           {ok, ServerHost};
+           {ok, decode_routes(Domain, Vals)};
        _ ->
-           error
-    end.
-
-is_my_route(Domain) ->
-    case ejabberd_redis:sismember(?ROUTES_KEY, Domain) of
-       {ok, Bool} ->
-           Bool;
-       {error, _} ->
-           false
+           {error, db_failure}
     end.
 
-is_my_host(Domain) ->
-    {ok, Domain} == host_of_route(Domain).
-
 get_all_routes() ->
     case ejabberd_redis:smembers(?ROUTES_KEY) of
        {ok, Routes} ->
-           Routes;
-       {error, _} ->
-           []
+           {ok, Routes};
+       _ ->
+           {error, db_failure}
     end.
 
-find_routes() ->
-    lists:flatmap(fun find_routes/1, get_all_routes()).
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([]) ->
+    clean_table(),
+    {ok, #state{}}.
+
+handle_call(_Request, _From, State) ->
+    Reply = ok,
+    {reply, Reply, State}.
+
+handle_cast(_Msg, State) ->
+    {noreply, State}.
+
+handle_info(Info, State) ->
+    ?ERROR_MSG("unexpected info: ~p", [Info]),
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
 
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
 clean_table() ->
+    ?INFO_MSG("Cleaning Redis route entries...", []),
     lists:foreach(
       fun(#route{domain = Domain, pid = Pid}) when node(Pid) == node() ->
              unregister_route(Domain, undefined, Pid);
@@ -131,6 +146,20 @@ clean_table() ->
              ok
       end, find_routes()).
 
+find_routes() ->
+    case get_all_routes() of
+       {ok, Domains} ->
+           lists:flatmap(
+             fun(Domain) ->
+                     case find_routes(Domain) of
+                         {ok, Routes} -> Routes;
+                         {error, _} -> []
+                     end
+             end, Domains);
+       {error, _} ->
+           []
+    end.
+
 domain_key(Domain) ->
     <<"ejabberd:route:", Domain/binary>>.
 
index 0747d03960b0195d6b2d6f4d67bf0cf8b512c968..b354eb212cb48ea19eaf113dfe2a610a55cfe29c 100644 (file)
@@ -27,8 +27,7 @@
 
 %% API
 -export([init/0, register_route/5, unregister_route/3, find_routes/1,
-        host_of_route/1, is_my_route/1, is_my_host/1, get_all_routes/0,
-        find_routes/0]).
+        get_all_routes/0]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
@@ -64,80 +63,47 @@ register_route(Domain, ServerHost, LocalHint, _, Pid) ->
            ok;
        Err ->
            ?ERROR_MSG("failed to update 'route' table: ~p", [Err]),
-           {error, Err}
+           {error, db_failure}
     end.
 
 unregister_route(Domain, _, Pid) ->
     PidS = misc:encode_pid(Pid),
     Node = erlang:atom_to_binary(node(Pid), latin1),
-    ejabberd_sql:sql_query(
-      ?MYNAME,
-      ?SQL("delete from route where domain=%(Domain)s "
-          "and pid=%(PidS)s and node=%(Node)s")),
-    %% TODO: return meaningful error
-    ok.
-
-find_routes(Domain) ->
     case ejabberd_sql:sql_query(
           ?MYNAME,
-          ?SQL("select @(server_host)s, @(node)s, @(pid)s, @(local_hint)s "
-               "from route where domain=%(Domain)s")) of
-       {selected, Rows} ->
-           lists:flatmap(
-             fun(Row) ->
-                     row_to_route(Domain, Row)
-             end, Rows);
+          ?SQL("delete from route where domain=%(Domain)s "
+               "and pid=%(PidS)s and node=%(Node)s")) of
+       {updated, _} ->
+           ok;
        Err ->
-           ?ERROR_MSG("failed to select from 'route' table: ~p", [Err]),
-           {error, Err}
+           ?ERROR_MSG("failed to delete from 'route' table: ~p", [Err]),
+           {error, db_failure}
     end.
 
-host_of_route(Domain) ->
+find_routes(Domain) ->
     case ejabberd_sql:sql_query(
           ?MYNAME,
-          ?SQL("select @(server_host)s from route where domain=%(Domain)s")) of
-       {selected, [{ServerHost}|_]} ->
-           {ok, ServerHost};
-       {selected, []} ->
-           error;
+          ?SQL("select @(server_host)s, @(node)s, @(pid)s, @(local_hint)s "
+               "from route where domain=%(Domain)s")) of
+       {selected, Rows} ->
+           {ok, lists:flatmap(
+                  fun(Row) ->
+                          row_to_route(Domain, Row)
+                  end, Rows)};
        Err ->
            ?ERROR_MSG("failed to select from 'route' table: ~p", [Err]),
-           error
+           {error, db_failure}
     end.
 
-is_my_route(Domain) ->
-    case host_of_route(Domain) of
-       {ok, _} -> true;
-       _ -> false
-    end.
-
-is_my_host(Domain) ->
-    {ok, Domain} == host_of_route(Domain).
-
 get_all_routes() ->
     case ejabberd_sql:sql_query(
           ?MYNAME,
           ?SQL("select @(domain)s from route where domain <> server_host")) of
        {selected, Domains} ->
-           [Domain || {Domain} <- Domains];
-       Err ->
-           ?ERROR_MSG("failed to select from 'route' table: ~p", [Err]),
-           []
-    end.
-
-find_routes() ->
-    case ejabberd_sql:sql_query(
-          ?MYNAME,
-          ?SQL("select @(domain)s, @(server_host)s, @(node)s, @(pid)s, "
-               "@(local_hint)s from route")) of
-       {selected, Rows} ->
-           lists:flatmap(
-             fun({Domain, ServerHost, Node, Pid, LocalHint}) ->
-                     row_to_route(Domain, {ServerHost, Node, Pid, LocalHint})
-             end, Rows);
+           {ok, [Domain || {Domain} <- Domains]};
        Err ->
            ?ERROR_MSG("failed to select from 'route' table: ~p", [Err]),
-           []
+           {error, db_failure}
     end.
 
 %%%===================================================================
index 7c63292fc77c857d397a11f203dc1a9199b1a086..1cd911e11ea70760b5c44a8830403ef6e00fb810 100644 (file)
@@ -76,7 +76,9 @@
         c2s_handle_info/2,
         host_up/1,
         host_down/1,
-        make_sid/0
+        make_sid/0,
+        clean_cache/1,
+        config_reloaded/0
        ]).
 
 -export([init/1, handle_call/3, handle_cast/2,
 -include("ejabberd_sm.hrl").
 
 -callback init() -> ok | {error, any()}.
--callback set_session(#session{}) -> ok.
--callback delete_session(binary(), binary(), binary(), sid()) ->
-    {ok, #session{}} | {error, notfound}.
+-callback set_session(#session{}) -> ok | {error, any()}.
+-callback delete_session(#session{}) -> ok | {error, any()}.
 -callback get_sessions() -> [#session{}].
 -callback get_sessions(binary()) -> [#session{}].
--callback get_sessions(binary(), binary()) -> [#session{}].
--callback get_sessions(binary(), binary(), binary()) -> [#session{}].
+-callback get_sessions(binary(), binary()) -> {ok, [#session{}]} | {error, any()}.
+-callback use_cache(binary()) -> boolean().
+-callback cache_nodes(binary()) -> [node()].
+
+-optional_callbacks([use_cache/1, cache_nodes/1]).
 
 -record(state, {}).
 
@@ -158,9 +162,12 @@ close_session(SID, User, Server, Resource) ->
     LServer = jid:nameprep(Server),
     LResource = jid:resourceprep(Resource),
     Mod = get_sm_backend(LServer),
-    Info = case Mod:delete_session(LUser, LServer, LResource, SID) of
-              {ok, #session{info = I}} -> I;
-              {error, notfound} -> []
+    Info = case get_sessions(Mod, LUser, LServer, LResource) of
+              [#session{info = I} = Session|_] ->
+                  delete_session(Mod, Session),
+                  I;
+              [] ->
+                  []
           end,
     JID = jid:make(User, Server, Resource),
     ejabberd_hooks:run(sm_remove_connection_hook,
@@ -196,14 +203,14 @@ get_user_resources(User, Server) ->
     LUser = jid:nodeprep(User),
     LServer = jid:nameprep(Server),
     Mod = get_sm_backend(LServer),
-    Ss = online(Mod:get_sessions(LUser, LServer)),
+    Ss = online(get_sessions(Mod, LUser, LServer)),
     [element(3, S#session.usr) || S <- clean_session_list(Ss)].
 
 -spec get_user_present_resources(binary(), binary()) -> [tuple()].
 
 get_user_present_resources(LUser, LServer) ->
     Mod = get_sm_backend(LServer),
-    Ss = online(Mod:get_sessions(LUser, LServer)),
+    Ss = online(get_sessions(Mod, LUser, LServer)),
     [{S#session.priority, element(3, S#session.usr)}
      || S <- clean_session_list(Ss), is_integer(S#session.priority)].
 
@@ -214,7 +221,7 @@ get_user_ip(User, Server, Resource) ->
     LServer = jid:nameprep(Server),
     LResource = jid:resourceprep(Resource),
     Mod = get_sm_backend(LServer),
-    case online(Mod:get_sessions(LUser, LServer, LResource)) of
+    case online(get_sessions(Mod, LUser, LServer, LResource)) of
        [] ->
            undefined;
        Ss ->
@@ -227,7 +234,7 @@ get_user_info(User, Server) ->
     LUser = jid:nodeprep(User),
     LServer = jid:nameprep(Server),
     Mod = get_sm_backend(LServer),
-    Ss = online(Mod:get_sessions(LUser, LServer)),
+    Ss = online(get_sessions(Mod, LUser, LServer)),
     [{LResource, [{node, node(Pid)}|Info]}
      || #session{usr = {_, _, LResource},
                 info = Info,
@@ -240,7 +247,7 @@ get_user_info(User, Server, Resource) ->
     LServer = jid:nameprep(Server),
     LResource = jid:resourceprep(Resource),
     Mod = get_sm_backend(LServer),
-    case online(Mod:get_sessions(LUser, LServer, LResource)) of
+    case online(get_sessions(Mod, LUser, LServer, LResource)) of
        [] ->
            offline;
        Ss ->
@@ -288,7 +295,7 @@ get_session_pid(User, Server, Resource) ->
     LServer = jid:nameprep(Server),
     LResource = jid:resourceprep(Resource),
     Mod = get_sm_backend(LServer),
-    case online(Mod:get_sessions(LUser, LServer, LResource)) of
+    case online(get_sessions(Mod, LUser, LServer, LResource)) of
        [#session{sid = {_, Pid}}] -> Pid;
        _ -> none
     end.
@@ -309,7 +316,7 @@ get_offline_info(Time, User, Server, Resource) ->
     LServer = jid:nameprep(Server),
     LResource = jid:resourceprep(Resource),
     Mod = get_sm_backend(LServer),
-    case Mod:get_sessions(LUser, LServer, LResource) of
+    case get_sessions(Mod, LUser, LServer, LResource) of
        [#session{sid = {Time, _}, info = Info}] ->
            case proplists:get_bool(offline, Info) of
                true ->
@@ -326,7 +333,7 @@ get_offline_info(Time, User, Server, Resource) ->
 dirty_get_sessions_list() ->
     lists:flatmap(
       fun(Mod) ->
-             [S#session.usr || S <- online(Mod:get_sessions())]
+             [S#session.usr || S <- online(get_sessions(Mod))]
       end, get_sm_backends()).
 
 -spec dirty_get_my_sessions_list() -> [#session{}].
@@ -334,7 +341,7 @@ dirty_get_sessions_list() ->
 dirty_get_my_sessions_list() ->
     lists:flatmap(
       fun(Mod) ->
-             [S || S <- online(Mod:get_sessions()),
+             [S || S <- online(get_sessions(Mod)),
                    node(element(2, S#session.sid)) == node()]
       end, get_sm_backends()).
 
@@ -343,14 +350,14 @@ dirty_get_my_sessions_list() ->
 get_vh_session_list(Server) ->
     LServer = jid:nameprep(Server),
     Mod = get_sm_backend(LServer),
-    [S#session.usr || S <- online(Mod:get_sessions(LServer))].
+    [S#session.usr || S <- online(get_sessions(Mod, LServer))].
 
 -spec get_all_pids() -> [pid()].
 
 get_all_pids() ->
     lists:flatmap(
       fun(Mod) ->
-             [element(2, S#session.sid) || S <- online(Mod:get_sessions())]
+             [element(2, S#session.sid) || S <- online(get_sessions(Mod))]
       end, get_sm_backends()).
 
 -spec get_vh_session_number(binary()) -> non_neg_integer().
@@ -358,7 +365,7 @@ get_all_pids() ->
 get_vh_session_number(Server) ->
     LServer = jid:nameprep(Server),
     Mod = get_sm_backend(LServer),
-    length(online(Mod:get_sessions(LServer))).
+    length(online(get_sessions(Mod, LServer))).
 
 -spec register_iq_handler(binary(), binary(), atom(), atom(), list()) -> ok.
 
@@ -387,16 +394,23 @@ c2s_handle_info(#{lang := Lang} = State, {exit, Reason}) ->
 c2s_handle_info(State, _) ->
     State.
 
+-spec config_reloaded() -> ok.
+config_reloaded() ->
+    init_cache().
+
 %%====================================================================
 %% gen_server callbacks
 %%====================================================================
 
 init([]) ->
     process_flag(trap_exit, true),
+    init_cache(),
     lists:foreach(fun(Mod) -> Mod:init() end, get_sm_backends()),
+    clean_cache(),
     ets:new(sm_iqtable, [named_table, public, {read_concurrency, true}]),
     ejabberd_hooks:add(host_up, ?MODULE, host_up, 50),
     ejabberd_hooks:add(host_down, ?MODULE, host_down, 60),
+    ejabberd_hooks:add(config_reloaded, ?MODULE, config_reloaded, 50),
     lists:foreach(fun host_up/1, ?MYHOSTS),
     ejabberd_commands:register_commands(get_commands_spec()),
     {ok, #state{}}.
@@ -432,6 +446,7 @@ terminate(_Reason, _State) ->
     lists:foreach(fun host_down/1, ?MYHOSTS),
     ejabberd_hooks:delete(host_up, ?MODULE, host_up, 50),
     ejabberd_hooks:delete(host_down, ?MODULE, host_down, 60),
+    ejabberd_hooks:delete(config_reloaded, ?MODULE, config_reloaded, 50),
     ejabberd_commands:unregister_commands(get_commands_spec()),
     ok.
 
@@ -460,7 +475,7 @@ host_down(Host) ->
              ejabberd_c2s:send(Pid, xmpp:serr_system_shutdown());
         (_) ->
              ok
-      end, Mod:get_sessions(Host)),
+      end, get_sessions(Mod, Host)),
     ejabberd_hooks:delete(c2s_handle_info, Host,
                          ejabberd_sm, c2s_handle_info, 50),
     ejabberd_hooks:delete(roster_in_subscription, Host,
@@ -472,7 +487,7 @@ host_down(Host) ->
     ejabberd_c2s:host_down(Host).
 
 -spec set_session(sid(), binary(), binary(), binary(),
-                  prio(), info()) -> ok.
+                  prio(), info()) -> ok | {error, any()}.
 
 set_session(SID, User, Server, Resource, Priority, Info) ->
     LUser = jid:nodeprep(User),
@@ -481,8 +496,69 @@ set_session(SID, User, Server, Resource, Priority, Info) ->
     US = {LUser, LServer},
     USR = {LUser, LServer, LResource},
     Mod = get_sm_backend(LServer),
-    Mod:set_session(#session{sid = SID, usr = USR, us = US,
-                            priority = Priority, info = Info}).
+    case Mod:set_session(#session{sid = SID, usr = USR, us = US,
+                                 priority = Priority, info = Info}) of
+       ok ->
+           case use_cache(Mod, LServer) of
+               true ->
+                   ets_cache:delete(?SM_CACHE, {LUser, LServer},
+                                    cache_nodes(Mod, LServer));
+               false ->
+                   ok
+           end;
+       {error, _} = Err ->
+           Err
+    end.
+
+-spec get_sessions(module()) -> [#session{}].
+get_sessions(Mod) ->
+    Mod:get_sessions().
+
+-spec get_sessions(module(), binary()) -> [#session{}].
+get_sessions(Mod, LServer) ->
+    Mod:get_sessions(LServer).
+
+-spec get_sessions(module(), binary(), binary()) -> [#session{}].
+get_sessions(Mod, LUser, LServer) ->
+    case use_cache(Mod, LServer) of
+       true ->
+           case ets_cache:lookup(
+                  ?SM_CACHE, {LUser, LServer},
+                  fun() ->
+                          case Mod:get_sessions(LUser, LServer) of
+                              {ok, Ss} when Ss /= [] ->
+                                  {ok, Ss};
+                              _ ->
+                                  error
+                          end
+                  end) of
+               {ok, Sessions} ->
+                   Sessions;
+               error ->
+                   []
+           end;
+       false ->
+           case Mod:get_sessions(LUser, LServer) of
+               {ok, Ss} -> Ss;
+               _ -> []
+           end
+    end.
+
+-spec get_sessions(module(), binary(), binary(), binary()) -> [#session{}].
+get_sessions(Mod, LUser, LServer, LResource) ->
+    Sessions = get_sessions(Mod, LUser, LServer),
+    [S || S <- Sessions, element(3, S#session.usr) == LResource].
+
+-spec delete_session(module(), #session{}) -> ok.
+delete_session(Mod, #session{usr = {LUser, LServer, _}} = Session) ->
+    Mod:delete_session(Session),
+    case use_cache(Mod, LServer) of
+       true ->
+           ets_cache:delete(?SM_CACHE, {LUser, LServer},
+                            cache_nodes(Mod, LServer));
+       false ->
+           ok
+    end.
 
 -spec online([#session{}]) -> [#session{}].
 
@@ -505,7 +581,7 @@ do_route(To, Term) ->
     ?DEBUG("broadcasting ~p to ~s", [Term, jid:encode(To)]),
     {U, S, R} = jid:tolower(To),
     Mod = get_sm_backend(S),
-    case online(Mod:get_sessions(U, S, R)) of
+    case online(get_sessions(Mod, U, S, R)) of
        [] ->
            ?DEBUG("dropping broadcast to unavailable resourse: ~p", [Term]);
        Ss ->
@@ -541,7 +617,7 @@ do_route(#presence{from = From, to = To, type = T, status = Status} = Packet)
                      ejabberd_c2s:route(Pid, {route, Packet1});
                 (_) ->
                      ok
-             end, online(Mod:get_sessions(LUser, LServer)));
+             end, online(get_sessions(Mod, LUser, LServer)));
        false ->
            ok
     end;
@@ -570,7 +646,7 @@ do_route(Packet) ->
     To = xmpp:get_to(Packet),
     {LUser, LServer, LResource} = jid:tolower(To),
     Mod = get_sm_backend(LServer),
-    case online(Mod:get_sessions(LUser, LServer, LResource)) of
+    case online(get_sessions(Mod, LUser, LServer, LResource)) of
        [] ->
            case Packet of
                #message{type = T} when T == chat; T == normal ->
@@ -618,7 +694,7 @@ route_message(#message{to = To, type = Type} = Packet) ->
                                          (P >= 0) and (Type == headline) ->
                                LResource = jid:resourceprep(R),
                                Mod = get_sm_backend(LServer),
-                               case online(Mod:get_sessions(LUser, LServer,
+                               case online(get_sessions(Mod, LUser, LServer,
                                                             LResource)) of
                                  [] ->
                                      ok; % Race condition
@@ -689,10 +765,10 @@ check_for_sessions_to_replace(User, Server, Resource) ->
 -spec check_existing_resources(binary(), binary(), binary()) -> ok.
 check_existing_resources(LUser, LServer, LResource) ->
     Mod = get_sm_backend(LServer),
-    Ss = Mod:get_sessions(LUser, LServer, LResource),
+    Ss = get_sessions(Mod, LUser, LServer, LResource),
     {OnlineSs, OfflineSs} = lists:partition(fun is_online/1, Ss),
-    lists:foreach(fun(#session{sid = S}) ->
-                         Mod:delete_session(LUser, LServer, LResource, S)
+    lists:foreach(fun(S) ->
+                         delete_session(Mod, S)
                  end, OfflineSs),
     if OnlineSs == [] -> ok;
        true ->
@@ -716,12 +792,12 @@ get_resource_sessions(User, Server, Resource) ->
     LServer = jid:nameprep(Server),
     LResource = jid:resourceprep(Resource),
     Mod = get_sm_backend(LServer),
-    [S#session.sid || S <- online(Mod:get_sessions(LUser, LServer, LResource))].
+    [S#session.sid || S <- online(get_sessions(Mod, LUser, LServer, LResource))].
 
 -spec check_max_sessions(binary(), binary()) -> ok | replaced.
 check_max_sessions(LUser, LServer) ->
     Mod = get_sm_backend(LServer),
-    Ss = Mod:get_sessions(LUser, LServer),
+    Ss = get_sessions(Mod, LUser, LServer),
     {OnlineSs, OfflineSs} = lists:partition(fun is_online/1, Ss),
     MaxSessions = get_max_user_sessions(LUser, LServer),
     if length(OnlineSs) =< MaxSessions -> ok;
@@ -731,8 +807,7 @@ check_max_sessions(LUser, LServer) ->
     end,
     if length(OfflineSs) =< MaxSessions -> ok;
        true ->
-           #session{sid = SID, usr = {_, _, R}} = lists:min(OfflineSs),
-           Mod:delete_session(LUser, LServer, R, SID)
+           delete_session(Mod, lists:min(OfflineSs))
     end.
 
 %% Get the user_max_session setting
@@ -779,7 +854,7 @@ process_iq(#iq{}) ->
 
 force_update_presence({LUser, LServer}) ->
     Mod = get_sm_backend(LServer),
-    Ss = online(Mod:get_sessions(LUser, LServer)),
+    Ss = online(get_sessions(Mod, LUser, LServer)),
     lists:foreach(fun (#session{sid = {_, Pid}}) ->
                          ejabberd_c2s:route(Pid, force_update_presence)
                  end,
@@ -811,6 +886,80 @@ get_vh_by_backend(Mod) ->
              get_sm_backend(Host) == Mod
       end, ?MYHOSTS).
 
+%%--------------------------------------------------------------------
+%%% Cache stuff
+%%--------------------------------------------------------------------
+-spec init_cache() -> ok.
+init_cache() ->
+    case use_cache() of
+       true ->
+           ets_cache:new(?SM_CACHE, cache_opts());
+       false ->
+           ets_cache:delete(?SM_CACHE)
+    end.
+
+-spec cache_opts() -> [proplists:property()].
+cache_opts() ->
+    MaxSize = ejabberd_config:get_option(
+               sm_cache_size,
+               opt_type(sm_cache_size),
+               ejabberd_config:cache_size(global)),
+    CacheMissed = ejabberd_config:get_option(
+                   sm_cache_missed,
+                   opt_type(sm_cache_missed),
+                   ejabberd_config:cache_missed(global)),
+    LifeTime = case ejabberd_config:get_option(
+                     sm_cache_life_time,
+                     opt_type(sm_cache_life_time),
+                     ejabberd_config:cache_life_time(global)) of
+                  infinity -> infinity;
+                  I -> timer:seconds(I)
+              end,
+    [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].
+
+-spec clean_cache(node()) -> ok.
+clean_cache(Node) ->
+    ets_cache:filter(
+      ?SM_CACHE,
+      fun(_, error) ->
+             false;
+        (_, {ok, Ss}) ->
+             not lists:any(
+                   fun(#session{sid = {_, Pid}}) ->
+                           node(Pid) == Node
+                   end, Ss)
+      end).
+
+-spec clean_cache() -> ok.
+clean_cache() ->
+    ejabberd_cluster:eval_everywhere(?MODULE, clean_cache, [node()]).
+
+-spec use_cache(module(), binary()) -> boolean().
+use_cache(Mod, LServer) ->
+    case erlang:function_exported(Mod, use_cache, 1) of
+       true -> Mod:use_cache(LServer);
+       false ->
+           ejabberd_config:get_option(
+             {sm_use_cache, LServer},
+             ejabberd_sm:opt_type(sm_use_cache),
+             ejabberd_config:use_cache(LServer))
+    end.
+
+-spec use_cache() -> boolean().
+use_cache() ->
+    lists:any(
+      fun(Host) ->
+             Mod = get_sm_backend(Host),
+             use_cache(Mod, Host)
+      end, ?MYHOSTS).
+
+-spec cache_nodes(module(), binary()) -> [node()].
+cache_nodes(Mod, LServer) ->
+    case erlang:function_exported(Mod, cache_nodes, 1) of
+       true -> Mod:cache_nodes(LServer);
+       false -> ejabberd_cluster:get_nodes()
+    end.
+
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 %%% ejabberd commands
 
@@ -869,4 +1018,13 @@ make_sid() ->
     {p1_time_compat:unique_timestamp(), self()}.
 
 opt_type(sm_db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
-opt_type(_) -> [sm_db_type].
+opt_type(O) when O == sm_use_cache; O == sm_cache_missed ->
+    fun(B) when is_boolean(B) -> B end;
+opt_type(O) when O == sm_cache_size; O == sm_cache_life_time ->
+    fun(I) when is_integer(I), I>0 -> I;
+       (unlimited) -> infinity;
+       (infinity) -> infinity
+    end;
+opt_type(_) ->
+    [sm_db_type, sm_use_cache, sm_cache_size, sm_cache_missed,
+     sm_cache_life_time].
index 35fc42e9d87d4fa0a5d7c17e416ad8ba4d88e68a..99e53fa12e704dedd9f193b6e2d4cff54aa00280 100644 (file)
 
 %% API
 -export([init/0,
+        use_cache/1,
         set_session/1,
-        delete_session/4,
+        delete_session/1,
         get_sessions/0,
         get_sessions/1,
-        get_sessions/2,
-        get_sessions/3]).
+        get_sessions/2]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
@@ -62,20 +62,17 @@ init() ->
 start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
+-spec use_cache(binary()) -> boolean().
+use_cache(_LServer) ->
+    false.
+
 -spec set_session(#session{}) -> ok.
 set_session(Session) ->
     mnesia:dirty_write(Session).
 
--spec delete_session(binary(), binary(), binary(), sid()) ->
-                           {ok, #session{}} | {error, notfound}.
-delete_session(_LUser, _LServer, _LResource, SID) ->
-    case mnesia:dirty_read(session, SID) of
-       [Session] ->
-           mnesia:dirty_delete(session, SID),
-           {ok, Session};
-       [] ->
-           {error, notfound}
-    end.
+-spec delete_session(#session{}) -> ok.
+delete_session(#session{sid = SID}) ->
+    mnesia:dirty_delete(session, SID).
 
 -spec get_sessions() -> [#session{}].
 get_sessions() ->
@@ -87,13 +84,9 @@ get_sessions(LServer) ->
                        [{#session{usr = '$1', _ = '_'},
                          [{'==', {element, 2, '$1'}, LServer}], ['$_']}]).
 
--spec get_sessions(binary(), binary()) -> [#session{}].
+-spec get_sessions(binary(), binary()) -> {ok, [#session{}]}.
 get_sessions(LUser, LServer) ->
-    mnesia:dirty_index_read(session, {LUser, LServer}, #session.us).
-
--spec get_sessions(binary(), binary(), binary()) -> [#session{}].
-get_sessions(LUser, LServer, LResource) ->
-    mnesia:dirty_index_read(session, {LUser, LServer, LResource}, #session.usr).
+    {ok, mnesia:dirty_index_read(session, {LUser, LServer}, #session.us)}.
 
 %%%===================================================================
 %%% gen_server callbacks
index 4854bf8a6ec0fdb8975451c12da230782c77ac04..4314f8d278ee872976593e02a88bdae21a44d06c 100644 (file)
 %%%----------------------------------------------------------------------
 
 -module(ejabberd_sm_redis).
-
+-ifndef(GEN_SERVER).
+-define(GEN_SERVER, p1_server).
+-endif.
+-behaviour(?GEN_SERVER).
 -behaviour(ejabberd_config).
 
 -behaviour(ejabberd_sm).
 
--export([init/0, set_session/1, delete_session/4,
+-export([init/0, set_session/1, delete_session/1,
         get_sessions/0, get_sessions/1, get_sessions/2,
-        get_sessions/3, opt_type/1]).
+        cache_nodes/1, opt_type/1]).
+%% gen_server callbacks
+-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
+        terminate/2, code_change/3, start_link/0]).
 
 -include("ejabberd.hrl").
 -include("ejabberd_sm.hrl").
 -include("logger.hrl").
 
+-define(SM_KEY, <<"ejabberd:sm">>).
+-record(state, {}).
+
 %%%===================================================================
 %%% API
 %%%===================================================================
 -spec init() -> ok | {error, any()}.
 init() ->
-    clean_table().
+    Spec = {?MODULE, {?MODULE, start_link, []},
+           transient, 5000, worker, [?MODULE]},
+    case supervisor:start_child(ejabberd_backend_sup, Spec) of
+       {ok, _Pid} -> ok;
+       Err -> Err
+    end.
+
+-spec start_link() -> {ok, pid()} | {error, any()}.
+start_link() ->
+    ?GEN_SERVER:start_link({local, ?MODULE}, ?MODULE, [], []).
 
--spec set_session(#session{}) -> ok.
+-spec cache_nodes(binary()) -> [node()].
+cache_nodes(_LServer) ->
+    [node()].
+
+-spec set_session(#session{}) -> ok | {error, ejabberd_redis:error_reason()}.
 set_session(Session) ->
     T = term_to_binary(Session),
     USKey = us_to_key(Session#session.us),
     SIDKey = sid_to_key(Session#session.sid),
     ServKey = server_to_key(element(2, Session#session.us)),
     USSIDKey = us_sid_to_key(Session#session.us, Session#session.sid),
-    ejabberd_redis:multi(
-      fun() ->
-             ejabberd_redis:hset(USKey, SIDKey, T),
-             ejabberd_redis:hset(ServKey, USSIDKey, T)
-      end),
-    ok.
+    case ejabberd_redis:multi(
+          fun() ->
+                  ejabberd_redis:hset(USKey, SIDKey, T),
+                  ejabberd_redis:hset(ServKey, USSIDKey, T),
+                  ejabberd_redis:publish(
+                    ?SM_KEY, term_to_binary({delete, Session#session.us}))
+          end) of
+       {ok, _} ->
+           ok;
+       Err ->
+           Err
+    end.
 
--spec delete_session(binary(), binary(), binary(), sid()) ->
-                           {ok, #session{}} | {error, notfound}.
-delete_session(LUser, LServer, _LResource, SID) ->
-    USKey = us_to_key({LUser, LServer}),
-    case ejabberd_redis:hgetall(USKey) of
-       {ok, Vals} ->
-           Ss = decode_session_list(Vals),
-           case lists:keyfind(SID, #session.sid, Ss) of
-               false ->
-                   {error, notfound};
-               Session ->
-                   SIDKey = sid_to_key(SID),
-                   ServKey = server_to_key(element(2, Session#session.us)),
-                   USSIDKey = us_sid_to_key(Session#session.us, SID),
-                   ejabberd_redis:multi(
-                     fun() ->
-                             ejabberd_redis:hdel(USKey, [SIDKey]),
-                             ejabberd_redis:hdel(ServKey, [USSIDKey])
-                     end),
-                   {ok, Session}
-           end;
-       {error, _} ->
-           {error, notfound}
+-spec delete_session(#session{}) -> ok | {error, ejabberd_redis:error_reason()}.
+delete_session(#session{sid = SID} = Session) ->
+    USKey = us_to_key(Session#session.us),
+    SIDKey = sid_to_key(SID),
+    ServKey = server_to_key(element(2, Session#session.us)),
+    USSIDKey = us_sid_to_key(Session#session.us, SID),
+    case ejabberd_redis:multi(
+          fun() ->
+                  ejabberd_redis:hdel(USKey, [SIDKey]),
+                  ejabberd_redis:hdel(ServKey, [USSIDKey]),
+                  ejabberd_redis:publish(
+                    ?SM_KEY,
+                    term_to_binary({delete, Session#session.us}))
+          end) of
+       {ok, _} ->
+           ok;
+       Err ->
+           Err
     end.
 
 -spec get_sessions() -> [#session{}].
@@ -99,27 +122,49 @@ get_sessions(LServer) ->
            []
     end.
 
--spec get_sessions(binary(), binary()) -> [#session{}].
+-spec get_sessions(binary(), binary()) -> {ok, [#session{}]} |
+                                         {error, ejabberd_redis:error_reason()}.
 get_sessions(LUser, LServer) ->
     USKey = us_to_key({LUser, LServer}),
     case ejabberd_redis:hgetall(USKey) of
        {ok, Vals} ->
-           decode_session_list(Vals);
-       {error, _} ->
-           []
+           {ok, decode_session_list(Vals)};
+       Err ->
+           Err
     end.
 
--spec get_sessions(binary(), binary(), binary()) ->
-    [#session{}].
-get_sessions(LUser, LServer, LResource) ->
-    USKey = us_to_key({LUser, LServer}),
-    case ejabberd_redis:hgetall(USKey) of
-       {ok, Vals} ->
-           [S || S <- decode_session_list(Vals),
-                 element(3, S#session.usr) == LResource];
-       {error, _} ->
-           []
-    end.
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([]) ->
+    ejabberd_redis:subscribe([?SM_KEY]),
+    clean_table(),
+    {ok, #state{}}.
+
+handle_call(_Request, _From, State) ->
+    Reply = ok,
+    {reply, Reply, State}.
+
+handle_cast(_Msg, State) ->
+    {noreply, State}.
+
+handle_info({redis_message, ?SM_KEY, Data}, State) ->
+    case binary_to_term(Data) of
+       {delete, Key} ->
+           ets_cache:delete(?SM_CACHE, Key);
+       Msg ->
+           ?WARNING_MSG("unexpected redis message: ~p", [Msg])
+    end,
+    {noreply, State};
+handle_info(Info, State) ->
+    ?ERROR_MSG("unexpected info: ~p", [Info]),
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
 
 %%%===================================================================
 %%% Internal functions
index 04f03f750c70711f329fb758b8afdcaf2e0746b4..3fda1f5e522604bc22df23b4a4e24e7f8b5c1403 100644 (file)
 %% API
 -export([init/0,
         set_session/1,
-        delete_session/4,
+        delete_session/1,
         get_sessions/0,
         get_sessions/1,
-        get_sessions/2,
-        get_sessions/3]).
+        get_sessions/2]).
 
 -include("ejabberd.hrl").
 -include("ejabberd_sm.hrl").
@@ -81,30 +80,21 @@ set_session(#session{sid = {Now, Pid}, usr = {U, LServer, R},
        ok ->
            ok;
        Err ->
-           ?ERROR_MSG("failed to update 'sm' table: ~p", [Err])
+           ?ERROR_MSG("failed to update 'sm' table: ~p", [Err]),
+           {error, db_failure}
     end.
 
-delete_session(_LUser, LServer, _LResource, {Now, Pid}) ->
+delete_session(#session{usr = {_, LServer, _}, sid = {Now, Pid}}) ->
     TS = now_to_timestamp(Now),
     PidS = list_to_binary(erlang:pid_to_list(Pid)),
     case ejabberd_sql:sql_query(
           LServer,
-          ?SQL("select @(usec)d, @(pid)s, @(node)s, @(username)s,"
-                " @(resource)s, @(priority)s, @(info)s "
-                "from sm where usec=%(TS)d and pid=%(PidS)s")) of
-       {selected, [Row]} ->
-            ejabberd_sql:sql_query(
-              LServer,
-              ?SQL("delete from sm"
-                   " where usec=%(TS)d and pid=%(PidS)s")),
-           try {ok, row_to_session(LServer, Row)}
-           catch _:{bad_node, _} -> {error, notfound}
-           end;
-       {selected, []} ->
-           {error, notfound};
+          ?SQL("delete from sm where usec=%(TS)d and pid=%(PidS)s")) of
+       {updated, _} ->
+           ok;
        Err ->
            ?ERROR_MSG("failed to delete from 'sm' table: ~p", [Err]),
-           {error, notfound}
+           {error, db_failure}
     end.
 
 get_sessions() ->
@@ -137,33 +127,15 @@ get_sessions(LUser, LServer) ->
                 " @(resource)s, @(priority)s, @(info)s from sm"
                 " where username=%(LUser)s")) of
        {selected, Rows} ->
-           lists:flatmap(
-             fun(Row) ->
-                     try [row_to_session(LServer, Row)]
-                     catch _:{bad_node, _} -> []
-                     end
-             end, Rows);
-       Err ->
-           ?ERROR_MSG("failed to select from 'sm' table: ~p", [Err]),
-           []
-    end.
-
-get_sessions(LUser, LServer, LResource) ->
-    case ejabberd_sql:sql_query(
-          LServer,
-           ?SQL("select @(usec)d, @(pid)s, @(node)s, @(username)s,"
-                " @(resource)s, @(priority)s, @(info)s from sm"
-                " where username=%(LUser)s and resource=%(LResource)s")) of
-       {selected, Rows} ->
-           lists:flatmap(
-             fun(Row) ->
-                     try [row_to_session(LServer, Row)]
-                     catch _:{bad_node, _} -> []
-                     end
-             end, Rows);
+           {ok, lists:flatmap(
+                  fun(Row) ->
+                          try [row_to_session(LServer, Row)]
+                          catch _:{bad_node, _} -> []
+                          end
+                  end, Rows)};
        Err ->
            ?ERROR_MSG("failed to select from 'sm' table: ~p", [Err]),
-           []
+           {error, db_failure}
     end.
 
 %%%===================================================================
index 57c81953747b586533aa2b0b45c464ac584a9064..c2bf7600e829e328cbcef489ef5340226a5ad6ec 100644 (file)
@@ -34,7 +34,7 @@
 
 -export([start_link/0]).
 -export([start/2, stop/1, reload/3, process/2, open_session/2,
-        close_session/1, find_session/1]).
+        close_session/1, find_session/1, clean_cache/1]).
 
 -export([depends/2, mod_opt_type/1]).
 
 -include("bosh.hrl").
 
 -callback init() -> any().
--callback open_session(binary(), pid()) -> any().
--callback close_session(binary()) -> any().
--callback find_session(binary()) -> {ok, pid()} | error.
+-callback open_session(binary(), pid()) -> ok | {error, any()}.
+-callback close_session(binary()) -> ok | {error, any()}.
+-callback find_session(binary()) -> {ok, pid()} | {error, any()}.
+-callback use_cache() -> boolean().
+-callback cache_nodes() -> [node()].
+
+-optional_callbacks([use_cache/0, cache_nodes/0]).
 
 %%%----------------------------------------------------------------------
 %%% API
@@ -76,22 +80,48 @@ process(_Path, _Request) ->
      #xmlel{name = <<"h1">>, attrs = [],
            children = [{xmlcdata, <<"400 Bad Request">>}]}}.
 
+-spec open_session(binary(), pid()) -> ok | {error, any()}.
 open_session(SID, Pid) ->
     Mod = gen_mod:ram_db_mod(global, ?MODULE),
-    Mod:open_session(SID, Pid).
+    case Mod:open_session(SID, Pid) of
+       ok ->
+           delete_cache(Mod, SID);
+       {error, _} = Err ->
+           Err
+    end.
 
+-spec close_session(binary()) -> ok.
 close_session(SID) ->
     Mod = gen_mod:ram_db_mod(global, ?MODULE),
-    Mod:close_session(SID).
+    Mod:close_session(SID),
+    delete_cache(Mod, SID).
 
+-spec find_session(binary()) -> {ok, pid()} | error.
 find_session(SID) ->
     Mod = gen_mod:ram_db_mod(global, ?MODULE),
-    Mod:find_session(SID).
+    case use_cache(Mod) of
+       true ->
+           ets_cache:lookup(
+             ?BOSH_CACHE, SID,
+             fun() ->
+                     case Mod:find_session(SID) of
+                         {ok, Pid} -> {ok, Pid};
+                         {error, _} -> error
+                     end
+             end);
+       false ->
+           case Mod:find_session(SID) of
+               {ok, Pid} -> {ok, Pid};
+               {error, _} -> error
+           end
+    end.
 
 start(Host, Opts) ->
     start_jiffy(Opts),
     Mod = gen_mod:ram_db_mod(global, ?MODULE),
+    init_cache(Mod),
     Mod:init(),
+    clean_cache(),
     TmpSup = gen_mod:get_module_proc(Host, ?MODULE),
     TmpSupSpec = {TmpSup,
                  {ejabberd_tmp_sup, start_link, [TmpSup, ejabberd_bosh]},
@@ -106,6 +136,7 @@ stop(Host) ->
 reload(_Host, NewOpts, _OldOpts) ->
     start_jiffy(NewOpts),
     Mod = gen_mod:ram_db_mod(global, ?MODULE),
+    init_cache(Mod),
     Mod:init(),
     ok.
 
@@ -160,9 +191,87 @@ mod_opt_type(ram_db_type) ->
     fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
 mod_opt_type(queue_type) ->
     fun(ram) -> ram; (file) -> file end;
+mod_opt_type(O) when O == use_cache; O == cache_missed ->
+    fun(B) when is_boolean(B) -> B end;
+mod_opt_type(O) when O == cache_size; O == cache_life_time ->
+    fun(I) when is_integer(I), I>0 -> I;
+       (unlimited) -> infinity;
+       (infinity) -> infinity
+    end;
 mod_opt_type(_) ->
     [json, max_concat, max_inactivity, max_pause, prebind, ram_db_type,
-     queue_type].
+     queue_type, use_cache, cache_size, cache_missed, cache_life_time].
+
+%%%----------------------------------------------------------------------
+%%% Cache stuff
+%%%----------------------------------------------------------------------
+-spec init_cache(module()) -> ok.
+init_cache(Mod) ->
+    case use_cache(Mod) of
+       true ->
+           ets_cache:new(?BOSH_CACHE, cache_opts());
+       false ->
+           ets_cache:delete(?BOSH_CACHE)
+    end.
+
+-spec use_cache(module()) -> boolean().
+use_cache(Mod) ->
+    case erlang:function_exported(Mod, use_cache, 0) of
+       true -> Mod:use_cache();
+       false ->
+           gen_mod:get_module_opt(
+             global, ?MODULE, use_cache, mod_opt_type(use_cache),
+             ejabberd_config:use_cache(global))
+    end.
+
+-spec cache_nodes(module()) -> [node()].
+cache_nodes(Mod) ->
+    case erlang:function_exported(Mod, cache_nodes, 0) of
+       true -> Mod:cache_nodes();
+       false -> ejabberd_cluster:get_nodes()
+    end.
+
+-spec delete_cache(module(), binary()) -> ok.
+delete_cache(Mod, SID) ->
+    case use_cache(Mod) of
+       true ->
+           ets_cache:delete(?BOSH_CACHE, SID, cache_nodes(Mod));
+       false ->
+           ok
+    end.
+
+-spec cache_opts() -> [proplists:property()].
+cache_opts() ->
+    MaxSize = gen_mod:get_module_opt(
+               global, ?MODULE, cache_size,
+               mod_opt_type(cache_size),
+               ejabberd_config:cache_size(global)),
+    CacheMissed = gen_mod:get_module_opt(
+                   global, ?MODULE, cache_missed,
+                   mod_opt_type(cache_missed),
+                   ejabberd_config:cache_missed(global)),
+    LifeTime = case gen_mod:get_module_opt(
+                     global, ?MODULE, cache_life_time,
+                     mod_opt_type(cache_life_time),
+                     ejabberd_config:cache_life_time(global)) of
+                  infinity -> infinity;
+                  I -> timer:seconds(I)
+              end,
+    [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].
+
+-spec clean_cache(node()) -> ok.
+clean_cache(Node) ->
+    ets_cache:filter(
+      ?BOSH_CACHE,
+      fun(_, error) ->
+             false;
+        (_, {ok, Pid}) ->
+             node(Pid) /= Node
+      end).
+
+-spec clean_cache() -> ok.
+clean_cache() ->
+    ejabberd_cluster:eval_everywhere(?MODULE, clean_cache, [node()]).
 
 %%%----------------------------------------------------------------------
 %%% Help Web Page
index b96d88d145399fd992a0f97ff502f36f0ee7c254..5954cbe4914bc87a95226ae2d40bc2b24e44e640 100644 (file)
@@ -25,7 +25,8 @@
 -behaviour(mod_bosh).
 
 %% mod_bosh API
--export([init/0, open_session/2, close_session/1, find_session/1]).
+-export([init/0, open_session/2, close_session/1, find_session/1,
+        use_cache/0]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
@@ -55,6 +56,9 @@ init() ->
 start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
+use_cache() ->
+    false.
+
 open_session(SID, Pid) ->
     Session = #bosh{sid = SID, timestamp = p1_time_compat:timestamp(), pid = Pid},
     lists:foreach(
@@ -82,7 +86,7 @@ find_session(SID) ->
         [#bosh{pid = Pid}] ->
             {ok, Pid};
         [] ->
-            error
+            {error, notfound}
     end.
 
 %%%===================================================================
index 156df368b80fe1f71b11c7673d409b167882ee2d..194d220a1d6e610a79de149fed5e7481f849fc8a 100644 (file)
@@ -8,24 +8,45 @@
 %%%-------------------------------------------------------------------
 -module(mod_bosh_redis).
 -behaviour(mod_bosh).
+-behaviour(gen_server).
 
 %% API
--export([init/0, open_session/2, close_session/1, find_session/1]).
+-export([init/0, open_session/2, close_session/1, find_session/1,
+        cache_nodes/0]).
+%% gen_server callbacks
+-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
+        terminate/2, code_change/3, start_link/0]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
+-include("bosh.hrl").
 
--define(BOSH_KEY, "ejabberd:bosh").
+-record(state, {}).
+
+-define(BOSH_KEY, <<"ejabberd:bosh">>).
 
 %%%===================================================================
 %%% API
 %%%===================================================================
 init() ->
-    clean_table().
+    Spec = {?MODULE, {?MODULE, start_link, []},
+           transient, 5000, worker, [?MODULE]},
+    case supervisor:start_child(ejabberd_backend_sup, Spec) of
+       {ok, _Pid} -> ok;
+       Err -> Err
+    end.
+
+-spec start_link() -> {ok, pid()} | {error, any()}.
+start_link() ->
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
 open_session(SID, Pid) ->
     PidBin = term_to_binary(Pid),
-    case ejabberd_redis:hset(?BOSH_KEY, SID, PidBin) of
+    case ejabberd_redis:multi(
+          fun() ->
+                  ejabberd_redis:hset(?BOSH_KEY, SID, PidBin),
+                  ejabberd_redis:publish(?BOSH_KEY, SID)
+          end) of
        {ok, _} ->
            ok;
        {error, _} ->
@@ -33,23 +54,63 @@ open_session(SID, Pid) ->
     end.
 
 close_session(SID) ->
-    ejabberd_redis:hdel(?BOSH_KEY, [SID]),
-    ok.
+    case ejabberd_redis:multi(
+          fun() ->
+                  ejabberd_redis:hdel(?BOSH_KEY, [SID]),
+                  ejabberd_redis:publish(?BOSH_KEY, SID)
+          end) of
+       {ok, _} ->
+           ok;
+       {error, _} ->
+           {error, db_failure}
+    end.
 
 find_session(SID) ->
     case ejabberd_redis:hget(?BOSH_KEY, SID) of
-       {ok, Pid} when is_binary(Pid) ->
+       {ok, undefined} ->
+           {error, notfound};
+       {ok, Pid} ->
            try
                {ok, binary_to_term(Pid)}
            catch _:badarg ->
                    ?ERROR_MSG("malformed data in redis (key = '~s'): ~p",
                               [SID, Pid]),
-                   error
+                   {error, db_failure}
            end;
-       _ ->
-           error
+       {error, _} ->
+           {error, db_failure}
     end.
 
+cache_nodes() ->
+    [node()].
+
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([]) ->
+    clean_table(),
+    {ok, #state{}}.
+
+handle_call(_Request, _From, State) ->
+    Reply = ok,
+    {reply, Reply, State}.
+
+handle_cast(_Msg, State) ->
+    {noreply, State}.
+
+handle_info({redis_message, ?BOSH_KEY, SID}, State) ->
+    ets_cache:delete(?BOSH_CACHE, SID),
+    {noreply, State};
+handle_info(Info, State) ->
+    ?ERROR_MSG("unexpected info: ~p", [Info]),
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
index 0171ad8f4584994d490be0eaad5c21783fb3c7c0..73e42868d86db35ac4a9999a95fa5b1b2032f69f 100644 (file)
@@ -44,13 +44,18 @@ open_session(SID, Pid) ->
            ok;
        Err ->
            ?ERROR_MSG("failed to update 'bosh' table: ~p", [Err]),
-           {error, Err}
+           {error, db_failure}
     end.
 
 close_session(SID) ->
-    %% TODO: report errors
-    ejabberd_sql:sql_query(
-      ?MYNAME, ?SQL("delete from bosh where sid=%(SID)s")).
+    case ejabberd_sql:sql_query(
+          ?MYNAME, ?SQL("delete from bosh where sid=%(SID)s")) of
+       {updated, _} ->
+           ok;
+       Err ->
+           ?ERROR_MSG("failed to delete from 'bosh' table: ~p", [Err]),
+           {error, db_failure}
+    end.
 
 find_session(SID) ->
     case ejabberd_sql:sql_query(
@@ -58,13 +63,13 @@ find_session(SID) ->
           ?SQL("select @(pid)s, @(node)s from bosh where sid=%(SID)s")) of
        {selected, [{Pid, Node}]} ->
            try {ok, misc:decode_pid(Pid, Node)}
-           catch _:{bad_node, _} -> error
+           catch _:{bad_node, _} -> {error, notfound}
            end;
        {selected, []} ->
-           error;
+           {error, notfound};
        Err ->
            ?ERROR_MSG("failed to select 'bosh' table: ~p", [Err]),
-           error
+           {error, db_failure}
     end.
 
 %%%===================================================================
index a7ae37f454dff677aca4d48827ddbd5a1b39d922..91c18aabf329b820582e3abe31c9508eecc472bc 100644 (file)
 
 -export([user_send_packet/1, user_receive_packet/1,
         iq_handler/1, remove_connection/4, disco_features/5,
-        is_carbon_copy/1, mod_opt_type/1, depends/2]).
+        is_carbon_copy/1, mod_opt_type/1, depends/2, clean_cache/1]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
 -include("xmpp.hrl").
+-include("mod_carboncopy.hrl").
 
 -type direction() :: sent | received.
 
 -callback init(binary(), gen_mod:opts()) -> any().
 -callback enable(binary(), binary(), binary(), binary()) -> ok | {error, any()}.
 -callback disable(binary(), binary(), binary()) -> ok | {error, any()}.
--callback list(binary(), binary()) -> [{binary(), binary()}].
+-callback list(binary(), binary()) -> [{binary(), binary(), node()}].
+-callback use_cache(binary()) -> boolean().
+-callback cache_nodes(binary()) -> [node()].
+
+-optional_callbacks([use_cache/1, cache_nodes/1]).
 
 -spec is_carbon_copy(stanza()) -> boolean().
 is_carbon_copy(#message{meta = #{carbon_copy := true}}) ->
@@ -59,7 +64,9 @@ start(Host, Opts) ->
     IQDisc = gen_mod:get_opt(iqdisc, Opts,fun gen_iq_handler:check_type/1, one_queue),
     ejabberd_hooks:add(disco_local_features, Host, ?MODULE, disco_features, 50),
     Mod = gen_mod:ram_db_mod(Host, ?MODULE),
+    init_cache(Mod, Host, Opts),
     Mod:init(Host, Opts),
+    clean_cache(),
     ejabberd_hooks:add(unset_presence_hook,Host, ?MODULE, remove_connection, 10),
     %% why priority 89: to define clearly that we must run BEFORE mod_logdb hook (90)
     ejabberd_hooks:add(user_send_packet,Host, ?MODULE, user_send_packet, 89),
@@ -82,6 +89,12 @@ reload(Host, NewOpts, OldOpts) ->
        true ->
            ok
     end,
+    case use_cache(NewMod, Host) of
+       true ->
+           ets_cache:new(?CARBONCOPY_CACHE, cache_opts(Host, NewOpts));
+       false ->
+           ok
+    end,
     case gen_mod:is_equal_opt(iqdisc, NewOpts, OldOpts,
                              fun gen_iq_handler:check_type/1,
                              one_queue) of
@@ -247,13 +260,20 @@ build_forward_packet(JID, #message{type = T} = Msg, Sender, Dest, Direction) ->
 enable(Host, U, R, CC)->
     ?DEBUG("enabling for ~p", [U]),
     Mod = gen_mod:ram_db_mod(Host, ?MODULE),
-    Mod:enable(U, Host, R, CC).
+    case Mod:enable(U, Host, R, CC) of
+       ok ->
+           delete_cache(Mod, U, Host);
+       {error, _} = Err ->
+           Err
+    end.
 
 -spec disable(binary(), binary(), binary()) -> ok | {error, any()}.
 disable(Host, U, R)->
     ?DEBUG("disabling for ~p", [U]),
     Mod = gen_mod:ram_db_mod(Host, ?MODULE),
-    Mod:disable(U, Host, R).
+    Res = Mod:disable(U, Host, R),
+    delete_cache(Mod, U, Host),
+    Res.
 
 -spec complete_packet(jid(), message(), direction()) -> message().
 complete_packet(From, #message{from = undefined} = Msg, sent) ->
@@ -276,15 +296,106 @@ is_muc_pm(#jid{lresource = <<>>}, _Packet) ->
 is_muc_pm(_To, Packet) ->
     xmpp:has_subtag(Packet, #muc_user{}).
 
--spec list(binary(), binary()) -> [{binary(), binary()}].
-%% list {resource, cc_version} with carbons enabled for given user and host
+-spec list(binary(), binary()) -> [{Resource :: binary(), Namespace :: binary()}].
 list(User, Server) ->
     Mod = gen_mod:ram_db_mod(Server, ?MODULE),
-    Mod:list(User, Server).
+    case use_cache(Mod, Server) of
+       true ->
+           case ets_cache:lookup(
+                  ?CARBONCOPY_CACHE, {User, Server},
+                  fun() ->
+                          case Mod:list(User, Server) of
+                              {ok, L} when L /= [] -> {ok, L};
+                              _ -> error
+                          end
+                  end) of
+               {ok, L} -> [{Resource, NS} || {Resource, NS, _} <- L];
+               error -> []
+           end;
+       false ->
+           case Mod:list(User, Server) of
+               {ok, L} -> [{Resource, NS} || {Resource, NS, _} <- L];
+               error -> []
+           end
+    end.
+
+-spec init_cache(module(), binary(), gen_mod:opts()) -> ok.
+init_cache(Mod, Host, Opts) ->
+    case use_cache(Mod, Host) of
+       true ->
+           ets_cache:new(?CARBONCOPY_CACHE, cache_opts(Host, Opts));
+       false ->
+           ets_cache:delete(?CARBONCOPY_CACHE)
+    end.
+
+-spec cache_opts(binary(), gen_mod:opts()) -> [proplists:property()].
+cache_opts(Host, Opts) ->
+    MaxSize = gen_mod:get_opt(
+               cache_size, Opts, mod_opt_type(cache_size),
+               ejabberd_config:cache_size(Host)),
+    CacheMissed = gen_mod:get_opt(
+                   cache_missed, Opts, mod_opt_type(cache_missed),
+                   ejabberd_config:cache_missed(Host)),
+    LifeTime = case gen_mod:get_opt(
+                     cache_life_time, Opts, mod_opt_type(cache_life_time),
+                     ejabberd_config:cache_life_time(Host)) of
+                  infinity -> infinity;
+                  I -> timer:seconds(I)
+              end,
+    [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].
+
+-spec use_cache(module(), binary()) -> boolean().
+use_cache(Mod, Host) ->
+    case erlang:function_exported(Mod, use_cache, 1) of
+       true -> Mod:use_cache(Host);
+       false ->
+           gen_mod:get_module_opt(
+             Host, ?MODULE, use_cache, mod_opt_type(use_cache),
+             ejabberd_config:use_cache(Host))
+    end.
+
+-spec cache_nodes(module(), binary()) -> [node()].
+cache_nodes(Mod, Host) ->
+    case erlang:function_exported(Mod, cache_nodes, 1) of
+       true -> Mod:cache_nodes(Host);
+       false -> ejabberd_cluster:get_nodes()
+    end.
+
+-spec clean_cache(node()) -> ok.
+clean_cache(Node) ->
+    ets_cache:filter(
+      ?CARBONCOPY_CACHE,
+      fun(_, error) ->
+             false;
+        (_, {ok, L}) ->
+             not lists:any(fun({_, _, N}) -> N == Node end, L)
+      end).
+
+-spec clean_cache() -> ok.
+clean_cache() ->
+    ejabberd_cluster:eval_everywhere(?MODULE, clean_cache, [node()]).
+
+-spec delete_cache(module(), binary(), binary()) -> ok.
+delete_cache(Mod, User, Server) ->
+    case use_cache(Mod, Server) of
+       true ->
+           ets_cache:delete(?CARBONCOPY_CACHE, {User, Server},
+                            cache_nodes(Mod, Server));
+       false ->
+           ok
+    end.
 
 depends(_Host, _Opts) ->
     [].
 
 mod_opt_type(iqdisc) -> fun gen_iq_handler:check_type/1;
 mod_opt_type(ram_db_type) -> fun(T) -> ejabberd_config:v_db(?MODULE, T) end;
-mod_opt_type(_) -> [ram_db_type, iqdisc].
+mod_opt_type(O) when O == use_cache; O == cache_missed ->
+    fun(B) when is_boolean(B) -> B end;
+mod_opt_type(O) when O == cache_size; O == cache_life_time ->
+    fun(I) when is_integer(I), I>0 -> I;
+       (unlimited) -> infinity;
+       (infinity) -> infinity
+    end;
+mod_opt_type(_) ->
+    [ram_db_type, iqdisc, use_cache, cache_size, cache_missed, cache_life_time].
index 9c6a2ffaf138bd6f4334ffedb3998d3411476174..62355165ea2c06701c64a6b6727aaff3670198b1 100644 (file)
@@ -27,7 +27,7 @@
 -behaviour(mod_carboncopy).
 
 %% API
--export([init/2, enable/4, disable/3, list/2]).
+-export([init/2, enable/4, disable/3, list/2, use_cache/1]).
 
 -include("mod_carboncopy.hrl").
 
@@ -53,31 +53,26 @@ init(_Host, _Opts) ->
     mnesia:add_table_copy(carboncopy, node(), ram_copies).
 
 enable(LUser, LServer, LResource, NS) ->
-    try mnesia:dirty_write(
-         #carboncopy{us = {LUser, LServer},
-                     resource = LResource,
-                     version = NS}) of
-       ok -> ok
-    catch _:Error ->
-           {error, Error}
-    end.
+    mnesia:dirty_write(
+      #carboncopy{us = {LUser, LServer},
+                 resource = LResource,
+                 version = NS}).
 
 disable(LUser, LServer, LResource) ->
     ToDelete = mnesia:dirty_match_object(
                 #carboncopy{us = {LUser, LServer},
                             resource = LResource,
                             version = '_'}),
-    try lists:foreach(fun mnesia:dirty_delete_object/1, ToDelete) of
-       ok -> ok
-    catch _:Error ->
-           {error, Error}
-    end.
+    lists:foreach(fun mnesia:dirty_delete_object/1, ToDelete).
 
 list(LUser, LServer) ->
-    mnesia:dirty_select(
-      carboncopy,
-      [{#carboncopy{us = {LUser, LServer}, resource = '$2', version = '$3'},
-       [], [{{'$2','$3'}}]}]).
+    {ok, mnesia:dirty_select(
+          carboncopy,
+          [{#carboncopy{us = {LUser, LServer}, resource = '$2', version = '$3'},
+            [], [{{'$2','$3', node()}}]}])}.
+
+use_cache(_LServer) ->
+    false.
 
 %%%===================================================================
 %%% Internal functions
index 8ed33468b8bd283eef6c7c597ab7aa32f7bc5ea0..b72755f4ea8cc6cf8e7274248ba94405066d6030 100644 (file)
 %%%-------------------------------------------------------------------
 -module(mod_carboncopy_redis).
 -behaviour(mod_carboncopy).
+-behaviour(gen_server).
 
 %% API
--export([init/2, enable/4, disable/3, list/2]).
+-export([init/2, enable/4, disable/3, list/2, cache_nodes/1]).
+%% gen_server callbacks
+-export([init/1, handle_cast/2, handle_call/3, handle_info/2,
+        terminate/2, code_change/3, start_link/0]).
 
 -include("ejabberd.hrl").
 -include("logger.hrl").
+-include("mod_carboncopy.hrl").
+
+-define(CARBONCOPY_KEY, <<"ejabberd:carboncopy">>).
+
+-record(state, {}).
 
 %%%===================================================================
 %%% API
 %%%===================================================================
 init(_Host, _Opts) ->
-    clean_table().
+    Spec = {?MODULE, {?MODULE, start_link, []},
+           transient, 5000, worker, [?MODULE]},
+    case supervisor:start_child(ejabberd_backend_sup, Spec) of
+       {ok, _Pid} -> ok;
+       Err -> Err
+    end.
+
+-spec start_link() -> {ok, pid()} | {error, any()}.
+start_link() ->
+    gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
+
+cache_nodes(_LServer) ->
+    [node()].
 
 enable(LUser, LServer, LResource, NS) ->
     USKey = us_key(LUser, LServer),
     NodeKey = node_key(),
     JID = jid:encode({LUser, LServer, LResource}),
+    Data = term_to_binary({NS, node()}),
     case ejabberd_redis:multi(
           fun() ->
-                  ejabberd_redis:hset(USKey, LResource, NS),
-                  ejabberd_redis:sadd(NodeKey, [JID])
+                  ejabberd_redis:hset(USKey, LResource, Data),
+                  ejabberd_redis:sadd(NodeKey, [JID]),
+                  ejabberd_redis:publish(
+                    ?CARBONCOPY_KEY,
+                    term_to_binary({delete, {LUser, LServer}}))
           end) of
        {ok, _} ->
            ok;
@@ -57,7 +82,10 @@ disable(LUser, LServer, LResource) ->
     case ejabberd_redis:multi(
           fun() ->
                   ejabberd_redis:hdel(USKey, [LResource]),
-                  ejabberd_redis:srem(NodeKey, [JID])
+                  ejabberd_redis:srem(NodeKey, [JID]),
+                  ejabberd_redis:publish(
+                    ?CARBONCOPY_KEY,
+                    term_to_binary({delete, {LUser, LServer}}))
           end) of
        {ok, _} ->
            ok;
@@ -68,12 +96,56 @@ disable(LUser, LServer, LResource) ->
 list(LUser, LServer) ->
     USKey = us_key(LUser, LServer),
     case ejabberd_redis:hgetall(USKey) of
-       {ok, Vals} ->
-           Vals;
+       {ok, Pairs} ->
+           {ok, lists:flatmap(
+                  fun({Resource, Data}) ->
+                          try
+                              {NS, Node} = binary_to_term(Data),
+                              [{Resource, NS, Node}]
+                          catch _:_ ->
+                                  ?ERROR_MSG("invalid term stored in Redis "
+                                             "(key = ~s): ~p",
+                                             [USKey, Data]),
+                                  []
+                          end
+                  end, Pairs)};
        {error, _} ->
-           []
+           {error, db_failure}
     end.
 
+%%%===================================================================
+%%% gen_server callbacks
+%%%===================================================================
+init([]) ->
+    ejabberd_redis:subscribe([?CARBONCOPY_KEY]),
+    clean_table(),
+    {ok, #state{}}.
+
+handle_call(_Request, _From, State) ->
+    Reply = ok,
+    {reply, Reply, State}.
+
+handle_cast(_Msg, State) ->
+    {noreply, State}.
+
+handle_info({redis_message, ?CARBONCOPY_KEY, Data}, State) ->
+    case binary_to_term(Data) of
+       {delete, Key} ->
+           ets_cache:delete(?CARBONCOPY_CACHE, Key);
+       Msg ->
+           ?WARNING_MSG("unexpected redis message: ~p", [Msg])
+    end,
+    {noreply, State};
+handle_info(Info, State) ->
+    ?ERROR_MSG("unexpected info: ~p", [Info]),
+    {noreply, State}.
+
+terminate(_Reason, _State) ->
+    ok.
+
+code_change(_OldVsn, State, _Extra) ->
+    {ok, State}.
+
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
index 2770d40aa88a754297acd064b0792464ff547d5a..41d9a0632f41db04ad8f5718adf6d49f49b0971d 100644 (file)
@@ -49,7 +49,7 @@ enable(LUser, LServer, LResource, NS) ->
            ok;
        Err ->
            ?ERROR_MSG("failed to update 'carboncopy' table: ~p", [Err]),
-           Err
+           {error, db_failure}
     end.
 
 disable(LUser, LServer, LResource) ->
@@ -61,19 +61,20 @@ disable(LUser, LServer, LResource) ->
            ok;
        Err ->
            ?ERROR_MSG("failed to delete from 'carboncopy' table: ~p", [Err]),
-           Err
+           {error, db_failure}
     end.
 
 list(LUser, LServer) ->
     case ejabberd_sql:sql_query(
           LServer,
-          ?SQL("select @(resource)s, @(namespace)s from carboncopy "
+          ?SQL("select @(resource)s, @(namespace)s, @(node)s from carboncopy "
                "where username=%(LUser)s")) of
        {selected, Rows} ->
-           Rows;
+           {ok, [{Resource, NS, binary_to_atom(Node, latin1)}
+                 || {Resource, NS, Node} <- Rows]};
        Err ->
            ?ERROR_MSG("failed to select from 'carboncopy' table: ~p", [Err]),
-           []
+           {error, db_failure}
     end.
 
 %%%===================================================================
@@ -89,5 +90,5 @@ clean_table(LServer) ->
            ok;
        Err ->
            ?ERROR_MSG("failed to clean 'carboncopy' table: ~p", [Err]),
-           Err
+           {error, db_failure}
     end.
index a5e33becdbb3041c88ec7e51cf00a255a78e7be2..ea21b4a1d230b7cd606bc5d1111fd9bebbdbd63b 100644 (file)
@@ -27,7 +27,7 @@
 
 -author('alexey@process-one.net').
 
--export([get_string/0, uniform/0, uniform/1, bytes/1]).
+-export([get_string/0, uniform/0, uniform/1, uniform/2, bytes/1]).
 
 -define(THRESHOLD, 16#10000000000000000).
 
@@ -41,6 +41,9 @@ uniform() ->
 uniform(N) ->
     crypto:rand_uniform(1, N+1).
 
+uniform(N, M) ->
+    crypto:rand_uniform(N, M+1).
+
 -ifdef(STRONG_RAND_BYTES).
 bytes(N) ->
     crypto:strong_rand_bytes(N).