From 533a4eec96a9810d2e1fdc0200f1bb136e0e88c2 Mon Sep 17 00:00:00 2001
From: Evgeny Khramtsov <ekhramtsov@process-one.net>
Date: Tue, 10 Sep 2019 16:02:51 +0300
Subject: [PATCH] Replicate Mnesia 'bosh' table when nodes are joined

---
 src/mod_bosh_mnesia.erl | 145 +++++++++++++++++++++++++++++++++-------
 1 file changed, 119 insertions(+), 26 deletions(-)

diff --git a/src/mod_bosh_mnesia.erl b/src/mod_bosh_mnesia.erl
index fd3135c31..c84b01704 100644
--- a/src/mod_bosh_mnesia.erl
+++ b/src/mod_bosh_mnesia.erl
@@ -33,12 +33,16 @@
 	 terminate/2, code_change/3, start_link/0]).
 
 -include("logger.hrl").
+-include_lib("stdlib/include/ms_transform.hrl").
 
--record(bosh, {sid = <<"">>      :: binary() | '_',
-               timestamp = erlang:timestamp() :: erlang:timestamp() | '_',
-               pid = self()      :: pid() | '$1'}).
+-define(CALL_TIMEOUT, timer:minutes(10)).
 
--record(state, {}).
+-record(bosh, {sid = <<"">>      :: binary(),
+               timestamp = erlang:timestamp() :: erlang:timestamp(),
+               pid = self()      :: pid()}).
+
+-record(state, {nodes = #{} :: #{node() => {pid(), reference()}}}).
+-type state() :: #state{}.
 
 %%%===================================================================
 %%% API
@@ -49,6 +53,7 @@ init() ->
 	    transient, 5000, worker, [?MODULE]},
     case supervisor:start_child(ejabberd_backend_sup, Spec) of
 	{ok, _Pid} -> ok;
+	{error, {already_started, _}} -> ok;
 	Err -> Err
     end.
 
@@ -59,28 +64,21 @@ start_link() ->
 use_cache() ->
     false.
 
+-spec open_session(binary(), pid()) -> ok.
 open_session(SID, Pid) ->
     Session = #bosh{sid = SID, timestamp = erlang:timestamp(), pid = Pid},
-    lists:foreach(
-      fun(Node) when Node == node() ->
-	      gen_server:call(?MODULE, {write, Session});
-	 (Node) ->
-	      cluster_send({?MODULE, Node}, {write, Session})
-      end, ejabberd_cluster:get_nodes()).
+    gen_server:call(?MODULE, {write, Session}, ?CALL_TIMEOUT).
 
+-spec close_session(binary()) -> ok.
 close_session(SID) ->
     case mnesia:dirty_read(bosh, SID) of
 	[Session] ->
-	    lists:foreach(
-	      fun(Node) when Node == node() ->
-		      gen_server:call(?MODULE, {delete, Session});
-		 (Node) ->
-		      cluster_send({?MODULE, Node}, {delete, Session})
-	      end, ejabberd_cluster:get_nodes());
+	    gen_server:call(?MODULE, {delete, Session}, ?CALL_TIMEOUT);
 	[] ->
 	    ok
     end.
 
+-spec find_session(binary()) -> {ok, pid()} | {error, notfound}.
 find_session(SID) ->
     case mnesia:dirty_read(bosh, SID) of
         [#bosh{pid = Pid}] ->
@@ -92,30 +90,104 @@ find_session(SID) ->
 %%%===================================================================
 %%% gen_server callbacks
 %%%===================================================================
+%% Sets up the table, announces this process to the other cluster nodes
+%% and subscribes to Mnesia system events.
+-spec init([]) -> {ok, state()}.
 init([]) ->
     setup_database(),
+    multicast({join, node(), self()}),
+    mnesia:subscribe(system),
     {ok, #state{}}.
 
-handle_call({write, Session}, _From, State) ->
-    Res = write_session(Session),
-    {reply, Res, State};
-handle_call({delete, Session}, _From, State) ->
-    Res = delete_session(Session),
-    {reply, Res, State};
+%% Applies a write/delete locally, then forwards the same message to the
+%% peer processes, which apply it in their handle_info/2.
+-spec handle_call(_, _, state()) -> {reply, ok, state()} | {noreply, state()}.
+handle_call({write, Session} = Msg, _From, State) ->
+    write_session(Session),
+    multicast(Msg),
+    {reply, ok, State};
+handle_call({delete, Session} = Msg, _From, State) ->
+    delete_session(Session),
+    multicast(Msg),
+    {reply, ok, State};
 handle_call(Request, From, State) ->
     ?WARNING_MSG("Unexpected call from ~p: ~p", [From, Request]),
     {noreply, State}.
 
+-spec handle_cast(_, state()) -> {noreply, state()}.
 handle_cast(Msg, State) ->
     ?WARNING_MSG("Unexpected cast: ~p", [Msg]),
     {noreply, State}.
 
+%% Messages exchanged between the per-node processes:
+%%   {join, Node, Pid}   - a peer announces itself; it is answered with
+%%                         'joined' (and 'join' if that pid is not tracked yet)
+%%   {joined, Node, Pid} - a peer acknowledges us; a peer seen for the
+%%                         first time is asked for its sessions
+%%   {first, From}, {next, From, Key} - request the first/next session
+%%   {replica, Node, Session | '$end_of_table'} - answer to first/next
+-spec handle_info(_, state()) -> {noreply, state()}.
 handle_info({write, Session}, State) ->
     write_session(Session),
     {noreply, State};
 handle_info({delete, Session}, State) ->
     delete_session(Session),
     {noreply, State};
+handle_info({join, Node, Pid}, State) ->
+    ejabberd_cluster:send(Pid, {joined, node(), self()}),
+    case maps:find(Node, State#state.nodes) of
+	{ok, {Pid, _}} ->
+	    ok;
+	_ ->
+	    ejabberd_cluster:send(Pid, {join, node(), self()})
+    end,
+    {noreply, State};
+handle_info({joined, Node, Pid}, State) ->
+    case maps:find(Node, State#state.nodes) of
+	{ok, {Pid, _}} ->
+	    {noreply, State};
+	Ret ->
+	    MRef = erlang:monitor(process, {?MODULE, Node}),
+	    Nodes = maps:put(Node, {Pid, MRef}, State#state.nodes),
+	    case Ret of
+		error -> ejabberd_cluster:send(Pid, {first, self()});
+		_ -> ok
+	    end,
+	    {noreply, State#state{nodes = Nodes}}
+    end;
+handle_info({first, From}, State) ->
+    ejabberd_cluster:send(From, {replica, node(), first_session()}),
+    {noreply, State};
+handle_info({next, From, Key}, State) ->
+    ejabberd_cluster:send(From, {replica, node(), next_session(Key)}),
+    {noreply, State};
+handle_info({replica, _From, '$end_of_table'}, State) ->
+    {noreply, State};
+handle_info({replica, From, Session}, State) ->
+    write_session(Session),
+    %% From is the sender's node name (see the 'first'/'next' clauses), not
+    %% a pid, so address the process registered as ?MODULE on that node.
+    ejabberd_cluster:send({?MODULE, From}, {next, self(), Session#bosh.sid}),
+    {noreply, State};
+handle_info({'DOWN', _, process, {?MODULE, _}, _Info}, State) ->
+    {noreply, State};
+handle_info({mnesia_system_event, {mnesia_down, Node}}, State) ->
+    %% Drop every session whose pid ran on the lost node, then forget the peer.
+    Sessions =
+	ets:select(
+	  bosh,
+	  ets:fun2ms(
+	    fun(#bosh{pid = Pid} = S) when node(Pid) == Node ->
+		    S
+	    end)),
+    lists:foreach(
+      fun(S) ->
+	      mnesia:dirty_delete_object(S)
+      end, Sessions),
+    Nodes = maps:remove(Node, State#state.nodes),
+    {noreply, State#state{nodes = Nodes}};
+handle_info({mnesia_system_event, _}, State) ->
+    {noreply, State};
 handle_info(Info, State) ->
     ?WARNING_MSG("Unexpected info: ~p", [Info]),
     {noreply, State}.
@@ -129,22 +187,24 @@ code_change(_OldVsn, State, _Extra) ->
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
+-spec write_session(#bosh{}) -> ok.
 write_session(#bosh{pid = Pid1, sid = SID, timestamp = T1} = S1) ->
     case mnesia:dirty_read(bosh, SID) of
 	[#bosh{pid = Pid2, timestamp = T2} = S2] ->
 	    if Pid1 == Pid2 ->
 		    mnesia:dirty_write(S1);
 	       T1 < T2 ->
-		    cluster_send(Pid2, replaced),
+		    ejabberd_cluster:send(Pid2, replaced),
 		    mnesia:dirty_write(S1);
 	       true ->
-		    cluster_send(Pid1, replaced),
+		    ejabberd_cluster:send(Pid1, replaced),
 		    mnesia:dirty_write(S2)
 	    end;
 	[] ->
 	    mnesia:dirty_write(S1)
     end.
 
+-spec delete_session(#bosh{}) -> ok.
 delete_session(#bosh{sid = SID, pid = Pid1}) ->
     case mnesia:dirty_read(bosh, SID) of
 	[#bosh{pid = Pid2}] ->
@@ -157,8 +217,15 @@ delete_session(#bosh{sid = SID, pid = Pid1}) ->
 	    ok
     end.
 
-cluster_send(NodePid, Msg) ->
-    erlang:send(NodePid, Msg, [noconnect, nosuspend]).
+%% Sends Msg to the process registered as ?MODULE on every node returned
+%% by ejabberd_cluster:get_nodes(), except the local node.
+-spec multicast(_) -> ok.
+multicast(Msg) ->
+    Others = [N || N <- ejabberd_cluster:get_nodes(), N /= node()],
+    lists:foreach(
+      fun(Node) ->
+	      ejabberd_cluster:send({?MODULE, Node}, Msg)
+      end, Others).
 
 setup_database() ->
     case catch mnesia:table_info(bosh, attributes) of
@@ -170,3 +236,36 @@ setup_database() ->
     ejabberd_mnesia:create(?MODULE, bosh,
 			[{ram_copies, [node()]}, {local_content, true},
 			 {attributes, record_info(fields, bosh)}]).
+
+%% Returns the first session in the 'bosh' table whose pid runs on the
+%% local node, or '$end_of_table' when there is none.
+-spec first_session() -> #bosh{} | '$end_of_table'.
+first_session() ->
+    session_at(mnesia:dirty_first(bosh)).
+
+%% Returns the first locally owned session found after key Prev, or
+%% '$end_of_table' when the table is exhausted.
+-spec next_session(binary()) -> #bosh{} | '$end_of_table'.
+next_session(Prev) ->
+    session_at(mnesia:dirty_next(bosh, Prev)).
+
+%% Passes the end-of-table marker through unchanged; any other value is
+%% a table key and is resolved by read_session/1.
+-spec session_at(binary() | '$end_of_table') -> #bosh{} | '$end_of_table'.
+session_at('$end_of_table') ->
+    '$end_of_table';
+session_at(Key) ->
+    read_session(Key).
+
+%% Reads the record stored under Key. A record whose pid lives on another
+%% node, or a key that no longer has a record, is skipped by moving on to
+%% the next key.
+-spec read_session(binary()) -> #bosh{} | '$end_of_table'.
+read_session(Key) ->
+    Local = node(),
+    case mnesia:dirty_read(bosh, Key) of
+	[#bosh{pid = Pid} = Session] when node(Pid) == Local ->
+	    Session;
+	_ ->
+	    next_session(Key)
+    end.
-- 
2.40.0