Implement packets reordering to avoid race conditions (EJAB-724).(thanks to Michael...
authorBadlop <badlop@process-one.net>
Tue, 16 Jun 2009 17:47:03 +0000 (17:47 +0000)
committerBadlop <badlop@process-one.net>
Tue, 16 Jun 2009 17:47:03 +0000 (17:47 +0000)
SVN Revision: 2243

src/web/ejabberd_http_bind.erl

index 8ab7627a79093795944939bf2bf553fed1aaf314..d3849f723182af55011d942635d2468fc8fdb0f7 100644 (file)
@@ -4,7 +4,7 @@
 %%% Purpose : Implements XMPP over BOSH (XEP-0205) (formerly known as 
 %%%           HTTP Binding)
 %%% Created : 21 Sep 2005 by Stefan Strigler <steve@zeank.in-berlin.de>
-%%% Id      : $Id: ejabberd_http_bind.erl 694 2008-07-15 13:27:03Z alexey $
+%%% Id      : $Id: ejabberd_http_bind.erl 720 2008-09-17 15:52:58Z mremond $
 %%%----------------------------------------------------------------------
 
 -module(ejabberd_http_bind).
@@ -58,7 +58,8 @@
                ctime = 0,
                timer,
                pause=0,
-               req_list = [], % list of requests
+               unprocessed_req_list = [], % list of request that have been delayed for proper reordering
+               req_list = [], % list of requests (cache)
                ip = ?NULL_PEER 
               }).
 
@@ -333,187 +334,57 @@ handle_sync_event(stop, _From, _StateName, StateData) ->
     Reply = ok,
     {stop, normal, Reply, StateData};
 
-handle_sync_event({http_put, Rid, Attrs, Payload, Hold, StreamTo, IP},
+%% HTTP PUT: Receive packets from the client
+handle_sync_event({http_put, Rid, Attrs, _Payload, Hold, _StreamTo, _IP}=Request,
                  _From, StateName, StateData) ->
-    Key = xml:get_attr_s("key", Attrs),
-    NewKey = xml:get_attr_s("newkey", Attrs),
-    %% check if Rid valid
-    RidAllow =  case StateData#state.rid of
-                    none -> 
-                        %% first request - nothing saved so far
-                        {true, 0};
-                    OldRid ->
-                        ?DEBUG("state.rid/cur rid: ~p/~p", 
-                               [OldRid, Rid]),
-                        if 
-                            (OldRid < Rid) and 
-                            (Rid =< (OldRid + Hold + 1)) ->
-                                case catch list_to_integer(
-                                       xml:get_attr_s("pause", Attrs)) of
-                                    {'EXIT', _} ->
-                                        {true, 0};
-                                    Pause1 when Pause1 =< ?MAX_PAUSE ->
-                                        ?DEBUG("got pause: ~p", [Pause1]),
-                                        {true, Pause1};
-                                    _ ->
-                                        {true, 0}
-                                end;
-                            (Rid =< OldRid) and 
-                            (Rid > OldRid - Hold - 1) ->
-                                repeat;
-                            true ->
-                                false
-                        end
-                end,
-    %% check if key valid
-    KeyAllow = case RidAllow of
-                  repeat -> 
-                      true;
-                  false ->
-                      false;
-                  {true, _} ->
-                      case StateData#state.key of
-                          "" ->
-                              true;
-                          OldKey ->
-                              NextKey = jlib:tolower(
-                                          hex(binary_to_list(
-                                                crypto:sha(Key)))),
-                              ?DEBUG("Key/OldKey/NextKey: ~s/~s/~s", 
-                                     [Key, OldKey, NextKey]),
-                              if
-                                  OldKey == NextKey ->
-                                      true;
-                                  true ->
-                                      ?DEBUG("wrong key: ~s",[Key]),
-                                      false
-                              end
-                      end
-              end,
-    {TMegSec, TSec, TMSec} = now(),
-    TNow = (TMegSec * 1000000 + TSec) * 1000000 + TMSec,
-    LastPoll = if 
-                  Payload == "" ->
-                      TNow;
-                  true ->
-                      0
-              end,
-    if
-       (Payload == "") and 
-        (Hold == 0) and
-       (TNow - StateData#state.last_poll < ?MIN_POLLING) ->
-           Reply = {error, polling_too_frequently},
-           {reply, Reply, StateName, StateData};
-       KeyAllow ->
-           case RidAllow of
-               false ->
-                   Reply = {error, not_exists},
-                   {reply, Reply, StateName, StateData};
-               repeat ->
-                   ?DEBUG("REPEATING ~p", [Rid]),
-                   [Out | _XS] = [El#hbr.out || 
-                                     El <- StateData#state.req_list, 
-                                     El#hbr.rid == Rid],
-                   case Out of 
-                       [[] | OutPacket] ->
-                           Reply = {repeat, OutPacket};
+    %% Check if Rid valid
+    RidAllow = 
+       case StateData#state.rid of
+           none -> 
+               %% First request - nothing saved so far
+               {true, 0};
+           OldRid ->
+               ?DEBUG("state.rid/cur rid: ~p/~p", [OldRid, Rid]),
+               if
+                   %% We did not miss any packet, we can process it immediately:
+                   Rid == OldRid + 1 ->
+                   case catch list_to_integer(
+                                xml:get_attr_s("pause", Attrs)) of
+                       {'EXIT', _} ->
+                           {true, 0};
+                       Pause1 when Pause1 =< ?MAX_PAUSE ->
+                           ?DEBUG("got pause: ~p", [Pause1]),
+                           {true, Pause1};
                        _ ->
-                           Reply = {repeat, Out}
-                   end,
-                   {reply, Reply, StateName, 
-                    StateData#state{input = "cancel", last_poll = LastPoll}};
-               {true, Pause} ->
-                   SaveKey = if 
-                                 NewKey == "" ->
-                                     Key;
-                                 true ->
-                                     NewKey
-                             end,
-                   ?DEBUG(" -- SaveKey: ~s~n", [SaveKey]),
-
-                   %% save request
-                   ReqList = [#hbr{rid=Rid,
-                                   key=StateData#state.key,
-                                   in=StateData#state.input,
-                                   out=StateData#state.output
-                                  } | 
-                              [El || El <- StateData#state.req_list, 
-                                     El#hbr.rid < Rid, 
-                                     El#hbr.rid > (Rid - 1 - Hold)]
-                             ],
-%%                 ?DEBUG("reqlist: ~p", [ReqList]),
-                    
-                    %% setup next timer
-                   if
-                       StateData#state.timer /= undefined ->
-                           cancel_timer(StateData#state.timer);
-                       true ->
-                           ok
-                   end,
-                   Timer = if
-                               Pause > 0 ->
-                                   erlang:start_timer(
-                                     Pause*1000, self(), []);
-                               true ->
-                                   erlang:start_timer(
-                                     ?MAX_INACTIVITY, self(), [])
-                           end,
-                   case StateData#state.waiting_input of
-                       false ->
-                           Input = Payload ++ [StateData#state.input],
-                           Reply = ok,
-                           {reply, Reply, StateName, 
-                            StateData#state{input = Input,
-                                            rid = Rid,
-                                            key = SaveKey,
-                                            ctime = TNow,
-                                            timer = Timer,
-                                             pause = Pause,
-                                            last_poll = LastPoll,
-                                            req_list = ReqList,
-                                            ip = IP
-                                           }};
-                       {Receiver, _Tag} ->
-                            SendPacket = 
-                                case StreamTo of
-                                    {To, ""} ->
-                                        ["<stream:stream to='", To, "' "
-                                         "xmlns='"++?NS_CLIENT++"' "
-                                         "xmlns:stream='"++?NS_STREAM++"'>"] 
-                                            ++ Payload;
-                                    {To, Version} ->
-                                        ["<stream:stream to='", To, "' "
-                                         "xmlns='"++?NS_CLIENT++"' "
-                                         "version='", Version, "' "
-                                         "xmlns:stream='"++?NS_STREAM++"'>"] 
-                                            ++ Payload;
-                                    _ ->
-                                        Payload
-                                end,
-                            ?DEBUG("really sending now: ~s", [SendPacket]),
-                           Receiver ! {tcp, StateData#state.socket,
-                                       list_to_binary(SendPacket)},
-                           Reply = ok,
-                           {reply, Reply, StateName,
-                            StateData#state{waiting_input = false,
-                                            last_receiver = Receiver,
-                                            input = "",
-                                            rid = Rid,
-                                            key = SaveKey,
-                                            ctime = TNow,
-                                            timer = Timer,
-                                             pause = Pause,
-                                            last_poll = LastPoll,
-                                            req_list = ReqList,
-                                            ip = IP
-                                           }}
-                   end
-           end;
-       true ->
-           Reply = {error, bad_key},
-           {reply, Reply, StateName, StateData}
+                           {true, 0}
+                   end;
+                   %% We have missed packets, we need to cached it to process it later on:
+                   (OldRid < Rid) and 
+                   (Rid =< (OldRid + Hold + 1)) ->
+                       buffer;
+                   (Rid =< OldRid) and 
+                   (Rid > OldRid - Hold - 1) ->
+                       repeat;
+                   true ->
+                       false
+               end
+       end,
+    
+    %% Check if Rid is in sequence or out of sequence:
+    case RidAllow of
+       buffer ->
+           ?DEBUG("Buffered request: ~p", [Request]),
+           %% Request is out of sequence:
+           PendingRequests = StateData#state.unprocessed_req_list,
+           %% In case an existing RID was already buffered:
+           Requests = lists:keydelete(Rid, 2, PendingRequests),
+           {reply, ok, StateName, StateData#state{unprocessed_req_list=[Request|Requests]}};
+       _ ->
+           %% Request is in sequence:
+           process_http_put(Request, StateName, StateData, RidAllow)
     end;
 
+%% HTTP GET: send packets to the client
 handle_sync_event({http_get, Rid, Wait, Hold}, From, StateName, StateData) ->
     %% setup timer
     if
@@ -695,6 +566,173 @@ terminate(_Reason, _StateName, StateData) ->
 %%% Internal functions
 %%%----------------------------------------------------------------------
 
+%% PUT / Get processing:
+process_http_put({http_put, Rid, Attrs, Payload, Hold, StreamTo, IP},
+                StateName, StateData, RidAllow) ->
+    %% Check if key valid
+    Key = xml:get_attr_s("key", Attrs),
+    NewKey = xml:get_attr_s("newkey", Attrs),
+    KeyAllow =
+       case RidAllow of
+           repeat -> 
+               true;
+           false ->
+               false;
+           {true, _} ->
+               case StateData#state.key of
+                   "" ->
+                       true;
+                   OldKey ->
+                       NextKey = jlib:tolower(
+                                   hex(binary_to_list(
+                                         crypto:sha(Key)))),
+                       ?DEBUG("Key/OldKey/NextKey: ~s/~s/~s", [Key, OldKey, NextKey]),
+                       if
+                           OldKey == NextKey ->
+                               true;
+                           true ->
+                               ?DEBUG("wrong key: ~s",[Key]),
+                               false
+                       end
+               end
+       end,
+    {TMegSec, TSec, TMSec} = now(),
+    TNow = (TMegSec * 1000000 + TSec) * 1000000 + TMSec,
+    LastPoll = if 
+                  Payload == "" ->
+                      TNow;
+                  true ->
+                      0
+              end,
+    if
+       (Payload == "") and 
+        (Hold == 0) and
+       (TNow - StateData#state.last_poll < ?MIN_POLLING) ->
+           Reply = {error, polling_too_frequently},
+           {reply, Reply, StateName, StateData};
+       KeyAllow ->
+           case RidAllow of
+               false ->
+                   Reply = {error, not_exists},
+                   {reply, Reply, StateName, StateData};
+               repeat ->
+                   ?DEBUG("REPEATING ~p", [Rid]),
+                   [Out | _XS] = [El#hbr.out || 
+                                     El <- StateData#state.req_list, 
+                                     El#hbr.rid == Rid],
+                   case Out of 
+                       [[] | OutPacket] ->
+                           Reply = {repeat, OutPacket};
+                       _ ->
+                           Reply = {repeat, Out}
+                   end,
+                   {reply, Reply, StateName, 
+                    StateData#state{input = "cancel", last_poll = LastPoll}};
+               {true, Pause} ->
+                   SaveKey = if 
+                                 NewKey == "" ->
+                                     Key;
+                                 true ->
+                                     NewKey
+                             end,
+                   ?DEBUG(" -- SaveKey: ~s~n", [SaveKey]),
+
+                   %% save request
+                   ReqList = [#hbr{rid=Rid,
+                                   key=StateData#state.key,
+                                   in=StateData#state.input,
+                                   out=StateData#state.output
+                                  } | 
+                              [El || El <- StateData#state.req_list, 
+                                     El#hbr.rid < Rid, 
+                                     El#hbr.rid > (Rid - 1 - Hold)]
+                             ],
+%%                 ?DEBUG("reqlist: ~p", [ReqList]),
+                    
+                    %% setup next timer
+                   if
+                       StateData#state.timer /= undefined ->
+                           cancel_timer(StateData#state.timer);
+                       true ->
+                           ok
+                   end,
+                   Timer = if
+                               Pause > 0 ->
+                                   erlang:start_timer(
+                                     Pause*1000, self(), []);
+                               true ->
+                                   erlang:start_timer(
+                                     ?MAX_INACTIVITY, self(), [])
+                           end,
+                   case StateData#state.waiting_input of
+                       false ->
+                           Input = Payload ++ [StateData#state.input],
+                           Reply = ok,
+                           process_buffered_request(Reply, StateName, 
+                                                    StateData#state{input = Input,
+                                                                    rid = Rid,
+                                                                    key = SaveKey,
+                                                                    ctime = TNow,
+                                                                    timer = Timer,
+                                                                    pause = Pause,
+                                                                    last_poll = LastPoll,
+                                                                    req_list = ReqList,
+                                                                    ip = IP
+                                                                   });
+                       {Receiver, _Tag} ->
+                            SendPacket = 
+                                case StreamTo of
+                                    {To, ""} ->
+                                        ["<stream:stream to='", To, "' "
+                                         "xmlns='"++?NS_CLIENT++"' "
+                                         "xmlns:stream='"++?NS_STREAM++"'>"] 
+                                            ++ Payload;
+                                    {To, Version} ->
+                                        ["<stream:stream to='", To, "' "
+                                         "xmlns='"++?NS_CLIENT++"' "
+                                         "version='", Version, "' "
+                                         "xmlns:stream='"++?NS_STREAM++"'>"] 
+                                            ++ Payload;
+                                    _ ->
+                                        Payload
+                                end,
+                            ?DEBUG("really sending now: ~s", [SendPacket]),
+                           Receiver ! {tcp, StateData#state.socket,
+                                       list_to_binary(SendPacket)},
+                           Reply = ok,
+                           process_buffered_request(Reply, StateName, 
+                                                    StateData#state{waiting_input = false,
+                                                                    last_receiver = Receiver,
+                                                                    input = "",
+                                                                    rid = Rid,
+                                                                    key = SaveKey,
+                                                                    ctime = TNow,
+                                                                    timer = Timer,
+                                                                    pause = Pause,
+                                                                    last_poll = LastPoll,
+                                                                    req_list = ReqList,
+                                                                    ip = IP
+                                                                   })
+                   end
+           end;
+       true ->
+           Reply = {error, bad_key},
+           {reply, Reply, StateName, StateData}
+    end.
+
+process_buffered_request(Reply, StateName, StateData) ->
+    Rid = StateData#state.rid,
+    Requests = StateData#state.unprocessed_req_list,
+    case lists:keysearch(Rid+1, 2, Requests) of
+       {value, Request} ->
+           ?DEBUG("Processing buffered request: ~p", [Request]),
+           NewRequests = Requests -- [Request],
+           handle_sync_event(Request, undefined, StateName,
+                             StateData#state{unprocessed_req_list=NewRequests});
+       _ ->
+           {reply, Reply, StateName, StateData}
+    end.
+
 handle_http_put(Sid, Rid, Attrs, Payload, StreamStart, IP) ->
     case http_put(Sid, Rid, Attrs, Payload, StreamStart, IP) of
         {error, not_exists} ->