]> granicus.if.org Git - ejabberd/commitdiff
Process 'Contact' headers more accurately (as per RFC3261)
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 30 May 2014 19:11:46 +0000 (23:11 +0400)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Fri, 30 May 2014 19:14:52 +0000 (23:14 +0400)
src/mod_sip_proxy.erl
src/mod_sip_registrar.erl

index b05c49061eee7f630bdbeaf0a4fa8ca1907fd1a7..185d72afe8a9b4f0913e02abc7639a6f3778c5d8 100644 (file)
@@ -66,15 +66,15 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) ->
     Opts = State#state.opts,
     Req1 = prepare_request(State#state.host, Req),
     case connect(Req1, Opts) of
-       {ok, SIPSockets} ->
+       {ok, SIPSocketsWithURIs} ->
            NewState =
                lists:foldl(
-                 fun(_SIPSocket, {error, _} = Err) ->
+                 fun(_SIPSocketWithURI, {error, _} = Err) ->
                          Err;
-                    (SIPSocket, #state{tr_ids = TrIDs} = AccState) ->
+                    ({SIPSocket, URI}, #state{tr_ids = TrIDs} = AccState) ->
                          Req2 = add_record_route(SIPSocket, State#state.host, Req1),
                          Req3 = add_via(SIPSocket, State#state.host, Req2),
-                         case esip:request(SIPSocket, Req3,
+                         case esip:request(SIPSocket, Req3#sip{uri = URI},
                                            {?MODULE, route, [self()]}) of
                              {ok, ClientTrID} ->
                                  NewTrIDs = [ClientTrID|TrIDs],
@@ -83,7 +83,7 @@ wait_for_request({#sip{type = request} = Req, TrID}, State) ->
                                  cancel_pending_transactions(AccState),
                                  Err
                          end
-                 end, State, SIPSockets),
+                 end, State, SIPSocketsWithURIs),
            case NewState of
                {error, _} = Err ->
                    {Status, Reason} = esip:error_status(Err),
@@ -214,7 +214,7 @@ connect(#sip{hdrs = Hdrs} = Req, Opts) ->
        false ->
            case esip:connect(Req, Opts) of
                {ok, SIPSock} ->
-                   {ok, [SIPSock]};
+                   {ok, [{SIPSock, Req#sip.uri}]};
                {error, _} = Err ->
                    Err
            end
index 689efe48eaa254697f9f3dd3d6c4d943b152d8b7..5080cf4a7e749fb8231b1d1b3aa982e58af2e0d5 100644 (file)
 -include("esip.hrl").
 
 -define(CALL_TIMEOUT, timer:seconds(30)).
+-define(DEFAULT_EXPIRES, 3600).
 
 -record(binding, {socket = #sip_socket{},
                  call_id = <<"">> :: binary(),
                  cseq = 0 :: non_neg_integer(),
                  timestamp = now() :: erlang:timestamp(),
+                 contact :: {binary(), #uri{}, [{binary(), binary()}]},
                  tref = make_ref() :: reference(),
                  expires = 0 :: non_neg_integer()}).
 
@@ -50,20 +52,19 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
     US = {LUser, LServer},
     CallID = esip:get_hdr('call-id', Hdrs),
     CSeq = esip:get_hdr('cseq', Hdrs),
-    Expires = esip:get_hdr('expires', Hdrs, 0),
+    Expires = esip:get_hdr('expires', Hdrs, ?DEFAULT_EXPIRES),
     case esip:get_hdrs('contact', Hdrs) of
         [<<"*">>] when Expires == 0 ->
-            case unregister_session(US, SIPSock, CallID, CSeq) of
-               ok ->
+            case unregister_session(US, CallID, CSeq) of
+               {ok, ContactsWithExpires} ->
                    ?INFO_MSG("unregister SIP session for user ~s@~s from ~s",
                              [LUser, LServer, inet_parse:ntoa(PeerIP)]),
-                   Contact = {<<"">>, #uri{user = LUser, host = LServer},
-                              [{<<"expires">>, <<"0">>}]},
+                   Cs = prepare_contacts_to_send(ContactsWithExpires),
                    mod_sip:make_response(
                      Req,
                      #sip{type = response,
                           status = 200,
-                          hdrs = [{'contact', [Contact]}]});
+                          hdrs = [{'contact', Cs}]});
                {error, Why} ->
                    {Status, Reason} = make_status(Why),
                    mod_sip:make_response(
@@ -72,51 +73,35 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
                                reason = Reason})
            end;
         [{_, _URI, _Params}|_] = Contacts ->
-            ExpiresList = lists:map(
-                           fun({_, _, Params}) ->
-                                   case to_integer(
-                                          esip:get_param(
-                                            <<"expires">>, Params),
-                                          0, (1 bsl 32)-1) of
-                                       {ok, E} -> E;
-                                       _ -> Expires
-                                   end
-                           end, Contacts),
-           Expires1 = lists:max(ExpiresList),
-           Contact = {<<"">>, #uri{user = LUser, host = LServer},
-                      [{<<"expires">>, jlib:integer_to_binary(Expires1)}]},
+           ContactsWithExpires = make_contacts_with_expires(Contacts, Expires),
+           Expires1 = lists:max([E || {_, E} <- ContactsWithExpires]),
            MinExpires = min_expires(),
-            if Expires1 >= MinExpires ->
-                   case register_session(US, SIPSock, CallID, CSeq, Expires1) of
-                       ok ->
-                           ?INFO_MSG("register SIP session for user ~s@~s from ~s",
-                                     [LUser, LServer, inet_parse:ntoa(PeerIP)]),
-                           mod_sip:make_response(
-                             Req,
-                             #sip{type = response,
-                                  status = 200,
-                                  hdrs = [{'contact', [Contact]}]});
-                       {error, Why} ->
-                           {Status, Reason} = make_status(Why),
-                           mod_sip:make_response(
-                             Req, #sip{type = response,
-                                       status = Status,
-                                       reason = Reason})
-                   end;
-               Expires1 > 0, Expires1 < MinExpires ->
-                    mod_sip:make_response(
+           if Expires1 > 0, Expires1 < MinExpires ->
+                   mod_sip:make_response(
                      Req, #sip{type = response,
                                status = 423,
                                hdrs = [{'min-expires', MinExpires}]});
-               true ->
-                    case unregister_session(US, SIPSock, CallID, CSeq) of
-                       ok ->
-                           ?INFO_MSG("unregister SIP session for user ~s@~s from ~s",
-                                     [LUser, LServer, inet_parse:ntoa(PeerIP)]),
+              true ->
+                   case register_session(US, SIPSock, CallID, CSeq,
+                                         ContactsWithExpires) of
+                       {ok, Res} ->
+                           if Res == updated ->
+                                   ?INFO_MSG("register SIP session for user "
+                                             "~s@~s from ~s",
+                                             [LUser, LServer,
+                                              inet_parse:ntoa(PeerIP)]);
+                              Res == deleted ->
+                                   ?INFO_MSG("unregister SIP session for user "
+                                             "~s@~s from ~s",
+                                             [LUser, LServer,
+                                              inet_parse:ntoa(PeerIP)])
+                           end,
+                           Cs = prepare_contacts_to_send(ContactsWithExpires),
                            mod_sip:make_response(
                              Req,
-                             #sip{type = response, status = 200,
-                                  hdrs = [{'contact', [Contact]}]});
+                             #sip{type = response,
+                                  status = 200,
+                                  hdrs = [{'contact', Cs}]});
                        {error, Why} ->
                            {Status, Reason} = make_status(Why),
                            mod_sip:make_response(
@@ -128,22 +113,15 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
        [] ->
            case mnesia:dirty_read(sip_session, US) of
                [#sip_session{bindings = Bindings}] ->
-                   case pop_previous_binding(SIPSock, Bindings) of
-                       {ok, #binding{expires = Expires1}, _} ->
-                           Contact = {<<"">>,
-                                      #uri{user = LUser, host = LServer},
-                                      [{<<"expires">>,
-                                        jlib:integer_to_binary(Expires1)}]},
-                           mod_sip:make_response(
-                             Req, #sip{type = response, status = 200,
-                                       hdrs = [{'contact', [Contact]}]});
-                       {error, notfound} ->
-                           {Status, Reason} = make_status(notfound),
-                           mod_sip:make_response(
-                             Req, #sip{type = response,
-                                       status = Status,
-                                       reason = Reason})
-                   end;
+                   ContactsWithExpires =
+                       lists:map(
+                         fun(#binding{contact = Contact, expires = Es}) ->
+                                 {Contact, Es}
+                         end, Bindings),
+                   Cs = prepare_contacts_to_send(ContactsWithExpires),
+                   mod_sip:make_response(
+                     Req, #sip{type = response, status = 200,
+                               hdrs = [{'contact', Cs}]});
                [] ->
                    {Status, Reason} = make_status(notfound),
                    mod_sip:make_response(
@@ -158,7 +136,11 @@ request(#sip{hdrs = Hdrs} = Req, SIPSock) ->
 find_sockets(U, S) ->
     case mnesia:dirty_read(sip_session, {U, S}) of
        [#sip_session{bindings = Bindings}] ->
-           [Binding#binding.socket || Binding <- Bindings];
+           lists:map(
+             fun(#binding{contact = {_, URI, _},
+                          socket = Socket}) ->
+                     {Socket, URI}
+             end, Bindings);
        [] ->
            []
     end.
@@ -176,8 +158,8 @@ init([]) ->
 handle_call({write, Session}, _From, State) ->
     Res = write_session(Session),
     {reply, Res, State};
-handle_call({delete, US, SIPSocket, CallID, CSeq}, _From, State) ->
-    Res = delete_session(US, SIPSocket, CallID, CSeq),
+handle_call({delete, US, CallID, CSeq}, _From, State) ->
+    Res = delete_session(US, CallID, CSeq),
     {reply, Res, State};
 handle_call(_Request, _From, State) ->
     Reply = ok,
@@ -189,8 +171,8 @@ handle_cast(_Msg, State) ->
 handle_info({write, Session}, State) ->
     write_session(Session),
     {noreply, State};
-handle_info({delete, US, SIPSocket, CallID, CSeq}, State) ->
-    delete_session(US, SIPSocket, CallID, CSeq),
+handle_info({delete, US, CallID, CSeq}, State) ->
+    delete_session(US, CallID, CSeq),
     {noreply, State};
 handle_info({timeout, TRef, US}, State) ->
     delete_expired_session(US, TRef),
@@ -208,70 +190,102 @@ code_change(_OldVsn, State, _Extra) ->
 %%%===================================================================
 %%% Internal functions
 %%%===================================================================
-register_session(US, SIPSocket, CallID, CSeq, Expires) ->
-    Session = #sip_session{us = US,
-                          bindings = [#binding{socket = SIPSocket,
-                                               call_id = CallID,
-                                               cseq = CSeq,
-                                               timestamp = now(),
-                                               expires = Expires}]},
+register_session(US, SIPSocket, CallID, CSeq, ContactsWithExpires) ->
+    Bindings = lists:map(
+                fun({Contact, Expires}) ->
+                        #binding{socket = SIPSocket,
+                                 call_id = CallID,
+                                 cseq = CSeq,
+                                 timestamp = now(),
+                                 contact = Contact,
+                                 expires = Expires}
+                end, ContactsWithExpires),
+    Session = #sip_session{us = US, bindings = Bindings},
     call({write, Session}).
 
-unregister_session(US, SIPSocket, CallID, CSeq) ->
-    Msg = {delete, US, SIPSocket, CallID, CSeq},
+unregister_session(US, CallID, CSeq) ->
+    Msg = {delete, US, CallID, CSeq},
     call(Msg).
 
-write_session(#sip_session{us = {U, S} = US,
-                          bindings = [#binding{socket = SIPSocket,
-                                               call_id = CallID,
-                                               expires = Expires,
-                                               cseq = CSeq} = Binding]}) ->
-    case mnesia:dirty_read(sip_session, US) of
-       [#sip_session{bindings = Bindings}] ->
-           case pop_previous_binding(SIPSocket, Bindings) of
-               {ok, #binding{call_id = CallID, cseq = PrevCSeq}, _}
-                 when PrevCSeq > CSeq ->
-                   {error, cseq_out_of_order};
-               {ok, #binding{tref = Tref}, Bindings1} ->
-                   erlang:cancel_timer(Tref),
-                   NewTRef = erlang:start_timer(Expires * 1000, self(), US),
-                   NewBindings = [Binding#binding{tref = NewTRef}|Bindings1],
-                   mnesia:dirty_write(
-                     #sip_session{us = US, bindings = NewBindings});
-               {error, notfound} ->
-                   MaxSessions = ejabberd_sm:get_max_user_sessions(U, S),
-                   if length(Bindings) < MaxSessions ->
-                           NewTRef = erlang:start_timer(Expires * 1000, self(), US),
-                           NewBindings = [Binding#binding{tref = NewTRef}|Bindings],
-                           mnesia:dirty_write(
-                             #sip_session{us = US, bindings = NewBindings});
-                      true ->
-                           {error, too_many_sessions}
+write_session(#sip_session{us = {U, S} = US, bindings = NewBindings}) ->
+    PrevBindings = case mnesia:dirty_read(sip_session, US) of
+                      [#sip_session{bindings = PrevBindings1}] ->
+                          PrevBindings1;
+                      [] ->
+                          []
+                  end,
+    Res = lists:foldl(
+           fun(_, {error, _} = Err) ->
+                   Err;
+              (#binding{call_id = CallID,
+                        expires = Expires,
+                        cseq = CSeq} = Binding, {Add, Del}) ->
+                   case find_binding(Binding, PrevBindings) of
+                       {ok, #binding{call_id = CallID, cseq = PrevCSeq}}
+                         when PrevCSeq > CSeq ->
+                           {error, cseq_out_of_order};
+                       {ok, PrevBinding} when Expires == 0 ->
+                           {Add, [PrevBinding|Del]};
+                       {ok, _} ->
+                           {[Binding|Add], Del};
+                       {error, notfound} when Expires == 0 ->
+                           {error, notfound};
+                       {error, notfound} ->
+                           {[Binding|Add], Del}
                    end
-           end;
-       [] ->
-           NewTRef = erlang:start_timer(Expires * 1000, self(), US),
-           NewBindings = [Binding#binding{tref = NewTRef}],
-           mnesia:dirty_write(#sip_session{us = US, bindings = NewBindings})
+           end, {[], []}, NewBindings),
+    MaxSessions = ejabberd_sm:get_max_user_sessions(U, S),
+    case Res of
+       {error, Why} ->
+           {error, Why};
+       {AddBindings, _} when length(AddBindings) > MaxSessions ->
+           {error, too_many_sessions};
+       {AddBindings, DelBindings} ->
+           lists:foreach(
+             fun(#binding{tref = TRef}) ->
+                     erlang:cancel_timer(TRef)
+             end, DelBindings),
+           Bindings = lists:map(
+                        fun(#binding{tref = TRef,
+                                     expires = Expires} = Binding) ->
+                                erlang:cancel_timer(TRef),
+                                NewTRef = erlang:start_timer(
+                                            Expires * 1000, self(), US),
+                                Binding#binding{tref = NewTRef}
+                        end, AddBindings),
+           case Bindings of
+               [] ->
+                   mnesia:dirty_delete(sip_session, US),
+                   {ok, deleted};
+               _ ->
+                   mnesia:dirty_write(
+                     #sip_session{us = US, bindings = Bindings}),
+                   {ok, updated}
+           end
     end.
 
-delete_session(US, SIPSocket, CallID, CSeq) ->
+delete_session(US, CallID, CSeq) ->
     case mnesia:dirty_read(sip_session, US) of
        [#sip_session{bindings = Bindings}] ->
-           case pop_previous_binding(SIPSocket, Bindings) of
-               {ok, #binding{call_id = CallID, cseq = PrevCSeq}, _}
-                 when PrevCSeq > CSeq ->
-                   {error, cseq_out_of_order};
-               {ok, #binding{tref = TRef}, []} ->
-                   erlang:cancel_timer(TRef),
-                   mnesia:dirty_delete(sip_session, US);
-               {ok, #binding{tref = TRef}, NewBindings} ->
-                   erlang:cancel_timer(TRef),
-                   mnesia:dirty_write(sip_session,
-                                      #sip_session{us = US,
-                                                   bindings = NewBindings});
-               {error, notfound} ->
-                   {error, notfound}
+           case lists:all(
+                  fun(B) when B#binding.call_id == CallID,
+                              B#binding.cseq > CSeq ->
+                          false;
+                     (_) ->
+                          true
+                  end, Bindings) of
+               true ->
+                   ContactsWithExpires =
+                       lists:map(
+                         fun(#binding{contact = Contact,
+                                      tref = TRef}) ->
+                                 erlang:cancel_timer(TRef),
+                                 {Contact, 0}
+                         end, Bindings),
+                   mnesia:dirty_delete(sip_session, US),
+                   {ok, ContactsWithExpires};
+               false ->
+                   {error, cseq_out_of_order}
            end;
        [] ->
            {error, notfound}
@@ -308,17 +322,6 @@ to_integer(Bin, Min, Max) ->
             error
     end.
 
-pop_previous_binding(#sip_socket{peer = Peer}, Bindings) ->
-    case lists:partition(
-          fun(#binding{socket = #sip_socket{peer = Peer1}}) ->
-                  Peer1 == Peer
-          end, Bindings) of
-       {[Binding], RestBindings} ->
-           {ok, Binding, RestBindings};
-       _ ->
-           {error, notfound}
-    end.
-
 call(Msg) ->
     case catch ?GEN_SERVER:call(?MODULE, Msg, ?CALL_TIMEOUT) of
        {'EXIT', {timeout, _}} ->
@@ -329,6 +332,47 @@ call(Msg) ->
            Reply
     end.
 
+make_contacts_with_expires(Contacts, Expires) ->
+    lists:map(
+      fun({Name, URI, Params}) ->
+             E1 = case to_integer(esip:get_param(<<"expires">>, Params),
+                                  0, (1 bsl 32)-1) of
+                      {ok, E} -> E;
+                      _ -> Expires
+                  end,
+             Params1 = lists:keydelete(<<"expires">>, 1, Params),
+             {{Name, URI, Params1}, E1}
+      end, Contacts).
+
+prepare_contacts_to_send(ContactsWithExpires) ->
+    lists:map(
+      fun({{Name, URI, Params}, Expires}) ->
+             Params1 = esip:set_param(<<"expires">>,
+                                      list_to_binary(
+                                        integer_to_list(Expires)),
+                                      Params),
+             {Name, URI, Params1}
+      end, ContactsWithExpires).
+
+find_binding(#binding{contact = {_, URI1, _}} = OrigBinding,
+            [#binding{contact = {_, URI2, _}} = Binding|Bindings]) ->
+    case cmp_uri(URI1, URI2) of
+       true ->
+           {ok, Binding};
+       false ->
+           find_binding(OrigBinding, Bindings)
+    end;
+find_binding(_, []) ->
+    {error, notfound}.
+
+%% TODO: this is *totally* wrong.
+%% Rewrite this using URI comparison rules
+cmp_uri(#uri{user = U, host = H, port = P},
+       #uri{user = U, host = H, port = P}) ->
+    true;
+cmp_uri(_, _) ->
+    false.
+
 make_status(notfound) ->
     {404, esip:reason(404)};
 make_status(cseq_out_of_order) ->