%% API
-export([start_link/0, add_certfile/1, format_error/1, opt_type/1,
- get_certfile/1, try_certfile/1, route_registered/1]).
+ get_certfile/1, try_certfile/1, route_registered/1,
+ config_reloaded/0]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
-include("jid.hrl").
-record(state, {validate = true :: boolean(),
- certs = #{}}).
--record(cert_state, {domains = [] :: [binary()]}).
+ paths = [] :: [file:filename()],
+ certs = #{} :: map(),
+ keys = [] :: [public_key:private_key()]}).
+-type state() :: #state{}.
-type cert() :: #'OTPCertificate'{}.
-type priv_key() :: public_key:private_key().
-type pub_key() :: #'RSAPublicKey'{} | {integer(), #'Dss-Parms'{}} | #'ECPoint'{}.
-spec try_certfile(filename:filename()) -> binary().
try_certfile(Path0) ->
Path = prep_path(Path0),
- case mk_cert_state(Path, false) of
- {ok, _} -> Path;
+ case load_certfile(Path) of
+ {ok, _, _} -> Path;
{error, _} -> erlang:error(badarg)
end.
format_error(not_der) ->
"failed to decode from DER format";
format_error(encrypted) ->
- "encrypted certificate found in the chain";
+ "encrypted certificate";
format_error({bad_cert, cert_expired}) ->
"certificate is no longer valid as its expiration date has passed";
format_error({bad_cert, invalid_issuer}) ->
"certificate issuer name does not match the name of the "
- "issuer certificate in the chain";
+ "issuer certificate";
format_error({bad_cert, invalid_signature}) ->
- "certificate was not signed by its issuer certificate in the chain";
+ "certificate was not signed by its issuer certificate";
format_error({bad_cert, name_not_permitted}) ->
"invalid Subject Alternative Name extension";
format_error({bad_cert, missing_basic_constraint}) ->
"certificate key is used in an invalid way according "
"to the key-usage extension";
format_error({bad_cert, selfsigned_peer}) ->
- "self-signed certificate in the chain";
+ "self-signed certificate";
format_error({bad_cert, unknown_sig_algo}) ->
"certificate is signed using unknown algorithm";
format_error({bad_cert, unknown_ca}) ->
start_link() ->
gen_server:start_link({local, ?MODULE}, ?MODULE, [], []).
+config_reloaded() ->
+ gen_server:cast(?MODULE, config_reloaded).
+
opt_type(ca_path) ->
fun(Path) -> iolist_to_binary(Path) end;
+opt_type(certfiles) ->
+ fun(CertList) ->
+ [binary_to_list(Path) || Path <- CertList]
+ end;
opt_type(_) ->
- [ca_path].
+ [ca_path, certfiles].
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
init([]) ->
+ application:load(fs),
+ application:set_env(fs, backwards_compatible, false),
+ ejabberd:start_app(fs),
process_flag(trap_exit, true),
ets:new(?MODULE, [named_table, public, bag]),
ejabberd_hooks:add(route_registered, ?MODULE, route_registered, 50),
+ ejabberd_hooks:add(config_reloaded, ?MODULE, config_reloaded, 30),
Validate = case os:type() of
{win32, _} -> false;
_ ->
true -> ok
end,
State = #state{validate = Validate},
- {ok, add_certfiles(State)}.
+ case filelib:ensure_dir(filename:join(certs_dir(), "foo")) of
+ ok ->
+ clean_dir(certs_dir()),
+ case add_certfiles(State) of
+ {ok, State1} ->
+ {ok, State1};
+ {error, Why} ->
+ {stop, Why}
+ end;
+ {error, Why} ->
+ ?CRITICAL_MSG("Failed to create directory ~s: ~s",
+ [certs_dir(), file:format_error(Why)]),
+ {stop, Why}
+ end.
handle_call({add_certfile, Path}, _, State) ->
{Result, NewState} = add_certfile(Path, State),
{reply, Result, NewState};
handle_call({route_registered, Host}, _, State) ->
- NewState = add_certfiles(Host, State),
- case get_certfile(Host) of
- {ok, _} -> ok;
- error ->
- ?WARNING_MSG("No certificate found matching '~s': strictly "
- "configured clients or servers will reject "
- "connections with this host", [Host])
- end,
- {reply, ok, NewState};
+ case add_certfiles(Host, State) of
+ {ok, NewState} ->
+ case get_certfile(Host) of
+ {ok, _} -> ok;
+ error ->
+ ?WARNING_MSG("No certificate found matching '~s': strictly "
+ "configured clients or servers will reject "
+ "connections with this host", [Host])
+ end,
+ {reply, ok, NewState};
+ {error, _} ->
+ {reply, ok, State}
+ end;
handle_call(_Request, _From, State) ->
Reply = ok,
{reply, Reply, State}.
+handle_cast(config_reloaded, State) ->
+ State1 = State#state{paths = [], certs = #{}, keys = []},
+ case add_certfiles(State1) of
+ {ok, State2} ->
+ {noreply, State2};
+ {error, _} ->
+ {noreply, State}
+ end;
handle_cast(_Msg, State) ->
{noreply, State}.
+handle_info({_, {fs, file_event}, {File, Events}}, State) ->
+ ?DEBUG("got FS events for ~s: ~p", [File, Events]),
+ Path = iolist_to_binary(File),
+ case lists:member(modified, Events) of
+ true ->
+ case lists:member(Path, State#state.paths) of
+ true ->
+ handle_cast(config_reloaded, State);
+ false ->
+ {noreply, State}
+ end;
+ false ->
+ {noreply, State}
+ end;
handle_info(_Info, State) ->
?WARNING_MSG("unexpected info: ~p", [_Info]),
{noreply, State}.
terminate(_Reason, _State) ->
- ejabberd_hooks:delete(route_registered, ?MODULE, route_registered, 50).
+ ejabberd_hooks:delete(route_registered, ?MODULE, route_registered, 50),
+ ejabberd_hooks:delete(config_reloaded, ?MODULE, config_reloaded, 30).
code_change(_OldVsn, State, _Extra) ->
{ok, State}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
+-spec certfiles_from_config_options() -> [atom()].
+certfiles_from_config_options() ->
+ [c2s_certfile, s2s_certfile, domain_certfile].
+
+-spec get_certfiles_from_config_options(state()) -> [binary()].
+get_certfiles_from_config_options(State) ->
+ Global = case ejabberd_config:get_option(certfiles) of
+ undefined ->
+ [];
+ Paths ->
+ lists:flatmap(fun filelib:wildcard/1, Paths)
+ end,
+ Local = lists:flatmap(
+ fun(OptHost) ->
+ case ejabberd_config:get_option(OptHost) of
+ undefined -> [];
+ Path -> [Path]
+ end
+ end, [{Opt, Host}
+ || Opt <- certfiles_from_config_options(),
+ Host <- ejabberd_config:get_myhosts()]),
+ [iolist_to_binary(P) || P <- lists:usort(Local ++ Global)].
+
+-spec add_certfiles(state()) -> {ok, state()} | {error, bad_cert()}.
add_certfiles(State) ->
- lists:foldl(
- fun(Host, AccState) ->
- add_certfiles(Host, AccState)
- end, State, ejabberd_config:get_myhosts()).
+ Paths = get_certfiles_from_config_options(State),
+ State1 = lists:foldl(
+ fun(Path, Acc) ->
+ {_, NewAcc} = add_certfile(Path, Acc),
+ NewAcc
+ end, State, Paths),
+ case build_chain_and_check(State1) of
+ ok -> {ok, State1};
+ {error, _} = Err -> Err
+ end.
+-spec add_certfiles(binary(), state()) -> {ok, state()} | {error, bad_cert()}.
add_certfiles(Host, State) ->
- lists:foldl(
- fun(Opt, AccState) ->
- case ejabberd_config:get_option({Opt, Host}) of
- undefined -> AccState;
- Path ->
- {_, NewAccState} = add_certfile(Path, AccState),
- NewAccState
- end
- end, State, [c2s_certfile, s2s_certfile, domain_certfile]).
+ State1 = lists:foldl(
+ fun(Opt, AccState) ->
+ case ejabberd_config:get_option({Opt, Host}) of
+ undefined -> AccState;
+ Path ->
+ {_, NewAccState} = add_certfile(Path, AccState),
+ NewAccState
+ end
+ end, State, certfiles_from_config_options()),
+ if State /= State1 ->
+ case build_chain_and_check(State1) of
+ ok -> {ok, State1};
+ {error, _} = Err -> Err
+ end;
+ true ->
+ {ok, State}
+ end.
+-spec add_certfile(file:filename_all(), state()) -> {ok, state()} |
+ {{error, cert_error()}, state()}.
add_certfile(Path, State) ->
- case maps:get(Path, State#state.certs, undefined) of
- #cert_state{} ->
+ case lists:member(Path, State#state.paths) of
+ true ->
{ok, State};
- undefined ->
- case mk_cert_state(Path, State#state.validate) of
- {error, Reason} ->
- {{error, Reason}, State};
- {ok, CertState} ->
- NewCerts = maps:put(Path, CertState, State#state.certs),
- lists:foreach(
- fun(Domain) ->
- ets:insert(?MODULE, {Domain, Path})
- end, CertState#cert_state.domains),
- {ok, State#state{certs = NewCerts}}
+ false ->
+ case load_certfile(Path) of
+ {ok, Certs, Keys} ->
+ NewCerts = lists:foldl(
+ fun(Cert, Acc) ->
+ maps:put(Cert, Path, Acc)
+ end, State#state.certs, Certs),
+ {ok, State#state{paths = [Path|State#state.paths],
+ certs = NewCerts,
+ keys = Keys ++ State#state.keys}};
+ {error, Why} = Err ->
+ ?ERROR_MSG("failed to read certificate from ~s: ~s",
+ [Path, format_error(Why)]),
+ {Err, State}
end
end.
-mk_cert_state(Path, Validate) ->
- case check_certfile(Path, Validate) of
- {ok, Ds} ->
- {ok, #cert_state{domains = Ds}};
- {invalid, Ds, {bad_cert, _} = Why} ->
- ?WARNING_MSG("certificate from ~s is invalid: ~s",
- [Path, format_error(Why)]),
- {ok, #cert_state{domains = Ds}};
- {error, Why} = Err ->
- ?ERROR_MSG("failed to read certificate from ~s: ~s",
+-spec build_chain_and_check(state()) -> ok | {error, bad_cert()}.
+build_chain_and_check(State) ->
+ ?DEBUG("Rebuilding certificate chains from ~s",
+ [str:join(State#state.paths, <<", ">>)]),
+ CertPaths = get_cert_paths(maps:keys(State#state.certs)),
+ case match_cert_keys(CertPaths, State#state.keys) of
+ {ok, Chains} ->
+ CertFilesWithDomains = store_certs(Chains, []),
+ ets:delete_all_objects(?MODULE),
+ lists:foreach(
+ fun({Path, Domain}) ->
+ ets:insert(?MODULE, {Domain, Path})
+ end, CertFilesWithDomains),
+ Errors = validate(CertPaths, State#state.validate),
+ subscribe(State),
+ lists:foreach(
+ fun({Cert, Why}) ->
+ Path = maps:get(Cert, State#state.certs),
+ ?ERROR_MSG("Failed to validate certificate from ~s: ~s",
+ [Path, format_error(Why)])
+ end, Errors);
+ {error, Cert, Why} ->
+ Path = maps:get(Cert, State#state.certs),
+ ?ERROR_MSG("Failed to build certificate chain for ~s: ~s",
[Path, format_error(Why)]),
- Err
+ {error, Why}
end.
--spec check_certfile(filename:filename(), boolean())
- -> {ok, [binary()]} | {invalid, [binary()], bad_cert()} |
- {error, cert_error() | file:posix()}.
-check_certfile(Path, Validate) ->
+-spec store_certs([{[cert()], priv_key()}],
+ [{binary(), binary()}]) -> [{binary(), binary()}].
+store_certs([{Certs, Key}|Chains], Acc) ->
+ CertPEMs = public_key:pem_encode(
+ lists:map(
+ fun(Cert) ->
+ Type = element(1, Cert),
+ DER = public_key:pkix_encode(Type, Cert, otp),
+ {'Certificate', DER, not_encrypted}
+ end, Certs)),
+ KeyPEM = public_key:pem_encode(
+ [{element(1, Key),
+ public_key:der_encode(element(1, Key), Key),
+ not_encrypted}]),
+ PEMs = <<CertPEMs/binary, KeyPEM/binary>>,
+ Cert = hd(Certs),
+ Domains = xmpp_stream_pkix:get_cert_domains(Cert),
+ FileName = filename:join(certs_dir(), str:sha(PEMs)),
+ case file:write_file(FileName, PEMs) of
+ ok ->
+ file:change_mode(FileName, 8#600),
+ NewAcc = [{FileName, Domain} || Domain <- Domains] ++ Acc,
+ store_certs(Chains, NewAcc);
+ {error, Why} ->
+ ?ERROR_MSG("Failed to write to ~s: ~s",
+ [FileName, file:format_error(Why)]),
+ store_certs(Chains, [])
+ end;
+store_certs([], Acc) ->
+ Acc.
+
+-spec load_certfile(file:filename_all()) -> {ok, [cert()], [priv_key()]} |
+ {error, cert_error() | file:posix()}.
+load_certfile(Path) ->
try
{ok, Data} = file:read_file(Path),
- {ok, Certs, PrivKeys} = pem_decode(Data),
- CertPaths = get_cert_paths(Certs),
- Domains = get_domains(CertPaths),
- case match_cert_keys(CertPaths, PrivKeys) of
- {ok, _} ->
- case validate(CertPaths, Validate) of
- ok -> {ok, Domains};
- {error, Why} -> {invalid, Domains, Why}
- end;
- {error, Why} ->
- {invalid, Domains, Why}
- end
+ pem_decode(Data)
catch _:{badmatch, {error, _} = Err} ->
Err
end.
fun(#'OTPCertificate'{}) -> true;
(_) -> false
end, Objects) of
- {[], _} ->
+ {[], []} ->
{error, not_cert};
{Certs, PrivKeys} ->
{ok, Certs, PrivKeys}
{error, not_der}
end.
--spec validate([{path, [cert()]}], boolean()) -> ok | {error, bad_cert()}.
-validate([{path, Path}|Paths], true) ->
- case validate_path(Path) of
- ok ->
- validate(Paths, true);
- Err ->
- Err
- end;
+-spec validate([{path, [cert()]}], boolean()) -> [{cert(), bad_cert()}].
+validate(Paths, true) ->
+ lists:flatmap(
+ fun({path, Path}) ->
+ case validate_path(Path) of
+ ok ->
+ [];
+ {error, Cert, Reason} ->
+ [{Cert, Reason}]
+ end
+ end, Paths);
validate(_, _) ->
ok.
--spec validate_path([cert()]) -> ok | {error, bad_cert()}.
+-spec validate_path([cert()]) -> ok | {error, cert(), bad_cert()}.
validate_path([Cert|_] = Certs) ->
case find_local_issuer(Cert) of
{ok, IssuerCert} ->
try public_key:pkix_path_validation(IssuerCert, Certs, []) of
{ok, _} ->
ok;
- Err ->
- Err
+ {error, Reason} ->
+ {error, Cert, Reason}
catch error:function_clause ->
case erlang:get_stacktrace() of
[{public_key, pkix_sign_types, _, _}|_] ->
- {error, {bad_cert, unknown_sig_algo}};
+ {error, Cert, {bad_cert, unknown_sig_algo}};
ST ->
%% Bug in public_key application
erlang:raise(error, function_clause, ST)
end
end;
- {error, _} = Err ->
+ {error, Reason} ->
case public_key:pkix_is_self_signed(Cert) of
true ->
- {error, {bad_cert, selfsigned_peer}};
+ {error, Cert, {bad_cert, selfsigned_peer}};
false ->
- Err
+ {error, Cert, Reason}
end
end.
ca_dir() ->
ejabberd_config:get_option(ca_path, "/etc/ssl/certs").
+-spec certs_dir() -> string().
+certs_dir() ->
+ MnesiaDir = mnesia:system_info(directory),
+ filename:join(MnesiaDir, "certs").
+
+-spec clean_dir(file:filename_all()) -> ok.
+clean_dir(Dir) ->
+ ?DEBUG("Cleaning directory ~s", [Dir]),
+ Files = filelib:wildcard(filename:join(Dir, "*")),
+ lists:foreach(
+ fun(Path) ->
+ case filelib:is_file(Path) of
+ true ->
+ file:delete(Path);
+ false ->
+ ok
+ end
+ end, Files).
+
-spec check_ca_dir() -> ok.
check_ca_dir() ->
case filelib:wildcard(filename:join(ca_dir(), "*.0")) of
-spec match_cert_keys([{path, [cert()]}], [{pub_key(), priv_key()}],
[{cert(), priv_key()}])
- -> {ok, [{cert(), priv_key()}]} | {error, {bad_cert, missing_priv_key}}.
+ -> {ok, [{[cert()], priv_key()}]} | {error, cert(), {bad_cert, missing_priv_key}}.
match_cert_keys([{path, Certs}|CertPaths], KeyPairs, Result) ->
[Cert|_] = RevCerts = lists:reverse(Certs),
PubKey = pubkey_from_cert(Cert),
case lists:keyfind(PubKey, 1, KeyPairs) of
false ->
- {error, {bad_cert, missing_priv_key}};
+ {error, Cert, {bad_cert, missing_priv_key}};
{_, PrivKey} ->
match_cert_keys(CertPaths, KeyPairs, [{RevCerts, PrivKey}|Result])
end;
pubkey_from_privkey(#'ECPrivateKey'{publicKey = Key}) ->
#'ECPoint'{point = Key}.
--spec get_domains([{path, [cert()]}]) -> [binary()].
-get_domains(CertPaths) ->
- lists:usort(
- lists:flatmap(
- fun({path, Certs}) ->
- Cert = lists:last(Certs),
- xmpp_stream_pkix:get_cert_domains(Cert)
- end, CertPaths)).
-
-spec get_cert_paths([cert()]) -> [{path, [cert()]}].
get_cert_paths(Certs) ->
G = digraph:new([acyclic]),
short_name_hash(_) ->
"".
-endif.
+
+-spec subscribe(state()) -> ok.
+subscribe(State) ->
+ lists:foreach(
+ fun(Path) ->
+ Dir = filename:dirname(Path),
+ Name = list_to_atom(integer_to_list(erlang:phash2(Dir))),
+ case fs:start_link(Name, Dir) of
+ {ok, _} ->
+ ?DEBUG("Subscribed to FS events from ~s", [Dir]),
+ fs:subscribe(Name);
+ {error, _} ->
+ ok
+ end
+ end, State#state.paths).