]> granicus.if.org Git - pdns/commitdiff
dnsdist: Move rules definitions to a separate header
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Jun 2018 10:53:49 +0000 (12:53 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 22 Jun 2018 07:58:14 +0000 (09:58 +0200)
It makes it easier to test them.

pdns/dnsdist-lua-rules.cc
pdns/dnsdistdist/Makefile.am
pdns/dnsdistdist/dnsdist-rules.hh [new file with mode: 0644]

index 437b47b4ec717c1bd3331e02b8dc4938c15fe3bd..6f1509af7e14476b3aac5538f72259ae0c8fc760 100644 (file)
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  */
 #include "dnsdist.hh"
-#include "dnsdist-ecs.hh"
+#include "dnsdist-rules.hh"
 #include "dnsdist-lua.hh"
 
-#include "dnsparser.hh"
-
-class MaxQPSIPRule : public DNSRule
-{
-public:
-  MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60):
-    d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration)
-  {
-    pthread_rwlock_init(&d_lock, 0);
-    gettime(&d_lastCleanup, true);
-  }
-
-  void cleanupIfNeeded(const struct timespec& now) const
-  {
-    if (d_cleanupDelay > 0) {
-      struct timespec cutOff = d_lastCleanup;
-      cutOff.tv_sec += d_cleanupDelay;
-
-      if (cutOff < now) {
-        WriteLock w(&d_lock);
-
-        /* the QPS Limiter doesn't use realtime, be careful! */
-        gettime(&cutOff, false);
-        cutOff.tv_sec -= d_expiration;
-
-        for (auto entry = d_limits.begin(); entry != d_limits.end(); ) {
-          if (!entry->second.seenSince(cutOff)) {
-            entry = d_limits.erase(entry);
-          }
-          else {
-            ++entry;
-          }
-        }
-
-        d_lastCleanup = now;
-      }
-    }
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    cleanupIfNeeded(*dq->queryTime);
-
-    ComboAddress zeroport(*dq->remote);
-    zeroport.sin4.sin_port=0;
-    zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
-    {
-      ReadLock r(&d_lock);
-      const auto iter = d_limits.find(zeroport);
-      if (iter != d_limits.end()) {
-        return !iter->second.check(d_qps, d_burst);
-      }
-    }
-    {
-      WriteLock w(&d_lock);
-
-      auto iter = d_limits.find(zeroport);
-      if(iter == d_limits.end()) {
-        iter=d_limits.insert({zeroport,QPSLimiter(d_qps, d_burst)}).first;
-      }
-      return !iter->second.check(d_qps, d_burst);
-    }
-  }
-
-  string toString() const override
-  {
-    return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
-  }
-
-
-private:
-  mutable pthread_rwlock_t d_lock;
-  mutable std::map<ComboAddress, BasicQPSLimiter> d_limits;
-  mutable struct timespec d_lastCleanup;
-  unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration;
-};
-
-class MaxQPSRule : public DNSRule
-{
-public:
-  MaxQPSRule(unsigned int qps)
-   : d_qps(qps, qps)
-  {}
-
-  MaxQPSRule(unsigned int qps, unsigned int burst)
-   : d_qps(qps, burst)
-  {}
-
-
-  bool matches(const DNSQuestion* qd) const override
-  {
-    return d_qps.check();
-  }
-
-  string toString() const override
-  {
-    return "Max " + std::to_string(d_qps.getRate()) + " qps";
-  }
-
-
-private:
-  mutable QPSLimiter d_qps;
-};
-
-class NMGRule : public DNSRule
-{
-public:
-  NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
-protected:
-  NetmaskGroup d_nmg;
-};
-
-class NetmaskGroupRule : public NMGRule
-{
-public:
-  NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
-  {
-      d_src = src;
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if(!d_src) {
-        return d_nmg.match(*dq->local);
-    }
-    return d_nmg.match(*dq->remote);
-  }
-
-  string toString() const override
-  {
-    if(!d_src) {
-        return "Dst: "+d_nmg.toString();
-    }
-    return "Src: "+d_nmg.toString();
-  }
-private:
-  bool d_src;
-};
-
-class TimedIPSetRule : public DNSRule, boost::noncopyable
-{
-private:
-  struct IPv6 {
-    IPv6(const ComboAddress& ca)
-    {
-      static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
-      memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
-    }
-    bool operator==(const IPv6& rhs) const
-    {
-      return a==rhs.a && b==rhs.b;
-    }
-    uint64_t a, b;
-  };
-
-public:
-  TimedIPSetRule()
-  {
-    pthread_rwlock_init(&d_lock4, 0);
-    pthread_rwlock_init(&d_lock6, 0);
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if(dq->remote->sin4.sin_family == AF_INET) {
-      ReadLock rl(&d_lock4);
-      auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
-      if(fnd == d_ip4s.end()) {
-        return false;
-      }
-      return time(0) < fnd->second;
-    } else {
-      ReadLock rl(&d_lock6);
-      auto fnd = d_ip6s.find({*dq->remote});
-      if(fnd == d_ip6s.end()) {
-        return false;
-      }
-      return time(0) < fnd->second;
-    }
-  }
-
-  void add(const ComboAddress& ca, time_t ttd)
-  {
-    // think twice before adding templates here
-    if(ca.sin4.sin_family == AF_INET) {
-      WriteLock rl(&d_lock4);
-      auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
-      if(!res.second && (time_t)res.first->second < ttd)
-        res.first->second = (uint32_t)ttd;
-    }
-    else {
-      WriteLock rl(&d_lock6);
-      auto res=d_ip6s.insert({{ca}, ttd});
-      if(!res.second && (time_t)res.first->second < ttd)
-        res.first->second = (uint32_t)ttd;
-    }
-  }
-
-  void remove(const ComboAddress& ca)
-  {
-    if(ca.sin4.sin_family == AF_INET) {
-      WriteLock rl(&d_lock4);
-      d_ip4s.erase(ca.sin4.sin_addr.s_addr);
-    }
-    else {
-      WriteLock rl(&d_lock6);
-      d_ip6s.erase({ca});
-    }
-  }
-
-  void clear()
-  {
-    {
-      WriteLock rl(&d_lock4);
-      d_ip4s.clear();
-    }
-    WriteLock rl(&d_lock6);
-    d_ip6s.clear();
-  }
-
-  void cleanup()
-  {
-    time_t now=time(0);
-    {
-      WriteLock rl(&d_lock4);
-
-      for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
-       if(iter->second < now)
-         iter=d_ip4s.erase(iter);
-       else
-         ++iter;
-      }
-
-    }
-
-    {
-      WriteLock rl(&d_lock6);
-
-      for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
-       if(iter->second < now)
-         iter=d_ip6s.erase(iter);
-       else
-         ++iter;
-      }
-
-    }
-
-  }
-
-  string toString() const override
-  {
-    time_t now=time(0);
-    uint64_t count = 0;
-    {
-      ReadLock rl(&d_lock4);
-      for(const auto& ip : d_ip4s)
-        if(now < ip.second)
-          ++count;
-    }
-    {
-      ReadLock rl(&d_lock6);
-      for(const auto& ip : d_ip6s)
-        if(now < ip.second)
-          ++count;
-    }
-
-    return "Src: "+std::to_string(count)+" ips";
-  }
-private:
-  struct IPv6Hash
-  {
-    std::size_t operator()(const IPv6& ip) const
-    {
-      auto ah=std::hash<uint64_t>{}(ip.a);
-      auto bh=std::hash<uint64_t>{}(ip.b);
-      return ah & (bh<<1);
-    }
-  };
-  std::unordered_map<IPv6, time_t, IPv6Hash> d_ip6s;
-  std::unordered_map<uint32_t, time_t> d_ip4s;
-  mutable pthread_rwlock_t d_lock4;
-  mutable pthread_rwlock_t d_lock6;
-};
-
-
-class AllRule : public DNSRule
-{
-public:
-  AllRule() {}
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return true;
-  }
-
-  string toString() const override
-  {
-    return "All";
-  }
-
-};
-
-
-class DNSSECRule : public DNSRule
-{
-public:
-  DNSSECRule()
-  {
-
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
-  }
-
-  string toString() const override
-  {
-    return "DNSSEC";
-  }
-};
-
-class AndRule : public DNSRule
-{
-public:
-  AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
-  {
-    for(const auto& r : rules)
-      d_rules.push_back(r.second);
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    auto iter = d_rules.begin();
-    for(; iter != d_rules.end(); ++iter)
-      if(!(*iter)->matches(dq))
-        break;
-    return iter == d_rules.end();
-  }
-
-  string toString() const override
-  {
-    string ret;
-    for(const auto& rule : d_rules) {
-      if(!ret.empty())
-        ret+= " && ";
-      ret += "("+ rule->toString()+")";
-    }
-    return ret;
-  }
-private:
-
-  vector<std::shared_ptr<DNSRule> > d_rules;
-
-};
-
-
-class OrRule : public DNSRule
-{
-public:
-  OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
-  {
-    for(const auto& r : rules)
-      d_rules.push_back(r.second);
-  }
-
-  bool matches(const DNSQuestion* dq) const override
-  {
-    auto iter = d_rules.begin();
-    for(; iter != d_rules.end(); ++iter)
-      if((*iter)->matches(dq))
-        return true;
-    return false;
-  }
-
-  string toString() const override
-  {
-    string ret;
-    for(const auto& rule : d_rules) {
-      if(!ret.empty())
-        ret+= " || ";
-      ret += "("+ rule->toString()+")";
-    }
-    return ret;
-  }
-private:
-
-  vector<std::shared_ptr<DNSRule> > d_rules;
-
-};
-
-
-class RegexRule : public DNSRule
-{
-public:
-  RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
-  {
-
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_regex.match(dq->qname->toStringNoDot());
-  }
-
-  string toString() const override
-  {
-    return "Regex: "+d_visual;
-  }
-private:
-  Regex d_regex;
-  string d_visual;
-};
-
-#ifdef HAVE_RE2
-#include <re2/re2.h>
-class RE2Rule : public DNSRule
-{
-public:
-  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
-  {
-
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return RE2::FullMatch(dq->qname->toStringNoDot(), d_re2);
-  }
-
-  string toString() const override
-  {
-    return "RE2 match: "+d_visual;
-  }
-private:
-  RE2 d_re2;
-  string d_visual;
-};
-#endif
-
-
-class SuffixMatchNodeRule : public DNSRule
-{
-public:
-  SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_smn.check(*dq->qname);
-  }
-  string toString() const override
-  {
-    if(d_quiet)
-      return "qname==in-set";
-    else
-      return "qname in "+d_smn.toString();
-  }
-private:
-  SuffixMatchNode d_smn;
-  bool d_quiet;
-};
-
-class QNameRule : public DNSRule
-{
-public:
-  QNameRule(const DNSName& qname) : d_qname(qname)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_qname==*dq->qname;
-  }
-  string toString() const override
-  {
-    return "qname=="+d_qname.toString();
-  }
-private:
-  DNSName d_qname;
-};
-
-
-class QTypeRule : public DNSRule
-{
-public:
-  QTypeRule(uint16_t qtype) : d_qtype(qtype)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_qtype == dq->qtype;
-  }
-  string toString() const override
-  {
-    QType qt(d_qtype);
-    return "qtype=="+qt.getName();
-  }
-private:
-  uint16_t d_qtype;
-};
-
-class QClassRule : public DNSRule
-{
-public:
-  QClassRule(uint16_t qclass) : d_qclass(qclass)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_qclass == dq->qclass;
-  }
-  string toString() const override
-  {
-    return "qclass=="+std::to_string(d_qclass);
-  }
-private:
-  uint16_t d_qclass;
-};
-
-class OpcodeRule : public DNSRule
-{
-public:
-  OpcodeRule(uint8_t opcode) : d_opcode(opcode)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_opcode == dq->dh->opcode;
-  }
-  string toString() const override
-  {
-    return "opcode=="+std::to_string(d_opcode);
-  }
-private:
-  uint8_t d_opcode;
-};
-
-class TCPRule : public DNSRule
-{
-public:
-  TCPRule(bool tcp): d_tcp(tcp)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->tcp == d_tcp;
-  }
-  string toString() const override
-  {
-    return (d_tcp ? "TCP" : "UDP");
-  }
-private:
-  bool d_tcp;
-};
-
-
-class NotRule : public DNSRule
-{
-public:
-  NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return !d_rule->matches(dq);
-  }
-  string toString() const override
-  {
-    return "!("+ d_rule->toString()+")";
-  }
-private:
-  shared_ptr<DNSRule> d_rule;
-};
-
-class RecordsCountRule : public DNSRule
-{
-public:
-  RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t count = 0;
-    switch(d_section) {
-    case 0:
-      count = ntohs(dq->dh->qdcount);
-      break;
-    case 1:
-      count = ntohs(dq->dh->ancount);
-      break;
-    case 2:
-      count = ntohs(dq->dh->nscount);
-      break;
-    case 3:
-      count = ntohs(dq->dh->arcount);
-      break;
-    }
-    return count >= d_minCount && count <= d_maxCount;
-  }
-  string toString() const override
-  {
-    string section;
-    switch(d_section) {
-    case 0:
-      section = "QD";
-      break;
-    case 1:
-      section = "AN";
-      break;
-    case 2:
-      section = "NS";
-      break;
-    case 3:
-      section = "AR";
-      break;
-    }
-    return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
-  }
-private:
-  uint16_t d_minCount;
-  uint16_t d_maxCount;
-  uint8_t d_section;
-};
-
-class RecordsTypeCountRule : public DNSRule
-{
-public:
-  RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t count = 0;
-    switch(d_section) {
-    case 0:
-      count = ntohs(dq->dh->qdcount);
-      break;
-    case 1:
-      count = ntohs(dq->dh->ancount);
-      break;
-    case 2:
-      count = ntohs(dq->dh->nscount);
-      break;
-    case 3:
-      count = ntohs(dq->dh->arcount);
-      break;
-    }
-    if (count < d_minCount) {
-      return false;
-    }
-    count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
-    return count >= d_minCount && count <= d_maxCount;
-  }
-  string toString() const override
-  {
-    string section;
-    switch(d_section) {
-    case 0:
-      section = "QD";
-      break;
-    case 1:
-      section = "AN";
-      break;
-    case 2:
-      section = "NS";
-      break;
-    case 3:
-      section = "AR";
-      break;
-    }
-    return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
-  }
-private:
-  uint16_t d_type;
-  uint16_t d_minCount;
-  uint16_t d_maxCount;
-  uint8_t d_section;
-};
-
-class TrailingDataRule : public DNSRule
-{
-public:
-  TrailingDataRule()
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
-    return length < dq->len;
-  }
-  string toString() const override
-  {
-    return "trailing data";
-  }
-};
-
-class QNameLabelsCountRule : public DNSRule
-{
-public:
-  QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    unsigned int count = dq->qname->countLabels();
-    return count < d_min || count > d_max;
-  }
-  string toString() const override
-  {
-    return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
-  }
-private:
-  unsigned int d_min;
-  unsigned int d_max;
-};
-
-class QNameWireLengthRule : public DNSRule
-{
-public:
-  QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    size_t const wirelength = dq->qname->wirelength();
-    return wirelength < d_min || wirelength > d_max;
-  }
-  string toString() const override
-  {
-    return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
-  }
-private:
-  size_t d_min;
-  size_t d_max;
-};
-
-class RCodeRule : public DNSRule
-{
-public:
-  RCodeRule(uint8_t rcode) : d_rcode(rcode)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return d_rcode == dq->dh->rcode;
-  }
-  string toString() const override
-  {
-    return "rcode=="+RCode::to_s(d_rcode);
-  }
-private:
-  uint8_t d_rcode;
-};
-
-class ERCodeRule : public DNSRule
-{
-public:
-  ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    // avoid parsing EDNS OPT RR when not needed.
-    if (d_rcode != dq->dh->rcode) {
-      return false;
-    }
-
-    char * optStart = NULL;
-    size_t optLen = 0;
-    bool last = false;
-    int res = locateEDNSOptRR(const_cast<char*>(reinterpret_cast<const char*>(dq->dh)), dq->len, &optStart, &optLen, &last);
-    if (res != 0) {
-      // no EDNS OPT RR
-      return d_extrcode == 0;
-    }
-
-    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
-    if (optLen < 11) {
-      return false;
-    }
-
-    if (*optStart != 0) {
-      // OPT RR Name != '.'
-      return false;
-    }
-    EDNS0Record edns0;
-    static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
-    // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
-    memcpy(&edns0, optStart + 5, sizeof edns0);
-
-    return d_extrcode == edns0.extRCode;
-  }
-  string toString() const override
-  {
-    return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
-  }
-private:
-  uint8_t d_rcode;     // plain DNS Rcode
-  uint8_t d_extrcode;  // upper bits in EDNS0 record
-};
-
-class RDRule : public DNSRule
-{
-public:
-  RDRule()
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    return dq->dh->rd == 1;
-  }
-  string toString() const override
-  {
-    return "rd==1";
-  }
-};
-
-class ProbaRule : public DNSRule
-{
-public:
-  ProbaRule(double proba) : d_proba(proba)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if(d_proba == 1.0)
-      return true;
-    double rnd = 1.0*random() / RAND_MAX;
-    return rnd > (1.0 - d_proba);
-  }
-  string toString() const override
-  {
-    return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
-  }
-private:
-  double d_proba;
-};
-
-class TagRule : public DNSRule
-{
-public:
-  TagRule(const std::string& tag, boost::optional<std::string> value) : d_value(value), d_tag(tag)
-  {
-  }
-  bool matches(const DNSQuestion* dq) const override
-  {
-    if (!dq->qTag) {
-      return false;
-    }
-
-    const auto it = dq->qTag->find(d_tag);
-    if (it == dq->qTag->cend()) {
-      return false;
-    }
-
-    if (!d_value) {
-      return true;
-    }
-
-    return it->second == *d_value;
-  }
-
-  string toString() const override
-  {
-    return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
-  }
-
-private:
-  boost::optional<std::string> d_value;
-  std::string d_tag;
-};
-
 std::shared_ptr<DNSRule> makeRule(const luadnsrule_t& var)
 {
   if (var.type() == typeid(std::shared_ptr<DNSRule>))
index 3081f33cf6ffe1a2b24de81a302702be295dbe9c..982901414e0ffad2025cf8d04bb5d38f90f4bf1a 100644 (file)
@@ -105,6 +105,7 @@ dnsdist_SOURCES = \
        dnsdist-lua-vars.cc \
        dnsdist-protobuf.cc dnsdist-protobuf.hh \
        dnsdist-rings.cc dnsdist-rings.hh \
+       dnsdist-rules.hh \
        dnsdist-snmp.cc dnsdist-snmp.hh \
        dnsdist-tcp.cc \
        dnsdist-web.cc \
diff --git a/pdns/dnsdistdist/dnsdist-rules.hh b/pdns/dnsdistdist/dnsdist-rules.hh
new file mode 100644 (file)
index 0000000..ee91eda
--- /dev/null
@@ -0,0 +1,890 @@
+/*
+ * This file is part of PowerDNS or dnsdist.
+ * Copyright -- PowerDNS.COM B.V. and its contributors
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of version 2 of the GNU General Public License as
+ * published by the Free Software Foundation.
+ *
+ * In addition, for the avoidance of any doubt, permission is granted to
+ * link this program with OpenSSL and to (re)distribute the binaries
+ * produced as the result of such linking.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+#pragma once
+
+#include "cachecleaner.hh"
+#include "dnsdist.hh"
+#include "dnsdist-ecs.hh"
+#include "dnsparser.hh"
+
+class MaxQPSIPRule : public DNSRule
+{
+public:
+  MaxQPSIPRule(unsigned int qps, unsigned int burst, unsigned int ipv4trunc=32, unsigned int ipv6trunc=64, unsigned int expiration=300, unsigned int cleanupDelay=60):
+    d_qps(qps), d_burst(burst), d_ipv4trunc(ipv4trunc), d_ipv6trunc(ipv6trunc), d_cleanupDelay(cleanupDelay), d_expiration(expiration)
+  {
+    pthread_rwlock_init(&d_lock, 0);
+    gettime(&d_lastCleanup, true);
+  }
+
+  void cleanupIfNeeded(const struct timespec& now) const
+  {
+    if (d_cleanupDelay > 0) {
+      struct timespec cutOff = d_lastCleanup;
+      cutOff.tv_sec += d_cleanupDelay;
+
+      if (cutOff < now) {
+        WriteLock w(&d_lock);
+
+        /* the QPS Limiter doesn't use realtime, be careful! */
+        gettime(&cutOff, false);
+        cutOff.tv_sec -= d_expiration;
+
+        for (auto entry = d_limits.begin(); entry != d_limits.end(); ) {
+          if (!entry->second.seenSince(cutOff)) {
+            entry = d_limits.erase(entry);
+          }
+          else {
+            ++entry;
+          }
+        }
+
+        d_lastCleanup = now;
+      }
+    }
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    cleanupIfNeeded(*dq->queryTime);
+
+    ComboAddress zeroport(*dq->remote);
+    zeroport.sin4.sin_port=0;
+    zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc);
+    {
+      ReadLock r(&d_lock);
+      const auto iter = d_limits.find(zeroport);
+      if (iter != d_limits.end()) {
+        return !iter->second.check(d_qps, d_burst);
+      }
+    }
+    {
+      WriteLock w(&d_lock);
+
+      auto iter = d_limits.find(zeroport);
+      if(iter == d_limits.end()) {
+        iter=d_limits.insert({zeroport,QPSLimiter(d_qps, d_burst)}).first;
+      }
+      return !iter->second.check(d_qps, d_burst);
+    }
+  }
+
+  string toString() const override
+  {
+    return "IP (/"+std::to_string(d_ipv4trunc)+", /"+std::to_string(d_ipv6trunc)+") match for QPS over " + std::to_string(d_qps) + " burst "+ std::to_string(d_burst);
+  }
+
+
+private:
+  mutable pthread_rwlock_t d_lock;
+  mutable std::map<ComboAddress, BasicQPSLimiter> d_limits;
+  mutable struct timespec d_lastCleanup;
+  unsigned int d_qps, d_burst, d_ipv4trunc, d_ipv6trunc, d_cleanupDelay, d_expiration;
+};
+
+class MaxQPSRule : public DNSRule
+{
+public:
+  MaxQPSRule(unsigned int qps)
+   : d_qps(qps, qps)
+  {}
+
+  MaxQPSRule(unsigned int qps, unsigned int burst)
+   : d_qps(qps, burst)
+  {}
+
+
+  bool matches(const DNSQuestion* qd) const override
+  {
+    return d_qps.check();
+  }
+
+  string toString() const override
+  {
+    return "Max " + std::to_string(d_qps.getRate()) + " qps";
+  }
+
+
+private:
+  mutable QPSLimiter d_qps;
+};
+
+class NMGRule : public DNSRule
+{
+public:
+  NMGRule(const NetmaskGroup& nmg) : d_nmg(nmg) {}
+protected:
+  NetmaskGroup d_nmg;
+};
+
+class NetmaskGroupRule : public NMGRule
+{
+public:
+  NetmaskGroupRule(const NetmaskGroup& nmg, bool src) : NMGRule(nmg)
+  {
+      d_src = src;
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if(!d_src) {
+        return d_nmg.match(*dq->local);
+    }
+    return d_nmg.match(*dq->remote);
+  }
+
+  string toString() const override
+  {
+    if(!d_src) {
+        return "Dst: "+d_nmg.toString();
+    }
+    return "Src: "+d_nmg.toString();
+  }
+private:
+  bool d_src;
+};
+
+class TimedIPSetRule : public DNSRule, boost::noncopyable
+{
+private:
+  struct IPv6 {
+    IPv6(const ComboAddress& ca)
+    {
+      static_assert(sizeof(*this)==16, "IPv6 struct has wrong size");
+      memcpy((char*)this, ca.sin6.sin6_addr.s6_addr, 16);
+    }
+    bool operator==(const IPv6& rhs) const
+    {
+      return a==rhs.a && b==rhs.b;
+    }
+    uint64_t a, b;
+  };
+
+public:
+  TimedIPSetRule()
+  {
+    pthread_rwlock_init(&d_lock4, 0);
+    pthread_rwlock_init(&d_lock6, 0);
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if(dq->remote->sin4.sin_family == AF_INET) {
+      ReadLock rl(&d_lock4);
+      auto fnd = d_ip4s.find(dq->remote->sin4.sin_addr.s_addr);
+      if(fnd == d_ip4s.end()) {
+        return false;
+      }
+      return time(0) < fnd->second;
+    } else {
+      ReadLock rl(&d_lock6);
+      auto fnd = d_ip6s.find({*dq->remote});
+      if(fnd == d_ip6s.end()) {
+        return false;
+      }
+      return time(0) < fnd->second;
+    }
+  }
+
+  void add(const ComboAddress& ca, time_t ttd)
+  {
+    // think twice before adding templates here
+    if(ca.sin4.sin_family == AF_INET) {
+      WriteLock rl(&d_lock4);
+      auto res=d_ip4s.insert({ca.sin4.sin_addr.s_addr, ttd});
+      if(!res.second && (time_t)res.first->second < ttd)
+        res.first->second = (uint32_t)ttd;
+    }
+    else {
+      WriteLock rl(&d_lock6);
+      auto res=d_ip6s.insert({{ca}, ttd});
+      if(!res.second && (time_t)res.first->second < ttd)
+        res.first->second = (uint32_t)ttd;
+    }
+  }
+
+  void remove(const ComboAddress& ca)
+  {
+    if(ca.sin4.sin_family == AF_INET) {
+      WriteLock rl(&d_lock4);
+      d_ip4s.erase(ca.sin4.sin_addr.s_addr);
+    }
+    else {
+      WriteLock rl(&d_lock6);
+      d_ip6s.erase({ca});
+    }
+  }
+
+  void clear()
+  {
+    {
+      WriteLock rl(&d_lock4);
+      d_ip4s.clear();
+    }
+    WriteLock rl(&d_lock6);
+    d_ip6s.clear();
+  }
+
+  void cleanup()
+  {
+    time_t now=time(0);
+    {
+      WriteLock rl(&d_lock4);
+
+      for(auto iter = d_ip4s.begin(); iter != d_ip4s.end(); ) {
+       if(iter->second < now)
+         iter=d_ip4s.erase(iter);
+       else
+         ++iter;
+      }
+
+    }
+
+    {
+      WriteLock rl(&d_lock6);
+
+      for(auto iter = d_ip6s.begin(); iter != d_ip6s.end(); ) {
+       if(iter->second < now)
+         iter=d_ip6s.erase(iter);
+       else
+         ++iter;
+      }
+
+    }
+
+  }
+
+  string toString() const override
+  {
+    time_t now=time(0);
+    uint64_t count = 0;
+    {
+      ReadLock rl(&d_lock4);
+      for(const auto& ip : d_ip4s)
+        if(now < ip.second)
+          ++count;
+    }
+    {
+      ReadLock rl(&d_lock6);
+      for(const auto& ip : d_ip6s)
+        if(now < ip.second)
+          ++count;
+    }
+
+    return "Src: "+std::to_string(count)+" ips";
+  }
+private:
+  struct IPv6Hash
+  {
+    std::size_t operator()(const IPv6& ip) const
+    {
+      auto ah=std::hash<uint64_t>{}(ip.a);
+      auto bh=std::hash<uint64_t>{}(ip.b);
+      return ah & (bh<<1);
+    }
+  };
+  std::unordered_map<IPv6, time_t, IPv6Hash> d_ip6s;
+  std::unordered_map<uint32_t, time_t> d_ip4s;
+  mutable pthread_rwlock_t d_lock4;
+  mutable pthread_rwlock_t d_lock6;
+};
+
+
+class AllRule : public DNSRule
+{
+public:
+  AllRule() {}
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return true;
+  }
+
+  string toString() const override
+  {
+    return "All";
+  }
+
+};
+
+
+class DNSSECRule : public DNSRule
+{
+public:
+  DNSSECRule()
+  {
+
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->dh->cd || (getEDNSZ((const char*)dq->dh, dq->len) & EDNS_HEADER_FLAG_DO);    // turns out dig sets ad by default..
+  }
+
+  string toString() const override
+  {
+    return "DNSSEC";
+  }
+};
+
+class AndRule : public DNSRule
+{
+public:
+  AndRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
+  {
+    for(const auto& r : rules)
+      d_rules.push_back(r.second);
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    auto iter = d_rules.begin();
+    for(; iter != d_rules.end(); ++iter)
+      if(!(*iter)->matches(dq))
+        break;
+    return iter == d_rules.end();
+  }
+
+  string toString() const override
+  {
+    string ret;
+    for(const auto& rule : d_rules) {
+      if(!ret.empty())
+        ret+= " && ";
+      ret += "("+ rule->toString()+")";
+    }
+    return ret;
+  }
+private:
+
+  vector<std::shared_ptr<DNSRule> > d_rules;
+
+};
+
+
+class OrRule : public DNSRule
+{
+public:
+  OrRule(const vector<pair<int, shared_ptr<DNSRule> > >& rules)
+  {
+    for(const auto& r : rules)
+      d_rules.push_back(r.second);
+  }
+
+  bool matches(const DNSQuestion* dq) const override
+  {
+    auto iter = d_rules.begin();
+    for(; iter != d_rules.end(); ++iter)
+      if((*iter)->matches(dq))
+        return true;
+    return false;
+  }
+
+  string toString() const override
+  {
+    string ret;
+    for(const auto& rule : d_rules) {
+      if(!ret.empty())
+        ret+= " || ";
+      ret += "("+ rule->toString()+")";
+    }
+    return ret;
+  }
+private:
+
+  vector<std::shared_ptr<DNSRule> > d_rules;
+
+};
+
+
+class RegexRule : public DNSRule
+{
+public:
+  RegexRule(const std::string& regex) : d_regex(regex), d_visual(regex)
+  {
+
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_regex.match(dq->qname->toStringNoDot());
+  }
+
+  string toString() const override
+  {
+    return "Regex: "+d_visual;
+  }
+private:
+  Regex d_regex;
+  string d_visual;
+};
+
+#ifdef HAVE_RE2
+#include <re2/re2.h>
+class RE2Rule : public DNSRule
+{
+public:
+  RE2Rule(const std::string& re2) : d_re2(re2, RE2::Latin1), d_visual(re2)
+  {
+
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return RE2::FullMatch(dq->qname->toStringNoDot(), d_re2);
+  }
+
+  string toString() const override
+  {
+    return "RE2 match: "+d_visual;
+  }
+private:
+  RE2 d_re2;
+  string d_visual;
+};
+#endif
+
+
+class SuffixMatchNodeRule : public DNSRule
+{
+public:
+  SuffixMatchNodeRule(const SuffixMatchNode& smn, bool quiet=false) : d_smn(smn), d_quiet(quiet)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_smn.check(*dq->qname);
+  }
+  string toString() const override
+  {
+    if(d_quiet)
+      return "qname==in-set";
+    else
+      return "qname in "+d_smn.toString();
+  }
+private:
+  SuffixMatchNode d_smn;
+  bool d_quiet;
+};
+
+class QNameRule : public DNSRule
+{
+public:
+  QNameRule(const DNSName& qname) : d_qname(qname)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_qname==*dq->qname;
+  }
+  string toString() const override
+  {
+    return "qname=="+d_qname.toString();
+  }
+private:
+  DNSName d_qname;
+};
+
+
+class QTypeRule : public DNSRule
+{
+public:
+  QTypeRule(uint16_t qtype) : d_qtype(qtype)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_qtype == dq->qtype;
+  }
+  string toString() const override
+  {
+    QType qt(d_qtype);
+    return "qtype=="+qt.getName();
+  }
+private:
+  uint16_t d_qtype;
+};
+
+class QClassRule : public DNSRule
+{
+public:
+  QClassRule(uint16_t qclass) : d_qclass(qclass)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_qclass == dq->qclass;
+  }
+  string toString() const override
+  {
+    return "qclass=="+std::to_string(d_qclass);
+  }
+private:
+  uint16_t d_qclass;
+};
+
+class OpcodeRule : public DNSRule
+{
+public:
+  OpcodeRule(uint8_t opcode) : d_opcode(opcode)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_opcode == dq->dh->opcode;
+  }
+  string toString() const override
+  {
+    return "opcode=="+std::to_string(d_opcode);
+  }
+private:
+  uint8_t d_opcode;
+};
+
+class TCPRule : public DNSRule
+{
+public:
+  TCPRule(bool tcp): d_tcp(tcp)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->tcp == d_tcp;
+  }
+  string toString() const override
+  {
+    return (d_tcp ? "TCP" : "UDP");
+  }
+private:
+  bool d_tcp;
+};
+
+
+class NotRule : public DNSRule
+{
+public:
+  NotRule(shared_ptr<DNSRule>& rule): d_rule(rule)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return !d_rule->matches(dq);
+  }
+  string toString() const override
+  {
+    return "!("+ d_rule->toString()+")";
+  }
+private:
+  shared_ptr<DNSRule> d_rule;
+};
+
+class RecordsCountRule : public DNSRule
+{
+public:
+  RecordsCountRule(uint8_t section, uint16_t minCount, uint16_t maxCount): d_minCount(minCount), d_maxCount(maxCount), d_section(section)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t count = 0;
+    switch(d_section) {
+    case 0:
+      count = ntohs(dq->dh->qdcount);
+      break;
+    case 1:
+      count = ntohs(dq->dh->ancount);
+      break;
+    case 2:
+      count = ntohs(dq->dh->nscount);
+      break;
+    case 3:
+      count = ntohs(dq->dh->arcount);
+      break;
+    }
+    return count >= d_minCount && count <= d_maxCount;
+  }
+  string toString() const override
+  {
+    string section;
+    switch(d_section) {
+    case 0:
+      section = "QD";
+      break;
+    case 1:
+      section = "AN";
+      break;
+    case 2:
+      section = "NS";
+      break;
+    case 3:
+      section = "AR";
+      break;
+    }
+    return std::to_string(d_minCount) + " <= records in " + section + " <= "+ std::to_string(d_maxCount);
+  }
+private:
+  uint16_t d_minCount;
+  uint16_t d_maxCount;
+  uint8_t d_section;
+};
+
+class RecordsTypeCountRule : public DNSRule
+{
+public:
+  RecordsTypeCountRule(uint8_t section, uint16_t type, uint16_t minCount, uint16_t maxCount): d_type(type), d_minCount(minCount), d_maxCount(maxCount), d_section(section)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t count = 0;
+    switch(d_section) {
+    case 0:
+      count = ntohs(dq->dh->qdcount);
+      break;
+    case 1:
+      count = ntohs(dq->dh->ancount);
+      break;
+    case 2:
+      count = ntohs(dq->dh->nscount);
+      break;
+    case 3:
+      count = ntohs(dq->dh->arcount);
+      break;
+    }
+    if (count < d_minCount) {
+      return false;
+    }
+    count = getRecordsOfTypeCount(reinterpret_cast<const char*>(dq->dh), dq->len, d_section, d_type);
+    return count >= d_minCount && count <= d_maxCount;
+  }
+  string toString() const override
+  {
+    string section;
+    switch(d_section) {
+    case 0:
+      section = "QD";
+      break;
+    case 1:
+      section = "AN";
+      break;
+    case 2:
+      section = "NS";
+      break;
+    case 3:
+      section = "AR";
+      break;
+    }
+    return std::to_string(d_minCount) + " <= " + QType(d_type).getName() + " records in " + section + " <= "+ std::to_string(d_maxCount);
+  }
+private:
+  uint16_t d_type;
+  uint16_t d_minCount;
+  uint16_t d_maxCount;
+  uint8_t d_section;
+};
+
+class TrailingDataRule : public DNSRule
+{
+public:
+  TrailingDataRule()
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    uint16_t length = getDNSPacketLength(reinterpret_cast<const char*>(dq->dh), dq->len);
+    return length < dq->len;
+  }
+  string toString() const override
+  {
+    return "trailing data";
+  }
+};
+
+class QNameLabelsCountRule : public DNSRule
+{
+public:
+  QNameLabelsCountRule(unsigned int minLabelsCount, unsigned int maxLabelsCount): d_min(minLabelsCount), d_max(maxLabelsCount)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    unsigned int count = dq->qname->countLabels();
+    return count < d_min || count > d_max;
+  }
+  string toString() const override
+  {
+    return "labels count < " + std::to_string(d_min) + " || labels count > " + std::to_string(d_max);
+  }
+private:
+  unsigned int d_min;
+  unsigned int d_max;
+};
+
+class QNameWireLengthRule : public DNSRule
+{
+public:
+  QNameWireLengthRule(size_t min, size_t max): d_min(min), d_max(max)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    size_t const wirelength = dq->qname->wirelength();
+    return wirelength < d_min || wirelength > d_max;
+  }
+  string toString() const override
+  {
+    return "wire length < " + std::to_string(d_min) + " || wire length > " + std::to_string(d_max);
+  }
+private:
+  size_t d_min;
+  size_t d_max;
+};
+
+class RCodeRule : public DNSRule
+{
+public:
+  RCodeRule(uint8_t rcode) : d_rcode(rcode)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return d_rcode == dq->dh->rcode;
+  }
+  string toString() const override
+  {
+    return "rcode=="+RCode::to_s(d_rcode);
+  }
+private:
+  uint8_t d_rcode;
+};
+
+class ERCodeRule : public DNSRule
+{
+public:
+  ERCodeRule(uint8_t rcode) : d_rcode(rcode & 0xF), d_extrcode(rcode >> 4)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    // avoid parsing EDNS OPT RR when not needed.
+    if (d_rcode != dq->dh->rcode) {
+      return false;
+    }
+
+    char * optStart = NULL;
+    size_t optLen = 0;
+    bool last = false;
+    int res = locateEDNSOptRR(const_cast<char*>(reinterpret_cast<const char*>(dq->dh)), dq->len, &optStart, &optLen, &last);
+    if (res != 0) {
+      // no EDNS OPT RR
+      return d_extrcode == 0;
+    }
+
+    // root label (1), type (2), class (2), ttl (4) + rdlen (2)
+    if (optLen < 11) {
+      return false;
+    }
+
+    if (*optStart != 0) {
+      // OPT RR Name != '.'
+      return false;
+    }
+    EDNS0Record edns0;
+    static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size");
+    // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2).
+    memcpy(&edns0, optStart + 5, sizeof edns0);
+
+    return d_extrcode == edns0.extRCode;
+  }
+  string toString() const override
+  {
+    return "ercode=="+ERCode::to_s(d_rcode | (d_extrcode << 4));
+  }
+private:
+  uint8_t d_rcode;     // plain DNS Rcode
+  uint8_t d_extrcode;  // upper bits in EDNS0 record
+};
+
+class RDRule : public DNSRule
+{
+public:
+  RDRule()
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    return dq->dh->rd == 1;
+  }
+  string toString() const override
+  {
+    return "rd==1";
+  }
+};
+
+class ProbaRule : public DNSRule
+{
+public:
+  ProbaRule(double proba) : d_proba(proba)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if(d_proba == 1.0)
+      return true;
+    double rnd = 1.0*random() / RAND_MAX;
+    return rnd > (1.0 - d_proba);
+  }
+  string toString() const override
+  {
+    return "match with prob. " + (boost::format("%0.2f") % d_proba).str();
+  }
+private:
+  double d_proba;
+};
+
+class TagRule : public DNSRule
+{
+public:
+  TagRule(const std::string& tag, boost::optional<std::string> value) : d_value(value), d_tag(tag)
+  {
+  }
+  bool matches(const DNSQuestion* dq) const override
+  {
+    if (!dq->qTag) {
+      return false;
+    }
+
+    const auto it = dq->qTag->find(d_tag);
+    if (it == dq->qTag->cend()) {
+      return false;
+    }
+
+    if (!d_value) {
+      return true;
+    }
+
+    return it->second == *d_value;
+  }
+
+  string toString() const override
+  {
+    return "tag '" + d_tag + "' is set" + (d_value ? (" to '" + *d_value + "'") : "");
+  }
+
+private:
+  boost::optional<std::string> d_value;
+  std::string d_tag;
+};