]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add Lua bindings to be able to rotate `DNSCrypt` keys
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 3 Jul 2017 18:42:17 +0000 (20:42 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 3 Jul 2017 18:42:17 +0000 (20:42 +0200)
pdns/README-dnsdist.md
pdns/dnscrypt.hh
pdns/dnsdist-console.cc
pdns/dnsdist-lua2.cc
regression-tests.dnsdist/dnscrypt.py
regression-tests.dnsdist/test_DNSCrypt.py

index 237f4f0001091db23c8affdf460e8b43ac7fd833..2d58f92bc9c16f4788358b251a195a51f875aede 100644 (file)
@@ -1682,6 +1682,7 @@ instantiate a server with additional parameters
     * `addDNSCryptBind("127.0.0.1:8443", "provider name", "/path/to/resolver.cert", "/path/to/resolver.key", [, {doTCP=true, reusePort=false, tcpFastOpenSize=0, interface=\"\"}]):` listen to incoming DNSCrypt queries on 127.0.0.1 port 8443, with a provider name of "provider name", using a resolver certificate and associated key stored respectively in the `resolver.cert` and `resolver.key` files. The fifth parameter is the same optional table than the one described in `addLocal()`, except that TCP is always enabled
     * `generateDNSCryptProviderKeys("/path/to/providerPublic.key", "/path/to/providerPrivate.key"):` generate a new provider keypair
     * `generateDNSCryptCertificate("/path/to/providerPrivate.key", "/path/to/resolver.cert", "/path/to/resolver.key", serial, validFrom, validUntil):` generate a new resolver private key and related certificate, valid from the `validFrom` UNIX timestamp until the `validUntil` one, signed with the provider private key
+    * `getDNSCryptBind(n)`: return the `DNSCryptContext` object corresponding to the bind `n`
     * `printDNSCryptProviderFingerprint("/path/to/providerPublic.key")`: display the fingerprint of the provided resolver public key
     * `showDNSCryptBinds():`: display the currently configured DNSCrypt binds
  * BPFFilter related:
@@ -1693,6 +1694,23 @@ instantiate a server with additional parameters
     * member `getStats()`: print the block tables
     * member `unblock(ComboAddress)`: unblock this address
     * member `unblockQName(DNSName [, qtype=255])`: remove this qname from the block list
+ * DNSCryptCert related:
+    * member `getClientMagic`: return this certificate's client magic value, as a string
+    * member `getEsVersion()`: return the cryptographic construction to use with this certificate, as a string
+    * member `getMagic()`: return the certificate magic number, as a string
+    * member `getProtocolMinorVersion()`: return this certificate's minor version, as a string
+    * member `getResolverPublicKey()`: return the public key corresponding to this certificate, as a string
+    * member `getSerial()`: return the certificate serial number
+    * member `getSignature()`: return this certificate's signature, as a string
+    * member `getTSEnd()`: return the date the certificate is valid from, as a Unix timestamp
+    * member `getTSStart()`: return the date the certificate is valid until (inclusive), as a Unix timestamp
+ * DNSCryptContext related:
+    * member `generateAndLoadInMemoryCertificate(path/to/provider/private/key/file, serial, begin, end)`: generate a new resolver key and the associated certificate in-memory, sign it with the provided provider key, and use the new certificate
+    * member `getCurrentCertificate()`: return the current certificate as a `DnsCryptCert` object
+    * member `getOldCertificate()`: return the previous certificate as a `DnsCryptCert` object
+    * member `getProviderName()`: return the provider name
+    * member `hasOldCertificate()`: return a boolean indicating if the context has a previous certificate, from a certificate rotation
+    * member `loadNewCertificate(path/to/certificate, path/to/key)`: load a new certificate and the corresponding private key, and use it
  * DNSDistProtoBufMessage related:
     * member `setBytes(bytes)`: set the size of the query
     * member `setEDNSSubnet(Netmask)`: set the EDNS Subnet
index f889d9cac513be1e5dd6042b2784780876eb20e7..dff4e9e281f84c51a850974e182c70987fe76bb6 100644 (file)
@@ -164,7 +164,7 @@ public:
   void setNewCertificate(const DnsCryptCert& newCert, const DnsCryptPrivateKey& newKey);
   const DnsCryptCert& getCurrentCertificate() const { return cert; };
   const DnsCryptCert& getOldCertificate() const { return oldCert; };
-  bool hadOldCertificate() const { return hasOldCert; };
+  bool hasOldCertificate() const { return hasOldCert; };
   const std::string& getProviderName() const { return providerName; }
   int encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DnsCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen) const;
 
index 6a1e15465cfbbd93fd86d3ad277695dee10c68b2..9c2d3a0ae1698bef3d247aeac70e35e8ef0f86e7 100644 (file)
@@ -313,7 +313,8 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "firstAvailable", false, "", "picks the server with the lowest `order` that has not exceeded its QPS limit" },
   { "fixupCase", true, "bool", "if set (default to no), rewrite the first qname of the question part of the answer to match the one from the query. It is only useful when you have a downstream server that messes up the case of the question qname in the answer" },
   { "generateDNSCryptCertificate", true, "\"/path/to/providerPrivate.key\", \"/path/to/resolver.cert\", \"/path/to/resolver.key\", serial, validFrom, validUntil", "generate a new resolver private key and related certificate, valid from the `validFrom` timestamp until the `validUntil` one, signed with the provider private key" },
-  { "generateDNSCryptProviderKeys", true, "\"/path/to/providerPublic.key\", \"/path/to/providerPrivate.key\"", "generate a new provider keypair"},
+  { "generateDNSCryptProviderKeys", true, "\"/path/to/providerPublic.key\", \"/path/to/providerPrivate.key\"", "generate a new provider keypair" },
+  { "getDNSCryptBind", true, "n", "return the `DNSCryptContext` object corresponding to the bind `n`" },
   { "getPoolServers", true, "pool", "return servers part of this pool" },
   { "getQueryCounters", true, "[max=10]", "show current buffer of query counters, limited by 'max' if provided" },
   { "getResponseRing", true, "", "return the current content of the response ring" },
index 20360e972570c331a626a93088ecd486b4467cbb..eb8848ab716d352d67dea99804acf59b0b3307f7 100644 (file)
@@ -197,7 +197,34 @@ map<ComboAddress,int> exceedRespByterate(int rate, int seconds)
                   });
 }
 
+#ifdef HAVE_DNSCRYPT
+static bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DnsCryptCert& certOut, DnsCryptPrivateKey& keyOut)
+{
+  bool success = false;
+  unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
+  sodium_mlock(providerPrivateKey, sizeof(providerPrivateKey));
+  sodium_memzero(providerPrivateKey, sizeof(providerPrivateKey));
+
+  try {
+    ifstream providerKStream(providerPrivateKeyFile);
+    providerKStream.read((char*) providerPrivateKey, sizeof(providerPrivateKey));
+    if (providerKStream.fail()) {
+      providerKStream.close();
+      throw std::runtime_error("Invalid DNSCrypt provider key file " + providerPrivateKeyFile);
+    }
+
+    DnsCryptContext::generateCertificate(serial, begin, end, providerPrivateKey, keyOut, certOut);
+    success = true;
+  }
+  catch(const std::exception& e) {
+    errlog(e.what());
+  }
 
+  sodium_memzero(providerPrivateKey, sizeof(providerPrivateKey));
+  sodium_munlock(providerPrivateKey, sizeof(providerPrivateKey));
+  return success;
+}
+#endif /* HAVE_DNSCRYPT */
 
 void moreLua(bool client)
 {
@@ -548,7 +575,7 @@ void moreLua(bool client)
 
       for (const auto& local : g_dnsCryptLocals) {
         const DnsCryptContext& ctx = std::get<1>(local);
-        bool const hasOldCert = ctx.hadOldCertificate();
+        bool const hasOldCert = ctx.hasOldCertificate();
         const DnsCryptCert& cert = ctx.getCurrentCertificate();
         const DnsCryptCert& oldCert = ctx.getOldCertificate();
 
@@ -562,6 +589,53 @@ void moreLua(bool client)
 #endif
     });
 
+  g_lua.writeFunction("getDNSCryptBind", [client](size_t idx) {
+      setLuaNoSideEffect();
+#ifdef HAVE_DNSCRYPT
+      DnsCryptContext* ret = nullptr;
+      if (idx < g_dnsCryptLocals.size()) {
+        ret = &(std::get<1>(g_dnsCryptLocals.at(idx)));
+      }
+      return ret;
+#else
+      g_outputBuffer="Error: DNSCrypt support is not enabled.\n";
+#endif
+    });
+
+#ifdef HAVE_DNSCRYPT
+    /* DnsCryptContext bindings */
+    g_lua.registerFunction<std::string(DnsCryptContext::*)()>("getProviderName", [](const DnsCryptContext& ctx) { return ctx.getProviderName(); });
+    g_lua.registerFunction<DnsCryptCert(DnsCryptContext::*)()>("getCurrentCertificate", [](const DnsCryptContext& ctx) { return ctx.getCurrentCertificate(); });
+    g_lua.registerFunction<DnsCryptCert(DnsCryptContext::*)()>("getOldCertificate", [](const DnsCryptContext& ctx) { return ctx.getOldCertificate(); });
+    g_lua.registerFunction("hasOldCertificate", &DnsCryptContext::hasOldCertificate);
+    g_lua.registerFunction("loadNewCertificate", &DnsCryptContext::loadNewCertificate);
+    g_lua.registerFunction<void(DnsCryptContext::*)(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end)>("generateAndLoadInMemoryCertificate", [](DnsCryptContext& ctx, const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end) {
+        DnsCryptPrivateKey privateKey;
+        DnsCryptCert cert;
+
+        try {
+          if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, cert, privateKey)) {
+            ctx.setNewCertificate(cert, privateKey);
+          }
+        }
+        catch(const std::exception& e) {
+          errlog(e.what());
+          g_outputBuffer="Error: "+string(e.what())+"\n";
+        }
+    });
+
+    /* DnsCryptCert */
+    g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getMagic", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.magic), sizeof(cert.magic)); });
+    g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getEsVersion", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.esVersion), sizeof(cert.esVersion)); });
+    g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getProtocolMinorVersion", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.protocolMinorVersion), sizeof(cert.protocolMinorVersion)); });
+    g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getSignature", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signature), sizeof(cert.signature)); });
+    g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getResolverPublicKey", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signedData.resolverPK), sizeof(cert.signedData.resolverPK)); });
+    g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getClientMagic", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signedData.clientMagic), sizeof(cert.signedData.clientMagic)); });
+    g_lua.registerFunction<uint32_t(DnsCryptCert::*)()>("getSerial", [](const DnsCryptCert& cert) { return cert.signedData.serial; });
+    g_lua.registerFunction<uint32_t(DnsCryptCert::*)()>("getTSStart", [](const DnsCryptCert& cert) { return cert.signedData.tsStart; });
+    g_lua.registerFunction<uint32_t(DnsCryptCert::*)()>("getTSEnd", [](const DnsCryptCert& cert) { return cert.signedData.tsEnd; });
+#endif
+
     g_lua.writeFunction("generateDNSCryptProviderKeys", [](const std::string& publicKeyFile, const std::string privateKeyFile) {
         setLuaNoSideEffect();
 #ifdef HAVE_DNSCRYPT
@@ -621,32 +695,19 @@ void moreLua(bool client)
     g_lua.writeFunction("generateDNSCryptCertificate", [](const std::string& providerPrivateKeyFile, const std::string& certificateFile, const std::string privateKeyFile, uint32_t serial, time_t begin, time_t end) {
         setLuaNoSideEffect();
 #ifdef HAVE_DNSCRYPT
-        unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-        sodium_mlock(providerPrivateKey, sizeof(providerPrivateKey));
-        sodium_memzero(providerPrivateKey, sizeof(providerPrivateKey));
+        DnsCryptPrivateKey privateKey;
+        DnsCryptCert cert;
 
         try {
-          DnsCryptPrivateKey privateKey;
-          DnsCryptCert cert;
-          ifstream providerKStream(providerPrivateKeyFile);
-          providerKStream.read((char*) providerPrivateKey, sizeof(providerPrivateKey));
-          if (providerKStream.fail()) {
-            providerKStream.close();
-            throw std::runtime_error("Invalid DNSCrypt provider key file " + providerPrivateKeyFile);
+          if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, cert, privateKey)) {
+            privateKey.saveToFile(privateKeyFile);
+            DnsCryptContext::saveCertFromFile(cert, certificateFile);
           }
-
-          DnsCryptContext::generateCertificate(serial, begin, end, providerPrivateKey, privateKey, cert);
-
-          privateKey.saveToFile(privateKeyFile);
-          DnsCryptContext::saveCertFromFile(cert, certificateFile);
         }
-        catch(std::exception& e) {
+        catch(const std::exception& e) {
           errlog(e.what());
           g_outputBuffer="Error: "+string(e.what())+"\n";
         }
-
-        sodium_memzero(providerPrivateKey, sizeof(providerPrivateKey));
-        sodium_munlock(providerPrivateKey, sizeof(providerPrivateKey));
 #else
       g_outputBuffer="Error: DNSCrypt support is not enabled.\n";
 #endif
index 5426a5fc41c09d5650689f595ad971f6618e313e..3ed8cb5d6ff10ac0e931b6637050d4523d72884e 100644 (file)
@@ -39,7 +39,7 @@ class DNSCryptResolverCertificate(object):
 
         resolverPK = orig[0:32]
         clientMagic = orig[32:40]
-        serial = struct.unpack_from("I", orig[40:44])
+        serial = struct.unpack_from("I", orig[40:44])[0]
         validFrom = struct.unpack_from("!I", orig[44:48])[0]
         validUntil = struct.unpack_from("!I", orig[48:52])[0]
         return DNSCryptResolverCertificate(serial, validFrom, validUntil, resolverPK, clientMagic)
@@ -109,7 +109,18 @@ class DNSCryptClient(object):
 
         return False
 
-    def _getResolverCertificates(self):
+    def clearExpiredResolverCertificates(self):
+        newCerts = []
+
+        for cert in self._resolverCertificates:
+            if cert.isValid():
+                newCerts.append(cert)
+
+        self._resolverCertificates = newCerts
+
+    def refreshResolverCertificates(self):
+        self.clearExpiredResolverCertificates()
+
         query = dns.message.make_query(self._providerName, dns.rdatatype.TXT, dns.rdataclass.IN)
         data = self._sendQuery(query.to_wire())
 
@@ -129,7 +140,7 @@ class DNSCryptClient(object):
             if cert.isValid():
                 self._resolverCertificates.append(cert)
 
-    def _getResolverCertificate(self):
+    def getResolverCertificate(self):
         certs = self._resolverCertificates
         result = None
         for cert in certs:
@@ -191,10 +202,10 @@ class DNSCryptClient(object):
     def query(self, queryContent, tcp=False):
 
         if not self._hasValidResolverCertificate():
-            self._getResolverCertificates()
+            self.refreshResolverCertificates()
 
         nonce = self._generateNonce()
-        resolverCert = self._getResolverCertificate()
+        resolverCert = self.getResolverCertificate()
         if resolverCert is None:
             raise Exception("No valid certificate found")
         encryptedQuery = self._encryptQuery(queryContent, resolverCert, nonce, tcp)
index c20bdcdc324888f56db8ac7919f108126f5763b1..58abd84e4bd2568f2893a8aff15fe2d48101463a 100644 (file)
@@ -1,11 +1,12 @@
 #!/usr/bin/env python
+import base64
 import time
 import dns
 import dns.message
 from dnsdisttests import DNSDistTest
 import dnscrypt
 
-class TestDNSCrypt(DNSDistTest):
+class DNSCryptTest(DNSDistTest):
     """
     dnsdist is configured to accept DNSCrypt queries on 127.0.0.1:_dnsDistPortDNSCrypt.
     The provider's keys have been generated with:
@@ -15,21 +16,46 @@ class TestDNSCrypt(DNSDistTest):
 
     _dnsDistPort = 5340
     _dnsDistPortDNSCrypt = 8443
-    _config_template = """
-    generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
-    addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
-    newServer{address="127.0.0.1:%s"}
-    """
+
+    _consoleKey = DNSDistTest.generateConsoleKey()
+    _consoleKeyB64 = base64.b64encode(_consoleKey)
 
     _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
     _providerName = "2.provider.name"
     _resolverCertificateSerial = 42
+
     # valid from 60s ago until 2h from now
     _resolverCertificateValidFrom = time.time() - 60
     _resolverCertificateValidUntil = time.time() + 7200
-    _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
+
     _dnsdistStartupDelay = 10
 
+    def doDNSCryptQuery(self, client, query, response, tcp):
+        self._toResponderQueue.put(response)
+        data = client.query(query.to_wire(), tcp=tcp)
+        receivedResponse = dns.message.from_wire(data)
+        receivedQuery = None
+        if not self._fromResponderQueue.empty():
+            receivedQuery = self._fromResponderQueue.get(query)
+
+        self.assertTrue(receivedQuery)
+        self.assertTrue(receivedResponse)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+
+
+class TestDNSCrypt(DNSCryptTest):
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+    generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
+    addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
+    newServer{address="127.0.0.1:%s"}
+    """
+
+    _config_params = ['_consoleKeyB64', '_consolePort', '_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
+
     def testSimpleA(self):
         """
         DNSCrypt: encrypted A query
@@ -45,31 +71,8 @@ class TestDNSCrypt(DNSDistTest):
                                     '192.2.0.1')
         response.answer.append(rrset)
 
-        self._toResponderQueue.put(response)
-        data = client.query(query.to_wire())
-        receivedResponse = dns.message.from_wire(data)
-        receivedQuery = None
-        if not self._fromResponderQueue.empty():
-            receivedQuery = self._fromResponderQueue.get(query)
-
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
-
-        self._toResponderQueue.put(response)
-        data = client.query(query.to_wire(), tcp=True)
-        receivedResponse = dns.message.from_wire(data)
-        receivedQuery = None
-        if not self._fromResponderQueue.empty():
-            receivedQuery = self._fromResponderQueue.get(query)
-
-        self.assertTrue(receivedQuery)
-        self.assertTrue(receivedResponse)
-        receivedQuery.id = query.id
-        self.assertEquals(query, receivedQuery)
-        self.assertEquals(response, receivedResponse)
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)
 
     def testResponseLargerThanPaddedQuery(self):
         """
@@ -107,14 +110,66 @@ class TestDNSCrypt(DNSDistTest):
         self.assertTrue(len(receivedResponse.authority) == 0)
         self.assertTrue(len(receivedResponse.additional) == 0)
 
-class TestDNSCryptWithCache(DNSDistTest):
-    _dnsDistPortDNSCrypt = 8443
-    _providerFingerprint = 'E1D7:2108:9A59:BF8D:F101:16FA:ED5E:EA6A:9F6C:C78F:7F91:AF6B:027E:62F4:69C3:B1AA'
-    _providerName = "2.provider.name"
-    _resolverCertificateSerial = 42
-    # valid from 60s ago until 2h from now
-    _resolverCertificateValidFrom = time.time() - 60
-    _resolverCertificateValidUntil = time.time() + 7200
+    def testCertRotation(self):
+        """
+        DNSCrypt: certificate rotation
+        """
+        client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443)
+        client.refreshResolverCertificates()
+
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial)
+
+        name = 'rotation.dnscrypt.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.2.0.1')
+        response.answer.append(rrset)
+
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)
+
+        # generate a new certificate
+        self.sendConsoleCommand("generateDNSCryptCertificate('DNSCryptProviderPrivate.key', 'DNSCryptResolver.cert.2', 'DNSCryptResolver.key.2', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 1, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil))
+        # switch to that new certificate
+        self.sendConsoleCommand("getDNSCryptBind(0):loadNewCertificate('DNSCryptResolver.cert.2', 'DNSCryptResolver.key.2')")
+
+        # we should still be able to send queries with the previous certificate
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial)
+
+        # but refreshing should get us the new one
+        client.refreshResolverCertificates()
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial + 1)
+
+        # generate a third certificate, this time in memory
+        self.sendConsoleCommand("getDNSCryptBind(0):generateAndLoadInMemoryCertificate('DNSCryptProviderPrivate.key', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 2, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil))
+
+        # we should still be able to send queries with the previous certificate
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial + 1)
+
+        # but refreshing should get us the new one
+        client.refreshResolverCertificates()
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial + 2)
+
+class TestDNSCryptWithCache(DNSCryptTest):
+
     _config_params = ['_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort']
     _config_template = """
     generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
@@ -169,3 +224,58 @@ class TestDNSCryptWithCache(DNSDistTest):
         for key in self._responsesCounter:
             total += self._responsesCounter[key]
         self.assertEquals(total, misses)
+
+class TestDNSCryptAutomaticRotation(DNSCryptTest):
+    _config_template = """
+    setKey("%s")
+    controlSocket("127.0.0.1:%s")
+    generateDNSCryptCertificate("DNSCryptProviderPrivate.key", "DNSCryptResolver.cert", "DNSCryptResolver.key", %d, %d, %d)
+    addDNSCryptBind("127.0.0.1:%d", "%s", "DNSCryptResolver.cert", "DNSCryptResolver.key")
+    newServer{address="127.0.0.1:%s"}
+
+    local last = 0
+    serial = %d
+    function maintenance()
+      local now = os.time()
+      if ((now - last) > 2) then
+        serial = serial + 1
+        getDNSCryptBind(0):generateAndLoadInMemoryCertificate('DNSCryptProviderPrivate.key', serial, now - 60, now + 120)
+        last = now
+      end
+    end
+    """
+
+    _config_params = ['_consoleKeyB64', '_consolePort', '_resolverCertificateSerial', '_resolverCertificateValidFrom', '_resolverCertificateValidUntil', '_dnsDistPortDNSCrypt', '_providerName', '_testServerPort', '_resolverCertificateSerial']
+
+    def testCertRotation(self):
+        """
+        DNSCrypt: automatic certificate rotation
+        """
+        client = dnscrypt.DNSCryptClient(self._providerName, self._providerFingerprint, "127.0.0.1", 8443)
+
+        client.refreshResolverCertificates()
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        firstSerial = cert.serial
+        self.assertGreaterEqual(cert.serial, self._resolverCertificateSerial)
+
+        time.sleep(3)
+
+        client.refreshResolverCertificates()
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        secondSerial = cert.serial
+        self.assertGreater(cert.serial, firstSerial)
+
+        name = 'automatic-rotation.dnscrypt.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.2.0.1')
+        response.answer.append(rrset)
+
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)