#include <iostream>
#include <stdio.h>
#include <functional>
+#include <bitset>
#include "pdnsexception.hh"
#include "misc.hh"
#include <sys/socket.h>
uint8_t d_bits;
};
+/** Per-bit binary tree map implementation with <Netmask,T> pair.
+ *
+ * This is an binary tree implementation for storing attributes for IPv4 and IPv6 prefixes.
+ * The most simple use case is simple NetmaskTree<bool> used by NetmaskGroup, which only
+ * wants to know if given IP address is matched in the prefixes stored.
+ *
+ * This element is useful for anything that needs to *STORE* prefixes, and *MATCH* IP addresses
+ * to a *LIST* of *PREFIXES*. Not the other way round.
+ *
+ * You can store IPv4 and IPv6 addresses to same tree, separate payload storage is kept per AFI.
+ *
+ * To erase something copy values to new tree sans the value you want to erase.
+ *
+ * Use swap if you need to move the tree to another NetmaskTree instance, it is WAY faster
+ * than using copy ctor or assigment operator, since it moves the nodes and tree root to
+ * new home instead of actually recreating the tree.
+ *
+ * Please see NetmaskGroup for example of simple use case. Other usecases can be found
+ * from GeoIPBackend and Sortlist, and from dnsdist.
+ */
+template <typename T>
+class NetmaskTree {
+public:
+ typedef Netmask key_type;
+ typedef T value_type;
+ typedef std::pair<key_type,value_type> node_type;
+ typedef size_t size_type;
+
+private:
+ /** Single node in tree, internal use only.
+ */
+ class TreeNode : boost::noncopyable {
+ public:
+ explicit TreeNode(int bits) noexcept : parent(NULL),d_bits(bits) {
+ }
+
+ //<! Makes a left node with one more bit than parent
+ TreeNode* make_left() {
+ if (!left) {
+ left = unique_ptr<TreeNode>(new TreeNode(d_bits+1));
+ left->parent = this;
+ }
+ return left.get();
+ }
+
+ //<! Makes a right node with one more bit than parent
+ TreeNode* make_right() {
+ if (!right) {
+ right = unique_ptr<TreeNode>(new TreeNode(d_bits+1));
+ right->parent = this;
+ }
+ return right.get();
+ }
+
+ unique_ptr<TreeNode> left;
+ unique_ptr<TreeNode> right;
+ TreeNode* parent;
+
+ unique_ptr<node_type> node4; //<! IPv4 value-pair
+ unique_ptr<node_type> node6; //<! IPv6 value-pair
+
+ int d_bits; //<! How many bits have been used so far
+ };
+
+public:
+ NetmaskTree() noexcept {
+ }
+
+ NetmaskTree(const NetmaskTree& rhs) {
+ // it is easier to copy the nodes than tree.
+ // also acts as handy compactor
+ for(auto const& node: rhs._nodes)
+ insert(node->first).second = node->second;
+ }
+
+ NetmaskTree& operator=(const NetmaskTree& rhs) {
+ clear();
+ // see above.
+ for(auto const& node: rhs._nodes)
+ insert(node->first).second = node->second;
+ return *this;
+ }
+
+ //<! Index operator for value-pair, throws if not found
+ const value_type& operator[](const key_type& rhs) const {
+ const node_type *value = lookup(rhs.getNetwork(), rhs.getBits());
+ if (value == nullptr) throw std::range_error(rhs.toString() + string(" not found"));
+ return value->second;
+ }
+
+ //<! Index operator for value-pair, creates new if not found
+ T& operator[](const key_type& rhs) {
+ return insert(rhs).second;
+ }
+
+ const typename std::vector<node_type*>::const_iterator begin() const { return _nodes.begin(); }
+ const typename std::vector<node_type*>::const_iterator end() const { return _nodes.end(); }
+
+ typename std::vector<node_type*>::iterator begin() { return _nodes.begin(); }
+ typename std::vector<node_type*>::iterator end() { return _nodes.end(); }
+
+ node_type& insert(const string &mask) {
+ return insert(key_type(mask));
+ }
+
+ //<! Creates new value-pair in tree and returns it.
+ node_type& insert(const key_type& key) {
+ // lazily initialize tree on first insert.
+ if (!root) root = unique_ptr<TreeNode>(new TreeNode(0));
+ TreeNode* node = root.get();
+ node_type* value = nullptr;
+
+ if (key.getNetwork().sin4.sin_family == AF_INET) {
+ std::bitset<32> addr(be32toh(key.getNetwork().sin4.sin_addr.s_addr));
+ int bits = 0;
+ // we turn left on 0 and right on 1
+ while(bits < key.getBits()) {
+ uint8_t val = addr[31-bits];
+ if (val)
+ node = node->make_right();
+ else
+ node = node->make_left();
+ bits++;
+ }
+ // only create node if not yet assigned
+ if (!node->node4) {
+ node->node4 = unique_ptr<node_type>(new node_type());
+ _nodes.push_back(node->node4.get());
+ }
+ value = node->node4.get();
+ } else {
+ uint64_t* addr = (uint64_t*)key.getNetwork().sin6.sin6_addr.s6_addr;
+ std::bitset<64> addr_low(be64toh(addr[1]));
+ std::bitset<64> addr_high(be64toh(addr[0]));
+ int bits = 0;
+ while(bits < key.getBits()) {
+ uint8_t val;
+ // we use high address until we are
+ if (bits < 64) val = addr_high[63-bits];
+ // past 64 bits, and start using low address
+ else val = addr_low[127-bits];
+
+ // we turn left on 0 and right on 1
+ if (val)
+ node = node->make_right();
+ else
+ node = node->make_left();
+ bits++;
+ }
+ // only create node if not yet assigned
+ if (!node->node6) {
+ node->node6 = unique_ptr<node_type>(new node_type());
+ _nodes.push_back(node->node6.get());
+ }
+ value = node->node6.get();
+ }
+ // assign key
+ value->first = key;
+ return *value;
+ }
+
+ void insert_or_assign(const key_type& mask, const value_type& value) {
+ insert(mask).second = value;
+ }
+
+ const node_type& at(const key_type& value) const {
+ const node_type* node = lookup(value);
+ if (node == nullptr) throw std::range_error(value.toString() + string(" not found"));
+ return *node;
+ }
+
+ const node_type* lookup(const key_type& value) const {
+ return lookup(value.getNetwork(), value.getBits());
+ }
+
+ const node_type* lookup(const ComboAddress& value, int max_bits = 128) const {
+ if (!root) return nullptr;
+
+ TreeNode *node = root.get();
+ node_type *ret = nullptr;
+
+ // exact same thing as above, except
+ if (value.sin4.sin_family == AF_INET) {
+ max_bits = std::max(0,std::min(max_bits,32));
+ std::bitset<32> addr(be32toh(value.sin4.sin_addr.s_addr));
+ int bits = 0;
+
+ while(bits < max_bits) {
+ // ...we keep track of last non-empty node
+ if (node->node4) ret = node->node4.get();
+ uint8_t val = addr[31-bits];
+ // ...and we don't create left/right hand
+ if (val) {
+ if (node->right) node = node->right.get();
+ // ..and we break when road ends
+ else break;
+ } else {
+ if (node->left) node = node->left.get();
+ else break;
+ }
+ bits++;
+ }
+ // needed if we did not find one in loop
+ if (node->node4) ret = node->node4.get();
+ } else {
+ uint64_t* addr = (uint64_t*)value.sin6.sin6_addr.s6_addr;
+ max_bits = std::max(0,std::min(max_bits,128));
+ std::bitset<64> addr_low(be64toh(addr[1]));
+ std::bitset<64> addr_high(be64toh(addr[0]));
+ int bits = 0;
+ while(bits < max_bits) {
+ if (node->node6) ret = node->node6.get();
+ uint8_t val;
+ if (bits < 64) val = addr_high[63-bits];
+ else val = addr_low[127-bits];
+ if (val) {
+ if (node->right) node = node->right.get();
+ else break;
+ } else {
+ if (node->left) node = node->left.get();
+ else break;
+ }
+ bits++;
+ }
+ if (node->node6) ret = node->node6.get();
+ }
+
+ // this can be nullptr.
+ return ret;
+ }
+
+ void erase(const key_type& key) {
+ TreeNode *node = root.get();
+
+ // no tree, no value
+ if ( node == nullptr ) return;
+
+ // exact same thing as above, except
+ if (key.getNetwork().sin4.sin_family == AF_INET) {
+ std::bitset<32> addr(be32toh(key.getNetwork().sin4.sin_addr.s_addr));
+ int bits = 0;
+ while(node && bits < key.getBits()) {
+ uint8_t val = addr[31-bits];
+ if (val) {
+ node = node->right.get();
+ } else {
+ node = node->left.get();
+ }
+ bits++;
+ }
+ if (node) {
+ for(auto it = _nodes.begin(); it != _nodes.end(); it++)
+ if (node->node4.get() == *it) _nodes.erase(it);
+ node->node4.reset();
+ }
+ } else {
+ uint64_t* addr = (uint64_t*)key.getNetwork().sin6.sin6_addr.s6_addr;
+ std::bitset<64> addr_low(be64toh(addr[1]));
+ std::bitset<64> addr_high(be64toh(addr[0]));
+ int bits = 0;
+ while(node && bits < key.getBits()) {
+ uint8_t val;
+ if (bits < 64) val = addr_high[63-bits];
+ else val = addr_low[127-bits];
+ if (val) {
+ node = node->right.get();
+ } else {
+ node = node->left.get();
+ }
+ bits++;
+ }
+ if (node) {
+ for(auto it = _nodes.begin(); it != _nodes.end(); it++)
+ if (node->node4.get() == *it) _nodes.erase(it);
+ node->node6.reset();
+ }
+ }
+ }
+
+ void erase(const string& key) {
+ erase(key_type(key));
+ }
+
+ //<! checks whether the container is empty.
+ bool empty() const {
+ return _nodes.empty();
+ }
+
+ //<! returns the number of elements
+ size_type size() const {
+ return _nodes.size();
+ }
+
+ //<! See if given ComboAddress matches any prefix
+ bool match(const ComboAddress& value) const {
+ return (lookup(value) != nullptr);
+ }
+
+ bool match(const std::string& value) const {
+ return match(ComboAddress(value));
+ }
+
+ //<! Clean out the tree
+ void clear() {
+ _nodes.clear();
+ root.reset(nullptr);
+ }
+
+ //<! swaps the contents, rhs is left with nullptr.
+ void swap(NetmaskTree& rhs) {
+ root.swap(rhs.root);
+ _nodes.swap(rhs._nodes);
+ }
+
+private:
+ unique_ptr<TreeNode> root; //<! Root of our tree
+ std::vector<node_type*> _nodes; //<! Container for actual values
+};
+
/** This class represents a group of supplemental Netmask classes. An IP address matchs
if it is matched by zero or more of the Netmask classes within.
*/