From: Remi Gacogne Date: Sun, 15 Oct 2017 20:28:32 +0000 (+0200) Subject: dnsdist: Add support for DNSCrypt's xchacha20, n active certs X-Git-Tag: dnsdist-1.3.0~37^2~5 X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=43234e7697a41c89fd93e738583bfb383de9000e;p=pdns dnsdist: Add support for DNSCrypt's xchacha20, n active certs --- diff --git a/m4/pdns_check_libsodium.m4 b/m4/pdns_check_libsodium.m4 index 30e961e89..03daf0058 100644 --- a/m4/pdns_check_libsodium.m4 +++ b/m4/pdns_check_libsodium.m4 @@ -15,7 +15,7 @@ AC_DEFUN([PDNS_CHECK_LIBSODIUM], [ save_LIBS=$LIBS CFLAGS="$LIBSODIUM_CFLAGS $CFLAGS" LIBS="$LIBSODIUM_LIBS $LIBS" - AC_CHECK_FUNCS([crypto_box_easy_afternm]) + AC_CHECK_FUNCS([crypto_box_easy_afternm crypto_box_curve25519xchacha20poly1305_easy]) CFLAGS=$save_CFLAGS LIBS=$save_LIBS ], [ : ]) diff --git a/pdns/dnscrypt.cc b/pdns/dnscrypt.cc index 62ec89706..c03284585 100644 --- a/pdns/dnscrypt.cc +++ b/pdns/dnscrypt.cc @@ -25,14 +25,15 @@ #include "dolog.hh" #include "dnscrypt.hh" #include "dnswriter.hh" +#include "lock.hh" -DnsCryptPrivateKey::DnsCryptPrivateKey() +DNSCryptPrivateKey::DNSCryptPrivateKey() { sodium_memzero(key, sizeof(key)); sodium_mlock(key, sizeof(key)); } -void DnsCryptPrivateKey::loadFromFile(const std::string& keyFile) +void DNSCryptPrivateKey::loadFromFile(const std::string& keyFile) { ifstream file(keyFile); sodium_memzero(key, sizeof(key)); @@ -47,54 +48,96 @@ void DnsCryptPrivateKey::loadFromFile(const std::string& keyFile) file.close(); } -void DnsCryptPrivateKey::saveToFile(const std::string& keyFile) const +void DNSCryptPrivateKey::saveToFile(const std::string& keyFile) const { ofstream file(keyFile); file.write((char*) key, sizeof(key)); file.close(); } -DnsCryptPrivateKey::~DnsCryptPrivateKey() +DNSCryptPrivateKey::~DNSCryptPrivateKey() { sodium_munlock(key, sizeof(key)); } +DNSCryptExchangeVersion DNSCryptQuery::getVersion() const +{ + if (d_pair == nullptr) { + throw std::runtime_error("Unable to determine the version of a DNSCrypt query if there is not associated cert"); + } + + return DNSCryptContext::getExchangeVersion(d_pair->cert); +} + #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM -DnsCryptQuery::~DnsCryptQuery() +DNSCryptQuery::~DNSCryptQuery() { - if (sharedKeyComputed) { - sodium_munlock(sharedKey, sizeof(sharedKey)); + if (d_sharedKeyComputed) { + sodium_munlock(d_sharedKey, sizeof(d_sharedKey)); } } -int DnsCryptQuery::computeSharedKey(const DnsCryptPrivateKey& privateKey) +int DNSCryptQuery::computeSharedKey() { + assert(d_pair != nullptr); + int res = 0; - if (sharedKeyComputed) { + if (d_sharedKeyComputed) { return res; } - sodium_mlock(sharedKey, sizeof(sharedKey)); - res = crypto_box_beforenm(sharedKey, - header.clientPK, - privateKey.key); + const DNSCryptExchangeVersion version = DNSCryptContext::getExchangeVersion(d_pair->cert); + + sodium_mlock(d_sharedKey, sizeof(d_sharedKey)); + + if (version == DNSCryptExchangeVersion::VERSION1) { + res = crypto_box_beforenm(d_sharedKey, + d_header.clientPK, + d_pair->privateKey.key); + } + else if (version == DNSCryptExchangeVersion::VERSION2) { +#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY + res = crypto_box_curve25519xchacha20poly1305_beforenm(d_sharedKey, + d_header.clientPK, + d_pair->privateKey.key); +#else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + res = -1; +#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + } + else { + res = -1; + } if (res != 0) { - sodium_munlock(sharedKey, sizeof(sharedKey)); + sodium_munlock(d_sharedKey, sizeof(d_sharedKey)); return res; } - sharedKeyComputed = true; + d_sharedKeyComputed = true; return res; } #else -DnsCryptQuery::~DnsCryptQuery() +DNSCryptQuery::~DNSCryptQuery() { } #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ -void DnsCryptContext::generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE], unsigned char privateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]) +DNSCryptContext::DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName) +{ + pthread_rwlock_init(&d_lock, 0); + + loadNewCertificate(certFile, keyFile); +} + +DNSCryptContext::DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey): providerName(pName) +{ + pthread_rwlock_init(&d_lock, 0); + + addNewCertificate(certificate, pKey); +} + +void DNSCryptContext::generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE], unsigned char privateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]) { int res = crypto_sign_ed25519_keypair(publicKey, privateKey); @@ -103,7 +146,7 @@ void DnsCryptContext::generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROV } } -std::string DnsCryptContext::getProviderFingerprint(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]) +std::string DNSCryptContext::getProviderFingerprint(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]) { boost::format fmt("%02X%02X"); ostringstream ret; @@ -119,12 +162,51 @@ std::string DnsCryptContext::getProviderFingerprint(unsigned char publicKey[DNSC return ret.str(); } -void DnsCryptContext::generateCertificate(uint32_t serial, time_t begin, time_t end, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DnsCryptPrivateKey& privateKey, DnsCryptCert& cert) +void DNSCryptContext::setExchangeVersion(const DNSCryptExchangeVersion& version, unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]) +{ + esVersion[0] = 0x00; + + if (version == DNSCryptExchangeVersion::VERSION1) { + esVersion[1] = { 0x01 }; + } + else if (version == DNSCryptExchangeVersion::VERSION2) { + esVersion[1] = { 0x02 }; + } + else { + throw std::runtime_error("Unknown DNSCrypt exchange version"); + } +} + +DNSCryptExchangeVersion DNSCryptContext::getExchangeVersion(const unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]) +{ + if (esVersion[0] != 0x00) { + throw std::runtime_error("Unknown DNSCrypt exchange version"); + } + + if (esVersion[1] == 0x01) { + return DNSCryptExchangeVersion::VERSION1; + } + else if (esVersion[1] == 0x02) { + return DNSCryptExchangeVersion::VERSION2; + } + + throw std::runtime_error("Unknown DNSCrypt exchange version"); +} + +DNSCryptExchangeVersion DNSCryptContext::getExchangeVersion(const DNSCryptCert& cert) +{ + return getExchangeVersion(cert.esVersion); +} + + +void DNSCryptContext::generateCertificate(uint32_t serial, time_t begin, time_t end, const DNSCryptExchangeVersion& version, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DNSCryptPrivateKey& privateKey, DNSCryptCert& cert) { unsigned char magic[DNSCRYPT_CERT_MAGIC_SIZE] = DNSCRYPT_CERT_MAGIC_VALUE; - unsigned char esVersion[] = DNSCRYPT_CERT_ES_VERSION_VALUE; unsigned char protocolMinorVersion[] = DNSCRYPT_CERT_PROTOCOL_MINOR_VERSION_VALUE; unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]; + unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]; + setExchangeVersion(version, esVersion); + generateResolverKeyPair(privateKey, pubK); memcpy(cert.magic, magic, sizeof(magic)); @@ -145,14 +227,14 @@ void DnsCryptContext::generateCertificate(uint32_t serial, time_t begin, time_t providerPrivateKey); if (res == 0) { - assert(signatureSize == sizeof(DnsCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE); + assert(signatureSize == sizeof(DNSCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE); } else { throw std::runtime_error("Error generating DNSCrypt certificate"); } } -void DnsCryptContext::loadCertFromFile(const std::string&filename, DnsCryptCert& dest) +void DNSCryptContext::loadCertFromFile(const std::string&filename, DNSCryptCert& dest) { ifstream file(filename); file.read((char *) &dest, sizeof(dest)); @@ -163,14 +245,14 @@ void DnsCryptContext::loadCertFromFile(const std::string&filename, DnsCryptCert& file.close(); } -void DnsCryptContext::saveCertFromFile(const DnsCryptCert& cert, const std::string&filename) +void DNSCryptContext::saveCertFromFile(const DNSCryptCert& cert, const std::string&filename) { ofstream file(filename); file.write((char *) &cert, sizeof(cert)); file.close(); } -void DnsCryptContext::generateResolverKeyPair(DnsCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]) +void DNSCryptContext::generateResolverKeyPair(DNSCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]) { int res = crypto_box_keypair(pubK, privK.key); @@ -179,7 +261,7 @@ void DnsCryptContext::generateResolverKeyPair(DnsCryptPrivateKey& privK, unsigne } } -void DnsCryptContext::computePublicKeyFromPrivate(const DnsCryptPrivateKey& privK, unsigned char* pubK) +void DNSCryptContext::computePublicKeyFromPrivate(const DNSCryptPrivateKey& privK, unsigned char* pubK) { int res = crypto_scalarmult_base(pubK, privK.key); @@ -189,10 +271,10 @@ void DnsCryptContext::computePublicKeyFromPrivate(const DnsCryptPrivateKey& priv } } -std::string DnsCryptContext::certificateDateToStr(uint32_t date) +std::string DNSCryptContext::certificateDateToStr(uint32_t date) { char buf[20]; - time_t tdate = (time_t) ntohl(date); + time_t tdate = static_cast(ntohl(date)); struct tm date_tm; localtime_r(&tdate, &date_tm); @@ -201,150 +283,233 @@ std::string DnsCryptContext::certificateDateToStr(uint32_t date) return string(buf); } -void DnsCryptContext::setNewCertificate(const DnsCryptCert& newCert, const DnsCryptPrivateKey& newKey) +void DNSCryptContext::addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active) { - // XXX TODO: this could use a lock - oldPrivateKey = privateKey; - oldCert = cert; - hasOldCert = true; - privateKey = newKey; - cert = newCert; + WriteLock w(&d_lock); + + for (auto pair : certs) { + if (pair->cert.getSerial() == newCert.getSerial()) { + throw std::runtime_error("Error adding a new certificate: we already have a certificate with the same serial"); + } + } + + auto pair = std::make_shared(); + pair->cert = newCert; + pair->privateKey = newKey; + computePublicKeyFromPrivate(pair->privateKey, pair->publicKey); + pair->active = active; + certs.push_back(pair); } -void DnsCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile) +void DNSCryptContext::loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active) { - DnsCryptCert newCert; - DnsCryptPrivateKey newPrivateKey; + DNSCryptCert newCert; + DNSCryptPrivateKey newPrivateKey; loadCertFromFile(certFile, newCert); newPrivateKey.loadFromFile(keyFile); - setNewCertificate(newCert, newPrivateKey); + + addNewCertificate(newCert, newPrivateKey, active); +} + +void DNSCryptContext::markActive(uint32_t serial) +{ + WriteLock w(&d_lock); + + for (auto pair : certs) { + if (pair->active == false && pair->cert.getSerial() == serial) { + pair->active = true; + return; + } + } + throw std::runtime_error("No inactive certificate found with this serial"); +} + +void DNSCryptContext::markInactive(uint32_t serial) +{ + WriteLock w(&d_lock); + + for (auto pair : certs) { + if (pair->active == true && pair->cert.getSerial() == serial) { + pair->active = false; + return; + } + } + throw std::runtime_error("No active certificate found with this serial"); } -void DnsCryptContext::parsePlaintextQuery(const char * packet, uint16_t packetSize, std::shared_ptr query) const +void DNSCryptContext::removeInactiveCertificate(uint32_t serial) { + WriteLock w(&d_lock); + + for (auto it = certs.begin(); it != certs.end(); ) { + if ((*it)->active == false && (*it)->cert.getSerial() == serial) { + it = certs.erase(it); + return; + } else { + it++; + } + } + throw std::runtime_error("No inactive certificate found with this serial"); +} + +bool DNSCryptQuery::parsePlaintextQuery(const char * packet, uint16_t packetSize) +{ + assert(d_ctx != nullptr); + if (packetSize < sizeof(dnsheader)) { - return; + return false; } - struct dnsheader * dh = (struct dnsheader *) packet; + const struct dnsheader * dh = reinterpret_cast(packet); if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query) - return; + return false; unsigned int consumed; uint16_t qtype, qclass; DNSName qname(packet, packetSize, sizeof(dnsheader), false, &qtype, &qclass, &consumed); if ((packetSize - sizeof(dnsheader)) < (consumed + sizeof(qtype) + sizeof(qclass))) - return; + return false; if (qtype != QType::TXT || qclass != QClass::IN) - return; + return false; - if (qname != DNSName(providerName)) - return; + if (qname != d_ctx->getProviderName()) + return false; + + d_qname = qname; + d_id = dh->id; + d_valid = true; - query->qname = qname; - query->id = dh->id; - query->valid = true; + return true; } -void DnsCryptContext::getCertificateResponse(const std::shared_ptr query, vector& response) const +void DNSCryptContext::getCertificateResponse(time_t now, const DNSName& qname, uint16_t qid, std::vector& response) { - DNSPacketWriter pw(response, query->qname, QType::TXT, QClass::IN, Opcode::Query); + DNSPacketWriter pw(response, qname, QType::TXT, QClass::IN, Opcode::Query); struct dnsheader * dh = pw.getHeader(); - dh->id = query->id; + dh->id = qid; dh->qr = true; dh->rcode = RCode::NoError; - pw.startRecord(query->qname, QType::TXT, (DNSCRYPT_CERTIFICATE_RESPONSE_TTL), QClass::IN, DNSResourceRecord::ANSWER, true); - std::string scert; - uint8_t certSize = sizeof(cert); - scert.assign((const char*) &certSize, sizeof(certSize)); - scert.append((const char*) &cert, certSize); - pw.xfrBlob(scert); - pw.commit(); -} + ReadLock r(&d_lock); + for (const auto pair : certs) { + if (!pair->active || !pair->cert.isValid(now)) { + continue; + } -bool DnsCryptContext::magicMatchesPublicKey(std::shared_ptr query) const -{ - const unsigned char* magic = query->header.clientMagic; + pw.startRecord(qname, QType::TXT, (DNSCRYPT_CERTIFICATE_RESPONSE_TTL), QClass::IN, DNSResourceRecord::ANSWER, true); + std::string scert; + uint8_t certSize = sizeof(pair->cert); + scert.assign((const char*) &certSize, sizeof(certSize)); + scert.append((const char*) &pair->cert, certSize); - if (memcmp(magic, cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) { - return true; + pw.xfrBlob(scert); + pw.commit(); } +} - if (hasOldCert == true && - memcmp(magic, oldCert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) { - query->useOldCert = true; - return true; +bool DNSCryptContext::magicMatchesAPublicKey(DNSCryptQuery& query, time_t now) +{ + const unsigned char* magic = query.getClientMagic(); + + ReadLock r(&d_lock); + for (const auto& pair : certs) { + if (pair->cert.isValid(now) && memcmp(magic, pair->cert.signedData.clientMagic, DNSCRYPT_CLIENT_MAGIC_SIZE) == 0) { + query.setCertificatePair(pair); + return true; + } } return false; } -void DnsCryptContext::isQueryEncrypted(const char * packet, uint16_t packetSize, std::shared_ptr query, bool tcp) const +bool DNSCryptQuery::isEncryptedQuery(const char * packet, uint16_t packetSize, bool tcp, time_t now) { - query->encrypted = false; + assert(d_ctx != nullptr); - if (packetSize < sizeof(DnsCryptQueryHeader)) { - return; + d_encrypted = false; + + if (packetSize < sizeof(DNSCryptQueryHeader)) { + return false; } - if (!tcp && packetSize < DnsCryptQuery::minUDPLength) { - return; + if (!tcp && packetSize < DNSCryptQuery::s_minUDPLength) { + return false; } - struct DnsCryptQueryHeader* header = (struct DnsCryptQueryHeader*) packet; + const struct DNSCryptQueryHeader* header = reinterpret_cast(packet); - query->header = *(header); + d_header = *header; - if (!magicMatchesPublicKey(query)) { - return; + if (!d_ctx->magicMatchesAPublicKey(*this, now)) { + return false; } - query->encrypted = true; + d_encrypted = true; + + return true; } -void DnsCryptContext::getDecryptedQuery(std::shared_ptr query, bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen) const +void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen) { - assert(decryptedQueryLen != NULL); - assert(query->encrypted); - assert(query->valid == false); + assert(decryptedQueryLen != nullptr); + assert(d_encrypted); + assert(d_pair != nullptr); + assert(d_valid == false); #ifdef DNSCRYPT_STRICT_PADDING_LENGTH - if (tcp && ((packetSize - sizeof(DnsCryptQueryHeader)) % DNSCRYPT_PADDED_BLOCK_SIZE) != 0) { - vinfolog("Dropping encrypted query with invalid size of %d (should be a multiple of %d)", (packetSize - sizeof(DnsCryptQueryHeader)), DNSCRYPT_PADDED_BLOCK_SIZE); + if (tcp && ((packetSize - sizeof(DNSCryptQueryHeader)) % DNSCRYPT_PADDED_BLOCK_SIZE) != 0) { + vinfolog("Dropping encrypted query with invalid size of %d (should be a multiple of %d)", (packetSize - sizeof(DNSCryptQueryHeader)), DNSCRYPT_PADDED_BLOCK_SIZE); return; } #endif unsigned char nonce[DNSCRYPT_NONCE_SIZE]; - static_assert(sizeof(nonce) == (2* sizeof(query->header.clientNonce)), "Nonce should be larger than clientNonce (half)"); - static_assert(sizeof(query->header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right"); - static_assert(sizeof(privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right"); + static_assert(sizeof(nonce) == (2* sizeof(d_header.clientNonce)), "Nonce should be larger than clientNonce (half)"); + static_assert(sizeof(d_header.clientPK) == DNSCRYPT_PUBLIC_KEY_SIZE, "Client Publick key size is not right"); + static_assert(sizeof(d_pair->privateKey.key) == DNSCRYPT_PRIVATE_KEY_SIZE, "Private key size is not right"); - memcpy(nonce, &query->header.clientNonce, sizeof(query->header.clientNonce)); - memset(nonce + sizeof(query->header.clientNonce), 0, sizeof(nonce) - sizeof(query->header.clientNonce)); + memcpy(nonce, &d_header.clientNonce, sizeof(d_header.clientNonce)); + memset(nonce + sizeof(d_header.clientNonce), 0, sizeof(nonce) - sizeof(d_header.clientNonce)); #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM - int res = query->computeSharedKey(query->useOldCert ? oldPrivateKey : privateKey); + int res = computeSharedKey(); if (res != 0) { vinfolog("Dropping encrypted query we can't compute the shared key for"); return; } - res = crypto_box_open_easy_afternm((unsigned char*) packet, - (unsigned char*) packet + sizeof(DnsCryptQueryHeader), - packetSize - sizeof(DnsCryptQueryHeader), - nonce, - query->sharedKey); -#else - int res = crypto_box_open_easy((unsigned char*) packet, - (unsigned char*) packet + sizeof(DnsCryptQueryHeader), - packetSize - sizeof(DnsCryptQueryHeader), + const DNSCryptExchangeVersion version = getVersion(); + + if (version == DNSCryptExchangeVersion::VERSION1) { + res = crypto_box_open_easy_afternm(reinterpret_cast(packet), + reinterpret_cast(packet + sizeof(DNSCryptQueryHeader)), + packetSize - sizeof(DNSCryptQueryHeader), + nonce, + d_sharedKey); + } + else if (version == DNSCryptExchangeVersion::VERSION2) { +#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY + res = crypto_box_curve25519xchacha20poly1305_open_easy_afternm(reinterpret_cast(packet), + reinterpret_cast(packet + sizeof(DNSCryptQueryHeader)), + packetSize - sizeof(DNSCryptQueryHeader), + nonce, + d_sharedKey); +#else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + res = -1; +#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + } else { + res = -1; + } + +#else /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ + int res = crypto_box_open_easy(reinterpret_cast(packet), + reinterpret_cast(packet + sizeof(DNSCryptQueryHeader)), + packetSize - sizeof(DNSCryptQueryHeader), nonce, - query->header.clientPK, - query->useOldCert ? oldPrivateKey.key : privateKey.key); + d_header.clientPK, + d_pair->privateKey.key); #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ if (res != 0) { @@ -352,14 +517,14 @@ void DnsCryptContext::getDecryptedQuery(std::shared_ptr query, bo return; } - *decryptedQueryLen = packetSize - sizeof(DnsCryptQueryHeader) - DNSCRYPT_MAC_SIZE; + *decryptedQueryLen = packetSize - sizeof(DNSCryptQueryHeader) - DNSCRYPT_MAC_SIZE; uint16_t pos = *decryptedQueryLen; assert(pos < packetSize); - query->paddedLen = *decryptedQueryLen; + d_paddedLen = *decryptedQueryLen; while(pos > 0 && packet[pos - 1] == 0) pos--; - if (pos == 0 || ((uint8_t) packet[pos - 1]) != 0x80) { + if (pos == 0 || static_cast(packet[pos - 1]) != 0x80) { vinfolog("Dropping encrypted query with invalid padding value"); return; } @@ -374,32 +539,35 @@ void DnsCryptContext::getDecryptedQuery(std::shared_ptr query, bo return; } - query->len = pos; + d_len = pos; + d_valid = true; +} - query->valid = true; +void DNSCryptQuery::getCertificateResponse(time_t now, std::vector& response) const +{ + assert(d_ctx != nullptr); + d_ctx->getCertificateResponse(now, d_qname, d_id, response); } -void DnsCryptContext::parsePacket(char* packet, uint16_t packetSize, std::shared_ptr query, bool tcp, uint16_t* decryptedQueryLen) const +void DNSCryptQuery::parsePacket(char* packet, uint16_t packetSize, bool tcp, uint16_t* decryptedQueryLen, time_t now) { - assert(packet != NULL); - assert(decryptedQueryLen != NULL); + assert(packet != nullptr); + assert(decryptedQueryLen != nullptr); - query->valid = false; + d_valid = false; /* might be a plaintext certificate request or an authenticated request */ - isQueryEncrypted(packet, packetSize, query, tcp); - - if (query->encrypted) { - getDecryptedQuery(query, tcp, packet, packetSize, decryptedQueryLen); + if (isEncryptedQuery(packet, packetSize, tcp, now)) { + getDecrypted(tcp, packet, packetSize, decryptedQueryLen); } else { - parsePlaintextQuery(packet, packetSize, query); + parsePlaintextQuery(packet, packetSize); } } -void DnsCryptContext::fillServerNonce(unsigned char* nonce) const +void DNSCryptQuery::fillServerNonce(unsigned char* nonce) const { - uint32_t* dest = (uint32_t*) nonce; + uint32_t* dest = reinterpret_cast(nonce); static const size_t nonceSize = DNSCRYPT_NONCE_SIZE / 2; for (size_t pos = 0; pos < (nonceSize / sizeof(*dest)); pos++) @@ -413,44 +581,47 @@ void DnsCryptContext::fillServerNonce(unsigned char* nonce) const "The length of must be between 0 and 256 bytes, and must be constant for a given (, ) tuple." */ -uint16_t DnsCryptContext::computePaddingSize(uint16_t unpaddedLen, size_t maxLen, const unsigned char* clientNonce) const +uint16_t DNSCryptQuery::computePaddingSize(uint16_t unpaddedLen, size_t maxLen) const { - size_t paddedLen = 0; + size_t paddedSize = 0; uint16_t result = 0; uint32_t rnd = 0; - assert(clientNonce != NULL); + assert(d_header.clientNonce); + assert(d_pair != nullptr); + unsigned char nonce[DNSCRYPT_NONCE_SIZE]; - memcpy(nonce, clientNonce, (DNSCRYPT_NONCE_SIZE / 2)); - memcpy(&(nonce[DNSCRYPT_NONCE_SIZE / 2]), clientNonce, (DNSCRYPT_NONCE_SIZE / 2)); - crypto_stream((unsigned char*) &rnd, sizeof(rnd), nonce, privateKey.key); + memcpy(nonce, d_header.clientNonce, (DNSCRYPT_NONCE_SIZE / 2)); + memcpy(&(nonce[DNSCRYPT_NONCE_SIZE / 2]), d_header.clientNonce, (DNSCRYPT_NONCE_SIZE / 2)); + crypto_stream((unsigned char*) &rnd, sizeof(rnd), nonce, d_pair->privateKey.key); - paddedLen = unpaddedLen + rnd % (maxLen - unpaddedLen + 1); - paddedLen += DNSCRYPT_PADDED_BLOCK_SIZE - (paddedLen % DNSCRYPT_PADDED_BLOCK_SIZE); + paddedSize = unpaddedLen + rnd % (maxLen - unpaddedLen + 1); + paddedSize += DNSCRYPT_PADDED_BLOCK_SIZE - (paddedSize % DNSCRYPT_PADDED_BLOCK_SIZE); - if (paddedLen > maxLen) - paddedLen = maxLen; + if (paddedSize > maxLen) + paddedSize = maxLen; - result = paddedLen - unpaddedLen; + result = paddedSize - unpaddedLen; return result; } -int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, const std::shared_ptr query, bool tcp, uint16_t* encryptedResponseLen) const +int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, bool tcp, uint16_t* encryptedResponseLen) { - struct DnsCryptResponseHeader header; - assert(response != NULL); + struct DNSCryptResponseHeader responseHeader; + assert(response != nullptr); assert(responseLen > 0); assert(responseSize >= responseLen); - assert(encryptedResponseLen != NULL); - assert(query->encrypted == true); + assert(encryptedResponseLen != nullptr); + assert(d_encrypted == true); + assert(d_pair != nullptr); - if (!tcp && query->paddedLen < responseLen) { - struct dnsheader* dh = (struct dnsheader*) response; + if (!tcp && d_paddedLen < responseLen) { + struct dnsheader* dh = reinterpret_cast(response); size_t questionSize = 0; if (responseLen > sizeof(dnsheader)) { unsigned int consumed = 0; - DNSName qname(response, responseLen, sizeof(dnsheader), false, 0, 0, &consumed); + DNSName tempQName(response, responseLen, sizeof(dnsheader), false, 0, 0, &consumed); if (consumed > 0) { questionSize = consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE; } @@ -458,31 +629,31 @@ int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint1 responseLen = sizeof(dnsheader) + questionSize; - if (responseLen > query->paddedLen) { - responseLen = query->paddedLen; + if (responseLen > d_paddedLen) { + responseLen = d_paddedLen; } dh->ancount = dh->arcount = dh->nscount = 0; dh->tc = 1; } - size_t requiredSize = sizeof(header) + DNSCRYPT_MAC_SIZE + responseLen; + size_t requiredSize = sizeof(responseHeader) + DNSCRYPT_MAC_SIZE + responseLen; size_t maxSize = (responseSize > (requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE)) ? (requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE) : responseSize; - uint16_t paddingSize = computePaddingSize(requiredSize, maxSize, query->header.clientNonce); + uint16_t paddingSize = computePaddingSize(requiredSize, maxSize); requiredSize += paddingSize; if (requiredSize > responseSize) return ENOBUFS; - memcpy(&header.nonce, &query->header.clientNonce, sizeof query->header.clientNonce); - fillServerNonce(&(header.nonce[sizeof(query->header.clientNonce)])); + memcpy(&responseHeader.nonce, &d_header.clientNonce, sizeof d_header.clientNonce); + fillServerNonce(&(responseHeader.nonce[sizeof(d_header.clientNonce)])); /* moving the existing response after the header + MAC */ - memmove(response + sizeof(header) + DNSCRYPT_MAC_SIZE, response, responseLen); + memmove(response + sizeof(responseHeader) + DNSCRYPT_MAC_SIZE, response, responseLen); uint16_t pos = 0; /* copying header */ - memcpy(response + pos, &header, sizeof(header)); - pos += sizeof(header); + memcpy(response + pos, &responseHeader, sizeof(responseHeader)); + pos += sizeof(responseHeader); /* setting MAC bytes to 0 */ memset(response + pos, 0, DNSCRYPT_MAC_SIZE); pos += DNSCRYPT_MAC_SIZE; @@ -490,30 +661,48 @@ int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint1 /* skipping response */ pos += responseLen; /* padding */ - response[pos] = (uint8_t) 0x80; + response[pos] = static_cast(0x80); pos++; memset(response + pos, 0, paddingSize - 1); pos += (paddingSize - 1); /* encrypting */ #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM - int res = query->computeSharedKey(query->useOldCert ? oldPrivateKey : privateKey); + int res = computeSharedKey(); if (res != 0) { return res; } - res = crypto_box_easy_afternm((unsigned char*) (response + sizeof(header)), - (unsigned char*) (response + toEncryptPos), - responseLen + paddingSize, - header.nonce, - query->sharedKey); + const DNSCryptExchangeVersion version = getVersion(); + + if (version == DNSCryptExchangeVersion::VERSION1) { + res = crypto_box_easy_afternm(reinterpret_cast(response + sizeof(responseHeader)), + reinterpret_cast(response + toEncryptPos), + responseLen + paddingSize, + responseHeader.nonce, + d_sharedKey); + } + else if (version == DNSCryptExchangeVersion::VERSION2) { +#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY + res = crypto_box_curve25519xchacha20poly1305_easy_afternm(reinterpret_cast(response + sizeof(responseHeader)), + reinterpret_cast(response + toEncryptPos), + responseLen + paddingSize, + responseHeader.nonce, + d_sharedKey); +#else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + res = -1; +#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + } + else { + res = -1; + } #else - int res = crypto_box_easy((unsigned char*) (response + sizeof(header)), - (unsigned char*) (response + toEncryptPos), + int res = crypto_box_easy(reinterpret_cast(response + sizeof(responseHeader)), + reinterpret_cast(response + toEncryptPos), responseLen + paddingSize, - header.nonce, - query->header.clientPK, - query->useOldCert ? oldPrivateKey.key : privateKey.key); + responseHeader.nonce, + d_header.clientPK, + d_pair->privateKey.key); #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ if (res == 0) { @@ -524,34 +713,36 @@ int DnsCryptContext::encryptResponse(char* response, uint16_t responseLen, uint1 return res; } -int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DnsCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen) const +int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr cert) const { - assert(query != NULL); + assert(query != nullptr); assert(queryLen > 0); assert(querySize >= queryLen); - assert(encryptedResponseLen != NULL); + assert(encryptedResponseLen != nullptr); + assert(cert != nullptr); + unsigned char nonce[DNSCRYPT_NONCE_SIZE]; - size_t requiredSize = sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE + queryLen; + size_t requiredSize = sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE + queryLen; /* this is not optimal, we should compute a random padding size, multiple of DNSCRYPT_PADDED_BLOCK_SIZE, DNSCRYPT_PADDED_BLOCK_SIZE <= padding size <= 4096? */ uint16_t paddingSize = DNSCRYPT_PADDED_BLOCK_SIZE - (queryLen % DNSCRYPT_PADDED_BLOCK_SIZE); requiredSize += paddingSize; - if (!tcp && requiredSize < DnsCryptQuery::minUDPLength) { - paddingSize += (DnsCryptQuery::minUDPLength - requiredSize); - requiredSize = DnsCryptQuery::minUDPLength; + if (!tcp && requiredSize < DNSCryptQuery::s_minUDPLength) { + paddingSize += (DNSCryptQuery::s_minUDPLength - requiredSize); + requiredSize = DNSCryptQuery::s_minUDPLength; } if (requiredSize > querySize) return ENOBUFS; /* moving the existing query after the header + MAC */ - memmove(query + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE, query, queryLen); + memmove(query + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE, query, queryLen); size_t pos = 0; /* client magic */ - memcpy(query + pos, cert.signedData.clientMagic, sizeof(cert.signedData.clientMagic)); - pos += sizeof(cert.signedData.clientMagic); + memcpy(query + pos, cert->signedData.clientMagic, sizeof(cert->signedData.clientMagic)); + pos += sizeof(cert->signedData.clientMagic); /* client PK */ memcpy(query + pos, clientPublicKey, DNSCRYPT_PUBLIC_KEY_SIZE); @@ -570,7 +761,7 @@ int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query pos += queryLen; /* padding */ - query[pos] = (uint8_t) 0x80; + query[pos] = static_cast(0x80); pos++; memset(query + pos, 0, paddingSize - 1); pos += paddingSize - 1; @@ -578,12 +769,30 @@ int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query memcpy(nonce, clientNonce, DNSCRYPT_NONCE_SIZE / 2); memset(nonce + (DNSCRYPT_NONCE_SIZE / 2), 0, DNSCRYPT_NONCE_SIZE / 2); - int res = crypto_box_easy((unsigned char*) query + encryptedPos, - (unsigned char*) query + encryptedPos + DNSCRYPT_MAC_SIZE, - queryLen + paddingSize, - nonce, - cert.signedData.resolverPK, - clientPrivateKey.key); + const DNSCryptExchangeVersion version = getExchangeVersion(*cert); + int res = -1; + + if (version == DNSCryptExchangeVersion::VERSION1) { + res = crypto_box_easy(reinterpret_cast(query + encryptedPos), + reinterpret_cast(query + encryptedPos + DNSCRYPT_MAC_SIZE), + queryLen + paddingSize, + nonce, + cert->signedData.resolverPK, + clientPrivateKey.key); + } + else if (version == DNSCryptExchangeVersion::VERSION2) { +#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY + res = crypto_box_curve25519xchacha20poly1305_easy(reinterpret_cast(query + encryptedPos), + reinterpret_cast(query + encryptedPos + DNSCRYPT_MAC_SIZE), + queryLen + paddingSize, + nonce, + cert->signedData.resolverPK, + clientPrivateKey.key); +#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + } + else { + throw std::runtime_error("Unknown DNSCrypt exchange version"); + } if (res == 0) { assert(pos == requiredSize); @@ -593,7 +802,7 @@ int DnsCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query return res; } -bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DnsCryptCert& certOut, DnsCryptPrivateKey& keyOut) +bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DNSCryptExchangeVersion version, DNSCryptCert& certOut, DNSCryptPrivateKey& keyOut) { bool success = false; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; @@ -608,7 +817,7 @@ bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint throw std::runtime_error("Invalid DNSCrypt provider key file " + providerPrivateKeyFile); } - DnsCryptContext::generateCertificate(serial, begin, end, providerPrivateKey, keyOut, certOut); + DNSCryptContext::generateCertificate(serial, begin, end, version, providerPrivateKey, keyOut, certOut); success = true; } catch(const std::exception& e) { diff --git a/pdns/dnscrypt.hh b/pdns/dnscrypt.hh index 49f1186c3..aad89cd8c 100644 --- a/pdns/dnscrypt.hh +++ b/pdns/dnscrypt.hh @@ -27,21 +27,32 @@ #include #include #include +#include + #include #include "dnsname.hh" #define DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE (crypto_sign_ed25519_PUBLICKEYBYTES) #define DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE (crypto_sign_ed25519_SECRETKEYBYTES) +#define DNSCRYPT_SIGNATURE_SIZE (crypto_sign_ed25519_BYTES) + #define DNSCRYPT_PUBLIC_KEY_SIZE (crypto_box_curve25519xsalsa20poly1305_PUBLICKEYBYTES) #define DNSCRYPT_PRIVATE_KEY_SIZE (crypto_box_curve25519xsalsa20poly1305_SECRETKEYBYTES) #define DNSCRYPT_NONCE_SIZE (crypto_box_curve25519xsalsa20poly1305_NONCEBYTES) #define DNSCRYPT_BEFORENM_SIZE (crypto_box_curve25519xsalsa20poly1305_BEFORENMBYTES) -#define DNSCRYPT_SIGNATURE_SIZE (crypto_sign_ed25519_BYTES) #define DNSCRYPT_MAC_SIZE (crypto_box_curve25519xsalsa20poly1305_MACBYTES) + +#ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY +static_assert(crypto_box_curve25519xsalsa20poly1305_PUBLICKEYBYTES == crypto_box_curve25519xchacha20poly1305_PUBLICKEYBYTES, "DNSCrypt public key size should be the same for all exchange versions"); +static_assert(crypto_box_curve25519xsalsa20poly1305_SECRETKEYBYTES == crypto_box_curve25519xchacha20poly1305_SECRETKEYBYTES, "DNSCrypt private key size should be the same for all exchange versions"); +static_assert(crypto_box_curve25519xchacha20poly1305_NONCEBYTES == crypto_box_curve25519xsalsa20poly1305_NONCEBYTES, "DNSCrypt nonce size should be the same for all exchange versions"); +static_assert(crypto_box_curve25519xsalsa20poly1305_MACBYTES == crypto_box_curve25519xchacha20poly1305_MACBYTES, "DNSCrypt MAC size should be the same for all exchange versions"); +static_assert(crypto_box_curve25519xchacha20poly1305_BEFORENMBYTES == crypto_box_curve25519xsalsa20poly1305_BEFORENMBYTES, "DNSCrypt BEFORENM size should be the same for all exchange versions"); +#endif /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ + #define DNSCRYPT_CERT_MAGIC_SIZE (4) #define DNSCRYPT_CERT_MAGIC_VALUE { 0x44, 0x4e, 0x53, 0x43 } -#define DNSCRYPT_CERT_ES_VERSION_VALUE { 0x00, 0x01 } #define DNSCRYPT_CERT_PROTOCOL_MINOR_VERSION_VALUE { 0x00, 0x00 } #define DNSCRYPT_CLIENT_MAGIC_SIZE (8) #define DNSCRYPT_RESOLVER_MAGIC { 0x72, 0x36, 0x66, 0x6e, 0x76, 0x57, 0x6a, 0x38 } @@ -54,11 +65,14 @@ /* "The client must check for new certificates every hour", so let's use one hour TTL */ #define DNSCRYPT_CERTIFICATE_RESPONSE_TTL (3600) -static_assert(DNSCRYPT_CLIENT_MAGIC_SIZE <= DNSCRYPT_PUBLIC_KEY_SIZE, "Dnscrypt Client Nonce size should be smaller or equal to public key size."); +static_assert(DNSCRYPT_CLIENT_MAGIC_SIZE <= DNSCRYPT_PUBLIC_KEY_SIZE, "DNSCrypt Client Nonce size should be smaller or equal to public key size."); + +#define DNSCRYPT_CERT_ES_VERSION1_VALUE { 0x00, 0x01 } +#define DNSCRYPT_CERT_ES_VERSION2_VALUE { 0x00, 0x02 } -class DnsCryptContext; +class DNSCryptContext; -struct DnsCryptCertSignedData +struct DNSCryptCertSignedData { unsigned char resolverPK[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char clientMagic[DNSCRYPT_CLIENT_MAGIC_SIZE]; @@ -67,128 +81,184 @@ struct DnsCryptCertSignedData uint32_t tsEnd; }; -struct DnsCryptCert +class DNSCryptCert { +public: + uint32_t getSerial() const + { + return signedData.serial; + } + uint32_t getTSStart() const + { + return signedData.tsStart; + } + uint32_t getTSEnd() const + { + return signedData.tsEnd; + } + bool isValid(time_t now) const + { + return ntohl(getTSStart()) <= now && now <= ntohl(getTSEnd()); + } unsigned char magic[DNSCRYPT_CERT_MAGIC_SIZE]; unsigned char esVersion[2]; unsigned char protocolMinorVersion[2]; unsigned char signature[DNSCRYPT_SIGNATURE_SIZE]; - struct DnsCryptCertSignedData signedData; + struct DNSCryptCertSignedData signedData; }; -static_assert((sizeof(DnsCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE) == 116, "Dnscrypt cert signed data size + signature size should be 116!"); -static_assert(sizeof(DnsCryptCert) == 124, "Dnscrypt cert size should be 124!"); +static_assert((sizeof(DNSCryptCertSignedData) + DNSCRYPT_SIGNATURE_SIZE) == 116, "Dnscrypt cert signed data size + signature size should be 116!"); +static_assert(sizeof(DNSCryptCert) == 124, "Dnscrypt cert size should be 124!"); -struct DnsCryptQueryHeader +struct DNSCryptQueryHeader { unsigned char clientMagic[DNSCRYPT_CLIENT_MAGIC_SIZE]; unsigned char clientPK[DNSCRYPT_PUBLIC_KEY_SIZE]; unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2]; }; -static_assert(sizeof(DnsCryptQueryHeader) == 52, "Dnscrypt query header size should be 52!"); +static_assert(sizeof(DNSCryptQueryHeader) == 52, "Dnscrypt query header size should be 52!"); -struct DnsCryptResponseHeader +struct DNSCryptResponseHeader { const unsigned char resolverMagic[DNSCRYPT_RESOLVER_MAGIC_SIZE] = DNSCRYPT_RESOLVER_MAGIC; unsigned char nonce[DNSCRYPT_NONCE_SIZE]; }; -class DnsCryptPrivateKey +typedef enum { + VERSION1, + VERSION2 +} DNSCryptExchangeVersion; + +class DNSCryptPrivateKey { public: - DnsCryptPrivateKey(); - ~DnsCryptPrivateKey(); + DNSCryptPrivateKey(); + ~DNSCryptPrivateKey(); void loadFromFile(const std::string& keyFile); void saveToFile(const std::string& keyFile) const; unsigned char key[DNSCRYPT_PRIVATE_KEY_SIZE]; }; -class DnsCryptQuery +struct DNSCryptCertificatePair +{ + unsigned char publicKey[DNSCRYPT_PUBLIC_KEY_SIZE]; + DNSCryptCert cert; + DNSCryptPrivateKey privateKey; + bool active; +}; + +class DNSCryptQuery { public: - DnsCryptQuery() + DNSCryptQuery(std::shared_ptr ctx): d_ctx(ctx) + { + } + ~DNSCryptQuery(); + + bool isValid() const + { + return d_valid; + } + + const DNSName& getQName() const { + return d_qname; } - ~DnsCryptQuery(); + + uint16_t getID() const + { + return d_id; + } + + const unsigned char* getClientMagic() const + { + return d_header.clientMagic; + } + + bool isEncrypted() const + { + return d_encrypted; + } + + void setCertificatePair(std::shared_ptr pair) + { + d_pair = pair; + } + + void parsePacket(char* packet, uint16_t packetSize, bool tcp, uint16_t* decryptedQueryLen, time_t now); + void getDecrypted(bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen); + void getCertificateResponse(time_t now, std::vector& response) const; + int encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, bool tcp, uint16_t* encryptedResponseLen); + + static const size_t s_minUDPLength = 256; + +private: + DNSCryptExchangeVersion getVersion() const; #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM - int computeSharedKey(const DnsCryptPrivateKey& privateKey); + int computeSharedKey(); #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ + void fillServerNonce(unsigned char* dest) const; + uint16_t computePaddingSize(uint16_t unpaddedLen, size_t maxLen) const; + bool parsePlaintextQuery(const char * packet, uint16_t packetSize); + bool isEncryptedQuery(const char * packet, uint16_t packetSize, bool tcp, time_t now); - static const size_t minUDPLength = 256; - - DnsCryptQueryHeader header; + DNSCryptQueryHeader d_header; #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM - unsigned char sharedKey[crypto_box_BEFORENMBYTES]; + unsigned char d_sharedKey[crypto_box_BEFORENMBYTES]; #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ - DNSName qname; - DnsCryptContext* ctx; - uint16_t id{0}; - uint16_t len{0}; - uint16_t paddedLen; - bool useOldCert{false}; - bool encrypted{false}; - bool valid{false}; + DNSName d_qname; + std::shared_ptr d_ctx{nullptr}; + std::shared_ptr d_pair{nullptr}; + uint16_t d_id{0}; + uint16_t d_len{0}; + uint16_t d_paddedLen; + bool d_encrypted{false}; + bool d_valid{false}; + #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM - bool sharedKeyComputed{false}; + bool d_sharedKeyComputed{false}; #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ }; -class DnsCryptContext +class DNSCryptContext { public: static void generateProviderKeys(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE], unsigned char privateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]); static std::string getProviderFingerprint(unsigned char publicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]); - static void generateCertificate(uint32_t serial, time_t begin, time_t end, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DnsCryptPrivateKey& privateKey, DnsCryptCert& cert); - static void saveCertFromFile(const DnsCryptCert& cert, const std::string&filename); + static void generateCertificate(uint32_t serial, time_t begin, time_t end, const DNSCryptExchangeVersion& version, const unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE], DNSCryptPrivateKey& privateKey, DNSCryptCert& cert); + static void saveCertFromFile(const DNSCryptCert& cert, const std::string&filename); static std::string certificateDateToStr(uint32_t date); - static void generateResolverKeyPair(DnsCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]); + static void generateResolverKeyPair(DNSCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]); + static void setExchangeVersion(const DNSCryptExchangeVersion& version, unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]); + static DNSCryptExchangeVersion getExchangeVersion(const unsigned char esVersion[sizeof(DNSCryptCert::esVersion)]); + static DNSCryptExchangeVersion getExchangeVersion(const DNSCryptCert& cert); - DnsCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile): providerName(pName) - { - loadCertFromFile(certFile, cert); - privateKey.loadFromFile(keyFile); - computePublicKeyFromPrivate(privateKey, publicKey); - } + DNSCryptContext(const std::string& pName, const std::string& certFile, const std::string& keyFile); + DNSCryptContext(const std::string& pName, const DNSCryptCert& certificate, const DNSCryptPrivateKey& pKey); - DnsCryptContext(const std::string& pName, const DnsCryptCert& certificate, const DnsCryptPrivateKey& pKey): providerName(pName), cert(certificate), privateKey(pKey) - { - computePublicKeyFromPrivate(privateKey, publicKey); - } - - void parsePacket(char* packet, uint16_t packetSize, std::shared_ptr query, bool tcp, uint16_t* decryptedQueryLen) const; - int encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, const std::shared_ptr query, bool tcp, uint16_t* encryptedResponseLen) const; - void getCertificateResponse(const std::shared_ptr query, std::vector& response) const; - void loadNewCertificate(const std::string& certFile, const std::string& keyFile); - void setNewCertificate(const DnsCryptCert& newCert, const DnsCryptPrivateKey& newKey); - const DnsCryptCert& getCurrentCertificate() const { return cert; }; - const DnsCryptCert& getOldCertificate() const { return oldCert; }; - bool hasOldCertificate() const { return hasOldCert; }; - const std::string& getProviderName() const { return providerName; } - int encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DnsCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen) const; + void loadNewCertificate(const std::string& certFile, const std::string& keyFile, bool active=true); + void addNewCertificate(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, bool active=true); + void markActive(uint32_t serial); + void markInactive(uint32_t serial); + void removeInactiveCertificate(uint32_t serial); + std::vector> getCertificates() { return certs; }; + const DNSName& getProviderName() const { return providerName; } + int encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr cert) const; + bool magicMatchesAPublicKey(DNSCryptQuery& query, time_t now); + void getCertificateResponse(time_t now, const DNSName& qname, uint16_t qid, std::vector& response); private: - static void computePublicKeyFromPrivate(const DnsCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]); - static void loadCertFromFile(const std::string&filename, DnsCryptCert& dest); + static void computePublicKeyFromPrivate(const DNSCryptPrivateKey& privK, unsigned char pubK[DNSCRYPT_PUBLIC_KEY_SIZE]); + static void loadCertFromFile(const std::string&filename, DNSCryptCert& dest); - void parsePlaintextQuery(const char * packet, uint16_t packetSize, std::shared_ptr query) const; - bool magicMatchesPublicKey(std::shared_ptr query) const; - void isQueryEncrypted(const char * packet, uint16_t packetSize, std::shared_ptr query, bool tcp) const; - void getDecryptedQuery(std::shared_ptr query, bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen) const; - void fillServerNonce(unsigned char* dest) const; - uint16_t computePaddingSize(uint16_t unpaddedLen, size_t maxLen, const unsigned char* clientNonce) const; - - std::string providerName; - DnsCryptCert cert; - DnsCryptCert oldCert; - DnsCryptPrivateKey privateKey; - unsigned char publicKey[DNSCRYPT_PUBLIC_KEY_SIZE]; - DnsCryptPrivateKey oldPrivateKey; - bool hasOldCert{false}; + pthread_rwlock_t d_lock; + std::vector> certs; + DNSName providerName; }; -bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DnsCryptCert& certOut, DnsCryptPrivateKey& keyOut); +bool generateDNSCryptCertificate(const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, DNSCryptExchangeVersion version, DNSCryptCert& certOut, DNSCryptPrivateKey& keyOut); #endif diff --git a/pdns/dnsdist-dnscrypt.cc b/pdns/dnsdist-dnscrypt.cc index 5a4c4885d..bc8c75a23 100644 --- a/pdns/dnsdist-dnscrypt.cc +++ b/pdns/dnsdist-dnscrypt.cc @@ -24,24 +24,22 @@ #include "dnscrypt.hh" #ifdef HAVE_DNSCRYPT -int handleDnsCryptQuery(DnsCryptContext* ctx, char* packet, uint16_t len, std::shared_ptr& query, uint16_t* decryptedQueryLen, bool tcp, std::vector& response) +int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector& response) { - query->ctx = ctx; + query->parsePacket(packet, len, tcp, decryptedQueryLen, now); - ctx->parsePacket(packet, len, query, tcp, decryptedQueryLen); - - if (query->valid == false) { + if (query->isValid() == false) { vinfolog("Dropping DNSCrypt invalid query"); return false; } - if (query->encrypted == false) { - ctx->getCertificateResponse(query, response); + if (query->isEncrypted() == false) { + query->getCertificateResponse(now, response); return false; } - if(*decryptedQueryLen < (int)sizeof(struct dnsheader)) { + if(*decryptedQueryLen < static_cast(sizeof(struct dnsheader))) { g_stats.nonCompliantQueries++; return false; } diff --git a/pdns/dnsdist-lua-bindings.cc b/pdns/dnsdist-lua-bindings.cc index bb1415c7a..292bc48c2 100644 --- a/pdns/dnsdist-lua-bindings.cc +++ b/pdns/dnsdist-lua-bindings.cc @@ -286,37 +286,113 @@ void setupLuaBindings(bool client) }); #ifdef HAVE_DNSCRYPT - /* DnsCryptContext bindings */ - g_lua.registerFunction("getProviderName", [](const DnsCryptContext& ctx) { return ctx.getProviderName(); }); - g_lua.registerFunction("getCurrentCertificate", [](const DnsCryptContext& ctx) { return ctx.getCurrentCertificate(); }); - g_lua.registerFunction("getOldCertificate", [](const DnsCryptContext& ctx) { return ctx.getOldCertificate(); }); - g_lua.registerFunction("hasOldCertificate", &DnsCryptContext::hasOldCertificate); - g_lua.registerFunction("loadNewCertificate", &DnsCryptContext::loadNewCertificate); - g_lua.registerFunction("generateAndLoadInMemoryCertificate", [](DnsCryptContext& ctx, const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end) { - DnsCryptPrivateKey privateKey; - DnsCryptCert cert; - - try { - if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, cert, privateKey)) { - ctx.setNewCertificate(cert, privateKey); + /* DNSCryptContext bindings */ + g_lua.registerFunction("getProviderName", [](const DNSCryptContext& ctx) { return ctx.getProviderName().toStringNoDot(); }); + g_lua.registerFunction("addNewCertificate", &DNSCryptContext::addNewCertificate); + g_lua.registerFunction("markActive", &DNSCryptContext::markActive); + g_lua.registerFunction("markInactive", &DNSCryptContext::markInactive); + g_lua.registerFunction("removeInactiveCertificate", &DNSCryptContext::removeInactiveCertificate); + g_lua.registerFunction::*)(const std::string& certFile, const std::string& keyFile, boost::optional active)>("loadNewCertificate", [](std::shared_ptr ctx, const std::string& certFile, const std::string& keyFile, boost::optional active) { + + if (ctx == nullptr) { + throw std::runtime_error("DNSCryptContext::loadNewCertificate() called on a nil value"); + } + + ctx->loadNewCertificate(certFile, keyFile, active ? *active : true); + }); + g_lua.registerFunction::*)(const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, boost::optional active)>("addNewCertificate", [](std::shared_ptr ctx, const DNSCryptCert& newCert, const DNSCryptPrivateKey& newKey, boost::optional active) { + + if (ctx == nullptr) { + throw std::runtime_error("DNSCryptContext::addNewCertificate() called on a nil value"); + } + + ctx->addNewCertificate(newCert, newKey, active ? *active : true); + }); + g_lua.registerFunction>(std::shared_ptr::*)()>("getCertificatePairs", [](std::shared_ptr ctx) { + std::map> result; + + if (ctx != nullptr) { + size_t idx = 1; + for (auto pair : ctx->getCertificates()) { + result[idx++] = pair; } } - catch(const std::exception& e) { - errlog(e.what()); - g_outputBuffer="Error: "+string(e.what())+"\n"; + + return result; + }); + g_lua.registerFunction(std::shared_ptr::*)(size_t idx)>("getCertificatePair", [](std::shared_ptr ctx, size_t idx) { + + if (ctx == nullptr) { + throw std::runtime_error("DNSCryptContext::getCertificatePair() called on a nil value"); + } + + std::shared_ptr result = nullptr; + auto certs = ctx->getCertificates(); + if (idx < certs.size()) { + result = certs.at(idx); } + + return result; }); - /* DnsCryptCert */ - g_lua.registerFunction("getMagic", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast(cert.magic), sizeof(cert.magic)); }); - g_lua.registerFunction("getEsVersion", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast(cert.esVersion), sizeof(cert.esVersion)); }); - g_lua.registerFunction("getProtocolMinorVersion", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast(cert.protocolMinorVersion), sizeof(cert.protocolMinorVersion)); }); - g_lua.registerFunction("getSignature", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast(cert.signature), sizeof(cert.signature)); }); - g_lua.registerFunction("getResolverPublicKey", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast(cert.signedData.resolverPK), sizeof(cert.signedData.resolverPK)); }); - g_lua.registerFunction("getClientMagic", [](const DnsCryptCert& cert) { return std::string(reinterpret_cast(cert.signedData.clientMagic), sizeof(cert.signedData.clientMagic)); }); - g_lua.registerFunction("getSerial", [](const DnsCryptCert& cert) { return cert.signedData.serial; }); - g_lua.registerFunction("getTSStart", [](const DnsCryptCert& cert) { return ntohl(cert.signedData.tsStart); }); - g_lua.registerFunction("getTSEnd", [](const DnsCryptCert& cert) { return ntohl(cert.signedData.tsEnd); }); + g_lua.registerFunction::*)()>("printCertificates", [](const std::shared_ptr ctx) { + ostringstream ret; + + if (ctx != nullptr) { + size_t idx = 1; + boost::format fmt("%1$-3d %|5t|%2$-8d %|10t|%3$-2d %|20t|%4$-21.21s %|41t|%5$-21.21s"); + ret << (fmt % "#" % "Serial" % "Version" % "From" % "To" ) << endl; + + for (auto pair : ctx->getCertificates()) { + const auto cert = pair->cert; + const DNSCryptExchangeVersion version = DNSCryptContext::getExchangeVersion(cert); + + ret << (fmt % idx % cert.getSerial() % (version == DNSCryptExchangeVersion::VERSION1 ? 1 : 2) % DNSCryptContext::certificateDateToStr(cert.getTSStart()) % DNSCryptContext::certificateDateToStr(cert.getTSEnd())) << endl; + } + } + + return ret.str(); + }); + + g_lua.registerFunction version)>("generateAndLoadInMemoryCertificate", [](DNSCryptContext& ctx, const std::string& providerPrivateKeyFile, uint32_t serial, time_t begin, time_t end, boost::optional version) { + DNSCryptPrivateKey privateKey; + DNSCryptCert cert; + + try { + if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, version ? *version : DNSCryptExchangeVersion::VERSION1, cert, privateKey)) { + ctx.addNewCertificate(cert, privateKey); + } + } + catch(const std::exception& e) { + errlog(e.what()); + g_outputBuffer="Error: "+string(e.what())+"\n"; + } + }); + + /* DNSCryptCertificatePair */ + g_lua.registerFunction::*)()>("getCertificate", [](const std::shared_ptr pair) { + if (pair == nullptr) { + throw std::runtime_error("DNSCryptCertificatePair::getCertificate() called on a nil value"); + } + return pair->cert; + }); + g_lua.registerFunction::*)()>("isActive", [](const std::shared_ptr pair) { + if (pair == nullptr) { + throw std::runtime_error("DNSCryptCertificatePair::isActive() called on a nil value"); + } + return pair->active; + }); + + /* DNSCryptCert */ + g_lua.registerFunction("getMagic", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast(cert.magic), sizeof(cert.magic)); }); + g_lua.registerFunction("getEsVersion", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast(cert.esVersion), sizeof(cert.esVersion)); }); + g_lua.registerFunction("getProtocolMinorVersion", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast(cert.protocolMinorVersion), sizeof(cert.protocolMinorVersion)); }); + g_lua.registerFunction("getSignature", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast(cert.signature), sizeof(cert.signature)); }); + g_lua.registerFunction("getResolverPublicKey", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast(cert.signedData.resolverPK), sizeof(cert.signedData.resolverPK)); }); + g_lua.registerFunction("getClientMagic", [](const DNSCryptCert& cert) { return std::string(reinterpret_cast(cert.signedData.clientMagic), sizeof(cert.signedData.clientMagic)); }); + g_lua.registerFunction("getSerial", [](const DNSCryptCert& cert) { return cert.getSerial(); }); + g_lua.registerFunction("getTSStart", [](const DNSCryptCert& cert) { return ntohl(cert.getTSStart()); }); + g_lua.registerFunction("getTSEnd", [](const DNSCryptCert& cert) { return ntohl(cert.getTSEnd()); }); #endif /* BPF Filter */ diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index d8a8c916f..470d8827e 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -912,7 +912,7 @@ void setupLuaConfig(bool client) parseLocalBindVars(vars, doTCP, reusePort, tcpFastOpenQueueSize, interface, cpus); try { - DnsCryptContext ctx(providerName, certFile, keyFile); + auto ctx = std::make_shared(providerName, certFile, keyFile); g_dnsCryptLocals.push_back(std::make_tuple(ComboAddress(addr, 443), ctx, reusePort, tcpFastOpenQueueSize, interface, cpus)); } catch(std::exception& e) { @@ -929,17 +929,13 @@ void setupLuaConfig(bool client) setLuaNoSideEffect(); #ifdef HAVE_DNSCRYPT ostringstream ret; - boost::format fmt("%1$-3d %2% %|25t|%3$-20.20s %|26t|%4$-8d %|35t|%5$-21.21s %|56t|%6$-9d %|66t|%7$-21.21s" ); - ret << (fmt % "#" % "Address" % "Provider Name" % "Serial" % "Validity" % "P. Serial" % "P. Validity") << endl; + boost::format fmt("%1$-3d %2% %|25t|%3$-20.20s"); + ret << (fmt % "#" % "Address" % "Provider Name") << endl; size_t idx = 0; for (const auto& local : g_dnsCryptLocals) { - const DnsCryptContext& ctx = std::get<1>(local); - bool const hasOldCert = ctx.hasOldCertificate(); - const DnsCryptCert& cert = ctx.getCurrentCertificate(); - const DnsCryptCert& oldCert = ctx.getOldCertificate(); - - ret<< (fmt % idx % std::get<0>(local).toStringWithPort() % ctx.getProviderName() % cert.signedData.serial % DnsCryptContext::certificateDateToStr(cert.signedData.tsEnd) % (hasOldCert ? oldCert.signedData.serial : 0) % (hasOldCert ? DnsCryptContext::certificateDateToStr(oldCert.signedData.tsEnd) : "-")) << endl; + const std::shared_ptr ctx = std::get<1>(local); + ret<< (fmt % idx % std::get<0>(local).toStringWithPort() % ctx->getProviderName()) << endl; idx++; } @@ -949,12 +945,12 @@ void setupLuaConfig(bool client) #endif }); - g_lua.writeFunction("getDNSCryptBind", [client](size_t idx) { + g_lua.writeFunction("getDNSCryptBind", [](size_t idx) { setLuaNoSideEffect(); #ifdef HAVE_DNSCRYPT - DnsCryptContext* ret = nullptr; + std::shared_ptr ret = nullptr; if (idx < g_dnsCryptLocals.size()) { - ret = &(std::get<1>(g_dnsCryptLocals.at(idx))); + ret = std::get<1>(g_dnsCryptLocals.at(idx)); } return ret; #else @@ -970,7 +966,7 @@ void setupLuaConfig(bool client) sodium_mlock(privateKey, sizeof(privateKey)); try { - DnsCryptContext::generateProviderKeys(publicKey, privateKey); + DNSCryptContext::generateProviderKeys(publicKey, privateKey); ofstream pubKStream(publicKeyFile); pubKStream.write((char*) publicKey, sizeof(publicKey)); @@ -980,7 +976,7 @@ void setupLuaConfig(bool client) privKStream.write((char*) privateKey, sizeof(privateKey)); privKStream.close(); - g_outputBuffer="Provider fingerprint is: " + DnsCryptContext::getProviderFingerprint(publicKey) + "\n"; + g_outputBuffer="Provider fingerprint is: " + DNSCryptContext::getProviderFingerprint(publicKey) + "\n"; } catch(std::exception& e) { errlog(e.what()); @@ -1007,7 +1003,7 @@ void setupLuaConfig(bool client) throw std::runtime_error("Invalid dnscrypt provider public key file " + publicKeyFile); file.close(); - g_outputBuffer="Provider fingerprint is: " + DnsCryptContext::getProviderFingerprint(publicKey) + "\n"; + g_outputBuffer="Provider fingerprint is: " + DNSCryptContext::getProviderFingerprint(publicKey) + "\n"; } catch(std::exception& e) { errlog(e.what()); @@ -1018,16 +1014,16 @@ void setupLuaConfig(bool client) #endif }); - g_lua.writeFunction("generateDNSCryptCertificate", [](const std::string& providerPrivateKeyFile, const std::string& certificateFile, const std::string privateKeyFile, uint32_t serial, time_t begin, time_t end) { + g_lua.writeFunction("generateDNSCryptCertificate", [](const std::string& providerPrivateKeyFile, const std::string& certificateFile, const std::string privateKeyFile, uint32_t serial, time_t begin, time_t end, boost::optional version) { setLuaNoSideEffect(); #ifdef HAVE_DNSCRYPT - DnsCryptPrivateKey privateKey; - DnsCryptCert cert; + DNSCryptPrivateKey privateKey; + DNSCryptCert cert; try { - if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, cert, privateKey)) { + if (generateDNSCryptCertificate(providerPrivateKeyFile, serial, begin, end, version ? *version : DNSCryptExchangeVersion::VERSION1, cert, privateKey)) { privateKey.saveToFile(privateKeyFile); - DnsCryptContext::saveCertFromFile(cert, certificateFile); + DNSCryptContext::saveCertFromFile(cert, certificateFile); } } catch(const std::exception& e) { diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 3516c25e5..b3b74864e 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -316,14 +316,21 @@ void* tcpClientThread(int pipefd) char* query = &queryBuffer[0]; handler.read(query, qlen, g_tcpRecvTimeout, remainingTime); + + /* we need this one to be accurate ("real") for the protobuf message */ + struct timespec queryRealTime; + struct timespec now; + gettime(&now); + gettime(&queryRealTime, true); + #ifdef HAVE_DNSCRYPT - std::shared_ptr dnsCryptQuery = nullptr; + std::shared_ptr dnsCryptQuery = nullptr; if (ci.cs->dnscryptCtx) { - dnsCryptQuery = std::make_shared(); + dnsCryptQuery = std::make_shared(ci.cs->dnscryptCtx); uint16_t decryptedQueryLen = 0; vector response; - bool decrypted = handleDnsCryptQuery(ci.cs->dnscryptCtx, query, qlen, dnsCryptQuery, &decryptedQueryLen, true, response); + bool decrypted = handleDNSCryptQuery(query, qlen, dnsCryptQuery, &decryptedQueryLen, true, queryRealTime.tv_sec, response); if (!decrypted) { if (response.size() > 0) { @@ -342,11 +349,6 @@ void* tcpClientThread(int pipefd) string poolname; int delayMsec=0; - /* we need this one to be accurate ("real") for the protobuf message */ - struct timespec queryRealTime; - struct timespec now; - gettime(&now); - gettime(&queryRealTime, true); const uint16_t* flags = getFlagsFromDNSHeader(dh); uint16_t origFlags = *flags; diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index e525ba9c0..6ecac15c5 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -90,7 +90,7 @@ string g_outputBuffer; vector>> g_locals; std::vector> g_tlslocals; #ifdef HAVE_DNSCRYPT -std::vector>> g_dnsCryptLocals; +std::vector,bool, int, string, std::set >> g_dnsCryptLocals; #endif #ifdef HAVE_EBPF shared_ptr g_defaultBPFFilter; @@ -336,7 +336,7 @@ bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, } #ifdef HAVE_DNSCRYPT -bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy) +bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy) { if (dnsCryptQuery) { uint16_t encryptedResponseLen = 0; @@ -347,7 +347,7 @@ bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, *dh = dhCopy; } - int res = dnsCryptQuery->ctx->encryptResponse(response, *responseLen, responseSize, dnsCryptQuery, tcp, &encryptedResponseLen); + int res = dnsCryptQuery->encryptResponse(response, *responseLen, responseSize, tcp, &encryptedResponseLen); if (res == 0) { *responseLen = encryptedResponseLen; } else { @@ -1154,15 +1154,15 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s } #ifdef HAVE_DNSCRYPT -static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote) +static bool checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr& dnsCryptQuery, const ComboAddress& dest, const ComboAddress& remote, time_t now) { if (cs.dnscryptCtx) { vector response; uint16_t decryptedQueryLen = 0; - dnsCryptQuery = std::make_shared(); + dnsCryptQuery = std::make_shared(cs.dnscryptCtx); - bool decrypted = handleDnsCryptQuery(cs.dnscryptCtx, const_cast(query), len, dnsCryptQuery, &decryptedQueryLen, false, response); + bool decrypted = handleDNSCryptQuery(const_cast(query), len, dnsCryptQuery, &decryptedQueryLen, false, now, response); if (!decrypted) { if (response.size() > 0) { @@ -1221,10 +1221,18 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } + /* we need an accurate ("real") value for the response and + to store into the IDS, but not for insertion into the + rings for example */ + struct timespec queryRealTime; + struct timespec now; + gettime(&now); + gettime(&queryRealTime, true); + #ifdef HAVE_DNSCRYPT - std::shared_ptr dnsCryptQuery = nullptr; + std::shared_ptr dnsCryptQuery = nullptr; - if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote)) { + if (!checkDNSCryptQuery(cs, query, len, dnsCryptQuery, dest, remote, queryRealTime.tv_sec)) { return; } #endif @@ -1238,14 +1246,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct string poolname; int delayMsec = 0; - /* we need an accurate ("real") value for the response and - to store into the IDS, but not for insertion into the - rings for example */ - struct timespec queryRealTime; - struct timespec now; - gettime(&now); - gettime(&queryRealTime, true); - const uint16_t * flags = getFlagsFromDNSHeader(dh); const uint16_t origFlags = *flags; uint16_t qtype, qclass; @@ -2362,7 +2362,7 @@ try for(auto& dcLocal : g_dnsCryptLocals) { ClientState* cs = new ClientState; cs->local = std::get<0>(dcLocal); - cs->dnscryptCtx = &(std::get<1>(dcLocal)); + cs->dnscryptCtx = std::get<1>(dcLocal); cs->udpFD = SSocket(cs->local.sin4.sin_family, SOCK_DGRAM, 0); if(cs->local.sin4.sin_family == AF_INET6) { SSetsockopt(cs->udpFD, IPPROTO_IPV6, IPV6_V6ONLY, 1); @@ -2408,7 +2408,7 @@ try cs = new ClientState; cs->local = std::get<0>(dcLocal); - cs->dnscryptCtx = &(std::get<1>(dcLocal)); + cs->dnscryptCtx = std::get<1>(dcLocal); cs->tcpFD = SSocket(cs->local.sin4.sin_family, SOCK_STREAM, 0); SSetsockopt(cs->tcpFD, SOL_SOCKET, SO_REUSEADDR, 1); #ifdef TCP_DEFER_ACCEPT diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index cfc92475f..90d9041d8 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -341,7 +341,7 @@ struct IDState StopWatch sentTime; // 16 DNSName qname; // 80 #ifdef HAVE_DNSCRYPT - std::shared_ptr dnsCryptQuery{0}; + std::shared_ptr dnsCryptQuery{nullptr}; #endif #ifdef HAVE_PROTOBUF boost::optional uniqueId; @@ -368,7 +368,7 @@ struct Rings { { queryRing.set_capacity(capacity); respRing.set_capacity(capacity); - pthread_rwlock_init(&queryLock, 0); + pthread_rwlock_init(&queryLock, nullptr); } struct Query { @@ -417,7 +417,7 @@ typedef std::function(DNSQuestion dq)> QueryCountFilter struct QueryCount { QueryCount() { - pthread_rwlock_init(&queryLock, 0); + pthread_rwlock_init(&queryLock, nullptr); } QueryCountRecords records; QueryCountFilter filter; @@ -432,7 +432,7 @@ struct ClientState std::set cpus; ComboAddress local; #ifdef HAVE_DNSCRYPT - DnsCryptContext* dnscryptCtx{0}; + std::shared_ptr dnscryptCtx{nullptr}; #endif shared_ptr tlsFrontend; std::atomic queries{0}; @@ -802,10 +802,10 @@ void restoreFlags(struct dnsheader* dh, uint16_t origFlags); bool checkQueryHeaders(const struct dnsheader* dh); #ifdef HAVE_DNSCRYPT -extern std::vector>> g_dnsCryptLocals; +extern std::vector, bool, int, std::string, std::set > > g_dnsCryptLocals; -int handleDnsCryptQuery(DnsCryptContext* ctx, char* packet, uint16_t len, std::shared_ptr& query, uint16_t* decryptedQueryLen, bool tcp, std::vector& response); -bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy); +bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy); +int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector& response); #endif bool addXPF(DNSQuestion& dq, uint16_t optionCode); diff --git a/pdns/test-dnscrypt_cc.cc b/pdns/test-dnscrypt_cc.cc index 5bf67a773..c049ff03c 100644 --- a/pdns/test-dnscrypt_cc.cc +++ b/pdns/test-dnscrypt_cc.cc @@ -40,14 +40,14 @@ BOOST_AUTO_TEST_SUITE(dnscrypt_cc) // plaintext query for cert BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); DNSName name("2.name."); vector plainQuery; @@ -55,17 +55,17 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) { pw.getHeader()->rd = 0; uint16_t len = plainQuery.size(); - std::shared_ptr query = std::make_shared(); + std::shared_ptr query = std::make_shared(ctx); uint16_t decryptedLen = 0; - ctx.parsePacket((char*) plainQuery.data(), len, query, false, &decryptedLen); + query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now); - BOOST_CHECK_EQUAL(query->valid, true); - BOOST_CHECK_EQUAL(query->encrypted, false); + BOOST_CHECK_EQUAL(query->isValid(), true); + BOOST_CHECK_EQUAL(query->isEncrypted(), false); std::vector response; - ctx.getCertificateResponse(query, response); + query->getCertificateResponse(now, response); MOADNSParser mdp(false, (char*) response.data(), response.size()); @@ -81,14 +81,14 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) { // invalid plaintext query (A) BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidA) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); DNSName name("2.name."); @@ -97,24 +97,24 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidA) { pw.getHeader()->rd = 0; uint16_t len = plainQuery.size(); - std::shared_ptr query = std::make_shared(); + std::shared_ptr query = std::make_shared(ctx); uint16_t decryptedLen = 0; - ctx.parsePacket((char*) plainQuery.data(), len, query, false, &decryptedLen); + query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now); - BOOST_CHECK_EQUAL(query->valid, false); + BOOST_CHECK_EQUAL(query->isValid(), false); } // invalid plaintext query (wrong provider name) BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidProviderName) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); DNSName name("2.WRONG.name."); @@ -123,29 +123,29 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidProviderName) { pw.getHeader()->rd = 0; uint16_t len = plainQuery.size(); - std::shared_ptr query = std::make_shared(); + std::shared_ptr query = std::make_shared(ctx); uint16_t decryptedLen = 0; - ctx.parsePacket((char*) plainQuery.data(), len, query, false, &decryptedLen); + query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now); - BOOST_CHECK_EQUAL(query->valid, false); + BOOST_CHECK_EQUAL(query->isValid(), false); } // valid encrypted query BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); - DnsCryptPrivateKey clientPrivateKey; + DNSCryptPrivateKey clientPrivateKey; unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE]; - DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); + DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B }; @@ -153,27 +153,27 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) { vector plainQuery; DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0); pw.getHeader()->rd = 1; - size_t requiredSize = plainQuery.size() + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE; - if (requiredSize < DnsCryptQuery::minUDPLength) { - requiredSize = DnsCryptQuery::minUDPLength; + size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE; + if (requiredSize < DNSCryptQuery::s_minUDPLength) { + requiredSize = DNSCryptQuery::s_minUDPLength; } plainQuery.reserve(requiredSize); uint16_t len = plainQuery.size(); uint16_t encryptedResponseLen = 0; - int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen); + int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, 0); BOOST_CHECK(encryptedResponseLen > len); - std::shared_ptr query = std::make_shared(); + std::shared_ptr query = std::make_shared(ctx); uint16_t decryptedLen = 0; - ctx.parsePacket((char*) plainQuery.data(), encryptedResponseLen, query, false, &decryptedLen); + query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now); - BOOST_CHECK_EQUAL(query->valid, true); - BOOST_CHECK_EQUAL(query->encrypted, true); + BOOST_CHECK_EQUAL(query->isValid(), true); + BOOST_CHECK_EQUAL(query->isEncrypted(), true); MOADNSParser mdp(true, (char*) plainQuery.data(), decryptedLen); @@ -189,19 +189,19 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) { // valid encrypted query with not enough room BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidButShort) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); - DnsCryptPrivateKey clientPrivateKey; + DNSCryptPrivateKey clientPrivateKey; unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE]; - DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); + DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B }; @@ -213,26 +213,26 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidButShort) { uint16_t len = plainQuery.size(); uint16_t encryptedResponseLen = 0; - int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen); + int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, ENOBUFS); } // valid encrypted query with old key BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); - DnsCryptPrivateKey clientPrivateKey; + DNSCryptPrivateKey clientPrivateKey; unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE]; - DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); + DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B }; @@ -241,9 +241,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) { DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0); pw.getHeader()->rd = 1; - size_t requiredSize = plainQuery.size() + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE; - if (requiredSize < DnsCryptQuery::minUDPLength) { - requiredSize = DnsCryptQuery::minUDPLength; + size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE; + if (requiredSize < DNSCryptQuery::s_minUDPLength) { + requiredSize = DNSCryptQuery::s_minUDPLength; } plainQuery.reserve(requiredSize); @@ -251,21 +251,23 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) { uint16_t len = plainQuery.size(); uint16_t encryptedResponseLen = 0; - int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen); + int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, 0); BOOST_CHECK(encryptedResponseLen > len); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - ctx.setNewCertificate(resolverCert, resolverPrivateKey); + DNSCryptCert newResolverCert; + DNSCryptContext::generateCertificate(2, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, newResolverCert); + ctx->addNewCertificate(newResolverCert, resolverPrivateKey); + ctx->markInactive(resolverCert.getSerial()); - std::shared_ptr query = std::make_shared(); + std::shared_ptr query = std::make_shared(ctx); uint16_t decryptedLen = 0; - ctx.parsePacket((char*) plainQuery.data(), encryptedResponseLen, query, false, &decryptedLen); + query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now); - BOOST_CHECK_EQUAL(query->valid, true); - BOOST_CHECK_EQUAL(query->encrypted, true); + BOOST_CHECK_EQUAL(query->isValid(), true); + BOOST_CHECK_EQUAL(query->isEncrypted(), true); MOADNSParser mdp(true, (char*) plainQuery.data(), decryptedLen); @@ -281,19 +283,19 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) { // valid encrypted query with wrong key BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) { - DnsCryptPrivateKey resolverPrivateKey; - DnsCryptCert resolverCert; + DNSCryptPrivateKey resolverPrivateKey; + DNSCryptCert resolverCert; unsigned char providerPublicKey[DNSCRYPT_PROVIDER_PUBLIC_KEY_SIZE]; unsigned char providerPrivateKey[DNSCRYPT_PROVIDER_PRIVATE_KEY_SIZE]; - time_t now = time(NULL); - DnsCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - DnsCryptContext ctx("2.name", resolverCert, resolverPrivateKey); + time_t now = time(nullptr); + DNSCryptContext::generateProviderKeys(providerPublicKey, providerPrivateKey); + DNSCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, resolverCert); + auto ctx = std::make_shared("2.name", resolverCert, resolverPrivateKey); - DnsCryptPrivateKey clientPrivateKey; + DNSCryptPrivateKey clientPrivateKey; unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE]; - DnsCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); + DNSCryptContext::generateResolverKeyPair(clientPrivateKey, clientPublicKey); unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x09, 0x0A, 0x0B }; @@ -302,9 +304,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) { DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0); pw.getHeader()->rd = 1; - size_t requiredSize = plainQuery.size() + sizeof(DnsCryptQueryHeader) + DNSCRYPT_MAC_SIZE; - if (requiredSize < DnsCryptQuery::minUDPLength) { - requiredSize = DnsCryptQuery::minUDPLength; + size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE; + if (requiredSize < DNSCryptQuery::s_minUDPLength) { + requiredSize = DNSCryptQuery::s_minUDPLength; } plainQuery.reserve(requiredSize); @@ -312,25 +314,25 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) { uint16_t len = plainQuery.size(); uint16_t encryptedResponseLen = 0; - int res = ctx.encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen); + int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, 0); BOOST_CHECK(encryptedResponseLen > len); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - ctx.setNewCertificate(resolverCert, resolverPrivateKey); + DNSCryptCert newResolverCert; + DNSCryptContext::generateCertificate(2, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, newResolverCert); + ctx->addNewCertificate(newResolverCert, resolverPrivateKey); + ctx->markInactive(resolverCert.getSerial()); + ctx->removeInactiveCertificate(resolverCert.getSerial()); - DnsCryptContext::generateCertificate(1, now, now + (24 * 60 * 3600), providerPrivateKey, resolverPrivateKey, resolverCert); - ctx.setNewCertificate(resolverCert, resolverPrivateKey); + /* we have removed the old certificate, we can't decrypt this query */ - /* we have changed the key two times, we don't have the one used to encrypt this query */ - - std::shared_ptr query = std::make_shared(); + std::shared_ptr query = std::make_shared(ctx); uint16_t decryptedLen = 0; - ctx.parsePacket((char*) plainQuery.data(), encryptedResponseLen, query, false, &decryptedLen); + query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now); - BOOST_CHECK_EQUAL(query->valid, false); + BOOST_CHECK_EQUAL(query->isValid(), false); } #endif diff --git a/regression-tests.dnsdist/dnscrypt.py b/regression-tests.dnsdist/dnscrypt.py index 238d973e4..a93aeaea0 100644 --- a/regression-tests.dnsdist/dnscrypt.py +++ b/regression-tests.dnsdist/dnscrypt.py @@ -94,7 +94,6 @@ class DNSCryptClient(object): data = None if tcp: got = sock.recv(2) - print(len(got)) if got: (rlen,) = struct.unpack("!H", got) data = sock.recv(rlen) @@ -134,6 +133,8 @@ class DNSCryptClient(object): if an.rdclass != dns.rdataclass.IN or an.rdtype != dns.rdatatype.TXT or len(an.items) == 0: raise Exception("Invalid response to public key request") + self._resolverCertificates = [] + for item in an.items: if len(item.strings) != 1: continue @@ -152,6 +153,15 @@ class DNSCryptClient(object): return result + def getAllResolverCertificates(self, onlyValid=False): + certs = self._resolverCertificates + result = [] + for cert in certs: + if not onlyValid or cert.isValid(): + result.append(cert) + + return result + @staticmethod def _generateNonce(): nonce = libnacl.utils.rand_nonce() diff --git a/regression-tests.dnsdist/test_DNSCrypt.py b/regression-tests.dnsdist/test_DNSCrypt.py index f5f45d41c..376140d80 100644 --- a/regression-tests.dnsdist/test_DNSCrypt.py +++ b/regression-tests.dnsdist/test_DNSCrypt.py @@ -1,5 +1,6 @@ #!/usr/bin/env python import base64 +import socket import time import dns import dns.message @@ -136,16 +137,16 @@ class TestDNSCrypt(DNSCryptTest): # generate a new certificate self.sendConsoleCommand("generateDNSCryptCertificate('DNSCryptProviderPrivate.key', 'DNSCryptResolver.cert.2', 'DNSCryptResolver.key.2', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 1, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil)) - # switch to that new certificate + # add that new certificate self.sendConsoleCommand("getDNSCryptBind(0):loadNewCertificate('DNSCryptResolver.cert.2', 'DNSCryptResolver.key.2')") - oldSerial = self.sendConsoleCommand("getDNSCryptBind(0):getOldCertificate():getSerial()") + oldSerial = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(0):getCertificate():getSerial()") self.assertEquals(int(oldSerial), self._resolverCertificateSerial) - effectiveSerial = self.sendConsoleCommand("getDNSCryptBind(0):getCurrentCertificate():getSerial()") + effectiveSerial = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(1):getCertificate():getSerial()") self.assertEquals(int(effectiveSerial), self._resolverCertificateSerial + 1) - tsStart = self.sendConsoleCommand("getDNSCryptBind(0):getCurrentCertificate():getTSStart()") + tsStart = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(1):getCertificate():getTSStart()") self.assertEquals(int(tsStart), self._resolverCertificateValidFrom) - tsEnd = self.sendConsoleCommand("getDNSCryptBind(0):getCurrentCertificate():getTSEnd()") + tsEnd = self.sendConsoleCommand("getDNSCryptBind(0):getCertificatePair(1):getCertificate():getTSEnd()") self.assertEquals(int(tsEnd), self._resolverCertificateValidUntil) # we should still be able to send queries with the previous certificate @@ -160,6 +161,11 @@ class TestDNSCrypt(DNSCryptTest): cert = client.getResolverCertificate() self.assertTrue(cert) self.assertEquals(cert.serial, self._resolverCertificateSerial + 1) + # we should still get the old ones + certs = client.getAllResolverCertificates(True) + self.assertEquals(len(certs), 2) + self.assertEquals(certs[0].serial, self._resolverCertificateSerial) + self.assertEquals(certs[1].serial, self._resolverCertificateSerial + 1) # generate a third certificate, this time in memory self.sendConsoleCommand("getDNSCryptBind(0):generateAndLoadInMemoryCertificate('DNSCryptProviderPrivate.key', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 2, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil)) @@ -176,6 +182,52 @@ class TestDNSCrypt(DNSCryptTest): cert = client.getResolverCertificate() self.assertTrue(cert) self.assertEquals(cert.serial, self._resolverCertificateSerial + 2) + # we should still get the old ones + certs = client.getAllResolverCertificates(True) + self.assertEquals(len(certs), 3) + self.assertEquals(certs[0].serial, self._resolverCertificateSerial) + self.assertEquals(certs[1].serial, self._resolverCertificateSerial + 1) + self.assertEquals(certs[2].serial, self._resolverCertificateSerial + 2) + + # generate a fourth certificate, still in memory + self.sendConsoleCommand("getDNSCryptBind(0):generateAndLoadInMemoryCertificate('DNSCryptProviderPrivate.key', {!s}, {:.0f}, {:.0f})".format(self._resolverCertificateSerial + 3, self._resolverCertificateValidFrom, self._resolverCertificateValidUntil)) + + # mark the old ones as inactive + self.sendConsoleCommand("getDNSCryptBind(0):markInactive({!s})".format(self._resolverCertificateSerial)) + self.sendConsoleCommand("getDNSCryptBind(0):markInactive({!s})".format(self._resolverCertificateSerial + 1)) + self.sendConsoleCommand("getDNSCryptBind(0):markInactive({!s})".format(self._resolverCertificateSerial + 2)) + # we should still be able to send queries with the third one + self.doDNSCryptQuery(client, query, response, False) + self.doDNSCryptQuery(client, query, response, True) + cert = client.getResolverCertificate() + self.assertTrue(cert) + self.assertEquals(cert.serial, self._resolverCertificateSerial + 2) + # now remove them + self.sendConsoleCommand("getDNSCryptBind(0):removeInactiveCertificate({!s})".format(self._resolverCertificateSerial)) + self.sendConsoleCommand("getDNSCryptBind(0):removeInactiveCertificate({!s})".format(self._resolverCertificateSerial + 1)) + self.sendConsoleCommand("getDNSCryptBind(0):removeInactiveCertificate({!s})".format(self._resolverCertificateSerial + 2)) + + # we should not be able to send with the old ones anymore + try: + data = client.query(query.to_wire()) + except socket.timeout: + data = None + self.assertEquals(data, None) + + # refreshing should get us the fourth one + client.refreshResolverCertificates() + cert = client.getResolverCertificate() + self.assertTrue(cert) + self.assertEquals(cert.serial, self._resolverCertificateSerial + 3) + # and only that one + certs = client.getAllResolverCertificates(True) + self.assertEquals(len(certs), 1) + # and we should be able to query with it + self.doDNSCryptQuery(client, query, response, False) + self.doDNSCryptQuery(client, query, response, True) + cert = client.getResolverCertificate() + self.assertTrue(cert) + self.assertEquals(cert.serial, self._resolverCertificateSerial + 3) class TestDNSCryptWithCache(DNSCryptTest):