]> granicus.if.org Git - ejabberd/commitdiff
EJAB-940: Implements reliable ODBC transaction nesting.
authorGeoff Cant <geoff.cant@process-one.net>
Tue, 28 Jul 2009 13:46:28 +0000 (13:46 +0000)
committerGeoff Cant <geoff.cant@process-one.net>
Tue, 28 Jul 2009 13:46:28 +0000 (13:46 +0000)
SVN Revision: 2397

src/odbc/ejabberd_odbc.erl

index 0cb673cc19bab8295bc7aae56013a804c0eb3500..1b03485d88b5387f855ff209c412f22c57d49fd0 100644 (file)
@@ -52,6 +52,8 @@
 -record(state, {db_ref, db_type}).
 
 -define(STATE_KEY, ejabberd_odbc_state).
+-define(NESTING_KEY, ejabberd_odbc_nesting_level). % pdict key: current txn nesting depth
+-define(TOP_LEVEL_TXN, 0). % depth value stored by handle_call before each command
 -define(MAX_TRANSACTION_RESTARTS, 10).
 -define(PGSQL_PORT, 5432).
 -define(MYSQL_PORT, 3306).
@@ -70,8 +72,7 @@ start_link(Host, StartInterval) ->
     gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []).
 
 sql_query(Host, Query) ->
-    Msg = {sql_query, Query},
-    sql_call(Host, Msg).
+    sql_call(Host, {sql_query, Query}). % sql_call/2 decides between a worker call and a nested op
 
 %% SQL transaction based on a list of queries
 %% This function automatically
@@ -84,34 +85,27 @@ sql_transaction(Host, Queries) when is_list(Queries) ->
        end,
     sql_transaction(Host, F);
 %% SQL transaction, based on a erlang anonymous function (F = fun)
-sql_transaction(Host, F) ->
-    Msg = {sql_transaction, F},
-    sql_call(Host, Msg).
+sql_transaction(Host, F) when is_function(F) ->
+    sql_call(Host, {sql_transaction, F}).
 
 %% SQL bloc, based on a erlang anonymous function (F = fun)
 sql_bloc(Host, F) ->
-    Msg = {sql_bloc, F},
-    sql_call(Host, Msg).
+    sql_call(Host, {sql_bloc, F}). % F ends up in execute_bloc/1, which issues no begin/commit
 
 sql_call(Host, Msg) ->
     case get(?STATE_KEY) of
         undefined ->
             gen_server:call(ejabberd_odbc_sup:get_random_pid(Host),
-                            Msg, ?TRANSACTION_TIMEOUT);
-        State ->
-            %% Query, Transaction or Bloc nested inside transaction
-            nested_op(Msg, State)
+                            {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT); % no state in pdict: wrap as sql_cmd for handle_call
+        _State -> % state present in pdict: handle_call stored it, so an operation is already running here
+            nested_op(Msg) % run in the current process instead of calling a gen_server
     end.
 
 
 %% This function is intended to be used from inside an sql_transaction:
 sql_query_t(Query) ->
-    State = get(?STATE_KEY),
-    QRes = sql_query_internal(State, Query),
+    QRes = sql_query_internal(Query),
     case QRes of
-       {error, "No SQL-driver information available."} ->
-           % workaround for odbc bug
-           {updated, 0};
        {error, Reason} ->
            throw({aborted, Reason});
        Rs when is_list(Rs) ->
@@ -184,8 +178,13 @@ init([Host, StartInterval]) ->
 %%          {stop, Reason, Reply, State}   | (terminate/2 is called)
 %%          {stop, Reason, State}            (terminate/2 is called)
 %%----------------------------------------------------------------------
-handle_call(Command, _From, State) ->
-    dispatch_sql_command(Command, State).
+handle_call({sql_cmd, Command}, _From, State) ->
+    put(?NESTING_KEY, ?TOP_LEVEL_TXN), % each top-level command starts at nesting depth 0
+    put(?STATE_KEY, State), % read back by sql_call/2, sql_query_internal/1 and abort_on_driver_error/1
+    abort_on_driver_error(outer_op(Command)); % turns the result into a {reply,...} or {stop,...} tuple
+handle_call(Request, {Who, _Ref}, State) ->
+    ?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]),
+    {reply, ok, State}. % anything that is not {sql_cmd, _} is logged and answered with ok
 
 %%----------------------------------------------------------------------
 %% Func: handle_cast/2
@@ -233,110 +232,139 @@ terminate(_Reason, State) ->
 %%%----------------------------------------------------------------------
 %%% Internal functions
 %%%----------------------------------------------------------------------
-dispatch_sql_command({sql_query, Query}, State) ->
-    abort_on_driver_error(sql_query_internal(State, Query), State);
-dispatch_sql_command({sql_transaction, F}, State) ->
-    abort_on_driver_error(
-      execute_transaction(State, F, ?MAX_TRANSACTION_RESTARTS, ""), State);
-dispatch_sql_command({sql_bloc, F}, State) ->
-    abort_on_driver_error(execute_bloc(State, F), State);
-dispatch_sql_command(Request, State) ->
-    ?WARNING_MSG("Unexpected call ~p.", [Request]),
-    {reply, ok, State}.
 
-sql_query_internal(State, Query) ->
-    Nested = case get(?STATE_KEY) of
-                undefined -> put(?STATE_KEY, State), false;
-                _State -> true
-            end,
-    Result = case State#state.db_type of
-                odbc ->
-                    odbc:sql_query(State#state.db_ref, Query);
-                pgsql ->
-                    pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query));
-                mysql ->
-                    ?DEBUG("MySQL, Send query~n~p~n", [Query]),
-                    R = mysql_to_odbc(mysql_conn:fetch(State#state.db_ref, Query, self())),
-                    ?DEBUG("MySQL, Received result~n~p~n", [R]),
-                    R
-            end,
-    case Nested of
-       true -> Result;
-       false -> erase(?STATE_KEY), Result
+%% Only called by handle_call, only handles top level operations.
+%% @spec outer_op(Op) -> {error, Reason} | {aborted, Reason} | {atomic, Result}
+outer_op({sql_query, Query}) ->
+    sql_query_internal(Query); % raw query result (e.g. {updated, 0}); not wrapped in {atomic, _}
+outer_op({sql_transaction, F}) ->
+    outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, ""); % "" is the initial (unused) abort reason
+outer_op({sql_bloc, F}) ->
+    execute_bloc(F). % runs F without begin/commit
+
+%% Called via sql_query/transaction/bloc from client code when inside a
+%% nested operation
+nested_op({sql_query, Query}) ->
+    %% XXX - use sql_query_t here instead? Most likely would break
+    %% callers who expect {error, _} tuples (sql_query_t turns
+    %% these into throws)
+    sql_query_internal(Query);
+nested_op({sql_transaction, F}) ->
+    NestingLevel = get(?NESTING_KEY), % depth maintained by outer_transaction/3 and inner_transaction/1
+    if NestingLevel =:= ?TOP_LEVEL_TXN ->
+            %% First transaction inside a (series of) sql_blocs
+            outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, "");
+       true ->
+            %% Transaction inside a transaction
+            inner_transaction(F)
+    end;
+nested_op({sql_bloc, F}) ->
+    execute_bloc(F). % blocs leave the nesting depth untouched
+
+%% Never retry nested transactions - only outer transactions
+inner_transaction(F) ->
+    PreviousNestingLevel = get(?NESTING_KEY), % restored after F has run
+    case get(?NESTING_KEY) of
+        ?TOP_LEVEL_TXN -> % depth 0 here means the caller picked the wrong function
+            {backtrace, T} = process_info(self(), backtrace),
+            ?ERROR_MSG("inner transaction called at outer txn level. Trace: ~s", [T]),
+            erlang:exit(implementation_faulty);
+        _N -> ok
+    end,
+    put(?NESTING_KEY, PreviousNestingLevel + 1),
+    Result = (catch F()), % no SQL begin/commit/rollback is issued around F at this level
+    put(?NESTING_KEY, PreviousNestingLevel),
+    case Result of
+        {aborted, Reason} ->
+            {aborted, Reason};
+        {'EXIT', Reason} ->
+            {'EXIT', Reason}; % returned as a value, not re-raised
+        {atomic, Res} ->
+            {atomic, Res}; % already tagged: pass through without wrapping twice
+        Res ->
+            {atomic, Res}
     end.
 
-execute_transaction(State, _F, 0, Reason) ->
-    ?ERROR_MSG("SQL transaction restarts exceeded~n"
-              "** Restarts: ~p~n"
-              "** Last abort reason: ~p~n"
-              "** Stacktrace: ~p~n"
-              "** When State == ~p",
-              [?MAX_TRANSACTION_RESTARTS, Reason,
-               erlang:get_stacktrace(), State]),
-    sql_query_internal(State, "rollback;"),
-    {aborted, restarts_exceeded};
-execute_transaction(State, F, NRestarts, _Reason) ->
-    Nested = case get(?STATE_KEY) of
-                undefined ->
-                    put(?STATE_KEY, State),
-                    sql_query_internal(State, "begin;"),
-                    false;
-                _State ->
-                    true
-            end,
-    Result = case catch F() of
-                {aborted, Reason} ->
-                    execute_transaction(State, F, NRestarts - 1, Reason);
-                {'EXIT', Reason} ->
-                    sql_query_internal(State, "rollback;"),
-                    {aborted, Reason};
-                Res  ->
-                    {atomic, Res}
-                            end,
-    case Nested of
-       true -> Result;
-       false -> sql_query_internal(State, "commit;"), erase(?STATE_KEY), Result
+outer_transaction(F, NRestarts, _Reason) -> % NRestarts: retries left; third argument is ignored
+    PreviousNestingLevel = get(?NESTING_KEY),
+    case get(?NESTING_KEY) of
+        ?TOP_LEVEL_TXN ->
+            ok;
+        _N -> % any other depth means a transaction is already open
+            {backtrace, T} = process_info(self(), backtrace),
+            ?ERROR_MSG("outer transaction called at inner txn level. Trace: ~s", [T]),
+            erlang:exit(implementation_faulty)
+    end,
+    sql_query_internal("begin;"), % the result of "begin;" is not checked
+    put(?NESTING_KEY, PreviousNestingLevel + 1),
+    Result = (catch F()), % a thrown {aborted, R} arrives as a value; crashes arrive as {'EXIT', _}
+    put(?NESTING_KEY, PreviousNestingLevel), % back to the previous depth before any retry
+    case Result of
+        {aborted, Reason} when NRestarts > 0 ->
+            %% Retry outer transaction up to NRestarts times.
+            sql_query_internal("rollback;"),
+            outer_transaction(F, NRestarts - 1, Reason);
+        {aborted, Reason} when NRestarts =:= 0 ->
+            %% Too many retries of outer transaction.
+            ?ERROR_MSG("SQL transaction restarts exceeded~n"
+                       "** Restarts: ~p~n"
+                       "** Last abort reason: ~p~n"
+                       "** Stacktrace: ~p~n"
+                       "** When State == ~p",
+                       [?MAX_TRANSACTION_RESTARTS, Reason,
+                        erlang:get_stacktrace(), get(?STATE_KEY)]),
+            sql_query_internal("rollback;"),
+            {aborted, Reason}; % the last abort reason is handed back to the caller
+        {'EXIT', Reason} ->
+            %% Abort sql transaction on EXIT from outer txn only.
+            sql_query_internal("rollback;"),
+            {aborted, Reason}; % exits are not retried
+        Res ->
+            %% Commit successful outer txn
+            sql_query_internal("commit;"),
+            {atomic, Res}
     end.
 
-execute_bloc(State, F) ->
-    Nested = case get(?STATE_KEY) of
-                undefined -> put(?STATE_KEY, State), false;
-                _State -> true
-            end,
-    Result = case catch F() of
-                {aborted, Reason} ->
-                    {aborted, Reason};
-                {'EXIT', Reason} ->
-                    {aborted, Reason};
-                Res ->
-                    {atomic, Res}
-            end,
-    case Nested of
-       true -> Result;
-       false -> erase(?STATE_KEY), Result
+execute_bloc(F) -> % runs F with no SQL begin/commit/rollback around it
+    %% We don't alter ?NESTING_KEY here as only SQL transactions alter txn nesting
+    case catch F() of
+        {aborted, Reason} ->
+            {aborted, Reason};
+        {'EXIT', Reason} ->
+            {aborted, Reason}; % a crash in F is reported to the caller as an abort
+        Res ->
+            {atomic, Res}
     end.
 
-nested_op(Op, State) ->
-    case dispatch_sql_command(Op, State) of
-       {reply, Res, NewState} ->
-           put(?STATE_KEY, NewState),
-           Res;
-       {stop, _Reason, Reply, NewState} ->
-           put(?STATE_KEY, NewState),
-           throw({aborted, Reply});
-       {noreply, NewState} ->
-           put(?STATE_KEY, NewState),
-           exit({bad_op_in_nested_txn, Op})
+sql_query_internal(Query) -> % runs Query on the connection kept under ?STATE_KEY
+    State = get(?STATE_KEY), % stored by handle_call before any operation runs
+    Res = case State#state.db_type of
+              odbc ->
+                  odbc:sql_query(State#state.db_ref, Query);
+              pgsql ->
+                  pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query));
+              mysql ->
+                  ?DEBUG("MySQL, Send query~n~p~n", [Query]),
+                  R = mysql_to_odbc(mysql_conn:fetch(State#state.db_ref, Query, self())),
+                  ?DEBUG("MySQL, Received result~n~p~n", [R]), % debug level, same as the send log
+                  R
+          end,
+    case Res of
+       {error, "No SQL-driver information available."} ->
+           % workaround for odbc bug
+           {updated, 0}; % this one driver error is reported as zero rows updated
+        _Else -> Res
     end.
 
-abort_on_driver_error({error, "query timed out"} = Reply, State) ->
+%% Generate the OTP callback return tuple depending on the driver result.
+abort_on_driver_error({error, "query timed out"} = Reply) ->
     %% mysql driver error
-    {stop, timeout, Reply, State};
-abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply, State) ->
+    {stop, timeout, Reply, get(?STATE_KEY)}; % reply with the error, then stop the server
+abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply) ->
     %% mysql driver error
-    {stop, closed, Reply, State};
-abort_on_driver_error(Reply, State) ->
-    {reply, Reply, State}.
+    {stop, closed, Reply, get(?STATE_KEY)}; % matches any message starting with this prefix
+abort_on_driver_error(Reply) ->
+    {reply, Reply, get(?STATE_KEY)}. % state comes from the pdict entry written by handle_call
 
 
 %% == pure ODBC code