]> granicus.if.org Git - ejabberd/commitdiff
Add Redis as mod_proxy65 RAM backend
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 31 Mar 2017 16:07:56 +0000 (19:07 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 31 Mar 2017 16:07:56 +0000 (19:07 +0300)
src/ejabberd_redis.erl
src/mod_proxy65_redis.erl [new file with mode: 0644]

index 1d8f32c28e42710733198f5613e86d0984bdd435..dbd55e914cbe3891d618fc6445e3228d1f3fd796 100644 (file)
 -behaviour(gen_server).
 -behaviour(ejabberd_config).
 
+-compile({no_auto_import, [get/1, put/2]}).
+
 %% API
 -export([start_link/0, q/1, qp/1, config_reloaded/0, opt_type/1]).
+%% Commands
+-export([multi/1, get/1, set/2, del/1, sadd/2, srem/2, smembers/1, scard/1]).
 
 %% gen_server callbacks
 -export([init/1, handle_call/3, handle_cast/2, handle_info/2,
 
 -define(SERVER, ?MODULE).
 -define(PROCNAME, 'ejabberd_redis_client').
+-define(TR_STACK, redis_transaction_stack).
 
 -include("logger.hrl").
 -include("ejabberd.hrl").
 
 -record(state, {connection :: {pid(), reference()} | undefined}).
 
+-type redis_error() :: {error, binary() | atom()}.
+
 %%%===================================================================
 %%% API
 %%%===================================================================
@@ -58,6 +65,28 @@ qp(Pipeline) ->
     catch _:Reason -> {error, Reason}
     end.
 
+-spec multi(fun(() -> any())) -> {ok, list()} | redis_error().
+multi(F) ->
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           erlang:put(?TR_STACK, []),
+           try F() of
+               _ ->
+                   Stack = erlang:get(?TR_STACK),
+                   erlang:erase(?TR_STACK),
+                   Command = [["MULTI"]|lists:reverse([["EXEC"]|Stack])],
+                   case qp(Command) of
+                       {error, _} = Err -> Err;
+                       Result -> get_result(Result)
+                   end
+           catch E:R ->
+                   erlang:erase(?TR_STACK),
+                   erlang:raise(E, R, erlang:get_stacktrace())
+           end;
+       _ ->
+           {error, nested_transaction}
+    end.
+
 config_reloaded() ->
     case is_redis_configured() of
        true ->
@@ -66,6 +95,93 @@ config_reloaded() ->
            ?MODULE ! disconnect
     end.
 
+get(Key) ->
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           q([<<"GET">>, Key]);
+       _ ->
+           {error, transaction_unsupported}
+    end.
+
+-spec set(iodata(), iodata()) -> ok | redis_error() | queued.
+set(Key, Val) ->
+    Cmd = [<<"SET">>, Key, Val],
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           case q(Cmd) of
+               {ok, <<"OK">>} -> ok;
+               {error, _} = Err -> Err
+           end;
+       Stack ->
+           erlang:put(?TR_STACK, [Cmd|Stack]),
+           queued
+    end.
+
+-spec del(list()) -> {ok, non_neg_integer()} | redis_error() | queued.
+del(Keys) ->
+    Cmd = [<<"DEL">>|Keys],
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           case q(Cmd) of
+               {ok, N} -> {ok, binary_to_integer(N)};
+               {error, _} = Err -> Err
+           end;
+       Stack ->
+           erlang:put(?TR_STACK, [Cmd|Stack]),
+           queued
+    end.
+
+-spec sadd(iodata(), list()) -> {ok, non_neg_integer()} | redis_error() | queued.
+sadd(Set, Members) ->
+    Cmd = [<<"SADD">>, Set|Members],
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           case q(Cmd) of
+               {ok, N} -> {ok, binary_to_integer(N)};
+               {error, _} = Err -> Err
+           end;
+       Stack ->
+           erlang:put(?TR_STACK, [Cmd|Stack]),
+           queued
+    end.
+
+-spec srem(iodata(), list()) -> {ok, non_neg_integer()} | redis_error() | queued.
+srem(Set, Members) ->
+    Cmd = [<<"SREM">>, Set|Members],
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           case q(Cmd) of
+               {ok, N} -> {ok, binary_to_integer(N)};
+               {error, _} = Err -> Err
+           end;
+       Stack ->
+           erlang:put(?TR_STACK, [Cmd|Stack]),
+           queued
+    end.
+
+-spec smembers(iodata()) -> {ok, [binary()]} | redis_error().
+smembers(Set) ->
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           q([<<"SMEMBERS">>, Set]);
+       _ ->
+           {error, transaction_unsupported}
+    end.
+
+-spec scard(iodata()) -> {ok, non_neg_integer()} | redis_error().
+scard(Set) ->
+    case erlang:get(?TR_STACK) of
+       undefined ->
+           case q([<<"SCARD">>, Set]) of
+               {ok, N} ->
+                   {ok, binary_to_integer(N)};
+               {error, _} = Err ->
+                   Err
+           end;
+       _ ->
+           {error, transaction_unsupported}
+    end.
+
 %%%===================================================================
 %%% gen_server callbacks
 %%%===================================================================
@@ -202,6 +318,13 @@ connect() ->
            {error, Reason}
     end.
 
+get_result([{error, _} = Err|_]) ->
+    Err;
+get_result([{ok, _} = OK]) ->
+    OK;
+get_result([_|T]) ->
+    get_result(T).
+
 opt_type(redis_connect_timeout) ->
     fun (I) when is_integer(I), I > 0 -> I end;
 opt_type(redis_db) ->
diff --git a/src/mod_proxy65_redis.erl b/src/mod_proxy65_redis.erl
new file mode 100644 (file)
index 0000000..4086574
--- /dev/null
@@ -0,0 +1,190 @@
+%%%-------------------------------------------------------------------
+%%% @author Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%% Created : 31 Mar 2017 by Evgeny Khramtsov <ekhramtsov@process-one.net>
+%%%
+%%%
+%%% ejabberd, Copyright (C) 2002-2017   ProcessOne
+%%%
+%%% This program is free software; you can redistribute it and/or
+%%% modify it under the terms of the GNU General Public License as
+%%% published by the Free Software Foundation; either version 2 of the
+%%% License, or (at your option) any later version.
+%%%
+%%% This program is distributed in the hope that it will be useful,
+%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
+%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+%%% General Public License for more details.
+%%%
+%%% You should have received a copy of the GNU General Public License along
+%%% with this program; if not, write to the Free Software Foundation, Inc.,
+%%% 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+%%%
+%%%-------------------------------------------------------------------
+-module(mod_proxy65_redis).
+-behaviour(mod_proxy65).
+
+%% API
+-export([init/0, register_stream/2, unregister_stream/1, activate_stream/4]).
+
+-include("ejabberd.hrl").
+-include("logger.hrl").
+
+-record(proxy65, {pid_t :: pid(),
+                 pid_i :: pid() | undefined,
+                 jid_i :: binary() | undefined}).
+
+%%%===================================================================
+%%% API
+%%%===================================================================
+init() ->
+    ?INFO_MSG("Cleaning Redis 'proxy65' table...", []),
+    NodeKey = node_key(),
+    case ejabberd_redis:smembers(NodeKey) of
+       {ok, SIDs} ->
+           SIDKeys = [sid_key(S) || S <- SIDs],
+           JIDs = lists:flatmap(
+                    fun(SIDKey) ->
+                            case ejabberd_redis:get(SIDKey) of
+                                {ok, Val} ->
+                                    try binary_to_term(Val) of
+                                        #proxy65{jid_i = J} when is_binary(J) ->
+                                            [jid_key(J)];
+                                        _ ->
+                                            []
+                                    catch _:badarg ->
+                                            []
+                                    end;
+                                _ ->
+                                    []
+                            end
+                    end, SIDKeys),
+           ejabberd_redis:multi(
+             fun() ->
+                     if SIDs /= [] ->
+                             ejabberd_redis:del(SIDKeys),
+                             if JIDs /= [] ->
+                                     ejabberd_redis:del(JIDs);
+                                true ->
+                                     ok
+                             end;
+                        true ->
+                             ok
+                     end,
+                     ejabberd_redis:del([NodeKey])
+             end),
+           ok;
+       Err ->
+           ?ERROR_MSG("redis request failed: ~p", [Err]),
+           {error, db_failure}
+    end.
+
+register_stream(SID, Pid) ->
+    SIDKey = sid_key(SID),
+    try
+       {ok, Val} = ejabberd_redis:get(SIDKey),
+       try binary_to_term(Val) of
+           #proxy65{pid_i = undefined} = R ->
+               NewVal = term_to_binary(R#proxy65{pid_i = Pid}),
+               ok = ejabberd_redis:set(SIDKey, NewVal);
+           _ ->
+               {error, conflict}
+       catch _:badarg when Val == undefined ->
+               NewVal = term_to_binary(#proxy65{pid_t = Pid}),
+               {ok, _} = ejabberd_redis:multi(
+                           fun() ->
+                                   ejabberd_redis:set(SIDKey, NewVal),
+                                   ejabberd_redis:sadd(node_key(), [SID])
+                           end),
+               ok;
+             _:badarg ->
+               ?ERROR_MSG("malformed data in redis (key = '~s'): ~p",
+                          [SIDKey, Val]),
+               {error, db_failure}
+       end
+    catch _:{badmatch, Err} ->
+           ?ERROR_MSG("redis request failed: ~p", [Err]),
+           {error, db_failure}
+    end.
+
+unregister_stream(SID) ->
+    SIDKey = sid_key(SID),
+    NodeKey = node_key(),
+    try
+       {ok, Val} = ejabberd_redis:get(SIDKey),
+       try binary_to_term(Val) of
+           #proxy65{jid_i = JID} when is_binary(JID) ->
+               JIDKey = jid_key(JID),
+               {ok, _} = ejabberd_redis:multi(
+                           fun() ->
+                                   ejabberd_redis:del([SIDKey]),
+                                   ejabberd_redis:srem(JIDKey, [SID]),
+                                   ejabberd_redis:srem(NodeKey, [SID])
+                           end),
+               ok;
+           _ ->
+               {ok, _} = ejabberd_redis:multi(
+                           fun() ->
+                                   ejabberd_redis:del([SIDKey]),
+                                   ejabberd_redis:srem(NodeKey, [SID])
+                           end),
+               ok
+       catch _:badarg when Val == undefined ->
+               ok;
+             _:badarg ->
+               ?ERROR_MSG("malformed data in redis (key = '~s'): ~p",
+                          [SIDKey, Val]),
+               {error, db_failure}
+       end
+    catch _:{badmatch, Err} ->
+           ?ERROR_MSG("redis request failed: ~p", [Err]),
+           {error, db_failure}
+    end.
+
+activate_stream(SID, IJID, MaxConnections, _Node) ->
+    SIDKey = sid_key(SID),
+    JIDKey = jid_key(IJID),
+    try
+       {ok, Val} = ejabberd_redis:get(SIDKey),
+       try binary_to_term(Val) of
+           #proxy65{pid_t = TPid, pid_i = IPid,
+                    jid_i = undefined} = R when is_pid(IPid) ->
+               {ok, Num} = ejabberd_redis:scard(JIDKey),
+               if Num >= MaxConnections ->
+                       {error, {limit, IPid, TPid}};
+                  true ->
+                       NewVal = term_to_binary(R#proxy65{jid_i = IJID}),
+                       {ok, _} = ejabberd_redis:multi(
+                                   fun() ->
+                                           ejabberd_redis:sadd(JIDKey, [SID]),
+                                           ejabberd_redis:set(SIDKey, NewVal)
+                                   end),
+                       {ok, IPid, TPid}
+               end;
+           #proxy65{jid_i = JID} when is_binary(JID) ->
+               {error, conflict};
+           _ ->
+               {error, notfound}
+       catch _:badarg when Val == undefined ->
+               {error, notfound};
+             _:badarg ->
+               ?ERROR_MSG("malformed data in redis (key = '~s'): ~p",
+                          [SIDKey, Val]),
+               {error, db_failure}
+       end
+    catch _:{badmatch, Err} ->
+           ?ERROR_MSG("redis request failed: ~p", [Err]),
+           {error, db_failure}
+    end.
+
+%%%===================================================================
+%%% Internal functions
+%%%===================================================================
+sid_key(SID) ->
+    <<"ejabberd:proxy65:sid:", SID/binary>>.
+
+jid_key(JID) ->
+    <<"ejabberd:proxy65:initiator:", JID/binary>>.
+
+node_key() ->
+    Node = erlang:atom_to_binary(node(), latin1),
+    <<"ejabberd:proxy65:node:", Node/binary>>.