]> granicus.if.org Git - pdns/commitdiff
Add implementation for NetmaskTree<T>
authorAki Tuomi <cmouse@desteem.org>
Sun, 15 Nov 2015 10:32:12 +0000 (12:32 +0200)
committerAki Tuomi <cmouse@desteem.org>
Tue, 17 Nov 2015 18:28:49 +0000 (20:28 +0200)
A tree based map implementation to keep collection
of network prefixes and return best match for
given ComboAddress.

pdns/iputils.hh

index db7670a67b71e48803484bd16949b055a10382f1..51da7a301a72b21c68ebfaeeeb10dcc14b918cd4 100644 (file)
@@ -29,6 +29,7 @@
 #include <iostream>
 #include <stdio.h>
 #include <functional>
+#include <bitset>
 #include "pdnsexception.hh"
 #include "misc.hh"
 #include <sys/socket.h>
@@ -369,6 +370,325 @@ private:
   uint8_t d_bits;
 };
 
+/** Per-bit binary tree map implementation with <Netmask,T> pair.
+ *
+ * This is an binary tree implementation for storing attributes for IPv4 and IPv6 prefixes.
+ * The most simple use case is simple NetmaskTree<bool> used by NetmaskGroup, which only
+ * wants to know if given IP address is matched in the prefixes stored.
+ *
+ * This element is useful for anything that needs to *STORE* prefixes, and *MATCH* IP addresses
+ * to a *LIST* of *PREFIXES*. Not the other way round.
+ *
+ * You can store IPv4 and IPv6 addresses to same tree, separate payload storage is kept per AFI.
+ *
+ * To erase something copy values to new tree sans the value you want to erase.
+ *
+ * Use swap if you need to move the tree to another NetmaskTree instance, it is WAY faster
+ * than using copy ctor or assigment operator, since it moves the nodes and tree root to
+ * new home instead of actually recreating the tree.
+ *
+ * Please see NetmaskGroup for example of simple use case. Other usecases can be found
+ * from GeoIPBackend and Sortlist, and from dnsdist.
+ */
+template <typename T>
+class NetmaskTree {
+public:
+  typedef Netmask key_type;
+  typedef T value_type;
+  typedef std::pair<key_type,value_type> node_type;
+  typedef size_t size_type;
+
+private:
+  /** Single node in tree, internal use only.
+    */
+  class TreeNode : boost::noncopyable {
+  public:
+     explicit TreeNode(int bits) noexcept : parent(NULL),d_bits(bits) {
+     }
+
+     //<! Makes a left node with one more bit than parent
+     TreeNode* make_left() {
+       if (!left) {
+         left = unique_ptr<TreeNode>(new TreeNode(d_bits+1));
+         left->parent = this;
+       }
+       return left.get();
+     }
+
+     //<! Makes a right node with one more bit than parent
+     TreeNode* make_right() {
+       if (!right) {
+         right = unique_ptr<TreeNode>(new TreeNode(d_bits+1));
+         right->parent = this;
+       }
+       return right.get();
+     }
+
+     unique_ptr<TreeNode> left;
+     unique_ptr<TreeNode> right;
+     TreeNode* parent;
+
+     unique_ptr<node_type> node4; //<! IPv4 value-pair
+     unique_ptr<node_type> node6; //<! IPv6 value-pair
+
+     int d_bits; //<! How many bits have been used so far
+  };
+
+public:
+  NetmaskTree() noexcept {
+  }
+
+  NetmaskTree(const NetmaskTree& rhs) {
+    // it is easier to copy the nodes than tree.
+    // also acts as handy compactor
+    for(auto const& node: rhs._nodes)
+      insert(node->first).second = node->second;
+  }
+
+  NetmaskTree& operator=(const NetmaskTree& rhs) {
+    clear();
+    // see above.
+    for(auto const& node: rhs._nodes)
+      insert(node->first).second = node->second;
+    return *this;
+  }
+
+  //<! Index operator for value-pair, throws if not found
+  const value_type& operator[](const key_type& rhs) const {
+    const node_type *value = lookup(rhs.getNetwork(), rhs.getBits());
+    if (value == nullptr) throw std::range_error(rhs.toString() + string(" not found"));
+    return value->second;
+  }
+
+  //<! Index operator for value-pair, creates new if not found
+  T& operator[](const key_type& rhs) {
+    return insert(rhs).second;
+  }
+
+  const typename std::vector<node_type*>::const_iterator begin() const { return _nodes.begin(); }
+  const typename std::vector<node_type*>::const_iterator end() const { return _nodes.end(); }
+
+  typename std::vector<node_type*>::iterator begin() { return _nodes.begin(); }
+  typename std::vector<node_type*>::iterator end() { return _nodes.end(); }
+
+  node_type& insert(const string &mask) {
+    return insert(key_type(mask));
+  }
+
+  //<! Creates new value-pair in tree and returns it.
+  node_type& insert(const key_type& key) {
+    // lazily initialize tree on first insert.
+    if (!root) root = unique_ptr<TreeNode>(new TreeNode(0));
+    TreeNode* node = root.get();
+    node_type* value = nullptr;
+
+    if (key.getNetwork().sin4.sin_family == AF_INET) {
+      std::bitset<32> addr(be32toh(key.getNetwork().sin4.sin_addr.s_addr));
+      int bits = 0;
+      // we turn left on 0 and right on 1
+      while(bits < key.getBits()) {
+        uint8_t val = addr[31-bits];
+        if (val)
+          node = node->make_right();
+        else
+          node = node->make_left();
+        bits++;
+      }
+      // only create node if not yet assigned
+      if (!node->node4) {
+        node->node4 = unique_ptr<node_type>(new node_type());
+        _nodes.push_back(node->node4.get());
+      }
+      value = node->node4.get();
+    } else {
+      uint64_t* addr = (uint64_t*)key.getNetwork().sin6.sin6_addr.s6_addr;
+      std::bitset<64> addr_low(be64toh(addr[1]));
+      std::bitset<64> addr_high(be64toh(addr[0]));
+      int bits = 0;
+      while(bits < key.getBits()) {
+        uint8_t val;
+        // we use high address until we are
+        if (bits < 64) val = addr_high[63-bits];
+        // past 64 bits, and start using low address
+        else val = addr_low[127-bits];
+
+        // we turn left on 0 and right on 1
+        if (val)
+          node = node->make_right();
+        else
+          node = node->make_left();
+        bits++;
+      }
+      // only create node if not yet assigned
+      if (!node->node6) {
+        node->node6 = unique_ptr<node_type>(new node_type());
+        _nodes.push_back(node->node6.get());
+      }
+      value = node->node6.get();
+    }
+    // assign key
+    value->first = key;
+    return *value;
+  }
+
+  void insert_or_assign(const key_type& mask, const value_type& value) {
+    insert(mask).second = value;
+  }
+
+  const node_type& at(const key_type& value) const {
+    const node_type* node = lookup(value);
+    if (node == nullptr) throw std::range_error(value.toString() + string(" not found"));
+    return *node;
+  }
+
+  const node_type* lookup(const key_type& value) const {
+    return lookup(value.getNetwork(), value.getBits());
+  }
+
+  const node_type* lookup(const ComboAddress& value, int max_bits = 128) const {
+    if (!root) return nullptr;
+
+    TreeNode *node = root.get();
+    node_type *ret = nullptr;
+
+    // exact same thing as above, except
+    if (value.sin4.sin_family == AF_INET) {
+      max_bits = std::max(0,std::min(max_bits,32));
+      std::bitset<32> addr(be32toh(value.sin4.sin_addr.s_addr));
+      int bits = 0;
+
+      while(bits < max_bits) {
+        // ...we keep track of last non-empty node
+        if (node->node4) ret = node->node4.get();
+        uint8_t val = addr[31-bits];
+        // ...and we don't create left/right hand
+        if (val) {
+          if (node->right) node = node->right.get();
+          // ..and we break when road ends
+          else break;
+        } else {
+          if (node->left) node = node->left.get();
+          else break;
+        }
+        bits++;
+      }
+      // needed if we did not find one in loop
+      if (node->node4) ret = node->node4.get();
+    } else {
+      uint64_t* addr = (uint64_t*)value.sin6.sin6_addr.s6_addr;
+      max_bits = std::max(0,std::min(max_bits,128));
+      std::bitset<64> addr_low(be64toh(addr[1]));
+      std::bitset<64> addr_high(be64toh(addr[0]));
+      int bits = 0;
+      while(bits < max_bits) {
+        if (node->node6) ret = node->node6.get();
+        uint8_t val;
+        if (bits < 64) val = addr_high[63-bits];
+        else val = addr_low[127-bits];
+        if (val) {
+          if (node->right) node = node->right.get();
+          else break;
+        } else {
+          if (node->left) node = node->left.get();
+          else break;
+        }
+        bits++;
+      }
+      if (node->node6) ret = node->node6.get();
+    }
+
+    // this can be nullptr.
+    return ret;
+  }
+
+  void erase(const key_type& key) {
+    TreeNode *node = root.get();
+
+    // no tree, no value
+    if ( node == nullptr ) return;
+
+    // exact same thing as above, except
+    if (key.getNetwork().sin4.sin_family == AF_INET) {
+      std::bitset<32> addr(be32toh(key.getNetwork().sin4.sin_addr.s_addr));
+      int bits = 0;
+      while(node && bits < key.getBits()) {
+        uint8_t val = addr[31-bits];
+        if (val) {
+          node = node->right.get();
+        } else {
+          node = node->left.get();
+        }
+        bits++;
+      }
+      if (node) {
+        for(auto it = _nodes.begin(); it != _nodes.end(); it++)
+           if (node->node4.get() == *it) _nodes.erase(it);
+        node->node4.reset();
+      }
+    } else {
+      uint64_t* addr = (uint64_t*)key.getNetwork().sin6.sin6_addr.s6_addr;
+      std::bitset<64> addr_low(be64toh(addr[1]));
+      std::bitset<64> addr_high(be64toh(addr[0]));
+      int bits = 0;
+      while(node && bits < key.getBits()) {
+        uint8_t val;
+        if (bits < 64) val = addr_high[63-bits];
+        else val = addr_low[127-bits];
+        if (val) {
+          node = node->right.get();
+        } else {
+          node = node->left.get();
+        }
+        bits++;
+      }
+      if (node) {
+        for(auto it = _nodes.begin(); it != _nodes.end(); it++)
+           if (node->node4.get() == *it) _nodes.erase(it);
+        node->node6.reset();
+      }
+    }
+  }
+
+  void erase(const string& key) {
+    erase(key_type(key));
+  }
+
+  //<! checks whether the container is empty.
+  bool empty() const {
+    return _nodes.empty();
+  }
+
+  //<! returns the number of elements
+  size_type size() const {
+    return _nodes.size();
+  }
+
+  //<! See if given ComboAddress matches any prefix
+  bool match(const ComboAddress& value) const {
+    return (lookup(value) != nullptr);
+  }
+
+  bool match(const std::string& value) const {
+    return match(ComboAddress(value));
+  }
+
+  //<! Clean out the tree
+  void clear() {
+    _nodes.clear();
+    root.reset(nullptr);
+  }
+
+  //<! swaps the contents, rhs is left with nullptr.
+  void swap(NetmaskTree& rhs) {
+    root.swap(rhs.root);
+    _nodes.swap(rhs._nodes);
+  }
+
+private:
+  unique_ptr<TreeNode> root; //<! Root of our tree
+  std::vector<node_type*> _nodes; //<! Container for actual values
+};
+
 /** This class represents a group of supplemental Netmask classes. An IP address matchs
     if it is matched by zero or more of the Netmask classes within.
 */