]> granicus.if.org Git - ejabberd/commitdiff
Return "Bad request" error when origin in websocket connection doesn't match
authorPaweł Chmielowski <pchmielowski@process-one.net>
Fri, 26 Apr 2019 13:29:43 +0000 (15:29 +0200)
committerPaweł Chmielowski <pchmielowski@process-one.net>
Fri, 26 Apr 2019 13:29:43 +0000 (15:29 +0200)
This also allow websocket_origin option to accept multiple values instead
of just single one.

src/ejabberd_websocket.erl

index 2b5a01460a6f52f6950b6866d1962f26c8816268..e954b42c2f7a7ff3e08417a6d81296d5493e8bbf 100644 (file)
@@ -42,7 +42,7 @@
 
 -author('ecestari@process-one.net').
 
--export([check/2, socket_handoff/5, opt_type/1]).
+-export([socket_handoff/5, opt_type/1]).
 
 -include("logger.hrl").
 
                           ?AC_ALLOW_HEADERS, ?AC_MAX_AGE]).
 -define(HEADER, [?CT_XML, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]).
 
-check(_Path, Headers) ->
-    HeadersValidators = [{'Upgrade', <<"websocket">>, true},
-                       {'Connection', ignore, true}, {'Host', ignore, true},
-                       {<<"Sec-Websocket-Key">>, ignore, true},
-                       {<<"Sec-Websocket-Version">>, <<"13">>, true},
-                       {<<"Origin">>, get_origin(), false}],
-
-    F = fun ({Tag, Val, Required}) ->
-               case lists:keyfind(Tag, 1, Headers) of
-                 false -> Required; % header not found, keep in list if required
-                 {_, HVal} ->
-                     case Val of
-                       ignore -> false; % ignore value -> ok, remove from list
-                       _ ->
-                           % expected value -> ok, remove from list (false)
-                           % value is different, keep in list (true)
-                           str:to_lower(HVal) /= Val
-                      end
-                end
-        end,
-    case lists:filter(F, HeadersValidators) of
-      [] -> true;
-      _InvalidHeaders -> false
+is_valid_websocket_upgrade(_Path, Headers) ->
+    HeadersToValidate = [{'Upgrade', <<"websocket">>},
+                         {'Connection', ignore},
+                         {'Host', ignore},
+                         {<<"Sec-Websocket-Key">>, ignore},
+                         {<<"Sec-Websocket-Version">>, <<"13">>}],
+    Res = lists:all(
+        fun({Tag, Val}) ->
+            case lists:keyfind(Tag, 1, Headers) of
+                false ->
+                    false;
+                {_, _} when Val == ignore ->
+                    true;
+                {_, HVal} ->
+                    str:to_lower(HVal) == Val
+            end
+        end, HeadersToValidate),
+
+    case {Res, lists:keyfind(<<"Origin">>, 1, Headers), get_origin()} of
+        {false, _, _} ->
+            false;
+        {true, _, []} ->
+            true;
+        {true, {_, HVal}, Origins} ->
+            HValLow = str:to_lower(HVal),
+            case lists:any(fun(V) -> V == HValLow end, Origins) of
+                true ->
+                    true;
+                _ ->
+                    invalid_origin
+            end;
+        {true, false, _} ->
+            true
     end.
 
 socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
@@ -92,7 +102,7 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
                                   socket = Socket, sockmod = SockMod,
                                   data = Buf, opts = HOpts},
                _Opts, HandlerModule, InfoMsgFun) ->
-    case check(LocalPath, Headers) of
+    case is_valid_websocket_upgrade(LocalPath, Headers) of
         true ->
             WS = #ws{socket = Socket,
                      sockmod = SockMod,
@@ -107,8 +117,11 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path,
                      http_opts = HOpts},
 
             connect(WS, HandlerModule);
-        _ ->
-            {200, ?HEADER, InfoMsgFun()}
+        false ->
+            {200, ?HEADER, InfoMsgFun()};
+        invalid_origin ->
+            {403, ?HEADER, #xmlel{name = <<"h1">>,
+                                  children = [{xmlcdata, <<"403 Bad Request - Invalid origin">>}]}}
     end;
 socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _) ->
     {200, ?OPTIONS_HEADER, []};
@@ -413,22 +426,27 @@ websocket_close(Socket, WsHandleLoopPid, SocketMode, _CloseCode) ->
     SocketMode:close(Socket).
 
 get_origin() ->
-    ejabberd_config:get_option(websocket_origin, ignore).
+    ejabberd_config:get_option(websocket_origin, []).
 
 opt_type(websocket_ping_interval) ->
     fun (I) when is_integer(I), I >= 0 -> I end;
 opt_type(websocket_timeout) ->
     fun (I) when is_integer(I), I > 0 -> I end;
 opt_type(websocket_origin) ->
-    %% Accept only values conforming to RFC6454 section 7.1
-    fun (<<"null">>) -> <<"null">>;
-       (null) -> <<"null">>;
-       (Origin) ->
-           URIs = [_|_] = lists:flatmap(
-                            fun(<<>>) -> [];
-                               (URI) -> [misc:try_url(URI)]
-                            end, re:split(Origin, "\\s")),
-           str:join(URIs, <<" ">>)
+    fun Verify(V) when is_binary(V) ->
+        Verify([V]);
+        Verify([]) ->
+            [];
+        Verify([<<"null">> | R]) ->
+            [<<"null">> | Verify(R)];
+        Verify([null | R]) ->
+            [<<"null">> | Verify(R)];
+        Verify([V | R]) when is_binary(V) ->
+           URIs = [_|_] = lists:filtermap(
+                            fun(<<>>) -> false;
+                               (URI) -> {true, misc:try_url(URI)}
+                            end, re:split(V, "\\s+")),
+           [str:join(URIs, <<" ">>) | Verify(R)]
     end;
 opt_type(_) ->
     [websocket_ping_interval, websocket_timeout, websocket_origin].