]> granicus.if.org Git - ejabberd/commitdiff
Implement X-OAUTH2 authentication for mod_mqtt
authorEvgeny Khramtsov <ekhramtsov@process-one.net>
Fri, 11 Oct 2019 13:46:16 +0000 (16:46 +0300)
committerEvgeny Khramtsov <ekhramtsov@process-one.net>
Fri, 11 Oct 2019 13:46:16 +0000 (16:46 +0300)
This will only work for MQTT 5.0 connections.
A client MUST set "Authentication Method" property of CONNECT
packet to "X-OAUTH2" and MUST set the token in "Authentication Data"
property of the same CONNECT packet.

The server responses as usual with CONNACK.

src/mod_mqtt_session.erl

index 8c0ced1f4d384cc1c25dee0d3a51a897391a48b2..6508a706c66d6fb94b9115fa884350404e108a7b 100644 (file)
@@ -677,12 +677,17 @@ set_will_properties(State, _) ->
 -spec get_connack_properties(state(), connect()) -> properties().
 get_connack_properties(#state{session_expiry = SessExp, jid = JID},
                        #connect{client_id = ClientID,
-                                keep_alive = KeepAlive}) ->
+                                keep_alive = KeepAlive,
+                               properties = Props}) ->
     Props1 = case ClientID of
                  <<>> -> #{assigned_client_identifier => JID#jid.lresource};
                  _ -> #{}
              end,
-    Props1#{session_expiry_interval => SessExp div 1000,
+    Props2 = case maps:find(authentication_method, Props) of
+                {ok, Method} -> Props1#{authentication_method => Method};
+                error -> Props1
+            end,
+    Props2#{session_expiry_interval => SessExp div 1000,
             shared_subscription_available => false,
             topic_alias_maximum => topic_alias_maximum(JID#jid.lserver),
             server_keep_alive => KeepAlive}.
@@ -1169,24 +1174,41 @@ parse_credentials(JID, ClientID) ->
     end.
 
 -spec authenticate(connect(), peername()) -> {ok, jid:jid()} | {error, reason_code()}.
-authenticate(#connect{password = Pass} = Pkt, IP) ->
+authenticate(Pkt, IP) ->
+    case authenticate(Pkt) of
+       {ok, JID, AuthModule} ->
+           ?INFO_MSG("Accepted MQTT authentication for ~ts by ~s backend from ~s",
+                     [jid:encode(JID),
+                      ejabberd_auth:backend_type(AuthModule),
+                      ejabberd_config:may_hide_data(misc:ip_to_list(IP))]),
+           {ok, JID};
+       {error, _} = Err ->
+           Err
+    end.
+
+-spec authenticate(connect()) -> {ok, jid:jid(), module()} | {error, reason_code()}.
+authenticate(#connect{password = Pass, properties = Props} = Pkt) ->
     case parse_credentials(Pkt) of
        {ok, #jid{luser = LUser, lserver = LServer} = JID} ->
-           case ejabberd_auth:check_password_with_authmodule(
-                   LUser, <<>>, LServer, Pass) of
-               {true, AuthModule} ->
-                    ?INFO_MSG(
-                       "Accepted MQTT authentication for ~ts "
-                       "by ~ts backend from ~ts",
-                       [jid:encode(JID),
-                        ejabberd_auth:backend_type(AuthModule),
-                        ejabberd_config:may_hide_data(misc:ip_to_list(IP))]),
-                    {ok, JID};
-                false ->
-                    {error, 'not-authorized'}
-            end;
-        {error, _} = Err ->
-            Err
+           case maps:find(authentication_method, Props) of
+               {ok, <<"X-OAUTH2">>} ->
+                   Token = maps:get(authentication_data, Props, <<>>),
+                   case ejabberd_oauth:check_token(
+                          LUser, LServer, [<<"sasl_auth">>], Token) of
+                       true -> {ok, JID, ejabberd_oauth};
+                       _ -> {error, 'not-authorized'}
+                   end;
+               {ok, _} ->
+                   {error, 'bad-authentication-method'};
+               error ->
+                   case ejabberd_auth:check_password_with_authmodule(
+                          LUser, <<>>, LServer, Pass) of
+                       {true, AuthModule} -> {ok, JID, AuthModule};
+                       false -> {error, 'not-authorized'}
+                   end
+           end;
+       {error, _} = Err ->
+           Err
     end.
 
 %%%===================================================================