From bcfe50f817b6365b2cada08e05cc8f59f5d00980 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Pawe=C5=82=20Chmielowski?= Date: Fri, 26 Apr 2019 15:29:43 +0200 Subject: [PATCH] Return "Bad request" error when origin in websocket connection doesn't match This also allow websocket_origin option to accept multiple values instead of just single one. --- src/ejabberd_websocket.erl | 92 +++++++++++++++++++++++--------------- 1 file changed, 55 insertions(+), 37 deletions(-) diff --git a/src/ejabberd_websocket.erl b/src/ejabberd_websocket.erl index 2b5a01460..e954b42c2 100644 --- a/src/ejabberd_websocket.erl +++ b/src/ejabberd_websocket.erl @@ -42,7 +42,7 @@ -author('ecestari@process-one.net'). --export([check/2, socket_handoff/5, opt_type/1]). +-export([socket_handoff/5, opt_type/1]). -include("logger.hrl"). @@ -62,29 +62,39 @@ ?AC_ALLOW_HEADERS, ?AC_MAX_AGE]). -define(HEADER, [?CT_XML, ?AC_ALLOW_ORIGIN, ?AC_ALLOW_HEADERS]). -check(_Path, Headers) -> - HeadersValidators = [{'Upgrade', <<"websocket">>, true}, - {'Connection', ignore, true}, {'Host', ignore, true}, - {<<"Sec-Websocket-Key">>, ignore, true}, - {<<"Sec-Websocket-Version">>, <<"13">>, true}, - {<<"Origin">>, get_origin(), false}], - - F = fun ({Tag, Val, Required}) -> - case lists:keyfind(Tag, 1, Headers) of - false -> Required; % header not found, keep in list if required - {_, HVal} -> - case Val of - ignore -> false; % ignore value -> ok, remove from list - _ -> - % expected value -> ok, remove from list (false) - % value is different, keep in list (true) - str:to_lower(HVal) /= Val - end - end - end, - case lists:filter(F, HeadersValidators) of - [] -> true; - _InvalidHeaders -> false +is_valid_websocket_upgrade(_Path, Headers) -> + HeadersToValidate = [{'Upgrade', <<"websocket">>}, + {'Connection', ignore}, + {'Host', ignore}, + {<<"Sec-Websocket-Key">>, ignore}, + {<<"Sec-Websocket-Version">>, <<"13">>}], + Res = lists:all( + fun({Tag, Val}) -> + case lists:keyfind(Tag, 1, Headers) of + false -> + false; + {_, _} when Val == ignore -> + true; + {_, HVal} -> + str:to_lower(HVal) == Val + end + end, HeadersToValidate), + + case {Res, lists:keyfind(<<"Origin">>, 1, Headers), get_origin()} of + {false, _, _} -> + false; + {true, _, []} -> + true; + {true, {_, HVal}, Origins} -> + HValLow = str:to_lower(HVal), + case lists:any(fun(V) -> V == HValLow end, Origins) of + true -> + true; + _ -> + invalid_origin + end; + {true, false, _} -> + true end. socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path, @@ -92,7 +102,7 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path, socket = Socket, sockmod = SockMod, data = Buf, opts = HOpts}, _Opts, HandlerModule, InfoMsgFun) -> - case check(LocalPath, Headers) of + case is_valid_websocket_upgrade(LocalPath, Headers) of true -> WS = #ws{socket = Socket, sockmod = SockMod, @@ -107,8 +117,11 @@ socket_handoff(LocalPath, #request{method = 'GET', ip = IP, q = Q, path = Path, http_opts = HOpts}, connect(WS, HandlerModule); - _ -> - {200, ?HEADER, InfoMsgFun()} + false -> + {200, ?HEADER, InfoMsgFun()}; + invalid_origin -> + {403, ?HEADER, #xmlel{name = <<"h1">>, + children = [{xmlcdata, <<"403 Bad Request - Invalid origin">>}]}} end; socket_handoff(_, #request{method = 'OPTIONS'}, _, _, _) -> {200, ?OPTIONS_HEADER, []}; @@ -413,22 +426,27 @@ websocket_close(Socket, WsHandleLoopPid, SocketMode, _CloseCode) -> SocketMode:close(Socket). get_origin() -> - ejabberd_config:get_option(websocket_origin, ignore). + ejabberd_config:get_option(websocket_origin, []). opt_type(websocket_ping_interval) -> fun (I) when is_integer(I), I >= 0 -> I end; opt_type(websocket_timeout) -> fun (I) when is_integer(I), I > 0 -> I end; opt_type(websocket_origin) -> - %% Accept only values conforming to RFC6454 section 7.1 - fun (<<"null">>) -> <<"null">>; - (null) -> <<"null">>; - (Origin) -> - URIs = [_|_] = lists:flatmap( - fun(<<>>) -> []; - (URI) -> [misc:try_url(URI)] - end, re:split(Origin, "\\s")), - str:join(URIs, <<" ">>) + fun Verify(V) when is_binary(V) -> + Verify([V]); + Verify([]) -> + []; + Verify([<<"null">> | R]) -> + [<<"null">> | Verify(R)]; + Verify([null | R]) -> + [<<"null">> | Verify(R)]; + Verify([V | R]) when is_binary(V) -> + URIs = [_|_] = lists:filtermap( + fun(<<>>) -> false; + (URI) -> {true, misc:try_url(URI)} + end, re:split(V, "\\s+")), + [str:join(URIs, <<" ">>) | Verify(R)] end; opt_type(_) -> [websocket_ping_interval, websocket_timeout, websocket_origin]. -- 2.40.0