]> granicus.if.org Git - pdns/commitdiff
dnsdist: Add support for DNSCrypt's xchacha20, n active certs
authorRemi Gacogne <remi.gacogne@powerdns.com>
Sun, 15 Oct 2017 20:28:32 +0000 (22:28 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 8 Mar 2018 09:05:09 +0000 (10:05 +0100)
12 files changed:
m4/pdns_check_libsodium.m4
pdns/dnscrypt.cc
pdns/dnscrypt.hh
pdns/dnsdist-dnscrypt.cc
pdns/dnsdist-lua-bindings.cc
pdns/dnsdist-lua.cc
pdns/dnsdist-tcp.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/test-dnscrypt_cc.cc
regression-tests.dnsdist/dnscrypt.py
regression-tests.dnsdist/test_DNSCrypt.py

index 30e961e89e8178bfbe3c96bb15f5949859b7401e..03daf00582d430c44d4cfb1aa9a0bbf29b8db5d7 100644 (file)
@@ -15,7 +15,7 @@ AC_DEFUN([PDNS_CHECK_LIBSODIUM], [
         save_LIBS=$LIBS
         CFLAGS="$LIBSODIUM_CFLAGS $CFLAGS"
         LIBS="$LIBSODIUM_LIBS $LIBS"
-        AC_CHECK_FUNCS([crypto_box_easy_afternm])
+        AC_CHECK_FUNCS([crypto_box_easy_afternm crypto_box_curve25519xchacha20poly1305_easy])
         CFLAGS=$save_CFLAGS
         LIBS=$save_LIBS
       ], [ : ])
index 62ec89706a504c01d0d5414744847bc6f4a2df2b..c03284585bc2b1058781523d3df1b1be9a8f7255 100644 (file)
 #include "dolog.hh"
 #include "dnscrypt.hh"
 #include "dnswriter.hh"
+#include "lock.hh"
 
-DnsCryptPrivateKey::DnsCryptPrivateKey()
+DNSCryptPrivateKey::DNSCryptPrivateKey()
 {
   sodium_memzero(key, sizeof(key));
   sodium_mlock(key, sizeof(key));
 }
 
-void DnsCryptPrivateKey::loadFromFile(const std::string& keyFile)
+void DNSCryptPrivateKey::loadFromFile(const std::string& keyFile)
 {
   ifstream file(keyFile);
   sodium_memzero(key, sizeof(key));
@@ -47,54 +48,96 @@ void DnsCryptPrivateKey::loadFromFile(const std::string& keyFile)
   file.close();
 }
 
-void DnsCryptPrivateKey::saveToFile(const std::string& keyFile) const
+void DNSCryptPrivateKey::saveToFile(const std::string& keyFile) const
 {
   ofstream file(keyFile);
   file.write((char*) key, sizeof(key));
   file.close();
 }
 
-DnsCryptPrivateKey::~DnsCryptPrivateKey()
+DNSCryptPrivateKey::~DNSCryptPrivateKey()
 {
   sodium_munlock(key, sizeof(key));
 }
 
+DNSCryptExchangeVersion DNSCryptQuery::getVersion() const
+{
+  if (d_pair == nullptr) {
+    throw std::runtime_error("Unable to determine the version of a DNSCrypt query if there is not associated cert");
+  }
+
+  return DNSCryptContext::getExchangeVersion(d_pair->cert);
+}
+
 #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM
-DnsCryptQuery::~DnsCryptQuery()
+DNSCryptQuery::~DNSCryptQuery()
 {
-  if (sharedKeyComputed) {
-    sodium_munlock(sharedKey, sizeof(sharedKey));
+  if (d_sharedKeyComputed) {
+    sodium_munlock(d_sharedKey, sizeof(d_sharedKey));
   }
 }
 
-int DnsCryptQuery::computeSharedKey(const DnsCryptPrivateKey& privateKey)
+int DNSCryptQuery::computeSharedKey()
 {
+  assert(d_pair != nullptr);
+
   int res = 0;
 
-  if (sharedKeyComputed) {
+  if (d_sharedKeyComputed) {
     return res;
   }
 
-  sodium_mlock(sharedKey, sizeof(sharedKey));
-  res = crypto_box_beforenm(sharedKey,
-                            header.clientPK,
-                            privateKey.key);
+  const DNSCryptExchangeVersion version = DNSCryptContext::getExchangeVersion(d_pair->cert);
+
+  sodium_mlock(d_sharedKey, sizeof(d_sharedKey));
+
+  if (version == DNSCryptExchangeVersion::VERSION1) {
+    res = crypto_box_beforenm(d_sharedKey,
+                              d_header.clientPK,
+                              d_pair->privateKey.key);
+  }
+  else if (version == DNSCryptExchangeVersion::VERSION2) {
+#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY
+    res = crypto_box_curve25519xchacha20poly1305_beforenm(d_sharedKey,
+                                                          d_header.clientPK,
+                                                          d_pair->privateKey.key);
+#else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+    res = -1;
+#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+  }
+  else {
+    res = -1;
+  }
 
   if (res != 0) {
-    sodium_munlock(sharedKey, sizeof(sharedKey));
+    sodium_munlock(d_sharedKey, sizeof(d_sharedKey));
     return res;
   }
 
-  sharedKeyComputed = true;
+  d_sharedKeyComputed = true;
   return res;
 }
 #else
-DnsCryptQuery::~DnsCryptQuery()
+DNSCryptQuery::~DNSCryptQuery()
 {
 }
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
 
-void DnsCryptContext::generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE], unsigned char privateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE])
+DNSCryptContext::DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName)
+{
+  pthread_rwlock_init(&d_lock, 0);
+
+  loadNewCertificate(certFile, keyFile);
+}
+
+DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName)
+{
+  pthread_rwlock_init(&d_lock, 0);
+
+  addNewCertificate(certificate, pKey);
+}
+
+void DNSCryptContext::generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE], unsigned char privateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE])
 {
   int res = crypto_sign_ed25519_keypair(publicKey, privateKey);
 
@@ -103,7 +146,7 @@ void DnsCryptContext::generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROV
   }
 }
 
-std::string DnsCryptContext::getProviderFingerprint(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE])
+std::string DNSCryptContext::getProviderFingerprint(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE])
 {
   boost::format fmt("%02X%02X");
   ostringstream ret;
@@ -119,12 +162,51 @@ std::string DnsCryptContext::getProviderFingerprint(unsigned char publicKey[DNSC
   return ret.str();
 }
 
-void DnsCryptContext::generateCertificate(uint32_t serial, time_t begin, time_t end, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DnsCryptPrivateKey& privateKey, DnsCryptCert& cert)
+void DNSCryptContext::setExchangeVersion(const DNSCryptExchangeVersion& version,  unsigned char esVersion[sizeof(DNSCryptCert::esVersion)])
+{
+  esVersion[0] = 0x00;
+
+  if (version == DNSCryptExchangeVersion::VERSION1) {
+    esVersion[1] = { 0x01 };
+  }
+  else if (version == DNSCryptExchangeVersion::VERSION2) {
+    esVersion[1] = { 0x02 };
+  }
+  else {
+    throw std::runtime_error("Unknown DNSCrypt exchange version");
+  }
+}
+
+DNSCryptExchangeVersion DNSCryptContext::getExchangeVersion(const unsigned char esVersion[sizeof(DNSCryptCert::esVersion)])
+{
+  if (esVersion[0] != 0x00) {
+    throw std::runtime_error("Unknown DNSCrypt exchange version");
+  }
+
+  if (esVersion[1] == 0x01) {
+    return DNSCryptExchangeVersion::VERSION1;
+  }
+  else if (esVersion[1] == 0x02) {
+    return DNSCryptExchangeVersion::VERSION2;
+  }
+
+  throw std::runtime_error("Unknown DNSCrypt exchange version");
+}
+
+DNSCryptExchangeVersion DNSCryptContext::getExchangeVersion(const DNSCryptCert& cert)
+{
+  return getExchangeVersion(cert.esVersion);
+}
+
+
+void DNSCryptContext::generateCertificate(uint32_t serial, time_t begin, time_t end, const DNSCryptExchangeVersion& version, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DNSCryptPrivateKey& privateKey, DNSCryptCert& cert)
 {
   unsigned char magic[DNSCRYPT_CERT_MAGIC_SIZE] = DNSCRYPT_CERT_MAGIC_VALUE;
-  unsigned char esVersion[] = DNSCRYPT_CERT_ES_VERSION_VALUE;
   unsigned char protocolMinorVersion[] = DNSCRYPT_CERT_PROTOCOL_MINOR_VERSION_VALUE;
   unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE];
+  unsigned char esVersion[sizeof(DNSCryptCert::esVersion)];
+  setExchangeVersion(version, esVersion);
+
   generateResolverKeyPair(privateKey, pubK);
 
   memcpy(cert.magic, magic, sizeof(magic));
@@ -145,14 +227,14 @@ void DnsCryptContext::generateCertificate(uint32_t serial, time_t begin, time_t
                                 providerPrivateKey);
 
   if (res == 0) {
-    assert(signatureSize == sizeof(DnsCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE);
+    assert(signatureSize == sizeof(DNSCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE);
   }
   else {
     throw std::runtime_error("Error generating DNSCrypt certificate");
   }
 }
 
-void DnsCryptContext::loadCertFromFile(const std::string&filename, DnsCryptCert& dest)
+void DNSCryptContext::loadCertFromFile(const std::string&filename, DNSCryptCert& dest)
 {
   ifstream file(filename);
   file.read((char *) &dest, sizeof(dest));
@@ -163,14 +245,14 @@ void DnsCryptContext::loadCertFromFile(const std::string&filename, DnsCryptCert&
   file.close();
 }
 
-void DnsCryptContext::saveCertFromFile(const DnsCryptCert& cert, const std::string&filename)
+void DNSCryptContext::saveCertFromFile(const DNSCryptCert& cert, const std::string&filename)
 {
   ofstream file(filename);
   file.write((char *) &cert, sizeof(cert));
   file.close();
 }
 
-void DnsCryptContext::generateResolverKeyPair(DnsCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE])
+void DNSCryptContext::generateResolverKeyPair(DNSCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE])
 {
   int res = crypto_box_keypair(pubK, privK.key);
 
@@ -179,7 +261,7 @@ void DnsCryptContext::generateResolverKeyPair(DnsCryptPrivateKey& privK, unsigne
   }
 }
 
-void DnsCryptContext::computePublicKeyFromPrivate(const DnsCryptPrivateKey& privK, unsigned char* pubK)
+void DNSCryptContext::computePublicKeyFromPrivate(const DNSCryptPrivateKey& privK, unsigned char* pubK)
 {
   int res = crypto_scalarmult_base(pubK,
                                    privK.key);
@@ -189,10 +271,10 @@ void DnsCryptContext::computePublicKeyFromPrivate(const DnsCryptPrivateKey& priv
   }
 }
 
-std::string DnsCryptContext::certificateDateToStr(uint32_t date)
+std::string DNSCryptContext::certificateDateToStr(uint32_t date)
 {
   char buf[20];
-  time_t tdate = (time_t) ntohl(date);
+  time_t tdate = static_cast<time_t>(ntohl(date));
   struct tm date_tm;
 
   localtime_r(&tdate, &date_tm);
@@ -201,150 +283,233 @@ std::string DnsCryptContext::certificateDateToStr(uint32_t date)
   return string(buf);
 }
 
-void DnsCryptContext::setNewCertificate(const DnsCryptCert& newCert, const DnsCryptPrivateKey& newKey)
+void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active)
 {
-  // XXX TODO: this could use a lock
-  oldPrivateKey = privateKey;
-  oldCert = cert;
-  hasOldCert = true;
-  privateKey = newKey;
-  cert = newCert;
+  WriteLock w(&d_lock);
+
+  for (auto pair : certs) {
+    if (pair->cert.getSerial() == newCert.getSerial()) {
+      throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial");
+    }
+  }
+
+  auto pair = std::make_shared<DNSCryptCertificatePair>();
+  pair->cert = newCert;
+  pair->privateKey = newKey;
+  computePublicKeyFromPrivate(pair->privateKey, pair->publicKey);
+  pair->active = active;
+  certs.push_back(pair);
 }
 
-void DnsCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile)
+void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active)
 {
-  DnsCryptCert newCert;
-  DnsCryptPrivateKey newPrivateKey;
+  DNSCryptCert newCert;
+  DNSCryptPrivateKey newPrivateKey;
 
   loadCertFromFile(certFile, newCert);
   newPrivateKey.loadFromFile(keyFile);
-  setNewCertificate(newCert, newPrivateKey);
+
+  addNewCertificate(newCert, newPrivateKey, active);
+}
+
+void DNSCryptContext::markActive(uint32_t serial)
+{
+  WriteLock w(&d_lock);
+
+  for (auto pair : certs) {
+    if (pair->active == false && pair->cert.getSerial() == serial) {
+      pair->active = true;
+      return;
+    }
+  }
+  throw std::runtime_error("No inactive certificate found with this serial");
+}
+
+void DNSCryptContext::markInactive(uint32_t serial)
+{
+  WriteLock w(&d_lock);
+
+  for (auto pair : certs) {
+    if (pair->active == true && pair->cert.getSerial() == serial) {
+      pair->active = false;
+      return;
+    }
+  }
+  throw std::runtime_error("No active certificate found with this serial");
 }
 
-void DnsCryptContext::parsePlaintextQuery(const char * packet, uint16_t packetSize, std::shared_ptr<DnsCryptQuery> query) const
+void DNSCryptContext::removeInactiveCertificate(uint32_t serial)
 {
+  WriteLock w(&d_lock);
+
+  for (auto it = certs.begin(); it != certs.end(); ) {
+    if ((*it)->active == false && (*it)->cert.getSerial() == serial) {
+      it = certs.erase(it);
+      return;
+    } else {
+      it++;
+    }
+  }
+  throw std::runtime_error("No inactive certificate found with this serial");
+}
+
+bool DNSCryptQuery::parsePlaintextQuery(const char * packet, uint16_t packetSize)
+{
+  assert(d_ctx != nullptr);
+
   if (packetSize < sizeof(dnsheader)) {
-    return;
+    return false;
   }
 
-  struct dnsheader * dh = (struct dnsheader *) packet;
+  const struct dnsheader * dh = reinterpret_cast<const struct dnsheader *>(packet);
   if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query)
-    return;
+    return false;
 
   unsigned int consumed;
   uint16_t qtype, qclass;
   DNSName qname(packet, packetSize, sizeof(dnsheader), false, &qtype, &qclass, &consumed);
   if ((packetSize - sizeof(dnsheader)) < (consumed + sizeof(qtype) + sizeof(qclass)))
-    return;
+    return false;
 
   if (qtype != QType::TXT || qclass != QClass::IN)
-    return;
+    return false;
 
-  if (qname != DNSName(providerName))
-    return;
+  if (qname != d_ctx->getProviderName())
+    return false;
+
+  d_qname = qname;
+  d_id = dh->id;
+  d_valid = true;
 
-  query->qname = qname;
-  query->id = dh->id;
-  query->valid = true;
+  return true;
 }
 
-void DnsCryptContext::getCertificateResponse(const std::shared_ptr<DnsCryptQuery> query, vector<uint8_t>& response) const
+void DNSCryptContext::getCertificateResponse(time_t now, const DNSName& qname, uint16_t qid, std::vector<uint8_t>& response)
 {
-  DNSPacketWriter pw(response, query->qname, QType::TXT, QClass::IN, Opcode::Query);
+  DNSPacketWriter pw(response, qname, QType::TXT, QClass::IN, Opcode::Query);
   struct dnsheader * dh = pw.getHeader();
-  dh->id = query->id;
+  dh->id = qid;
   dh->qr = true;
   dh->rcode = RCode::NoError;
-  pw.startRecord(query->qname, QType::TXT, (DNSCRYPT_CERTIFICATE_RESPONSE_TTL), QClass::IN, DNSResourceRecord::ANSWER, true);
-  std::string scert;
-  uint8_t certSize = sizeof(cert);
-  scert.assign((const char*) &certSize, sizeof(certSize));
-  scert.append((const char*) &cert, certSize);
 
-  pw.xfrBlob(scert);
-  pw.commit();
-}
+  ReadLock r(&d_lock);
+  for (const auto pair : certs) {
+    if (!pair->active || !pair->cert.isValid(now)) {
+      continue;
+    }
 
-bool DnsCryptContext::magicMatchesPublicKey(std::shared_ptr<DnsCryptQuery> query) const
-{
-  const unsigned char* magic = query->header.clientMagic;
+    pw.startRecord(qname, QType::TXT, (DNSCRYPT_CERTIFICATE_RESPONSE_TTL), QClass::IN, DNSResourceRecord::ANSWER, true);
+    std::string scert;
+    uint8_t certSize = sizeof(pair->cert);
+    scert.assign((const char*) &certSize, sizeof(certSize));
+    scert.append((const char*) &pair->cert, certSize);
 
-  if (memcmp(magic, cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
-    return true;
+    pw.xfrBlob(scert);
+    pw.commit();
   }
+}
 
-  if (hasOldCert == true &&
-      memcmp(magic, oldCert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
-    query->useOldCert = true;
-    return true;
+bool DNSCryptContext::magicMatchesAPublicKey(DNSCryptQuery& query, time_t now)
+{
+  const unsigned char* magic = query.getClientMagic();
+
+  ReadLock r(&d_lock);
+  for (const auto& pair : certs) {
+    if (pair->cert.isValid(now) && memcmp(magic, pair->cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) {
+      query.setCertificatePair(pair);
+      return true;
+    }
   }
 
   return false;
 }
 
-void DnsCryptContext::isQueryEncrypted(const char * packet, uint16_t packetSize, std::shared_ptr<DnsCryptQuery> query, bool tcp) const
+bool DNSCryptQuery::isEncryptedQuery(const char * packet, uint16_t packetSize, bool tcp, time_t now)
 {
-  query->encrypted = false;
+  assert(d_ctx != nullptr);
 
-  if (packetSize < sizeof(DnsCryptQueryHeader)) {
-    return;
+  d_encrypted = false;
+
+  if (packetSize < sizeof(DNSCryptQueryHeader)) {
+    return false;
   }
 
-  if (!tcp && packetSize < DnsCryptQuery::minUDPLength) {
-    return;
+  if (!tcp && packetSize < DNSCryptQuery::s_minUDPLength) {
+    return false;
   }
 
-  struct DnsCryptQueryHeader* header = (struct DnsCryptQueryHeader*) packet;
+  const struct DNSCryptQueryHeader* header = reinterpret_cast<const struct DNSCryptQueryHeader*>(packet);
 
-  query->header = *(header);
+  d_header = *header;
 
-  if (!magicMatchesPublicKey(query)) {
-    return;
+  if (!d_ctx->magicMatchesAPublicKey(*this, now)) {
+    return false;
   }
 
-  query->encrypted = true;
+  d_encrypted = true;
+
+  return true;
 }
 
-void DnsCryptContext::getDecryptedQuery(std::shared_ptr<DnsCryptQuery> query, bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen) const
+void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen)
 {
-  assert(decryptedQueryLen != NULL);
-  assert(query->encrypted);
-  assert(query->valid == false);
+  assert(decryptedQueryLen != nullptr);
+  assert(d_encrypted);
+  assert(d_pair != nullptr);
+  assert(d_valid == false);
 
 #ifdef DNSCRYPT_STRICT_PADDING_LENGTH
-  if (tcp && ((packetSize - sizeof(DnsCryptQueryHeader)) % DNSCRYPT_PADDED_BLOCK_SIZE) != 0) {
-    vinfolog("Dropping encrypted query with invalid size of %d (should be a multiple of %d)", (packetSize - sizeof(DnsCryptQueryHeader)), DNSCRYPT_PADDED_BLOCK_SIZE);
+  if (tcp && ((packetSize - sizeof(DNSCryptQueryHeader)) % DNSCRYPT_PADDED_BLOCK_SIZE) != 0) {
+    vinfolog("Dropping encrypted query with invalid size of %d (should be a multiple of %d)", (packetSize - sizeof(DNSCryptQueryHeader)), DNSCRYPT_PADDED_BLOCK_SIZE);
     return;
   }
 #endif
 
   unsigned char nonce[DNSCRYPT_NONCE_SIZE];
-  static_assert(sizeof(nonce) == (2* sizeof(query->header.clientNonce)), "Nonce should be larger than clientNonce (half)");
-  static_assert(sizeof(query->header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right");
-  static_assert(sizeof(privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right");
+  static_assert(sizeof(nonce) == (2* sizeof(d_header.clientNonce)), "Nonce should be larger than clientNonce (half)");
+  static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right");
+  static_assert(sizeof(d_pair->privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right");
 
-  memcpy(nonce, &query->header.clientNonce, sizeof(query->header.clientNonce));
-  memset(nonce + sizeof(query->header.clientNonce), 0, sizeof(nonce) - sizeof(query->header.clientNonce));
+  memcpy(nonce, &d_header.clientNonce, sizeof(d_header.clientNonce));
+  memset(nonce + sizeof(d_header.clientNonce), 0, sizeof(nonce) - sizeof(d_header.clientNonce));
 
 #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM
-  int res = query->computeSharedKey(query->useOldCert ? oldPrivateKey : privateKey);
+  int res = computeSharedKey();
   if (res != 0) {
     vinfolog("Dropping encrypted query we can't compute the shared key for");
     return;
   }
 
-  res = crypto_box_open_easy_afternm((unsigned char*) packet,
-                                     (unsigned char*) packet + sizeof(DnsCryptQueryHeader),
-                                     packetSize - sizeof(DnsCryptQueryHeader),
-                                     nonce,
-                                     query->sharedKey);
-#else
-  int res = crypto_box_open_easy((unsigned char*) packet,
-                                 (unsigned char*) packet + sizeof(DnsCryptQueryHeader),
-                                 packetSize - sizeof(DnsCryptQueryHeader),
+  const DNSCryptExchangeVersion version = getVersion();
+
+  if (version == DNSCryptExchangeVersion::VERSION1) {
+    res = crypto_box_open_easy_afternm(reinterpret_cast<unsigned char*>(packet),
+                                       reinterpret_cast<unsigned char*>(packet + sizeof(DNSCryptQueryHeader)),
+                                       packetSize - sizeof(DNSCryptQueryHeader),
+                                       nonce,
+                                       d_sharedKey);
+  }
+  else if (version == DNSCryptExchangeVersion::VERSION2) {
+#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY
+    res = crypto_box_curve25519xchacha20poly1305_open_easy_afternm(reinterpret_cast<unsigned char*>(packet),
+                                                                   reinterpret_cast<unsigned char*>(packet + sizeof(DNSCryptQueryHeader)),
+                                                                   packetSize - sizeof(DNSCryptQueryHeader),
+                                                                   nonce,
+                                                                   d_sharedKey);
+#else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+    res = -1;
+#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+  } else {
+    res = -1;
+  }
+
+#else /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
+  int res = crypto_box_open_easy(reinterpret_cast<unsigned char*>(packet),
+                                 reinterpret_cast<unsigned char*>(packet + sizeof(DNSCryptQueryHeader)),
+                                 packetSize - sizeof(DNSCryptQueryHeader),
                                  nonce,
-                                 query->header.clientPK,
-                                 query->useOldCert ? oldPrivateKey.key : privateKey.key);
+                                 d_header.clientPK,
+                                 d_pair->privateKey.key);
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
 
   if (res != 0) {
@@ -352,14 +517,14 @@ void DnsCryptContext::getDecryptedQuery(std::shared_ptr<DnsCryptQuery> query, bo
     return;
   }
 
-  *decryptedQueryLen = packetSize - sizeof(DnsCryptQueryHeader) - DNSCRYPT_MAC_SIZE;
+  *decryptedQueryLen = packetSize - sizeof(DNSCryptQueryHeader) - DNSCRYPT_MAC_SIZE;
   uint16_t pos = *decryptedQueryLen;
   assert(pos < packetSize);
-  query->paddedLen = *decryptedQueryLen;
+  d_paddedLen = *decryptedQueryLen;
 
   while(pos > 0 && packet[pos - 1] == 0) pos--;
 
-  if (pos == 0 || ((uint8_t) packet[pos - 1]) != 0x80) {
+  if (pos == 0 || static_cast<uint8_t>(packet[pos - 1]) != 0x80) {
     vinfolog("Dropping encrypted query with invalid padding value");
     return;
   }
@@ -374,32 +539,35 @@ void DnsCryptContext::getDecryptedQuery(std::shared_ptr<DnsCryptQuery> query, bo
     return;
   }
 
-  query->len = pos;
+  d_len = pos;
+  d_valid = true;
+}
 
-  query->valid = true;
+void DNSCryptQuery::getCertificateResponse(time_t now, std::vector<uint8_t>& response) const
+{
+  assert(d_ctx != nullptr);
+  d_ctx->getCertificateResponse(now, d_qname, d_id, response);
 }
 
-void DnsCryptContext::parsePacket(char* packet, uint16_t packetSize, std::shared_ptr<DnsCryptQuery> query, bool tcp, uint16_t* decryptedQueryLen) const
+void DNSCryptQuery::parsePacket(char* packet, uint16_t packetSize, bool tcp, uint16_t* decryptedQueryLen, time_t now)
 {
-  assert(packet != NULL);
-  assert(decryptedQueryLen != NULL);
+  assert(packet != nullptr);
+  assert(decryptedQueryLen != nullptr);
 
-  query->valid = false;
+  d_valid = false;
 
   /* might be a plaintext certificate request or an authenticated request */
-  isQueryEncrypted(packet, packetSize, query, tcp);
-
-  if (query->encrypted) {
-    getDecryptedQuery(query, tcp, packet, packetSize, decryptedQueryLen);
+  if (isEncryptedQuery(packet, packetSize, tcp, now)) {
+    getDecrypted(tcp, packet, packetSize, decryptedQueryLen);
   }
   else {
-    parsePlaintextQuery(packet, packetSize, query);
+    parsePlaintextQuery(packet, packetSize);
   }
 }
 
-void DnsCryptContext::fillServerNonce(unsigned char* nonce) const
+void DNSCryptQuery::fillServerNonce(unsigned char* nonce) const
 {
-  uint32_t* dest = (uint32_t*) nonce;
+  uint32_t* dest = reinterpret_cast<uint32_t*>(nonce);
   static const size_t nonceSize = DNSCRYPT_NONCE_SIZE / 2;
 
   for (size_t pos = 0; pos < (nonceSize / sizeof(*dest)); pos++)
@@ -413,44 +581,47 @@ void DnsCryptContext::fillServerNonce(unsigned char* nonce) const
    "The length of <resolver-response-pad> must be between 0 and 256 bytes,
    and must be constant for a given (<resolver-sk>, <client-nonce>) tuple."
 */
-uint16_t DnsCryptContext::computePaddingSize(uint16_t unpaddedLen, size_t maxLen, const unsigned char* clientNonce) const
+uint16_t DNSCryptQuery::computePaddingSize(uint16_t unpaddedLen, size_t maxLen) const
 {
-  size_t paddedLen = 0;
+  size_t paddedSize = 0;
   uint16_t result = 0;
   uint32_t rnd = 0;
-  assert(clientNonce != NULL);
+  assert(d_header.clientNonce);
+  assert(d_pair != nullptr);
+
   unsigned char nonce[DNSCRYPT_NONCE_SIZE];
-  memcpy(nonce, clientNonce, (DNSCRYPT_NONCE_SIZE / 2));
-  memcpy(&(nonce[DNSCRYPT_NONCE_SIZE / 2]), clientNonce, (DNSCRYPT_NONCE_SIZE / 2));
-  crypto_stream((unsigned char*) &rnd, sizeof(rnd), nonce, privateKey.key);
+  memcpy(nonce, d_header.clientNonce, (DNSCRYPT_NONCE_SIZE / 2));
+  memcpy(&(nonce[DNSCRYPT_NONCE_SIZE / 2]), d_header.clientNonce, (DNSCRYPT_NONCE_SIZE / 2));
+  crypto_stream((unsigned char*) &rnd, sizeof(rnd), nonce, d_pair->privateKey.key);
 
-  paddedLen = unpaddedLen + rnd % (maxLen - unpaddedLen + 1);
-  paddedLen += DNSCRYPT_PADDED_BLOCK_SIZE - (paddedLen % DNSCRYPT_PADDED_BLOCK_SIZE);
+  paddedSize = unpaddedLen + rnd % (maxLen - unpaddedLen + 1);
+  paddedSize += DNSCRYPT_PADDED_BLOCK_SIZE - (paddedSize % DNSCRYPT_PADDED_BLOCK_SIZE);
 
-  if (paddedLen > maxLen)
-    paddedLen = maxLen;
+  if (paddedSize > maxLen)
+    paddedSize = maxLen;
 
-  result = paddedLen - unpaddedLen;
+  result = paddedSize - unpaddedLen;
 
   return result;
 }
 
-int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, const std::shared_ptr<DnsCryptQuery> query, bool tcp, uint16_t* encryptedResponseLen) const
+int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, bool tcp, uint16_t* encryptedResponseLen)
 {
-  struct DnsCryptResponseHeader header;
-  assert(response != NULL);
+  struct DNSCryptResponseHeader responseHeader;
+  assert(response != nullptr);
   assert(responseLen > 0);
   assert(responseSize >= responseLen);
-  assert(encryptedResponseLen != NULL);
-  assert(query->encrypted == true);
+  assert(encryptedResponseLen != nullptr);
+  assert(d_encrypted == true);
+  assert(d_pair != nullptr);
 
-  if (!tcp && query->paddedLen < responseLen) {
-    struct dnsheader* dh = (struct dnsheader*) response;
+  if (!tcp && d_paddedLen < responseLen) {
+    struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(response);
     size_t questionSize = 0;
 
     if (responseLen > sizeof(dnsheader)) {
       unsigned int consumed = 0;
-      DNSName qname(response, responseLen, sizeof(dnsheader), false, 0, 0, &consumed);
+      DNSName tempQName(response, responseLen, sizeof(dnsheader), false, 0, 0, &consumed);
       if (consumed > 0) {
         questionSize = consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE;
       }
@@ -458,31 +629,31 @@ int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint1
 
     responseLen = sizeof(dnsheader) + questionSize;
 
-    if (responseLen > query->paddedLen) {
-      responseLen = query->paddedLen;
+    if (responseLen > d_paddedLen) {
+      responseLen = d_paddedLen;
     }
     dh->ancount = dh->arcount = dh->nscount = 0;
     dh->tc = 1;
   }
 
-  size_t requiredSize = sizeof(header) + DNSCRYPT_MAC_SIZE + responseLen;
+  size_t requiredSize = sizeof(responseHeader) + DNSCRYPT_MAC_SIZE + responseLen;
   size_t maxSize = (responseSize > (requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE)) ? (requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE) : responseSize;
-  uint16_t paddingSize = computePaddingSize(requiredSize, maxSize, query->header.clientNonce);
+  uint16_t paddingSize = computePaddingSize(requiredSize, maxSize);
   requiredSize += paddingSize;
 
   if (requiredSize > responseSize)
     return ENOBUFS;
 
-  memcpy(&header.nonce, &query->header.clientNonce, sizeof query->header.clientNonce);
-  fillServerNonce(&(header.nonce[sizeof(query->header.clientNonce)]));
+  memcpy(&responseHeader.nonce, &d_header.clientNonce, sizeof d_header.clientNonce);
+  fillServerNonce(&(responseHeader.nonce[sizeof(d_header.clientNonce)]));
 
   /* moving the existing response after the header + MAC */
-  memmove(response + sizeof(header) + DNSCRYPT_MAC_SIZE, response, responseLen);
+  memmove(response + sizeof(responseHeader) + DNSCRYPT_MAC_SIZE, response, responseLen);
 
   uint16_t pos = 0;
   /* copying header */
-  memcpy(response + pos, &header, sizeof(header));
-  pos += sizeof(header);
+  memcpy(response + pos, &responseHeader, sizeof(responseHeader));
+  pos += sizeof(responseHeader);
   /* setting MAC bytes to 0 */
   memset(response + pos, 0, DNSCRYPT_MAC_SIZE);
   pos += DNSCRYPT_MAC_SIZE;
@@ -490,30 +661,48 @@ int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint1
   /* skipping response */
   pos += responseLen;
   /* padding */
-  response[pos] = (uint8_t) 0x80;
+  response[pos] = static_cast<uint8_t>(0x80);
   pos++;
   memset(response + pos, 0, paddingSize - 1);
   pos += (paddingSize - 1);
 
   /* encrypting */
 #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM
-  int res = query->computeSharedKey(query->useOldCert ? oldPrivateKey : privateKey);
+  int res = computeSharedKey();
   if (res != 0) {
     return res;
   }
 
-  res = crypto_box_easy_afternm((unsigned char*) (response + sizeof(header)),
-                                (unsigned char*) (response + toEncryptPos),
-                                responseLen + paddingSize,
-                                header.nonce,
-                                query->sharedKey);
+  const DNSCryptExchangeVersion version = getVersion();
+
+  if (version == DNSCryptExchangeVersion::VERSION1) {
+    res = crypto_box_easy_afternm(reinterpret_cast<unsigned char*>(response + sizeof(responseHeader)),
+                                  reinterpret_cast<unsigned char*>(response + toEncryptPos),
+                                  responseLen + paddingSize,
+                                  responseHeader.nonce,
+                                  d_sharedKey);
+  }
+  else if (version == DNSCryptExchangeVersion::VERSION2) {
+#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY
+    res = crypto_box_curve25519xchacha20poly1305_easy_afternm(reinterpret_cast<unsigned char*>(response + sizeof(responseHeader)),
+                                                              reinterpret_cast<unsigned char*>(response + toEncryptPos),
+                                                              responseLen + paddingSize,
+                                                              responseHeader.nonce,
+                                                              d_sharedKey);
+#else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+    res = -1;
+#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+  }
+  else {
+    res = -1;
+  }
 #else
-  int res = crypto_box_easy((unsigned char*) (response + sizeof(header)),
-                            (unsigned char*) (response + toEncryptPos),
+  int res = crypto_box_easy(reinterpret_cast<unsigned char*>(response + sizeof(responseHeader)),
+                            reinterpret_cast<unsigned char*>(response + toEncryptPos),
                             responseLen + paddingSize,
-                            header.nonce,
-                            query->header.clientPK,
-                            query->useOldCert ? oldPrivateKey.key : privateKey.key);
+                            responseHeader.nonce,
+                            d_header.clientPK,
+                            d_pair->privateKey.key);
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
 
   if (res == 0) {
@@ -524,34 +713,36 @@ int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint1
   return res;
 }
 
-int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DnsCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen) const
+int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr<DNSCryptCert> cert) const
 {
-  assert(query != NULL);
+  assert(query != nullptr);
   assert(queryLen > 0);
   assert(querySize >= queryLen);
-  assert(encryptedResponseLen != NULL);
+  assert(encryptedResponseLen != nullptr);
+  assert(cert != nullptr);
+
   unsigned char nonce[DNSCRYPT_NONCE_SIZE];
-  size_t requiredSize = sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE + queryLen;
+  size_t requiredSize = sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE + queryLen;
   /* this is not optimal, we should compute a random padding size, multiple of DNSCRYPT_PADDED_BLOCK_SIZE,
      DNSCRYPT_PADDED_BLOCK_SIZE <= padding size <= 4096? */
   uint16_t paddingSize = DNSCRYPT_PADDED_BLOCK_SIZE - (queryLen % DNSCRYPT_PADDED_BLOCK_SIZE);
   requiredSize += paddingSize;
 
-  if (!tcp && requiredSize < DnsCryptQuery::minUDPLength) {
-    paddingSize += (DnsCryptQuery::minUDPLength - requiredSize);
-    requiredSize = DnsCryptQuery::minUDPLength;
+  if (!tcp && requiredSize < DNSCryptQuery::s_minUDPLength) {
+    paddingSize += (DNSCryptQuery::s_minUDPLength - requiredSize);
+    requiredSize = DNSCryptQuery::s_minUDPLength;
   }
 
   if (requiredSize > querySize)
     return ENOBUFS;
 
   /* moving the existing query after the header + MAC */
-  memmove(query + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE, query, queryLen);
+  memmove(query + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE, query, queryLen);
 
   size_t pos = 0;
   /* client magic */
-  memcpy(query + pos, cert.signedData.clientMagic, sizeof(cert.signedData.clientMagic));
-  pos += sizeof(cert.signedData.clientMagic);
+  memcpy(query + pos, cert->signedData.clientMagic, sizeof(cert->signedData.clientMagic));
+  pos += sizeof(cert->signedData.clientMagic);
 
   /* client PK */
   memcpy(query + pos, clientPublicKey, DNSCRYPT_PUBLIC_KEY_SIZE);
@@ -570,7 +761,7 @@ int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query
   pos += queryLen;
 
   /* padding */
-  query[pos] = (uint8_t) 0x80;
+  query[pos] = static_cast<uint8_t>(0x80);
   pos++;
   memset(query + pos, 0, paddingSize - 1);
   pos += paddingSize - 1;
@@ -578,12 +769,30 @@ int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query
   memcpy(nonce, clientNonce, DNSCRYPT_NONCE_SIZE / 2);
   memset(nonce + (DNSCRYPT_NONCE_SIZE / 2), 0, DNSCRYPT_NONCE_SIZE / 2);
 
-  int res = crypto_box_easy((unsigned char*) query + encryptedPos,
-                            (unsigned char*) query + encryptedPos + DNSCRYPT_MAC_SIZE,
-                            queryLen + paddingSize,
-                            nonce,
-                            cert.signedData.resolverPK,
-                            clientPrivateKey.key);
+  const DNSCryptExchangeVersion version = getExchangeVersion(*cert);
+  int res = -1;
+
+  if (version == DNSCryptExchangeVersion::VERSION1) {
+    res = crypto_box_easy(reinterpret_cast<unsigned char*>(query + encryptedPos),
+                          reinterpret_cast<unsigned char*>(query + encryptedPos + DNSCRYPT_MAC_SIZE),
+                          queryLen + paddingSize,
+                          nonce,
+                          cert->signedData.resolverPK,
+                          clientPrivateKey.key);
+  }
+  else if (version == DNSCryptExchangeVersion::VERSION2) {
+#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY
+    res = crypto_box_curve25519xchacha20poly1305_easy(reinterpret_cast<unsigned char*>(query + encryptedPos),
+                                                      reinterpret_cast<unsigned char*>(query + encryptedPos + DNSCRYPT_MAC_SIZE),
+                                                      queryLen + paddingSize,
+                                                      nonce,
+                                                      cert->signedData.resolverPK,
+                                                      clientPrivateKey.key);
+#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+  }
+  else {
+    throw std::runtime_error("Unknown DNSCrypt exchange version");
+  }
 
   if (res == 0) {
     assert(pos == requiredSize);
@@ -593,7 +802,7 @@ int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query
   return res;
 }
 
-bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DnsCryptCert& certOut, DnsCryptPrivateKey& keyOut)
+bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DNSCryptExchangeVersion version, DNSCryptCert& certOut, DNSCryptPrivateKey& keyOut)
 {
   bool success = false;
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
@@ -608,7 +817,7 @@ bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint
       throw std::runtime_error("Invalid DNSCrypt provider key file " + providerPrivateKeyFile);
     }
 
-    DnsCryptContext::generateCertificate(serial, begin, end, providerPrivateKey, keyOut, certOut);
+    DNSCryptContext::generateCertificate(serial, begin, end, version, providerPrivateKey, keyOut, certOut);
     success = true;
   }
   catch(const std::exception& e) {
index 49f1186c38c86a35414ff46ad5ffd183fc253af6..aad89cd8c87acca7b66ec0f2e27d0be1cad3421e 100644 (file)
 #include <memory>
 #include <string>
 #include <vector>
+#include <arpa/inet.h>
+
 #include <sodium.h>
 
 #include "dnsname.hh"
 
 #define DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE (crypto_sign_ed25519_PUBLICKEYBYTES)
 #define DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE (crypto_sign_ed25519_SECRETKEYBYTES)
+#define DNSCRYPT_SIGNATURE_SIZE (crypto_sign_ed25519_BYTES)
+
 #define DNSCRYPT_PUBLIC_KEY_SIZE (crypto_box_curve25519xsalsa20poly1305_PUBLICKEYBYTES)
 #define DNSCRYPT_PRIVATE_KEY_SIZE (crypto_box_curve25519xsalsa20poly1305_SECRETKEYBYTES)
 #define DNSCRYPT_NONCE_SIZE (crypto_box_curve25519xsalsa20poly1305_NONCEBYTES)
 #define DNSCRYPT_BEFORENM_SIZE (crypto_box_curve25519xsalsa20poly1305_BEFORENMBYTES)
-#define DNSCRYPT_SIGNATURE_SIZE (crypto_sign_ed25519_BYTES)
 #define DNSCRYPT_MAC_SIZE (crypto_box_curve25519xsalsa20poly1305_MACBYTES)
+
+#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY
+static_assert(crypto_box_curve25519xsalsa20poly1305_PUBLICKEYBYTES == crypto_box_curve25519xchacha20poly1305_PUBLICKEYBYTES, "DNSCrypt public key size should be the same for all exchange versions");
+static_assert(crypto_box_curve25519xsalsa20poly1305_SECRETKEYBYTES == crypto_box_curve25519xchacha20poly1305_SECRETKEYBYTES, "DNSCrypt private key size should be the same for all exchange versions");
+static_assert(crypto_box_curve25519xchacha20poly1305_NONCEBYTES == crypto_box_curve25519xsalsa20poly1305_NONCEBYTES, "DNSCrypt nonce size should be the same for all exchange versions");
+static_assert(crypto_box_curve25519xsalsa20poly1305_MACBYTES == crypto_box_curve25519xchacha20poly1305_MACBYTES, "DNSCrypt MAC size should be the same for all exchange versions");
+static_assert(crypto_box_curve25519xchacha20poly1305_BEFORENMBYTES == crypto_box_curve25519xsalsa20poly1305_BEFORENMBYTES, "DNSCrypt BEFORENM size should be the same for all exchange versions");
+#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */
+
 #define DNSCRYPT_CERT_MAGIC_SIZE (4)
 #define DNSCRYPT_CERT_MAGIC_VALUE { 0x44, 0x4e, 0x53, 0x43 }
-#define DNSCRYPT_CERT_ES_VERSION_VALUE { 0x00, 0x01 }
 #define DNSCRYPT_CERT_PROTOCOL_MINOR_VERSION_VALUE { 0x00, 0x00 }
 #define DNSCRYPT_CLIENT_MAGIC_SIZE (8)
 #define DNSCRYPT_RESOLVER_MAGIC { 0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38 }
 /* "The client must check for new certificates every hour", so let's use one hour TTL */
 #define DNSCRYPT_CERTIFICATE_RESPONSE_TTL (3600)
 
-static_assert(DNSCRYPT_CLIENT_MAGIC_SIZE <= DNSCRYPT_PUBLIC_KEY_SIZE, "Dnscrypt Client Nonce size should be smaller or equal to public key size.");
+static_assert(DNSCRYPT_CLIENT_MAGIC_SIZE <= DNSCRYPT_PUBLIC_KEY_SIZE, "DNSCrypt Client Nonce size should be smaller or equal to public key size.");
+
+#define DNSCRYPT_CERT_ES_VERSION1_VALUE { 0x00, 0x01 }
+#define DNSCRYPT_CERT_ES_VERSION2_VALUE { 0x00, 0x02 }
 
-class DnsCryptContext;
+class DNSCryptContext;
 
-struct DnsCryptCertSignedData
+struct DNSCryptCertSignedData
 {
   unsigned char resolverPK[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char clientMagic[DNSCRYPT_CLIENT_MAGIC_SIZE];
@@ -67,128 +81,184 @@ struct DnsCryptCertSignedData
   uint32_t tsEnd;
 };
 
-struct DnsCryptCert
+class DNSCryptCert
 {
+public:
+  uint32_t getSerial() const
+  {
+    return signedData.serial;
+  }
+  uint32_t getTSStart() const
+  {
+    return signedData.tsStart;
+  }
+  uint32_t getTSEnd() const
+  {
+    return signedData.tsEnd;
+  }
+  bool isValid(time_t now) const
+  {
+    return ntohl(getTSStart()) <= now && now <= ntohl(getTSEnd());
+  }
   unsigned char magic[DNSCRYPT_CERT_MAGIC_SIZE];
   unsigned char esVersion[2];
   unsigned char protocolMinorVersion[2];
   unsigned char signature[DNSCRYPT_SIGNATURE_SIZE];
-  struct DnsCryptCertSignedData signedData;
+  struct DNSCryptCertSignedData signedData;
 };
 
-static_assert((sizeof(DnsCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE) == 116, "Dnscrypt cert signed data size + signature size should be 116!");
-static_assert(sizeof(DnsCryptCert) == 124, "Dnscrypt cert size should be 124!");
+static_assert((sizeof(DNSCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE) == 116, "Dnscrypt cert signed data size + signature size should be 116!");
+static_assert(sizeof(DNSCryptCert) == 124, "Dnscrypt cert size should be 124!");
 
-struct DnsCryptQueryHeader
+struct DNSCryptQueryHeader
 {
   unsigned char clientMagic[DNSCRYPT_CLIENT_MAGIC_SIZE];
   unsigned char clientPK[DNSCRYPT_PUBLIC_KEY_SIZE];
   unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2];
 };
 
-static_assert(sizeof(DnsCryptQueryHeader) == 52, "Dnscrypt query header size should be 52!");
+static_assert(sizeof(DNSCryptQueryHeader) == 52, "Dnscrypt query header size should be 52!");
 
-struct DnsCryptResponseHeader
+struct DNSCryptResponseHeader
 {
   const unsigned char resolverMagic[DNSCRYPT_RESOLVER_MAGIC_SIZE] = DNSCRYPT_RESOLVER_MAGIC;
   unsigned char nonce[DNSCRYPT_NONCE_SIZE];
 };
 
-class DnsCryptPrivateKey
+typedef enum {
+  VERSION1,
+  VERSION2
+} DNSCryptExchangeVersion;
+
+class DNSCryptPrivateKey
 {
 public:
-  DnsCryptPrivateKey();
-  ~DnsCryptPrivateKey();
+  DNSCryptPrivateKey();
+  ~DNSCryptPrivateKey();
   void loadFromFile(const std::string& keyFile);
   void saveToFile(const std::string& keyFile) const;
 
   unsigned char key[DNSCRYPT_PRIVATE_KEY_SIZE];
 };
 
-class DnsCryptQuery
+struct DNSCryptCertificatePair
+{
+  unsigned char publicKey[DNSCRYPT_PUBLIC_KEY_SIZE];
+  DNSCryptCert cert;
+  DNSCryptPrivateKey privateKey;
+  bool active;
+};
+
+class DNSCryptQuery
 {
 public:
-  DnsCryptQuery()
+  DNSCryptQuery(std::shared_ptr<DNSCryptContext> ctx): d_ctx(ctx)
+  {
+  }
+  ~DNSCryptQuery();
+
+  bool isValid() const
+  {
+    return d_valid;
+  }
+
+  const DNSName& getQName() const
   {
+    return d_qname;
   }
-  ~DnsCryptQuery();
+
+  uint16_t getID() const
+  {
+    return d_id;
+  }
+
+  const unsigned char* getClientMagic() const
+  {
+    return d_header.clientMagic;
+  }
+
+  bool isEncrypted() const
+  {
+    return d_encrypted;
+  }
+
+  void setCertificatePair(std::shared_ptr<DNSCryptCertificatePair> pair)
+  {
+    d_pair = pair;
+  }
+
+  void parsePacket(char* packet, uint16_t packetSize, bool tcp, uint16_t* decryptedQueryLen, time_t now);
+  void getDecrypted(bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen);
+  void getCertificateResponse(time_t now, std::vector<uint8_t>& response) const;
+  int encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, bool tcp, uint16_t* encryptedResponseLen);
+
+  static const size_t s_minUDPLength = 256;
+
+private:
+  DNSCryptExchangeVersion getVersion() const;
 #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM
-  int computeSharedKey(const DnsCryptPrivateKey& privateKey);
+  int computeSharedKey();
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
+  void fillServerNonce(unsigned char* dest) const;
+  uint16_t computePaddingSize(uint16_t unpaddedLen, size_t maxLen) const;
+  bool parsePlaintextQuery(const char * packet, uint16_t packetSize);
+  bool isEncryptedQuery(const char * packet, uint16_t packetSize, bool tcp, time_t now);
 
-  static const size_t minUDPLength = 256;
-
-  DnsCryptQueryHeader header;
+  DNSCryptQueryHeader d_header;
 #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM
-  unsigned char sharedKey[crypto_box_BEFORENMBYTES];
+  unsigned char d_sharedKey[crypto_box_BEFORENMBYTES];
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
-  DNSName qname;
-  DnsCryptContext* ctx;
-  uint16_t id{0};
-  uint16_t len{0};
-  uint16_t paddedLen;
-  bool useOldCert{false};
-  bool encrypted{false};
-  bool valid{false};
+  DNSName d_qname;
+  std::shared_ptr<DNSCryptContext> d_ctx{nullptr};
+  std::shared_ptr<DNSCryptCertificatePair> d_pair{nullptr};
+  uint16_t d_id{0};
+  uint16_t d_len{0};
+  uint16_t d_paddedLen;
+  bool d_encrypted{false};
+  bool d_valid{false};
+
 #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM
-  bool sharedKeyComputed{false};
+  bool d_sharedKeyComputed{false};
 #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */
 };
 
-class DnsCryptContext
+class DNSCryptContext
 {
 public:
   static void generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE], unsigned char privateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]);
   static std::string getProviderFingerprint(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]);
-  static void generateCertificate(uint32_t serial, time_t begin, time_t end, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DnsCryptPrivateKey& privateKey, DnsCryptCert& cert);
-  static void saveCertFromFile(const DnsCryptCert& cert, const std::string&filename);
+  static void generateCertificate(uint32_t serial, time_t begin, time_t end, const DNSCryptExchangeVersion& version, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DNSCryptPrivateKey& privateKey, DNSCryptCert& cert);
+  static void saveCertFromFile(const DNSCryptCert& cert, const std::string&filename);
   static std::string certificateDateToStr(uint32_t date);
-  static void generateResolverKeyPair(DnsCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]);
+  static void generateResolverKeyPair(DNSCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]);
+  static void setExchangeVersion(const DNSCryptExchangeVersion& version,  unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]);
+  static DNSCryptExchangeVersion getExchangeVersion(const unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]);
+  static DNSCryptExchangeVersion getExchangeVersion(const DNSCryptCert& cert);
 
-  DnsCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName)
-  {
-    loadCertFromFile(certFile, cert);
-    privateKey.loadFromFile(keyFile);
-    computePublicKeyFromPrivate(privateKey, publicKey);
-  }
+  DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile);
+  DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey);
 
-  DnsCryptContext(const std::string& pName, const DnsCryptCert& certificate, const DnsCryptPrivateKey& pKey): providerName(pName), cert(certificate), privateKey(pKey)
-  {
-    computePublicKeyFromPrivate(privateKey, publicKey);
-  }
-
-  void parsePacket(char* packet, uint16_t packetSize, std::shared_ptr<DnsCryptQuery> query, bool tcp, uint16_t* decryptedQueryLen) const;
-  int encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, const std::shared_ptr<DnsCryptQuery> query, bool tcp, uint16_t* encryptedResponseLen) const;
-  void getCertificateResponse(const std::shared_ptr<DnsCryptQuery> query, std::vector<uint8_t>& response) const;
-  void loadNewCertificate(const std::string& certFile, const std::string& keyFile);
-  void setNewCertificate(const DnsCryptCert& newCert, const DnsCryptPrivateKey& newKey);
-  const DnsCryptCert& getCurrentCertificate() const { return cert; };
-  const DnsCryptCert& getOldCertificate() const { return oldCert; };
-  bool hasOldCertificate() const { return hasOldCert; };
-  const std::string& getProviderName() const { return providerName; }
-  int encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DnsCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen) const;
+  void loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active=true);
+  void addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active=true);
+  void markActive(uint32_t serial);
+  void markInactive(uint32_t serial);
+  void removeInactiveCertificate(uint32_t serial);
+  std::vector<std::shared_ptr<DNSCryptCertificatePair>> getCertificates() { return certs; };
+  const DNSName& getProviderName() const { return providerName; }
 
+  int encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr<DNSCryptCert> cert) const;
+  bool magicMatchesAPublicKey(DNSCryptQuery& query, time_t now);
+  void getCertificateResponse(time_t now, const DNSName& qname, uint16_t qid, std::vector<uint8_t>& response);
 
 private:
-  static void computePublicKeyFromPrivate(const DnsCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]);
-  static void loadCertFromFile(const std::string&filename, DnsCryptCert& dest);
+  static void computePublicKeyFromPrivate(const DNSCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]);
+  static void loadCertFromFile(const std::string&filename, DNSCryptCert& dest);
 
-  void parsePlaintextQuery(const char * packet, uint16_t packetSize, std::shared_ptr<DnsCryptQuery> query) const;
-  bool magicMatchesPublicKey(std::shared_ptr<DnsCryptQuery> query) const;
-  void isQueryEncrypted(const char * packet, uint16_t packetSize, std::shared_ptr<DnsCryptQuery> query, bool tcp) const;
-  void getDecryptedQuery(std::shared_ptr<DnsCryptQuery> query, bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen) const;
-  void fillServerNonce(unsigned char* dest) const;
-  uint16_t computePaddingSize(uint16_t unpaddedLen, size_t maxLen, const unsigned char* clientNonce) const;
-
-  std::string providerName;
-  DnsCryptCert cert;
-  DnsCryptCert oldCert;
-  DnsCryptPrivateKey privateKey;
-  unsigned char publicKey[DNSCRYPT_PUBLIC_KEY_SIZE];
-  DnsCryptPrivateKey oldPrivateKey;
-  bool hasOldCert{false};
+  pthread_rwlock_t d_lock;
+  std::vector<std::shared_ptr<DNSCryptCertificatePair>> certs;
+  DNSName providerName;
 };
 
-bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DnsCryptCert& certOut, DnsCryptPrivateKey& keyOut);
+bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DNSCryptExchangeVersion version, DNSCryptCert& certOut, DNSCryptPrivateKey& keyOut);
 
 #endif
index 5a4c4885db8edc1e8bbadfab7382a365134ada83..bc8c75a234fe5ddc7706ddd66882cc8b36e4ba0b 100644 (file)
 #include "dnscrypt.hh"
 
 #ifdef HAVE_DNSCRYPT
-int handleDnsCryptQuery(DnsCryptContext* ctx, char* packet, uint16_t len, std::shared_ptr<DnsCryptQuery>& query, uint16_t* decryptedQueryLen, bool tcp, std::vector<uint8_t>& response)
+int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr<DNSCryptQuery> query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector<uint8_t>& response)
 {
-  query->ctx = ctx;
+  query->parsePacket(packet, len, tcp, decryptedQueryLen, now);
 
-  ctx->parsePacket(packet, len, query, tcp, decryptedQueryLen);
-
-  if (query->valid == false) {
+  if (query->isValid() == false) {
     vinfolog("Dropping DNSCrypt invalid query");
     return false;
   }
 
-  if (query->encrypted == false) {
-    ctx->getCertificateResponse(query, response);
+  if (query->isEncrypted() == false) {
+    query->getCertificateResponse(now, response);
 
     return false;
   }
 
-  if(*decryptedQueryLen < (int)sizeof(struct dnsheader)) {
+  if(*decryptedQueryLen < static_cast<uint16_t>(sizeof(struct dnsheader))) {
     g_stats.nonCompliantQueries++;
     return false;
   }
index bb1415c7a4871ef34c5196920a9755f57b0ced1f..292bc48c2ed06b4d6a635662fa0ac629348964fe 100644 (file)
@@ -286,37 +286,113 @@ void setupLuaBindings(bool client)
     });
 
 #ifdef HAVE_DNSCRYPT
-  /* DnsCryptContext bindings */
-  g_lua.registerFunction<std::string(DnsCryptContext::*)()>("getProviderName", [](const DnsCryptContext& ctx) { return ctx.getProviderName(); });
-  g_lua.registerFunction<DnsCryptCert(DnsCryptContext::*)()>("getCurrentCertificate", [](const DnsCryptContext& ctx) { return ctx.getCurrentCertificate(); });
-  g_lua.registerFunction<DnsCryptCert(DnsCryptContext::*)()>("getOldCertificate", [](const DnsCryptContext& ctx) { return ctx.getOldCertificate(); });
-  g_lua.registerFunction("hasOldCertificate", &DnsCryptContext::hasOldCertificate);
-  g_lua.registerFunction("loadNewCertificate", &DnsCryptContext::loadNewCertificate);
-  g_lua.registerFunction<void(DnsCryptContext::*)(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end)>("generateAndLoadInMemoryCertificate", [](DnsCryptContext& ctx, const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end) {
-      DnsCryptPrivateKey privateKey;
-      DnsCryptCert cert;
-
-      try {
-        if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, cert, privateKey)) {
-          ctx.setNewCertificate(cert, privateKey);
+    /* DNSCryptContext bindings */
+    g_lua.registerFunction<std::string(DNSCryptContext::*)()>("getProviderName", [](const DNSCryptContext& ctx) { return ctx.getProviderName().toStringNoDot(); });
+    g_lua.registerFunction("addNewCertificate", &DNSCryptContext::addNewCertificate);
+    g_lua.registerFunction("markActive", &DNSCryptContext::markActive);
+    g_lua.registerFunction("markInactive", &DNSCryptContext::markInactive);
+    g_lua.registerFunction("removeInactiveCertificate", &DNSCryptContext::removeInactiveCertificate);
+    g_lua.registerFunction<void(std::shared_ptr<DNSCryptContext>::*)(const std::string& certFile, const std::string& keyFile, boost::optional<bool> active)>("loadNewCertificate", [](std::shared_ptr<DNSCryptContext> ctx, const std::string& certFile, const std::string& keyFile, boost::optional<bool> active) {
+
+      if (ctx == nullptr) {
+        throw std::runtime_error("DNSCryptContext::loadNewCertificate() called on a nil value");
+      }
+
+      ctx->loadNewCertificate(certFile, keyFile, active ? *active : true);
+    });
+    g_lua.registerFunction<void(std::shared_ptr<DNSCryptContext>::*)(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, boost::optional<bool> active)>("addNewCertificate", [](std::shared_ptr<DNSCryptContext> ctx, const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, boost::optional<bool> active) {
+
+      if (ctx == nullptr) {
+        throw std::runtime_error("DNSCryptContext::addNewCertificate() called on a nil value");
+      }
+
+      ctx->addNewCertificate(newCert, newKey, active ? *active : true);
+    });
+    g_lua.registerFunction<std::map<int, std::shared_ptr<DNSCryptCertificatePair>>(std::shared_ptr<DNSCryptContext>::*)()>("getCertificatePairs", [](std::shared_ptr<DNSCryptContext> ctx) {
+      std::map<int, std::shared_ptr<DNSCryptCertificatePair>> result;
+
+      if (ctx != nullptr) {
+        size_t idx = 1;
+        for (auto pair : ctx->getCertificates()) {
+          result[idx++] = pair;
         }
       }
-      catch(const std::exception& e) {
-        errlog(e.what());
-        g_outputBuffer="Error: "+string(e.what())+"\n";
+
+      return result;
+    });
+    g_lua.registerFunction<std::shared_ptr<DNSCryptCertificatePair>(std::shared_ptr<DNSCryptContext>::*)(size_t idx)>("getCertificatePair", [](std::shared_ptr<DNSCryptContext> ctx, size_t idx) {
+
+      if (ctx == nullptr) {
+        throw std::runtime_error("DNSCryptContext::getCertificatePair() called on a nil value");
+      }
+
+      std::shared_ptr<DNSCryptCertificatePair> result = nullptr;
+      auto certs = ctx->getCertificates();
+      if (idx < certs.size()) {
+        result = certs.at(idx);
       }
+
+      return result;
     });
 
-  /* DnsCryptCert */
-  g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getMagic", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.magic), sizeof(cert.magic)); });
-  g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getEsVersion", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.esVersion), sizeof(cert.esVersion)); });
-  g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getProtocolMinorVersion", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.protocolMinorVersion), sizeof(cert.protocolMinorVersion)); });
-  g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getSignature", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signature), sizeof(cert.signature)); });
-  g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getResolverPublicKey", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signedData.resolverPK), sizeof(cert.signedData.resolverPK)); });
-  g_lua.registerFunction<std::string(DnsCryptCert::*)()>("getClientMagic", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signedData.clientMagic), sizeof(cert.signedData.clientMagic)); });
-  g_lua.registerFunction<uint32_t(DnsCryptCert::*)()>("getSerial", [](const DnsCryptCert& cert) { return cert.signedData.serial; });
-  g_lua.registerFunction<uint32_t(DnsCryptCert::*)()>("getTSStart", [](const DnsCryptCert& cert) { return ntohl(cert.signedData.tsStart); });
-  g_lua.registerFunction<uint32_t(DnsCryptCert::*)()>("getTSEnd", [](const DnsCryptCert& cert) { return ntohl(cert.signedData.tsEnd); });
+    g_lua.registerFunction<std::string(std::shared_ptr<DNSCryptContext>::*)()>("printCertificates", [](const std::shared_ptr<DNSCryptContext> ctx) {
+      ostringstream ret;
+
+      if (ctx != nullptr) {
+        size_t idx = 1;
+        boost::format fmt("%1$-3d %|5t|%2$-8d %|10t|%3$-2d %|20t|%4$-21.21s %|41t|%5$-21.21s");
+        ret << (fmt % "#" % "Serial" % "Version" % "From" % "To" ) << endl;
+
+        for (auto pair : ctx->getCertificates()) {
+          const auto cert = pair->cert;
+          const DNSCryptExchangeVersion version = DNSCryptContext::getExchangeVersion(cert);
+
+          ret << (fmt % idx % cert.getSerial() % (version == DNSCryptExchangeVersion::VERSION1 ? 1 : 2) % DNSCryptContext::certificateDateToStr(cert.getTSStart()) % DNSCryptContext::certificateDateToStr(cert.getTSEnd())) << endl;
+        }
+      }
+
+      return ret.str();
+    });
+
+    g_lua.registerFunction<void(DNSCryptContext::*)(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, boost::optional<DNSCryptExchangeVersion> version)>("generateAndLoadInMemoryCertificate", [](DNSCryptContext& ctx, const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, boost::optional<DNSCryptExchangeVersion> version) {
+        DNSCryptPrivateKey privateKey;
+        DNSCryptCert cert;
+
+        try {
+          if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, version ? *version : DNSCryptExchangeVersion::VERSION1, cert, privateKey)) {
+            ctx.addNewCertificate(cert, privateKey);
+          }
+        }
+        catch(const std::exception& e) {
+          errlog(e.what());
+          g_outputBuffer="Error: "+string(e.what())+"\n";
+        }
+    });
+
+    /* DNSCryptCertificatePair */
+    g_lua.registerFunction<const DNSCryptCert(std::shared_ptr<DNSCryptCertificatePair>::*)()>("getCertificate", [](const std::shared_ptr<DNSCryptCertificatePair> pair) {
+      if (pair == nullptr) {
+        throw std::runtime_error("DNSCryptCertificatePair::getCertificate() called on a nil value");
+      }
+      return pair->cert;
+    });
+    g_lua.registerFunction<bool(std::shared_ptr<DNSCryptCertificatePair>::*)()>("isActive", [](const std::shared_ptr<DNSCryptCertificatePair> pair) {
+      if (pair == nullptr) {
+        throw std::runtime_error("DNSCryptCertificatePair::isActive() called on a nil value");
+      }
+      return pair->active;
+    });
+
+    /* DNSCryptCert */
+    g_lua.registerFunction<std::string(DNSCryptCert::*)()>("getMagic", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.magic), sizeof(cert.magic)); });
+    g_lua.registerFunction<std::string(DNSCryptCert::*)()>("getEsVersion", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.esVersion), sizeof(cert.esVersion)); });
+    g_lua.registerFunction<std::string(DNSCryptCert::*)()>("getProtocolMinorVersion", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.protocolMinorVersion), sizeof(cert.protocolMinorVersion)); });
+    g_lua.registerFunction<std::string(DNSCryptCert::*)()>("getSignature", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signature), sizeof(cert.signature)); });
+    g_lua.registerFunction<std::string(DNSCryptCert::*)()>("getResolverPublicKey", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signedData.resolverPK), sizeof(cert.signedData.resolverPK)); });
+    g_lua.registerFunction<std::string(DNSCryptCert::*)()>("getClientMagic", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast<const char*>(cert.signedData.clientMagic), sizeof(cert.signedData.clientMagic)); });
+    g_lua.registerFunction<uint32_t(DNSCryptCert::*)()>("getSerial", [](const DNSCryptCert& cert) { return cert.getSerial(); });
+    g_lua.registerFunction<uint32_t(DNSCryptCert::*)()>("getTSStart", [](const DNSCryptCert& cert) { return ntohl(cert.getTSStart()); });
+    g_lua.registerFunction<uint32_t(DNSCryptCert::*)()>("getTSEnd", [](const DNSCryptCert& cert) { return ntohl(cert.getTSEnd()); });
 #endif
 
   /* BPF Filter */
index d8a8c916f2e1e16637922fcba8aa2960452ef3f6..470d8827ecb2cc48a06facc66182ef20a2284ab8 100644 (file)
@@ -912,7 +912,7 @@ void setupLuaConfig(bool client)
       parseLocalBindVars(vars, doTCP, reusePort, tcpFastOpenQueueSize, interface, cpus);
 
       try {
-        DnsCryptContext ctx(providerName, certFile, keyFile);
+        auto ctx = std::make_shared<DNSCryptContext>(providerName, certFile, keyFile);
         g_dnsCryptLocals.push_back(std::make_tuple(ComboAddress(addr, 443), ctx, reusePort, tcpFastOpenQueueSize, interface, cpus));
       }
       catch(std::exception& e) {
@@ -929,17 +929,13 @@ void setupLuaConfig(bool client)
       setLuaNoSideEffect();
 #ifdef HAVE_DNSCRYPT
       ostringstream ret;
-      boost::format fmt("%1$-3d %2% %|25t|%3$-20.20s %|26t|%4$-8d %|35t|%5$-21.21s %|56t|%6$-9d %|66t|%7$-21.21s" );
-      ret << (fmt % "#" % "Address" % "Provider Name" % "Serial" % "Validity" % "P. Serial" % "P. Validity") << endl;
+      boost::format fmt("%1$-3d %2% %|25t|%3$-20.20s");
+      ret << (fmt % "#" % "Address" % "Provider Name") << endl;
       size_t idx = 0;
 
       for (const auto& local : g_dnsCryptLocals) {
-        const DnsCryptContext& ctx = std::get<1>(local);
-        bool const hasOldCert = ctx.hasOldCertificate();
-        const DnsCryptCert& cert = ctx.getCurrentCertificate();
-        const DnsCryptCert& oldCert = ctx.getOldCertificate();
-
-        ret<< (fmt % idx % std::get<0>(local).toStringWithPort() % ctx.getProviderName() % cert.signedData.serial % DnsCryptContext::certificateDateToStr(cert.signedData.tsEnd) % (hasOldCert ? oldCert.signedData.serial : 0) % (hasOldCert ? DnsCryptContext::certificateDateToStr(oldCert.signedData.tsEnd) : "-")) << endl;
+        const std::shared_ptr<DNSCryptContext> ctx = std::get<1>(local);
+        ret<< (fmt % idx % std::get<0>(local).toStringWithPort() % ctx->getProviderName()) << endl;
         idx++;
       }
 
@@ -949,12 +945,12 @@ void setupLuaConfig(bool client)
 #endif
     });
 
-  g_lua.writeFunction("getDNSCryptBind", [client](size_t idx) {
+  g_lua.writeFunction("getDNSCryptBind", [](size_t idx) {
       setLuaNoSideEffect();
 #ifdef HAVE_DNSCRYPT
-      DnsCryptContext* ret = nullptr;
+      std::shared_ptr<DNSCryptContext> ret = nullptr;
       if (idx < g_dnsCryptLocals.size()) {
-        ret = &(std::get<1>(g_dnsCryptLocals.at(idx)));
+        ret = std::get<1>(g_dnsCryptLocals.at(idx));
       }
       return ret;
 #else
@@ -970,7 +966,7 @@ void setupLuaConfig(bool client)
       sodium_mlock(privateKey, sizeof(privateKey));
 
       try {
-        DnsCryptContext::generateProviderKeys(publicKey, privateKey);
+        DNSCryptContext::generateProviderKeys(publicKey, privateKey);
 
         ofstream pubKStream(publicKeyFile);
         pubKStream.write((char*) publicKey, sizeof(publicKey));
@@ -980,7 +976,7 @@ void setupLuaConfig(bool client)
         privKStream.write((char*) privateKey, sizeof(privateKey));
         privKStream.close();
 
-        g_outputBuffer="Provider fingerprint is: " + DnsCryptContext::getProviderFingerprint(publicKey) + "\n";
+        g_outputBuffer="Provider fingerprint is: " + DNSCryptContext::getProviderFingerprint(publicKey) + "\n";
       }
       catch(std::exception& e) {
         errlog(e.what());
@@ -1007,7 +1003,7 @@ void setupLuaConfig(bool client)
           throw std::runtime_error("Invalid dnscrypt provider public key file " + publicKeyFile);
 
         file.close();
-        g_outputBuffer="Provider fingerprint is: " + DnsCryptContext::getProviderFingerprint(publicKey) + "\n";
+        g_outputBuffer="Provider fingerprint is: " + DNSCryptContext::getProviderFingerprint(publicKey) + "\n";
       }
       catch(std::exception& e) {
         errlog(e.what());
@@ -1018,16 +1014,16 @@ void setupLuaConfig(bool client)
 #endif
     });
 
-  g_lua.writeFunction("generateDNSCryptCertificate", [](const std::string& providerPrivateKeyFile, const std::string& certificateFile, const std::string privateKeyFile, uint32_t serial, time_t begin, time_t end) {
+  g_lua.writeFunction("generateDNSCryptCertificate", [](const std::string& providerPrivateKeyFile, const std::string& certificateFile, const std::string privateKeyFile, uint32_t serial, time_t begin, time_t end, boost::optional<DNSCryptExchangeVersion> version) {
       setLuaNoSideEffect();
 #ifdef HAVE_DNSCRYPT
-      DnsCryptPrivateKey privateKey;
-      DnsCryptCert cert;
+      DNSCryptPrivateKey privateKey;
+      DNSCryptCert cert;
 
       try {
-        if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, cert, privateKey)) {
+        if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, version ? *version : DNSCryptExchangeVersion::VERSION1, cert, privateKey)) {
           privateKey.saveToFile(privateKeyFile);
-          DnsCryptContext::saveCertFromFile(cert, certificateFile);
+          DNSCryptContext::saveCertFromFile(cert, certificateFile);
         }
       }
       catch(const std::exception& e) {
index 3516c25e52886f81ae7f15be3fda321d2b88099d..b3b74864ebf338d83d14568213efd9b848705b63 100644 (file)
@@ -316,14 +316,21 @@ void* tcpClientThread(int pipefd)
 
         char* query = &queryBuffer[0];
         handler.read(query, qlen, g_tcpRecvTimeout, remainingTime);
+
+        /* we need this one to be accurate ("real") for the protobuf message */
+       struct timespec queryRealTime;
+       struct timespec now;
+       gettime(&now);
+       gettime(&queryRealTime, true);
+
 #ifdef HAVE_DNSCRYPT
-        std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;
+        std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
 
         if (ci.cs->dnscryptCtx) {
-          dnsCryptQuery = std::make_shared<DnsCryptQuery>();
+          dnsCryptQuery = std::make_shared<DNSCryptQuery>(ci.cs->dnscryptCtx);
           uint16_t decryptedQueryLen = 0;
           vector<uint8_t> response;
-          bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, query, qlen, dnsCryptQuery, &decryptedQueryLen, true, response);
+          bool decrypted = handleDNSCryptQuery(query, qlen, dnsCryptQuery, &decryptedQueryLen, true, queryRealTime.tv_sec, response);
 
           if (!decrypted) {
             if (response.size() > 0) {
@@ -342,11 +349,6 @@ void* tcpClientThread(int pipefd)
 
        string poolname;
        int delayMsec=0;
-       /* we need this one to be accurate ("real") for the protobuf message */
-       struct timespec queryRealTime;
-       struct timespec now;
-       gettime(&now);
-       gettime(&queryRealTime, true);
 
        const uint16_t* flags = getFlagsFromDNSHeader(dh);
        uint16_t origFlags = *flags;
index e525ba9c054c47dc1a0d7ef8ed3fd7c31f506e28..6ecac15c57f66af9ba1a4c209859a5c3c0f5b55d 100644 (file)
@@ -90,7 +90,7 @@ string g_outputBuffer;
 vector<std::tuple<ComboAddress, bool, bool, int, string, std::set<int>>> g_locals;
 std::vector<std::shared_ptr<TLSFrontend>> g_tlslocals;
 #ifdef HAVE_DNSCRYPT
-std::vector<std::tuple<ComboAddress,DnsCryptContext,bool, int, string, std::set<int>>> g_dnsCryptLocals;
+std::vector<std::tuple<ComboAddress,std::shared_ptr<DNSCryptContext>,bool, int, string, std::set<int> >> g_dnsCryptLocals;
 #endif
 #ifdef HAVE_EBPF
 shared_ptr<BPFFilter> g_defaultBPFFilter;
@@ -336,7 +336,7 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize,
 }
 
 #ifdef HAVE_DNSCRYPT
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DnsCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
+bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy)
 {
   if (dnsCryptQuery) {
     uint16_t encryptedResponseLen = 0;
@@ -347,7 +347,7 @@ bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize,
       *dh = dhCopy;
     }
 
-    int res = dnsCryptQuery->ctx->encryptResponse(response, *responseLen, responseSize, dnsCryptQuery, tcp, &encryptedResponseLen);
+    int res = dnsCryptQuery->encryptResponse(response, *responseLen, responseSize, tcp, &encryptedResponseLen);
     if (res == 0) {
       *responseLen = encryptedResponseLen;
     } else {
@@ -1154,15 +1154,15 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
 }
 
 #ifdef HAVE_DNSCRYPT
-static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DnsCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote)
+static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr<DNSCryptQuery>& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote, time_t now)
 {
   if (cs.dnscryptCtx) {
     vector<uint8_t> response;
     uint16_t decryptedQueryLen = 0;
 
-    dnsCryptQuery = std::make_shared<DnsCryptQuery>();
+    dnsCryptQuery = std::make_shared<DNSCryptQuery>(cs.dnscryptCtx);
 
-    bool decrypted = handleDnsCryptQuery(cs.dnscryptCtx, const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, response);
+    bool decrypted = handleDNSCryptQuery(const_cast<char*>(query), len, dnsCryptQuery, &decryptedQueryLen, false, now, response);
 
     if (!decrypted) {
       if (response.size() > 0) {
@@ -1221,10 +1221,18 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
       return;
     }
 
+    /* we need an accurate ("real") value for the response and
+       to store into the IDS, but not for insertion into the
+       rings for example */
+    struct timespec queryRealTime;
+    struct timespec now;
+    gettime(&now);
+    gettime(&queryRealTime, true);
+
 #ifdef HAVE_DNSCRYPT
-    std::shared_ptr<DnsCryptQuery> dnsCryptQuery = nullptr;
+    std::shared_ptr<DNSCryptQuery> dnsCryptQuery = nullptr;
 
-    if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote)) {
+    if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote, queryRealTime.tv_sec)) {
       return;
     }
 #endif
@@ -1238,14 +1246,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
 
     string poolname;
     int delayMsec = 0;
-    /* we need an accurate ("real") value for the response and
-       to store into the IDS, but not for insertion into the
-       rings for example */
-    struct timespec queryRealTime;
-    struct timespec now;
-    gettime(&now);
-    gettime(&queryRealTime, true);
-
     const uint16_t * flags = getFlagsFromDNSHeader(dh);
     const uint16_t origFlags = *flags;
     uint16_t qtype, qclass;
@@ -2362,7 +2362,7 @@ try
   for(auto& dcLocal : g_dnsCryptLocals) {
     ClientState* cs = new ClientState;
     cs->local = std::get<0>(dcLocal);
-    cs->dnscryptCtx = &(std::get<1>(dcLocal));
+    cs->dnscryptCtx = std::get<1>(dcLocal);
     cs->udpFD = SSocket(cs->local.sin4.sin_family, SOCK_DGRAM, 0);
     if(cs->local.sin4.sin_family == AF_INET6) {
       SSetsockopt(cs->udpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1);
@@ -2408,7 +2408,7 @@ try
 
     cs = new ClientState;
     cs->local = std::get<0>(dcLocal);
-    cs->dnscryptCtx = &(std::get<1>(dcLocal));
+    cs->dnscryptCtx = std::get<1>(dcLocal);
     cs->tcpFD = SSocket(cs->local.sin4.sin_family, SOCK_STREAM, 0);
     SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEADDR, 1);
 #ifdef TCP_DEFER_ACCEPT
index cfc92475ff9018312e73a59fcb8f41b6953adf24..90d9041d8db22cf1052f821d91a992bac0529a72 100644 (file)
@@ -341,7 +341,7 @@ struct IDState
   StopWatch sentTime;                                         // 16
   DNSName qname;                                              // 80
 #ifdef HAVE_DNSCRYPT
-  std::shared_ptr<DnsCryptQuery> dnsCryptQuery{0};
+  std::shared_ptr<DNSCryptQuery> dnsCryptQuery{nullptr};
 #endif
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
@@ -368,7 +368,7 @@ struct Rings {
   {
     queryRing.set_capacity(capacity);
     respRing.set_capacity(capacity);
-    pthread_rwlock_init(&queryLock, 0);
+    pthread_rwlock_init(&queryLock, nullptr);
   }
   struct Query
   {
@@ -417,7 +417,7 @@ typedef std::function<std::tuple<bool, string>(DNSQuestion dq)> QueryCountFilter
 struct QueryCount {
   QueryCount()
   {
-    pthread_rwlock_init(&queryLock, 0);
+    pthread_rwlock_init(&queryLock, nullptr);
   }
   QueryCountRecords records;
   QueryCountFilter filter;
@@ -432,7 +432,7 @@ struct ClientState
   std::set<int> cpus;
   ComboAddress local;
 #ifdef HAVE_DNSCRYPT
-  DnsCryptContext* dnscryptCtx{0};
+  std::shared_ptr<DNSCryptContext> dnscryptCtx{nullptr};
 #endif
   shared_ptr<TLSFrontend> tlsFrontend;
   std::atomic<uint64_t> queries{0};
@@ -802,10 +802,10 @@ void restoreFlags(struct dnsheader* dh, uint16_t origFlags);
 bool checkQueryHeaders(const struct dnsheader* dh);
 
 #ifdef HAVE_DNSCRYPT
-extern std::vector<std::tuple<ComboAddress,DnsCryptContext,bool,int, std::string, std::set<int>>> g_dnsCryptLocals;
+extern std::vector<std::tuple<ComboAddress, std::shared_ptr<DNSCryptContext>, bool, int, std::string, std::set<int> > > g_dnsCryptLocals;
 
-int handleDnsCryptQuery(DnsCryptContext* ctx, char* packet, uint16_t len, std::shared_ptr<DnsCryptQuery>& query, uint16_t* decryptedQueryLen, bool tcp, std::vector<uint8_t>& response);
-bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DnsCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy);
+bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy);
+int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr<DNSCryptQuery> query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector<uint8_t>& response);
 #endif
 
 bool addXPF(DNSQuestion& dq, uint16_t optionCode);
index 5bf67a773a3c22a3dcb95db86bd86811a74f8b45..c049ff03cccbc00953fc812d306eefa114493c77 100644 (file)
@@ -40,14 +40,14 @@ BOOST_AUTO_TEST_SUITE(dnscrypt_cc)
 
 // plaintext query for cert
 BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) {
-  DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
   DNSName name("2.name.");
   vector<uint8_t> plainQuery;
@@ -55,17 +55,17 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) {
   pw.getHeader()->rd = 0;
   uint16_t len = plainQuery.size();
 
-  std::shared_ptr<DnsCryptQuery> query = std::make_shared<DnsCryptQuery>();
+  std::shared_ptr<DNSCryptQuery> query = std::make_shared<DNSCryptQuery>(ctx);
   uint16_t decryptedLen = 0;
 
-  ctx.parsePacket((char*) plainQuery.data(), len, query, false, &decryptedLen);
+  query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now);
 
-  BOOST_CHECK_EQUAL(query->valid, true);
-  BOOST_CHECK_EQUAL(query->encrypted, false);
+  BOOST_CHECK_EQUAL(query->isValid(), true);
+  BOOST_CHECK_EQUAL(query->isEncrypted(), false);
 
   std::vector<uint8_t> response;
 
-  ctx.getCertificateResponse(query, response);
+  query->getCertificateResponse(now, response);
 
   MOADNSParser mdp(false, (char*) response.data(), response.size());
 
@@ -81,14 +81,14 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) {
 
 // invalid plaintext query (A)
 BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidA) {
-    DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
   DNSName name("2.name.");
 
@@ -97,24 +97,24 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidA) {
   pw.getHeader()->rd = 0;
   uint16_t len = plainQuery.size();
 
-  std::shared_ptr<DnsCryptQuery> query = std::make_shared<DnsCryptQuery>();
+  std::shared_ptr<DNSCryptQuery> query = std::make_shared<DNSCryptQuery>(ctx);
   uint16_t decryptedLen = 0;
 
-  ctx.parsePacket((char*) plainQuery.data(), len, query, false, &decryptedLen);
+  query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now);
 
-  BOOST_CHECK_EQUAL(query->valid, false);
+  BOOST_CHECK_EQUAL(query->isValid(), false);
 }
 
 // invalid plaintext query (wrong provider name)
 BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidProviderName) {
-  DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
   DNSName name("2.WRONG.name.");
 
@@ -123,29 +123,29 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidProviderName) {
   pw.getHeader()->rd = 0;
   uint16_t len = plainQuery.size();
 
-  std::shared_ptr<DnsCryptQuery> query = std::make_shared<DnsCryptQuery>();
+  std::shared_ptr<DNSCryptQuery> query = std::make_shared<DNSCryptQuery>(ctx);
   uint16_t decryptedLen = 0;
 
-  ctx.parsePacket((char*) plainQuery.data(), len, query, false, &decryptedLen);
+  query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now);
 
-  BOOST_CHECK_EQUAL(query->valid, false);
+  BOOST_CHECK_EQUAL(query->isValid(), false);
 }
 
 // valid encrypted query
 BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) {
-  DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
-  DnsCryptPrivateKey clientPrivateKey;
+  DNSCryptPrivateKey clientPrivateKey;
   unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE];
 
-  DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
+  DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
 
   unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B };
 
@@ -153,27 +153,27 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) {
   vector<uint8_t> plainQuery;
   DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0);
   pw.getHeader()->rd = 1;
-  size_t requiredSize = plainQuery.size() + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE;
-  if (requiredSize < DnsCryptQuery::minUDPLength) {
-    requiredSize = DnsCryptQuery::minUDPLength;
+  size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE;
+  if (requiredSize < DNSCryptQuery::s_minUDPLength) {
+    requiredSize = DNSCryptQuery::s_minUDPLength;
   }
 
   plainQuery.reserve(requiredSize);
   uint16_t len = plainQuery.size();
   uint16_t encryptedResponseLen = 0;
 
-  int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen);
+  int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared<DNSCryptCert>(resolverCert));
 
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK(encryptedResponseLen > len);
 
-  std::shared_ptr<DnsCryptQuery> query = std::make_shared<DnsCryptQuery>();
+  std::shared_ptr<DNSCryptQuery> query = std::make_shared<DNSCryptQuery>(ctx);
   uint16_t decryptedLen = 0;
 
-  ctx.parsePacket((char*) plainQuery.data(), encryptedResponseLen, query, false, &decryptedLen);
+  query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now);
 
-  BOOST_CHECK_EQUAL(query->valid, true);
-  BOOST_CHECK_EQUAL(query->encrypted, true);
+  BOOST_CHECK_EQUAL(query->isValid(), true);
+  BOOST_CHECK_EQUAL(query->isEncrypted(), true);
 
   MOADNSParser mdp(true, (char*) plainQuery.data(), decryptedLen);
 
@@ -189,19 +189,19 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) {
 
 // valid encrypted query with not enough room
 BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidButShort) {
-  DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
-  DnsCryptPrivateKey clientPrivateKey;
+  DNSCryptPrivateKey clientPrivateKey;
   unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE];
 
-  DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
+  DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
 
   unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B };
 
@@ -213,26 +213,26 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidButShort) {
   uint16_t len = plainQuery.size();
   uint16_t encryptedResponseLen = 0;
 
-  int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen);
+  int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared<DNSCryptCert>(resolverCert));
 
   BOOST_CHECK_EQUAL(res, ENOBUFS);
 }
 
 // valid encrypted query with old key
 BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) {
-  DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
-  DnsCryptPrivateKey clientPrivateKey;
+  DNSCryptPrivateKey clientPrivateKey;
   unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE];
 
-  DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
+  DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
 
   unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B };
 
@@ -241,9 +241,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) {
   DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0);
   pw.getHeader()->rd = 1;
 
-  size_t requiredSize = plainQuery.size() + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE;
-  if (requiredSize < DnsCryptQuery::minUDPLength) {
-    requiredSize = DnsCryptQuery::minUDPLength;
+  size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE;
+  if (requiredSize < DNSCryptQuery::s_minUDPLength) {
+    requiredSize = DNSCryptQuery::s_minUDPLength;
   }
 
   plainQuery.reserve(requiredSize);
@@ -251,21 +251,23 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) {
   uint16_t len = plainQuery.size();
   uint16_t encryptedResponseLen = 0;
 
-  int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen);
+  int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared<DNSCryptCert>(resolverCert));
 
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK(encryptedResponseLen > len);
 
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  ctx.setNewCertificate(resolverCert, resolverPrivateKey);
+  DNSCryptCert newResolverCert;
+  DNSCryptContext::generateCertificate(2, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, newResolverCert);
+  ctx->addNewCertificate(newResolverCert, resolverPrivateKey);
+  ctx->markInactive(resolverCert.getSerial());
 
-  std::shared_ptr<DnsCryptQuery> query = std::make_shared<DnsCryptQuery>();
+  std::shared_ptr<DNSCryptQuery> query = std::make_shared<DNSCryptQuery>(ctx);
   uint16_t decryptedLen = 0;
 
-  ctx.parsePacket((char*) plainQuery.data(), encryptedResponseLen, query, false, &decryptedLen);
+  query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now);
 
-  BOOST_CHECK_EQUAL(query->valid, true);
-  BOOST_CHECK_EQUAL(query->encrypted, true);
+  BOOST_CHECK_EQUAL(query->isValid(), true);
+  BOOST_CHECK_EQUAL(query->isEncrypted(), true);
 
   MOADNSParser mdp(true, (char*) plainQuery.data(), decryptedLen);
 
@@ -281,19 +283,19 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) {
 
 // valid encrypted query with wrong key
 BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) {
-  DnsCryptPrivateKey resolverPrivateKey;
-  DnsCryptCert resolverCert;
+  DNSCryptPrivateKey resolverPrivateKey;
+  DNSCryptCert resolverCert;
   unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE];
   unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE];
-  time_t now = time(NULL);
-  DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey);
+  time_t now = time(nullptr);
+  DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey);
+  DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert);
+  auto ctx = std::make_shared<DNSCryptContext>("2.name", resolverCert, resolverPrivateKey);
 
-  DnsCryptPrivateKey clientPrivateKey;
+  DNSCryptPrivateKey clientPrivateKey;
   unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE];
 
-  DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
+  DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey);
 
   unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B };
 
@@ -302,9 +304,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) {
   DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0);
   pw.getHeader()->rd = 1;
 
-  size_t requiredSize = plainQuery.size() + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE;
-  if (requiredSize < DnsCryptQuery::minUDPLength) {
-    requiredSize = DnsCryptQuery::minUDPLength;
+  size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE;
+  if (requiredSize < DNSCryptQuery::s_minUDPLength) {
+    requiredSize = DNSCryptQuery::s_minUDPLength;
   }
 
   plainQuery.reserve(requiredSize);
@@ -312,25 +314,25 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) {
   uint16_t len = plainQuery.size();
   uint16_t encryptedResponseLen = 0;
 
-  int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen);
+  int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared<DNSCryptCert>(resolverCert));
 
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK(encryptedResponseLen > len);
 
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  ctx.setNewCertificate(resolverCert, resolverPrivateKey);
+  DNSCryptCert newResolverCert;
+  DNSCryptContext::generateCertificate(2, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, newResolverCert);
+  ctx->addNewCertificate(newResolverCert, resolverPrivateKey);
+  ctx->markInactive(resolverCert.getSerial());
+  ctx->removeInactiveCertificate(resolverCert.getSerial());
 
-  DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert);
-  ctx.setNewCertificate(resolverCert, resolverPrivateKey);
+  /* we have removed the old certificate, we can't decrypt this query */
 
-  /* we have changed the key two times, we don't have the one used to encrypt this query */
-
-  std::shared_ptr<DnsCryptQuery> query = std::make_shared<DnsCryptQuery>();
+  std::shared_ptr<DNSCryptQuery> query = std::make_shared<DNSCryptQuery>(ctx);
   uint16_t decryptedLen = 0;
 
-  ctx.parsePacket((char*) plainQuery.data(), encryptedResponseLen, query, false, &decryptedLen);
+  query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now);
 
-  BOOST_CHECK_EQUAL(query->valid, false);
+  BOOST_CHECK_EQUAL(query->isValid(), false);
 }
 
 #endif
index 238d973e47b90e21a2a8d5defd0df0c51ba6047f..a93aeaea0a98efcc401eceae205fd66a67245ce5 100644 (file)
@@ -94,7 +94,6 @@ class DNSCryptClient(object):
         data = None
         if tcp:
             got = sock.recv(2)
-            print(len(got))
             if got:
                 (rlen,) = struct.unpack("!H", got)
                 data = sock.recv(rlen)
@@ -134,6 +133,8 @@ class DNSCryptClient(object):
         if an.rdclass != dns.rdataclass.IN or an.rdtype != dns.rdatatype.TXT or len(an.items) == 0:
             raise Exception("Invalid response to public key request")
 
+        self._resolverCertificates = []
+
         for item in an.items:
             if len(item.strings) != 1:
                 continue
@@ -152,6 +153,15 @@ class DNSCryptClient(object):
 
         return result
 
+    def getAllResolverCertificates(self, onlyValid=False):
+        certs = self._resolverCertificates
+        result = []
+        for cert in certs:
+            if not onlyValid or cert.isValid():
+                result.append(cert)
+
+        return result
+
     @staticmethod
     def _generateNonce():
         nonce = libnacl.utils.rand_nonce()
index f5f45d41c7940dd2720e8d760f4b56d8c4ef50d2..376140d804fc40dffe3ace9aea73de030a6a0a34 100644 (file)
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import base64
+import socket
 import time
 import dns
 import dns.message
@@ -136,16 +137,16 @@ class TestDNSCrypt(DNSCryptTest):
 
         # generate a new certificate
         self.sendConsoleCommand("generateDNSCryptCertificate('DNSCryptProviderPrivate.key', 'DNSCryptResolver.cert.2', 'DNSCryptResolver.key.2', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 1, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil))
-        # switch to that new certificate
+        # add that new certificate
         self.sendConsoleCommand("getDNSCryptBind(0):loadNewCertificate('DNSCryptResolver.cert.2', 'DNSCryptResolver.key.2')")
 
-        oldSerial = self.sendConsoleCommand("getDNSCryptBind(0):getOldCertificate():getSerial()")
+        oldSerial = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(0):getCertificate():getSerial()")
         self.assertEquals(int(oldSerial), self._resolverCertificateSerial)
-        effectiveSerial = self.sendConsoleCommand("getDNSCryptBind(0):getCurrentCertificate():getSerial()")
+        effectiveSerial = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(1):getCertificate():getSerial()")
         self.assertEquals(int(effectiveSerial), self._resolverCertificateSerial + 1)
-        tsStart = self.sendConsoleCommand("getDNSCryptBind(0):getCurrentCertificate():getTSStart()")
+        tsStart = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(1):getCertificate():getTSStart()")
         self.assertEquals(int(tsStart), self._resolverCertificateValidFrom)
-        tsEnd = self.sendConsoleCommand("getDNSCryptBind(0):getCurrentCertificate():getTSEnd()")
+        tsEnd = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(1):getCertificate():getTSEnd()")
         self.assertEquals(int(tsEnd), self._resolverCertificateValidUntil)
 
         # we should still be able to send queries with the previous certificate
@@ -160,6 +161,11 @@ class TestDNSCrypt(DNSCryptTest):
         cert = client.getResolverCertificate()
         self.assertTrue(cert)
         self.assertEquals(cert.serial, self._resolverCertificateSerial + 1)
+        # we should still get the old ones
+        certs = client.getAllResolverCertificates(True)
+        self.assertEquals(len(certs), 2)
+        self.assertEquals(certs[0].serial, self._resolverCertificateSerial)
+        self.assertEquals(certs[1].serial, self._resolverCertificateSerial + 1)
 
         # generate a third certificate, this time in memory
         self.sendConsoleCommand("getDNSCryptBind(0):generateAndLoadInMemoryCertificate('DNSCryptProviderPrivate.key', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 2, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil))
@@ -176,6 +182,52 @@ class TestDNSCrypt(DNSCryptTest):
         cert = client.getResolverCertificate()
         self.assertTrue(cert)
         self.assertEquals(cert.serial, self._resolverCertificateSerial + 2)
+        # we should still get the old ones
+        certs = client.getAllResolverCertificates(True)
+        self.assertEquals(len(certs), 3)
+        self.assertEquals(certs[0].serial, self._resolverCertificateSerial)
+        self.assertEquals(certs[1].serial, self._resolverCertificateSerial + 1)
+        self.assertEquals(certs[2].serial, self._resolverCertificateSerial + 2)
+
+        # generate a fourth certificate, still in memory
+        self.sendConsoleCommand("getDNSCryptBind(0):generateAndLoadInMemoryCertificate('DNSCryptProviderPrivate.key', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 3, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil))
+
+        # mark the old ones as inactive
+        self.sendConsoleCommand("getDNSCryptBind(0):markInactive({!s})".format(self._resolverCertificateSerial))
+        self.sendConsoleCommand("getDNSCryptBind(0):markInactive({!s})".format(self._resolverCertificateSerial + 1))
+        self.sendConsoleCommand("getDNSCryptBind(0):markInactive({!s})".format(self._resolverCertificateSerial + 2))
+        # we should still be able to send queries with the third one
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial + 2)
+        # now remove them
+        self.sendConsoleCommand("getDNSCryptBind(0):removeInactiveCertificate({!s})".format(self._resolverCertificateSerial))
+        self.sendConsoleCommand("getDNSCryptBind(0):removeInactiveCertificate({!s})".format(self._resolverCertificateSerial + 1))
+        self.sendConsoleCommand("getDNSCryptBind(0):removeInactiveCertificate({!s})".format(self._resolverCertificateSerial + 2))
+
+        # we should not be able to send with the old ones anymore
+        try:
+            data = client.query(query.to_wire())
+        except socket.timeout:
+            data = None
+        self.assertEquals(data, None)
+
+        # refreshing should get us the fourth one
+        client.refreshResolverCertificates()
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial + 3)
+        # and only that one
+        certs = client.getAllResolverCertificates(True)
+        self.assertEquals(len(certs), 1)
+        # and we should be able to query with it
+        self.doDNSCryptQuery(client, query, response, False)
+        self.doDNSCryptQuery(client, query, response, True)
+        cert = client.getResolverCertificate()
+        self.assertTrue(cert)
+        self.assertEquals(cert.serial, self._resolverCertificateSerial + 3)
 
 class TestDNSCryptWithCache(DNSCryptTest):