-record(state, {db_ref, db_type}).
-define(STATE_KEY, ejabberd_odbc_state).
+-define(NESTING_KEY, ejabberd_odbc_nesting_level).
+-define(TOP_LEVEL_TXN, 0).
-define(MAX_TRANSACTION_RESTARTS, 10).
-define(PGSQL_PORT, 5432).
-define(MYSQL_PORT, 3306).
gen_server:start_link(ejabberd_odbc, [Host, StartInterval], []).
sql_query(Host, Query) ->
- Msg = {sql_query, Query},
- sql_call(Host, Msg).
+ sql_call(Host, {sql_query, Query}).
%% SQL transaction based on a list of queries
%% This function automatically
end,
sql_transaction(Host, F);
%% SQL transaction, based on a erlang anonymous function (F = fun)
-sql_transaction(Host, F) ->
- Msg = {sql_transaction, F},
- sql_call(Host, Msg).
+sql_transaction(Host, F) when is_function(F) ->
+ sql_call(Host, {sql_transaction, F}).
%% SQL bloc, based on a erlang anonymous function (F = fun)
sql_bloc(Host, F) ->
- Msg = {sql_bloc, F},
- sql_call(Host, Msg).
+ sql_call(Host, {sql_bloc, F}).
sql_call(Host, Msg) ->
case get(?STATE_KEY) of
undefined ->
gen_server:call(ejabberd_odbc_sup:get_random_pid(Host),
- Msg, ?TRANSACTION_TIMEOUT);
- State ->
- %% Query, Transaction or Bloc nested inside transaction
- nested_op(Msg, State)
+ {sql_cmd, Msg}, ?TRANSACTION_TIMEOUT);
+ _State ->
+ nested_op(Msg)
end.
%% This function is intended to be used from inside an sql_transaction:
sql_query_t(Query) ->
- State = get(?STATE_KEY),
- QRes = sql_query_internal(State, Query),
+ QRes = sql_query_internal(Query),
case QRes of
- {error, "No SQL-driver information available."} ->
- % workaround for odbc bug
- {updated, 0};
{error, Reason} ->
throw({aborted, Reason});
Rs when is_list(Rs) ->
%% {stop, Reason, Reply, State} | (terminate/2 is called)
%% {stop, Reason, State} (terminate/2 is called)
%%----------------------------------------------------------------------
-handle_call(Command, _From, State) ->
- dispatch_sql_command(Command, State).
+handle_call({sql_cmd, Command}, _From, State) ->
+ put(?NESTING_KEY, ?TOP_LEVEL_TXN),
+ put(?STATE_KEY, State),
+ abort_on_driver_error(outer_op(Command));
+handle_call(Request, {Who, _Ref}, State) ->
+ ?WARNING_MSG("Unexpected call ~p from ~p.", [Request, Who]),
+ {reply, ok, State}.
%%----------------------------------------------------------------------
%% Func: handle_cast/2
%%%----------------------------------------------------------------------
%%% Internal functions
%%%----------------------------------------------------------------------
-dispatch_sql_command({sql_query, Query}, State) ->
- abort_on_driver_error(sql_query_internal(State, Query), State);
-dispatch_sql_command({sql_transaction, F}, State) ->
- abort_on_driver_error(
- execute_transaction(State, F, ?MAX_TRANSACTION_RESTARTS, ""), State);
-dispatch_sql_command({sql_bloc, F}, State) ->
- abort_on_driver_error(execute_bloc(State, F), State);
-dispatch_sql_command(Request, State) ->
- ?WARNING_MSG("Unexpected call ~p.", [Request]),
- {reply, ok, State}.
-sql_query_internal(State, Query) ->
- Nested = case get(?STATE_KEY) of
- undefined -> put(?STATE_KEY, State), false;
- _State -> true
- end,
- Result = case State#state.db_type of
- odbc ->
- odbc:sql_query(State#state.db_ref, Query);
- pgsql ->
- pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query));
- mysql ->
- ?DEBUG("MySQL, Send query~n~p~n", [Query]),
- R = mysql_to_odbc(mysql_conn:fetch(State#state.db_ref, Query, self())),
- ?DEBUG("MySQL, Received result~n~p~n", [R]),
- R
- end,
- case Nested of
- true -> Result;
- false -> erase(?STATE_KEY), Result
+%% Only called by handle_call, only handles top level operations.
+%% @spec outer_op(Op) -> {error, Reason} | {aborted, Reason} | {atomic, Result}
+outer_op({sql_query, Query}) ->
+ sql_query_internal(Query);
+outer_op({sql_transaction, F}) ->
+ outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, "");
+outer_op({sql_bloc, F}) ->
+ execute_bloc(F).
+
+%% Called via sql_query/transaction/bloc from client code when inside a
+%% nested operation
+nested_op({sql_query, Query}) ->
+ %% XXX - use sql_query_t here insted? Most likely would break
+ %% callers who expect {error, _} tuples (sql_query_t turns
+ %% these into throws)
+ sql_query_internal(Query);
+nested_op({sql_transaction, F}) ->
+ NestingLevel = get(?NESTING_KEY),
+ if NestingLevel =:= ?TOP_LEVEL_TXN ->
+ %% First transaction inside a (series of) sql_blocs
+ outer_transaction(F, ?MAX_TRANSACTION_RESTARTS, "");
+ true ->
+ %% Transaction inside a transaction
+ inner_transaction(F)
+ end;
+nested_op({sql_bloc, F}) ->
+ execute_bloc(F).
+
+%% Never retry nested transactions - only outer transactions
+inner_transaction(F) ->
+ PreviousNestingLevel = get(?NESTING_KEY),
+ case get(?NESTING_KEY) of
+ ?TOP_LEVEL_TXN ->
+ {backtrace, T} = process_info(self(), backtrace),
+ ?ERROR_MSG("inner transaction called at outer txn level. Trace: ~s", [T]),
+ erlang:exit(implementation_faulty);
+ _N -> ok
+ end,
+ put(?NESTING_KEY, PreviousNestingLevel + 1),
+ Result = (catch F()),
+ put(?NESTING_KEY, PreviousNestingLevel),
+ case Result of
+ {aborted, Reason} ->
+ {aborted, Reason};
+ {'EXIT', Reason} ->
+ {'EXIT', Reason};
+ {atomic, Res} ->
+ {atomic, Res};
+ Res ->
+ {atomic, Res}
end.
-execute_transaction(State, _F, 0, Reason) ->
- ?ERROR_MSG("SQL transaction restarts exceeded~n"
- "** Restarts: ~p~n"
- "** Last abort reason: ~p~n"
- "** Stacktrace: ~p~n"
- "** When State == ~p",
- [?MAX_TRANSACTION_RESTARTS, Reason,
- erlang:get_stacktrace(), State]),
- sql_query_internal(State, "rollback;"),
- {aborted, restarts_exceeded};
-execute_transaction(State, F, NRestarts, _Reason) ->
- Nested = case get(?STATE_KEY) of
- undefined ->
- put(?STATE_KEY, State),
- sql_query_internal(State, "begin;"),
- false;
- _State ->
- true
- end,
- Result = case catch F() of
- {aborted, Reason} ->
- execute_transaction(State, F, NRestarts - 1, Reason);
- {'EXIT', Reason} ->
- sql_query_internal(State, "rollback;"),
- {aborted, Reason};
- Res ->
- {atomic, Res}
- end,
- case Nested of
- true -> Result;
- false -> sql_query_internal(State, "commit;"), erase(?STATE_KEY), Result
+outer_transaction(F, NRestarts, _Reason) ->
+ PreviousNestingLevel = get(?NESTING_KEY),
+ case get(?NESTING_KEY) of
+ ?TOP_LEVEL_TXN ->
+ ok;
+ _N ->
+ {backtrace, T} = process_info(self(), backtrace),
+ ?ERROR_MSG("outer transaction called at inner txn level. Trace: ~s", [T]),
+ erlang:exit(implementation_faulty)
+ end,
+ sql_query_internal("begin;"),
+ put(?NESTING_KEY, PreviousNestingLevel + 1),
+ Result = (catch F()),
+ put(?NESTING_KEY, PreviousNestingLevel),
+ case Result of
+ {aborted, Reason} when NRestarts > 0 ->
+ %% Retry outer transaction upto NRestarts times.
+ sql_query_internal("rollback;"),
+ outer_transaction(F, NRestarts - 1, Reason);
+ {aborted, Reason} when NRestarts =:= 0 ->
+ %% Too many retries of outer transaction.
+ ?ERROR_MSG("SQL transaction restarts exceeded~n"
+ "** Restarts: ~p~n"
+ "** Last abort reason: ~p~n"
+ "** Stacktrace: ~p~n"
+ "** When State == ~p",
+ [?MAX_TRANSACTION_RESTARTS, Reason,
+ erlang:get_stacktrace(), get(?STATE_KEY)]),
+ sql_query_internal("rollback;"),
+ {aborted, Reason};
+ {'EXIT', Reason} ->
+ %% Abort sql transaction on EXIT from outer txn only.
+ sql_query_internal("rollback;"),
+ {aborted, Reason};
+ Res ->
+ %% Commit successful outer txn
+ sql_query_internal("commit;"),
+ {atomic, Res}
end.
-execute_bloc(State, F) ->
- Nested = case get(?STATE_KEY) of
- undefined -> put(?STATE_KEY, State), false;
- _State -> true
- end,
- Result = case catch F() of
- {aborted, Reason} ->
- {aborted, Reason};
- {'EXIT', Reason} ->
- {aborted, Reason};
- Res ->
- {atomic, Res}
- end,
- case Nested of
- true -> Result;
- false -> erase(?STATE_KEY), Result
+execute_bloc(F) ->
+ %% We don't alter ?NESTING_KEY here as only SQL transactions alter txn nesting
+ case catch F() of
+ {aborted, Reason} ->
+ {aborted, Reason};
+ {'EXIT', Reason} ->
+ {aborted, Reason};
+ Res ->
+ {atomic, Res}
end.
-nested_op(Op, State) ->
- case dispatch_sql_command(Op, State) of
- {reply, Res, NewState} ->
- put(?STATE_KEY, NewState),
- Res;
- {stop, _Reason, Reply, NewState} ->
- put(?STATE_KEY, NewState),
- throw({aborted, Reply});
- {noreply, NewState} ->
- put(?STATE_KEY, NewState),
- exit({bad_op_in_nested_txn, Op})
+sql_query_internal(Query) ->
+ State = get(?STATE_KEY),
+ Res = case State#state.db_type of
+ odbc ->
+ odbc:sql_query(State#state.db_ref, Query);
+ pgsql ->
+ pgsql_to_odbc(pgsql:squery(State#state.db_ref, Query));
+ mysql ->
+ ?DEBUG("MySQL, Send query~n~p~n", [Query]),
+ R = mysql_to_odbc(mysql_conn:fetch(State#state.db_ref, Query, self())),
+ ?INFO_MSG("MySQL, Received result~n~p~n", [R]),
+ R
+ end,
+ case Res of
+ {error, "No SQL-driver information available."} ->
+ % workaround for odbc bug
+ {updated, 0};
+ _Else -> Res
end.
-abort_on_driver_error({error, "query timed out"} = Reply, State) ->
+%% Generate the OTP callback return tuple depending on the driver result.
+abort_on_driver_error({error, "query timed out"} = Reply) ->
%% mysql driver error
- {stop, timeout, Reply, State};
-abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply, State) ->
+ {stop, timeout, Reply, get(?STATE_KEY)};
+abort_on_driver_error({error, "Failed sending data on socket"++_} = Reply) ->
%% mysql driver error
- {stop, closed, Reply, State};
-abort_on_driver_error(Reply, State) ->
- {reply, Reply, State}.
+ {stop, closed, Reply, get(?STATE_KEY)};
+abort_on_driver_error(Reply) ->
+ {reply, Reply, get(?STATE_KEY)}.
%% == pure ODBC code