Update mod_roster and ejabberd_auth_odbc SQL queries to the new API
authorAlexey Shchepin <alexey@process-one.net>
Mon, 15 Feb 2016 18:02:22 +0000 (21:02 +0300)
committerAlexey Shchepin <alexey@process-one.net>
Tue, 1 Mar 2016 21:12:49 +0000 (00:12 +0300)
src/ejabberd_auth_odbc.erl
src/mod_roster.erl
src/odbc_queries.erl

index b8b4594b630a4685b364e57a9c7771b5642e7dd5..60812781757874caeb9cab4f3d2d6c17c0048256 100644 (file)
@@ -72,22 +72,18 @@ check_password(User, Server, Password) ->
        (LUser == <<>>) or (LServer == <<>>) ->
             false;
        true ->
-            Username = ejabberd_odbc:escape(LUser),
             case is_scrammed() of
                 true ->
-                    try odbc_queries:get_password_scram(LServer, Username) of
-                        {selected, [<<"password">>, <<"serverkey">>,
-                                    <<"salt">>, <<"iterationcount">>],
-                         [[StoredKey, ServerKey, Salt, IterationCount]]} ->
+                    try odbc_queries:get_password_scram(LServer, LUser) of
+                        {selected,
+                         [{StoredKey, ServerKey, Salt, IterationCount}]} ->
                             Scram =
                                 #scram{storedkey = StoredKey,
                                        serverkey = ServerKey,
                                        salt = Salt,
-                                       iterationcount = binary_to_integer(
-                                                          IterationCount)},
+                                       iterationcount = IterationCount},
                             is_password_scram_valid(Password, Scram);
-                        {selected, [<<"password">>, <<"serverkey">>,
-                                    <<"salt">>, <<"iterationcount">>], []} ->
+                        {selected, []} ->
                             false; %% Account does not exist
                         {error, _Error} ->
                             false %% Typical error is that table doesn't exist
@@ -96,12 +92,12 @@ check_password(User, Server, Password) ->
                             false %% Typical error is database not accessible
                     end;
                 false ->
-                    try odbc_queries:get_password(LServer, Username) of
-                        {selected, [<<"password">>], [[Password]]} ->
+                    try odbc_queries:get_password(LServer, LUser) of
+                        {selected, [{Password}]} ->
                             Password /= <<"">>;
-                        {selected, [<<"password">>], [[_Password2]]} ->
+                        {selected, [{_Password2}]} ->
                             false; %% Password is not correct
-                        {selected, [<<"password">>], []} ->
+                        {selected, []} ->
                             false; %% Account does not exist
                         {error, _Error} ->
                             false %% Typical error is that table doesn't exist
@@ -124,10 +120,9 @@ check_password(User, Server, Password, Digest,
        true ->
             case is_scrammed() of
                 false ->
-                    Username = ejabberd_odbc:escape(LUser),
-                    try odbc_queries:get_password(LServer, Username) of
+                    try odbc_queries:get_password(LServer, LUser) of
                         %% Account exists, check if password is valid
-                        {selected, [<<"password">>], [[Passwd]]} ->
+                        {selected, [{Passwd}]} ->
                             DigRes = if Digest /= <<"">> ->
                                              Digest == DigestGen(Passwd);
                                         true -> false
@@ -135,7 +130,7 @@ check_password(User, Server, Password, Digest,
                             if DigRes -> true;
                                true -> (Passwd == Password) and (Password /= <<"">>)
                             end;
-                        {selected, [<<"password">>], []} ->
+                        {selected, []} ->
                             false; %% Account does not exist
                         {error, _Error} ->
                             false %% Typical error is that table doesn't exist
@@ -267,24 +262,22 @@ get_password(User, Server) ->
        (LUser == <<>>) or (LServer == <<>>) ->
             false;
        true ->
-            Username = ejabberd_odbc:escape(LUser),
             case is_scrammed() of
                 true ->
                     case catch odbc_queries:get_password_scram(
-                                 LServer, Username) of
-                        {selected, [<<"password">>, <<"serverkey">>,
-                                    <<"salt">>, <<"iterationcount">>],
-                         [[StoredKey, ServerKey, Salt, IterationCount]]} ->
+                                 LServer, LUser) of
+                        {selected,
+                         [{StoredKey, ServerKey, Salt, IterationCount}]} ->
                             {jlib:decode_base64(StoredKey),
                              jlib:decode_base64(ServerKey),
                              jlib:decode_base64(Salt),
-                             binary_to_integer(IterationCount)};
+                             IterationCount};
                         _ -> false
                     end;
                 false ->
-                    case catch odbc_queries:get_password(LServer, Username)
+                    case catch odbc_queries:get_password(LServer, LUser)
                         of
-                        {selected, [<<"password">>], [[Password]]} -> Password;
+                        {selected, [{Password}]} -> Password;
                         _ -> false
                     end
             end
@@ -300,9 +293,8 @@ get_password_s(User, Server) ->
        true ->
             case is_scrammed() of
                 false ->
-                    Username = ejabberd_odbc:escape(LUser),
-                    case catch odbc_queries:get_password(LServer, Username) of
-                        {selected, [<<"password">>], [[Password]]} -> Password;
+                    case catch odbc_queries:get_password(LServer, LUser) of
+                        {selected, [{Password}]} -> Password;
                         _ -> <<"">>
                     end;
                 true -> <<"">>
@@ -311,15 +303,17 @@ get_password_s(User, Server) ->
 
 %% @spec (User, Server) -> true | false | {error, Error}
 is_user_exists(User, Server) ->
-    case jid:nodeprep(User) of
-      error -> false;
-      LUser ->
-         Username = ejabberd_odbc:escape(LUser),
-         LServer = jid:nameprep(Server),
-         try odbc_queries:get_password(LServer, Username) of
-           {selected, [<<"password">>], [[_Password]]} ->
+    LServer = jid:nameprep(Server),
+    LUser = jid:nodeprep(User),
+    if (LUser == error) or (LServer == error) ->
+            false;
+       (LUser == <<>>) or (LServer == <<>>) ->
+            false;
+       true ->
+         try odbc_queries:get_password(LServer, LUser) of
+           {selected, [{_Password}]} ->
                true; %% Account exists
-           {selected, [<<"password">>], []} ->
+           {selected, []} ->
                false; %% Account does not exist
            {error, Error} -> {error, Error}
          catch
index f68300763a3f48aedc401fabc8a374ec7327083a..1e5e3b70c22ec61aff484eabb6fe4ea8a9d3dd56 100644 (file)
@@ -203,11 +203,9 @@ read_roster_version(LUser, LServer, mnesia) ->
       [] -> error
     end;
 read_roster_version(LUser, LServer, odbc) ->
-    Username = ejabberd_odbc:escape(LUser),
-    case odbc_queries:get_roster_version(LServer, Username)
-       of
-      {selected, [<<"version">>], [[Version]]} -> Version;
-      {selected, [<<"version">>], []} -> error
+    case odbc_queries:get_roster_version(LServer, LUser) of
+      {selected, [{Version}]} -> Version;
+      {selected, []} -> error
     end;
 read_roster_version(LServer, LUser, riak) ->
     case ejabberd_riak:get(roster_version, roster_version_schema(),
@@ -369,46 +367,37 @@ get_roster(LUser, LServer, riak) ->
         _Err -> []
     end;
 get_roster(LUser, LServer, odbc) ->
-    Username = ejabberd_odbc:escape(LUser),
-    case catch odbc_queries:get_roster(LServer, Username) of
-      {selected,
-       [<<"username">>, <<"jid">>, <<"nick">>,
-       <<"subscription">>, <<"ask">>, <<"askmessage">>,
-       <<"server">>, <<"subscribe">>, <<"type">>],
-       Items}
-         when is_list(Items) ->
-         JIDGroups = case catch
-                            odbc_queries:get_roster_jid_groups(LServer,
-                                                               Username)
-                         of
-                       {selected, [<<"jid">>, <<"grp">>], JGrps}
-                           when is_list(JGrps) ->
-                           JGrps;
-                       _ -> []
-                     end,
-         GroupsDict = lists:foldl(fun ([J, G], Acc) ->
-                                          dict:append(J, G, Acc)
-                                  end,
-                                  dict:new(), JIDGroups),
-         RItems = lists:flatmap(fun (I) ->
-                                        case raw_to_record(LServer, I) of
-                                          %% Bad JID in database:
-                                          error -> [];
-                                          R ->
-                                              SJID =
-                                                  jid:to_string(R#roster.jid),
-                                              Groups = case dict:find(SJID,
-                                                                      GroupsDict)
-                                                           of
-                                                         {ok, Gs} -> Gs;
-                                                         error -> []
-                                                       end,
-                                              [R#roster{groups = Groups}]
-                                        end
-                                end,
-                                Items),
-         RItems;
-      _ -> []
+    case catch odbc_queries:get_roster(LServer, LUser) of
+        {selected, Items} when is_list(Items) ->
+            JIDGroups = case catch odbc_queries:get_roster_jid_groups(
+                                     LServer, LUser) of
+                            {selected, JGrps}
+                            when is_list(JGrps) ->
+                                JGrps;
+                            _ -> []
+                        end,
+            GroupsDict = lists:foldl(fun({J, G}, Acc) ->
+                                             dict:append(J, G, Acc)
+                                     end,
+                                     dict:new(), JIDGroups),
+            RItems =
+                lists:flatmap(
+                  fun(I) ->
+                          case raw_to_record(LServer, I) of
+                              %% Bad JID in database:
+                              error -> [];
+                              R ->
+                                  SJID = jid:to_string(R#roster.jid),
+                                  Groups = case dict:find(SJID, GroupsDict) of
+                                               {ok, Gs} -> Gs;
+                                               error -> []
+                                           end,
+                                  [R#roster{groups = Groups}]
+                          end
+                  end,
+                  Items),
+            RItems;
+        _ -> []
     end.
 
 set_roster(#roster{us = {LUser, LServer}, jid = LJID} = Item) ->
@@ -460,14 +449,8 @@ get_roster_by_jid_t(LUser, LServer, LJID, mnesia) ->
                   xs = []}
     end;
 get_roster_by_jid_t(LUser, LServer, LJID, odbc) ->
-    Username = ejabberd_odbc:escape(LUser),
-    SJID = ejabberd_odbc:escape(jid:to_string(LJID)),
-    {selected,
-     [<<"username">>, <<"jid">>, <<"nick">>,
-      <<"subscription">>, <<"ask">>, <<"askmessage">>,
-      <<"server">>, <<"subscribe">>, <<"type">>],
-     Res} =
-       odbc_queries:get_roster_by_jid(LServer, Username, SJID),
+    {selected, Res} =
+       odbc_queries:get_roster_by_jid(LServer, LUser, jid:to_string(LJID)),
     case Res of
       [] ->
          #roster{usj = {LUser, LServer, LJID},
@@ -750,30 +733,18 @@ get_roster_by_jid_with_groups_t(LUser, LServer, LJID,
     end;
 get_roster_by_jid_with_groups_t(LUser, LServer, LJID,
                                odbc) ->
-    Username = ejabberd_odbc:escape(LUser),
-    SJID = ejabberd_odbc:escape(jid:to_string(LJID)),
-    case odbc_queries:get_roster_by_jid(LServer, Username,
-                                       SJID)
-       of
-      {selected,
-       [<<"username">>, <<"jid">>, <<"nick">>,
-       <<"subscription">>, <<"ask">>, <<"askmessage">>,
-       <<"server">>, <<"subscribe">>, <<"type">>],
-       [I]} ->
-         R = raw_to_record(LServer, I),
-         Groups = case odbc_queries:get_roster_groups(LServer,
-                                                      Username, SJID)
-                      of
-                    {selected, [<<"grp">>], JGrps} when is_list(JGrps) ->
-                        [JGrp || [JGrp] <- JGrps];
-                    _ -> []
-                  end,
-         R#roster{groups = Groups};
-      {selected,
-       [<<"username">>, <<"jid">>, <<"nick">>,
-       <<"subscription">>, <<"ask">>, <<"askmessage">>,
-       <<"server">>, <<"subscribe">>, <<"type">>],
-       []} ->
+    SJID = jid:to_string(LJID),
+    case odbc_queries:get_roster_by_jid(LServer, LUser, SJID) of
+      {selected, [I]} ->
+            R = raw_to_record(LServer, I),
+            Groups =
+                case odbc_queries:get_roster_groups(LServer, LUser, SJID) of
+                    {selected, JGrps} when is_list(JGrps) ->
+                        [JGrp || {JGrp} <- JGrps];
+                    _ -> []
+                end,
+            R#roster{groups = Groups};
+      {selected, []} ->
          #roster{usj = {LUser, LServer, LJID},
                  us = {LUser, LServer}, jid = LJID}
     end;
@@ -995,8 +966,7 @@ remove_user(LUser, LServer, mnesia) ->
        end,
     mnesia:transaction(F);
 remove_user(LUser, LServer, odbc) ->
-    Username = ejabberd_odbc:escape(LUser),
-    odbc_queries:del_user_roster_t(LServer, Username),
+    odbc_queries:del_user_roster_t(LServer, LUser),
     ok;
 remove_user(LUser, LServer, riak) ->
     {atomic, ejabberd_riak:delete_by_index(roster, <<"us">>, {LUser, LServer})}.
@@ -1243,12 +1213,9 @@ read_subscription_and_groups(LUser, LServer, LJID,
     end;
 read_subscription_and_groups(LUser, LServer, LJID,
                             odbc) ->
-    Username = ejabberd_odbc:escape(LUser),
-    SJID = ejabberd_odbc:escape(jid:to_string(LJID)),
-    case catch odbc_queries:get_subscription(LServer,
-                                            Username, SJID)
-       of
-      {selected, [<<"subscription">>], [[SSubscription]]} ->
+    SJID = jid:to_string(LJID),
+    case catch odbc_queries:get_subscription(LServer, LUser, SJID) of
+      {selected, [{SSubscription}]} ->
          Subscription = case SSubscription of
                           <<"B">> -> both;
                           <<"T">> -> to;
@@ -1256,11 +1223,11 @@ read_subscription_and_groups(LUser, LServer, LJID,
                           _ -> none
                         end,
          Groups = case catch
-                         odbc_queries:get_rostergroup_by_jid(LServer, Username,
+                         odbc_queries:get_rostergroup_by_jid(LServer, LUser,
                                                              SJID)
                       of
-                    {selected, [<<"grp">>], JGrps} when is_list(JGrps) ->
-                        [JGrp || [JGrp] <- JGrps];
+                    {selected, JGrps} when is_list(JGrps) ->
+                        [JGrp || {JGrp} <- JGrps];
                     _ -> []
                   end,
          {Subscription, Groups};
@@ -1297,6 +1264,12 @@ get_jid_info(_, User, Server, JID) ->
 raw_to_record(LServer,
              [User, SJID, Nick, SSubscription, SAsk, SAskMessage,
               _SServer, _SSubscribe, _SType]) ->
+    raw_to_record(LServer,
+                  {User, SJID, Nick, SSubscription, SAsk, SAskMessage,
+                   _SServer, _SSubscribe, _SType});
+raw_to_record(LServer,
+             {User, SJID, Nick, SSubscription, SAsk, SAskMessage,
+              _SServer, _SSubscribe, _SType}) ->
     case jid:from_string(SJID) of
       error -> error;
       JID ->
index ee8fa16906b6f589cfaa1cb3115aeef5789d050c..b6c9a36c0ebd4827e00ca1723660f2561129d6d8 100644 (file)
@@ -139,16 +139,17 @@ del_last(LServer, Username) ->
                            [<<"delete from last where username='">>, Username,
                             <<"'">>]).
 
-get_password(LServer, Username) ->
-    ejabberd_odbc:sql_query(LServer,
-                           [<<"select password from users where username='">>,
-                            Username, <<"';">>]).
+get_password(LServer, LUser) ->
+    ejabberd_odbc:sql_query(
+      LServer,
+      ?SQL("select @(password)s from users where username=%(LUser)s")).
 
-get_password_scram(LServer, Username) ->
+get_password_scram(LServer, LUser) ->
     ejabberd_odbc:sql_query(
       LServer,
-      [<<"select password, serverkey, salt, iterationcount from users where "
-        "username='">>, Username, <<"';">>]).
+      ?SQL("select @(password)s, @(serverkey)s, @(salt)s, @(iterationcount)d"
+           " from users"
+           " where username=%(LUser)s")).
 
 set_password_t(LServer, Username, Pass) ->
     ejabberd_odbc:sql_transaction(LServer,
@@ -311,46 +312,46 @@ del_spool_msg(LServer, LUser) ->
       LServer,
       ?SQL("delete from spool where username=%(LUser)s")).
 
-get_roster(LServer, Username) ->
-    ejabberd_odbc:sql_query(LServer,
-                           [<<"select username, jid, nick, subscription, "
-                              "ask, askmessage, server, subscribe, "
-                              "type from rosterusers where username='">>,
-                            Username, <<"'">>]).
-
-get_roster_jid_groups(LServer, Username) ->
-    ejabberd_odbc:sql_query(LServer,
-                           [<<"select jid, grp from rostergroups where "
-                              "username='">>,
-                            Username, <<"'">>]).
-
-get_roster_groups(_LServer, Username, SJID) ->
-    ejabberd_odbc:sql_query_t([<<"select grp from rostergroups where username='">>,
-                              Username, <<"' and jid='">>, SJID, <<"';">>]).
+get_roster(LServer, LUser) ->
+    ejabberd_odbc:sql_query(
+      LServer,
+      ?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s, "
+           "@(ask)s, @(askmessage)s, @(server)s, @(subscribe)s, "
+           "@(type)s from rosterusers where username=%(LUser)s")).
 
-del_user_roster_t(LServer, Username) ->
-    ejabberd_odbc:sql_transaction(LServer,
-                                 fun () ->
-                                         ejabberd_odbc:sql_query_t([<<"delete from rosterusers       where "
-                                                                      "username='">>,
-                                                                    Username,
-                                                                    <<"';">>]),
-                                         ejabberd_odbc:sql_query_t([<<"delete from rostergroups       where "
-                                                                      "username='">>,
-                                                                    Username,
-                                                                    <<"';">>])
-                                 end).
+get_roster_jid_groups(LServer, LUser) ->
+    ejabberd_odbc:sql_query(
+      LServer,
+      ?SQL("select @(jid)s, @(grp)s from rostergroups where "
+           "username=%(LUser)s")).
 
-get_roster_by_jid(_LServer, Username, SJID) ->
-    ejabberd_odbc:sql_query_t([<<"select username, jid, nick, subscription, "
-                                "ask, askmessage, server, subscribe, "
-                                "type from rosterusers where username='">>,
-                              Username, <<"' and jid='">>, SJID, <<"';">>]).
+get_roster_groups(_LServer, LUser, SJID) ->
+    ejabberd_odbc:sql_query_t(
+      ?SQL("select @(grp)s from rostergroups"
+           " where username=%(LUser)s and jid=%(SJID)s")).
 
-get_rostergroup_by_jid(LServer, Username, SJID) ->
-    ejabberd_odbc:sql_query(LServer,
-                           [<<"select grp from rostergroups where username='">>,
-                            Username, <<"' and jid='">>, SJID, <<"'">>]).
+del_user_roster_t(LServer, LUser) ->
+    ejabberd_odbc:sql_transaction(
+      LServer,
+      fun () ->
+              ejabberd_odbc:sql_query_t(
+                ?SQL("delete from rosterusers where username=%(LUser)s")),
+              ejabberd_odbc:sql_query_t(
+                ?SQL("delete from rostergroups where username=%(LUser)s"))
+      end).
+
+get_roster_by_jid(_LServer, LUser, SJID) ->
+    ejabberd_odbc:sql_query_t(
+      ?SQL("select @(username)s, @(jid)s, @(nick)s, @(subscription)s,"
+           " @(ask)s, @(askmessage)s, @(server)s, @(subscribe)s,"
+           " @(type)s from rosterusers"
+           " where username=%(LUser)s and jid=%(SJID)s")).
+
+get_rostergroup_by_jid(LServer, LUser, SJID) ->
+    ejabberd_odbc:sql_query(
+      LServer,
+      ?SQL("select @(grp)s from rostergroups"
+           " where username=%(LUser)s and jid=%(SJID)s")).
 
 del_roster(_LServer, Username, SJID) ->
     ejabberd_odbc:sql_query_t([<<"delete from rosterusers       where "
@@ -421,11 +422,11 @@ roster_subscribe(_LServer, Username, SJID, ItemVals) ->
             [<<"username='">>, Username, <<"' and jid='">>, SJID,
              <<"'">>]).
 
-get_subscription(LServer, Username, SJID) ->
-    ejabberd_odbc:sql_query(LServer,
-                           [<<"select subscription from rosterusers "
-                              "where username='">>,
-                            Username, <<"' and jid='">>, SJID, <<"'">>]).
+get_subscription(LServer, LUser, SJID) ->
+    ejabberd_odbc:sql_query(
+      LServer,
+      ?SQL("select @(subscription)s from rosterusers "
+           "where username=%(LUser)s and jid=%(SJID)s")).
 
 set_private_data(_LServer, Username, LXMLNS, SData) ->
     update_t(<<"private_storage">>,
@@ -639,10 +640,10 @@ count_records_where(LServer, Table, WhereClause) ->
                             WhereClause, <<";">>]).
 
 get_roster_version(LServer, LUser) ->
-    ejabberd_odbc:sql_query(LServer,
-                           [<<"select version from roster_version where "
-                              "username = '">>,
-                            LUser, <<"'">>]).
+    ejabberd_odbc:sql_query(
+      LServer,
+      ?SQL("select @(version)s from roster_version"
+           " where username = %(LUser)s")).
 
 set_roster_version(LUser, Version) ->
     update_t(<<"roster_version">>,