]> granicus.if.org Git - ejabberd/commitdiff
Speedup certificate chains creation and validation
authorEvgeniy Khramtsov <ekhramtsov@process-one.net>
Thu, 7 Dec 2017 11:32:12 +0000 (14:32 +0300)
committerEvgeniy Khramtsov <ekhramtsov@process-one.net>
Thu, 7 Dec 2017 11:32:12 +0000 (14:32 +0300)
src/ejabberd_pkix.erl

index a67df1288e948c09ef9df41d400399de7e4fecb9..037fc9e9e70e8a00d6da997cd692d3aaa4ec4f0f 100644 (file)
@@ -40,6 +40,7 @@
                notify = false :: boolean(),
                paths = [] :: [file:filename()],
                certs = #{} :: map(),
+               graph :: digraph:graph(),
                keys = [] :: [public_key:private_key()]}).
 
 -type state() :: #state{}.
@@ -54,6 +55,8 @@
 -type cert_error() :: not_cert | not_der | not_pem | encrypted.
 -export_type([cert_error/0]).
 
+-define(CA_CACHE, ca_cache).
+
 %%%===================================================================
 %%% API
 %%%===================================================================
@@ -143,6 +146,10 @@ start_link() ->
     gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
 
 config_reloaded() ->
+    case use_cache() of
+       true -> init_cache();
+       false -> delete_cache()
+    end,
     gen_server:call(?MODULE, config_reloaded, 60000).
 
 opt_type(ca_path) ->
@@ -182,7 +189,9 @@ init([]) ->
     if Validate -> check_ca();
        true -> ok
     end,
-    State = #state{validate = Validate, notify = Notify},
+    G = digraph:new([acyclic]),
+    init_cache(),
+    State = #state{validate = Validate, notify = Notify, graph = G},
     case filelib:ensure_dir(filename:join(certs_dir(), "foo")) of
        ok ->
            clean_dir(certs_dir()),
@@ -201,11 +210,15 @@ init([]) ->
 handle_call({add_certfile, Path}, _, State) ->
     case add_certfile(Path, State) of
        {ok, State1} ->
-           case build_chain_and_check(State1) of
-               {ok, State2} ->
-                   {reply, ok, State2};
-               Err ->
-                   {reply, Err, State}
+           if State /= State1 ->
+                   case build_chain_and_check(State1) of
+                       {ok, State2} ->
+                           {reply, ok, State2};
+                       Err ->
+                           {reply, Err, State1}
+                   end;
+              true ->
+                   {reply, ok, State1}
            end;
        {Err, State1} ->
            {reply, Err, State1}
@@ -297,6 +310,7 @@ get_certfiles_from_config_options(_State) ->
 
 -spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}.
 add_certfiles(State) ->
+    ?DEBUG("Reading certificates", []),
     Paths = get_certfiles_from_config_options(State),
     State1 = lists:foldl(
               fun(Path, Acc) ->
@@ -353,18 +367,21 @@ add_certfile(Path, State) ->
 
 -spec build_chain_and_check(state()) -> ok | {error, bad_cert()}.
 build_chain_and_check(State) ->
-    ?DEBUG("Rebuilding certificate chains from ~s",
-          [str:join(State#state.paths, <<", ">>)]),
-    CertPaths = get_cert_paths(maps:keys(State#state.certs)),
+    ?DEBUG("Building certificates graph", []),
+    CertPaths = get_cert_paths(maps:keys(State#state.certs), State#state.graph),
+    ?DEBUG("Finding matched certificate keys", []),
     case match_cert_keys(CertPaths, State#state.keys) of
        {ok, Chains} ->
+           ?DEBUG("Storing certificate chains", []),
            CertFilesWithDomains = store_certs(Chains, []),
            ets:delete_all_objects(?MODULE),
            lists:foreach(
              fun({Path, Domain}) ->
                      ets:insert(?MODULE, {Domain, Path})
              end, CertFilesWithDomains),
+           ?DEBUG("Validating certificates", []),
            Errors = validate(CertPaths, State#state.validate),
+           ?DEBUG("Subscribing to file events", []),
            subscribe(State),
            lists:foreach(
              fun({Cert, Why}) ->
@@ -485,21 +502,43 @@ decode_certs(PemEntries) ->
 
 -spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}].
 validate(Paths, true) ->
-    lists:flatmap(
+    {ok, Re} = re:compile("^[a-f0-9]+\\.[0-9]+$", [unicode]),
+    Hashes = case file:list_dir(ca_dir()) of
+                {ok, Files} ->
+                    lists:foldl(
+                      fun(File, Acc) ->
+                              try re:run(File, Re) of
+                                  {match, _} ->
+                                      [Hash|_] = string:tokens(File, "."),
+                                      Path = filename:join(ca_dir(), File),
+                                      dict:append(Hash, Path, Acc);
+                                  nomatch ->
+                                      Acc
+                              catch _:badarg ->
+                                      ?ERROR_MSG("Regexp failure on ~w", [File]),
+                                      Acc
+                              end
+                      end, dict:new(), Files);
+                {error, Why} ->
+                    ?ERROR_MSG("Failed to list directory ~s: ~s",
+                              [ca_dir(), file:format_error(Why)]),
+                    dict:new()
+            end,
+    lists:filtermap(
       fun({path, Path}) ->
-             case validate_path(Path) of
+             case validate_path(Path, Hashes) of
                  ok ->
-                     [];
+                     false;
                  {error, Cert, Reason} ->
-                     [{Cert, Reason}]
+                     {true, {Cert, Reason}}
              end
       end, Paths);
 validate(_, _) ->
     [].
 
--spec validate_path([cert()]) -> ok | {error, cert(), bad_cert()}.
-validate_path([Cert|_] = Certs) ->
-    case find_local_issuer(Cert) of
+-spec validate_path([cert()], dict:dict()) -> ok | {error, cert(), bad_cert()}.
+validate_path([Cert|_] = Certs, Cache) ->
+    case find_local_issuer(Cert, Cache) of
        {ok, IssuerCert} ->
            try public_key:pkix_path_validation(IssuerCert, Certs, []) of
                {ok, _} ->
@@ -570,65 +609,86 @@ check_ca() ->
            ok
     end.
 
--spec find_local_issuer(cert()) -> {ok, cert()} | {error, {bad_cert, unknown_ca}}.
-find_local_issuer(Cert) ->
-    case find_issuer_in_dir(Cert, ca_dir()) of
+-spec find_local_issuer(cert(), dict:dict()) -> {ok, cert()} |
+                                               {error, {bad_cert, unknown_ca}}.
+find_local_issuer(Cert, Hashes) ->
+    case find_issuer_in_dir(Cert, Hashes) of
        {ok, IssuerCert} ->
            {ok, IssuerCert};
-       {error, _} = Err ->
+       {error, Reason} ->
            case ca_file() of
-               undefined -> Err;
+               undefined -> {error, Reason};
                CAFile -> find_issuer_in_file(Cert, CAFile)
            end
     end.
 
--spec find_issuer_in_dir(cert(), file:filename_all())
-      -> {ok, cert()} | {error, {bad_cert, unknown_ca}}.
-find_issuer_in_dir(Cert, CADir) ->
+-spec find_issuer_in_dir(cert(), dict:dict())
+      -> {{ok, cert()} | {error, {bad_cert, unknown_ca}}, dict:dict()}.
+find_issuer_in_dir(Cert, Cache) ->
     {ok, {_, IssuerID}} = public_key:pkix_issuer_id(Cert, self),
     Hash = short_name_hash(IssuerID),
-    filelib:fold_files(
-      CADir, Hash ++ "\\.[0-9]+", false,
-      fun(_, {ok, IssuerCert}) ->
-             {ok, IssuerCert};
-        (CertFile, Acc) ->
-             try
-                 {ok, Data} = file:read_file(CertFile),
-                 {ok, [IssuerCert|_], _} = pem_decode(Data),
-                 case public_key:pkix_is_issuer(Cert, IssuerCert) of
-                     true ->
-                         {ok, IssuerCert};
-                     false ->
-                         Acc
-                 end
-             catch _:{badmatch, {error, Why}} ->
-                     ?ERROR_MSG("failed to read CA certificate from \"~s\": ~s",
-                                [CertFile, format_error(Why)]),
-                     Acc
+    Files = case dict:find(Hash, Cache) of
+               {ok, L} -> L;
+               error -> []
+           end,
+    lists:foldl(
+      fun(_, {ok, _IssuerCert} = Acc) ->
+             Acc;
+        (Path, Err) ->
+             case read_ca_file(Path) of
+                 {ok, [IssuerCert|_]} ->
+                     case public_key:pkix_is_issuer(Cert, IssuerCert) of
+                         true ->
+                             {ok, IssuerCert};
+                         false ->
+                             Err
+                     end;
+                 error ->
+                     Err
              end
-      end, {error, {bad_cert, unknown_ca}}).
+      end, {error, {bad_cert, unknown_ca}}, Files).
 
 -spec find_issuer_in_file(cert(), file:filename_all() | undefined)
       -> {ok, cert()} | {error, {bad_cert, unknown_ca}}.
 find_issuer_in_file(_Cert, undefined) ->
     {error, {bad_cert, unknown_ca}};
 find_issuer_in_file(Cert, CAFile) ->
+    case read_ca_file(CAFile) of
+       {ok, IssuerCerts} ->
+           lists:foldl(
+             fun(_, {ok, _} = Res) ->
+                     Res;
+                (IssuerCert, Err) ->
+                     case public_key:pkix_is_issuer(Cert, IssuerCert) of
+                         true -> {ok, IssuerCert};
+                         false -> Err
+                     end
+             end, {error, {bad_cert, unknown_ca}}, IssuerCerts);
+       error ->
+           {error, {bad_cert, unknown_ca}}
+    end.
+
+-spec read_ca_file(file:filename_all()) -> {ok, [cert()]} | error.
+read_ca_file(Path) ->
+    case use_cache() of
+       true ->
+           ets_cache:lookup(?CA_CACHE, Path,
+                            fun() -> do_read_ca_file(Path) end);
+       false ->
+           do_read_ca_file(Path)
+    end.
+
+-spec do_read_ca_file(file:filename_all()) -> {ok, [cert()]} | error.
+do_read_ca_file(Path) ->
     try
-       {ok, Data} = file:read_file(CAFile),
+       {ok, Data} = file:read_file(Path),
        {ok, IssuerCerts, _} = pem_decode(Data),
-       lists:foldl(
-         fun(_, {ok, _} = Res) ->
-                 Res;
-            (IssuerCert, Err) ->
-                 case public_key:pkix_is_issuer(Cert, IssuerCert) of
-                     true -> {ok, IssuerCert};
-                     false -> Err
-                 end
-         end, {error, {bad_cert, unknown_ca}}, IssuerCerts)
+       {ok, IssuerCerts}
     catch _:{badmatch, {error, Why}} ->
-           ?ERROR_MSG("failed to read CA certificates from \"~s\": ~s",
-                      [CAFile, format_error(Why)]),
-           {error, {bad_cert, unknown_ca}}
+           ?ERROR_MSG("Failed to read CA certificate "
+                      "from \"~s\": ~s",
+                      [Path, format_error(Why)]),
+           error
     end.
 
 -spec match_cert_keys([{path, [cert()]}], [priv_key()])
@@ -680,13 +740,22 @@ pubkey_from_privkey(#'DSAPrivateKey'{p = P, q = Q, g = G, y = Y}) ->
 pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) ->
     #'ECPoint'{point = Key}.
 
--spec get_cert_paths([cert()]) -> [{path, [cert()]}].
-get_cert_paths(Certs) ->
-    G = digraph:new([acyclic]),
-    lists:foreach(
-      fun(Cert) ->
-             digraph:add_vertex(G, Cert)
-      end, Certs),
+-spec get_cert_paths([cert()], digraph:graph()) -> [{path, [cert()]}].
+get_cert_paths(Certs, G) ->
+    {NewCerts, OldCerts} =
+       lists:partition(
+         fun(Cert) ->
+                 case digraph:vertex(G, Cert) of
+                     false ->
+                         digraph:add_vertex(G, Cert),
+                         true;
+                     {_, _} ->
+                         false
+                 end
+         end, Certs),
+    CertPairs = [{C1, C2} || C1 <- NewCerts, C2 <- OldCerts] ++
+               [{C1, C2} || C1 <- OldCerts, C2 <- NewCerts] ++
+               [{C1, C2} || C1 <- NewCerts, C2 <- NewCerts],
     lists:foreach(
       fun({Cert1, Cert2}) when Cert1 /= Cert2 ->
              case public_key:pkix_is_self_signed(Cert1) of
@@ -702,18 +771,16 @@ get_cert_paths(Certs) ->
              end;
         (_) ->
              ok
-      end, [{Cert1, Cert2} || Cert1 <- Certs, Cert2 <- Certs]),
-    Paths = lists:flatmap(
-             fun(Cert) ->
-                     case digraph:in_degree(G, Cert) of
-                         0 ->
-                             get_cert_path(G, [Cert]);
-                         _ ->
-                             []
-                     end
-             end, Certs),
-    digraph:delete(G),
-    Paths.
+      end, CertPairs),
+    lists:flatmap(
+      fun(Cert) ->
+             case digraph:in_degree(G, Cert) of
+                 0 ->
+                     get_cert_path(G, [Cert]);
+                 _ ->
+                     []
+             end
+      end, Certs).
 
 get_cert_path(G, [Root|_] = Acc) ->
     case digraph:out_edges(G, Root) of
@@ -783,3 +850,25 @@ wildcard(Path) when is_binary(Path) ->
     wildcard(binary_to_list(Path));
 wildcard(Path) ->
     filelib:wildcard(Path).
+
+-spec use_cache() -> boolean().
+use_cache() ->
+    ejabberd_config:use_cache(global).
+
+-spec init_cache() -> ok.
+init_cache() ->
+    ets_cache:new(?CA_CACHE, cache_opts()).
+
+-spec delete_cache() -> ok.
+delete_cache() ->
+    ets_cache:delete(?CA_CACHE).
+
+-spec cache_opts() -> [proplists:property()].
+cache_opts() ->
+    MaxSize = ejabberd_config:cache_size(global),
+    CacheMissed = ejabberd_config:cache_missed(global),
+    LifeTime = case ejabberd_config:cache_life_time(global) of
+                   infinity -> infinity;
+                   I -> timer:seconds(I)
+               end,
+    [{max_size, MaxSize}, {cache_missed, CacheMissed}, {life_time, LifeTime}].