]> granicus.if.org Git - ipset/commitdiff
Fix cidr book keeping for hash:*net* types
author: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
Mon, 10 Sep 2012 19:19:09 +0000 (21:19 +0200)
committer: Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
Mon, 10 Sep 2012 19:19:09 +0000 (21:19 +0200)
The book-keeping of the different sized networks was bogus; fix it.
The broken code could lead to invalid matching in such sets when the number
of different sized networks was greater than the smallest CIDR value of
the networks.

kernel/include/linux/netfilter/ipset/ip_set_ahash.h

index b114d35aea5e652c90b13ef94e1ef00f6d864c90..8708c343f656b488978030b32ed347167808513f 100644 (file)
@@ -137,50 +137,59 @@ htable_bits(u32 hashsize)
 #endif
 
 #define SET_HOST_MASK(family)  (family == AF_INET ? 32 : 128)
+#ifdef IP_SET_HASH_WITH_MULTI
+#define NETS_LENGTH(family)    (SET_HOST_MASK(family) + 1)
+#else
+#define NETS_LENGTH(family)    SET_HOST_MASK(family)
+#endif
 
 /* Network cidr size book keeping when the hash stores different
  * sized networks */
 static void
-add_cidr(struct ip_set_hash *h, u8 cidr, u8 host_mask)
+add_cidr(struct ip_set_hash *h, u8 cidr, u8 nets_length)
 {
-       u8 i;
-
-       ++h->nets[cidr-1].nets;
-
-       pr_debug("add_cidr added %u: %u\n", cidr, h->nets[cidr-1].nets);
+       int i, j;
 
-       if (h->nets[cidr-1].nets > 1)
-               return;
-
-       /* New cidr size */
-       for (i = 0; i < host_mask && h->nets[i].cidr; i++) {
-               /* Add in increasing prefix order, so larger cidr first */
-               if (h->nets[i].cidr < cidr)
-                       swap(h->nets[i].cidr, cidr);
+       /* Add in increasing prefix order, so larger cidr first */
+       for (i = 0, j = -1; i < nets_length && h->nets[i].nets; i++) {
+               if (j != -1)
+                       continue;
+               else if (h->nets[i].cidr < cidr)
+                       j = i;
+               else if (h->nets[i].cidr == cidr) {
+                       h->nets[i].nets++;
+                       return;
+               }
+       }
+       if (j != -1) {
+               for (; j != -1 && i > j; i--) {
+                       h->nets[i].cidr = h->nets[i - 1].cidr;
+                       h->nets[i].nets = h->nets[i - 1].nets;
+               }
        }
-       if (i < host_mask)
-               h->nets[i].cidr = cidr;
+       h->nets[i].cidr = cidr;
+       h->nets[i].nets = 1;
 }
 
 static void
-del_cidr(struct ip_set_hash *h, u8 cidr, u8 host_mask)
+del_cidr(struct ip_set_hash *h, u8 cidr, u8 nets_length)
 {
-       u8 i;
-
-       --h->nets[cidr-1].nets;
+       u8 i, j;
 
-       pr_debug("del_cidr deleted %u: %u\n", cidr, h->nets[cidr-1].nets);
+       for (i = 0; i < nets_length - 1 && h->nets[i].cidr != cidr; i++)
+               ;
+       h->nets[i].nets--;
 
-       if (h->nets[cidr-1].nets != 0)
+       if (h->nets[i].nets != 0)
                return;
 
-       /* All entries with this cidr size deleted, so cleanup h->cidr[] */
-       for (i = 0; i < host_mask - 1 && h->nets[i].cidr; i++) {
-               if (h->nets[i].cidr == cidr)
-                       h->nets[i].cidr = cidr = h->nets[i+1].cidr;
+       for (j = i; j < nets_length - 1 && h->nets[j].nets; j++) {
+               h->nets[j].cidr = h->nets[j + 1].cidr;
+               h->nets[j].nets = h->nets[j + 1].nets;
        }
-       h->nets[i - 1].cidr = 0;
 }
+#else
+#define NETS_LENGTH(family)            0
 #endif
 
 /* Destroy the hashtable part of the set */
@@ -202,14 +211,14 @@ ahash_destroy(struct htable *t)
 
 /* Calculate the actual memory size of the set data */
 static size_t
-ahash_memsize(const struct ip_set_hash *h, size_t dsize, u8 host_mask)
+ahash_memsize(const struct ip_set_hash *h, size_t dsize, u8 nets_length)
 {
        u32 i;
        struct htable *t = h->table;
        size_t memsize = sizeof(*h)
                         + sizeof(*t)
 #ifdef IP_SET_HASH_WITH_NETS
-                        + sizeof(struct ip_set_hash_nets) * host_mask
+                        + sizeof(struct ip_set_hash_nets) * nets_length
 #endif
                         + jhash_size(t->htable_bits) * sizeof(struct hbucket);
 
@@ -238,7 +247,7 @@ ip_set_hash_flush(struct ip_set *set)
        }
 #ifdef IP_SET_HASH_WITH_NETS
        memset(h->nets, 0, sizeof(struct ip_set_hash_nets)
-                          * SET_HOST_MASK(set->family));
+                          * NETS_LENGTH(set->family));
 #endif
        h->elements = 0;
 }
@@ -271,9 +280,6 @@ ip_set_hash_destroy(struct ip_set *set)
 (jhash2((u32 *)(data), HKEY_DATALEN/sizeof(u32), initval)      \
        & jhash_mask(htable_bits))
 
-#define CONCAT(a, b, c)                a##b##c
-#define TOKEN(a, b, c)         CONCAT(a, b, c)
-
 /* Type/family dependent function prototypes */
 
 #define type_pf_data_equal     TOKEN(TYPE, PF, _data_equal)
@@ -478,7 +484,7 @@ type_pf_add(struct ip_set *set, void *value, u32 timeout, u32 flags)
        }
 
 #ifdef IP_SET_HASH_WITH_NETS
-       add_cidr(h, CIDR(d->cidr), HOST_MASK);
+       add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
 #endif
        h->elements++;
 out:
@@ -513,7 +519,7 @@ type_pf_del(struct ip_set *set, void *value, u32 timeout, u32 flags)
                n->pos--;
                h->elements--;
 #ifdef IP_SET_HASH_WITH_NETS
-               del_cidr(h, CIDR(d->cidr), HOST_MASK);
+               del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
 #endif
                if (n->pos + AHASH_INIT_SIZE < n->size) {
                        void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
@@ -546,10 +552,10 @@ type_pf_test_cidrs(struct ip_set *set, struct type_pf_elem *d, u32 timeout)
        const struct type_pf_elem *data;
        int i, j = 0;
        u32 key, multi = 0;
-       u8 host_mask = SET_HOST_MASK(set->family);
+       u8 nets_length = NETS_LENGTH(set->family);
 
        pr_debug("test by nets\n");
-       for (; j < host_mask && h->nets[j].cidr && !multi; j++) {
+       for (; j < nets_length && h->nets[j].nets && !multi; j++) {
                type_pf_data_netmask(d, h->nets[j].cidr);
                key = HKEY(d, h->initval, t->htable_bits);
                n = hbucket(t, key);
@@ -604,7 +610,7 @@ type_pf_head(struct ip_set *set, struct sk_buff *skb)
        memsize = ahash_memsize(h, with_timeout(h->timeout)
                                        ? sizeof(struct type_pf_telem)
                                        : sizeof(struct type_pf_elem),
-                               set->family == AF_INET ? 32 : 128);
+                               NETS_LENGTH(set->family));
        read_unlock_bh(&set->lock);
 
        nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
@@ -783,7 +789,7 @@ type_pf_elem_tadd(struct hbucket *n, const struct type_pf_elem *value,
 
 /* Delete expired elements from the hashtable */
 static void
-type_pf_expire(struct ip_set_hash *h)
+type_pf_expire(struct ip_set_hash *h, u8 nets_length)
 {
        struct htable *t = h->table;
        struct hbucket *n;
@@ -798,7 +804,7 @@ type_pf_expire(struct ip_set_hash *h)
                        if (type_pf_data_expired(data)) {
                                pr_debug("expired %u/%u\n", i, j);
 #ifdef IP_SET_HASH_WITH_NETS
-                               del_cidr(h, CIDR(data->cidr), HOST_MASK);
+                               del_cidr(h, CIDR(data->cidr), nets_length);
 #endif
                                if (j != n->pos - 1)
                                        /* Not last one */
@@ -839,7 +845,7 @@ type_pf_tresize(struct ip_set *set, bool retried)
        if (!retried) {
                i = h->elements;
                write_lock_bh(&set->lock);
-               type_pf_expire(set->data);
+               type_pf_expire(set->data, NETS_LENGTH(set->family));
                write_unlock_bh(&set->lock);
                if (h->elements <  i)
                        return 0;
@@ -904,7 +910,7 @@ type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)
 
        if (h->elements >= h->maxelem)
                /* FIXME: when set is full, we slow down here */
-               type_pf_expire(h);
+               type_pf_expire(h, NETS_LENGTH(set->family));
        if (h->elements >= h->maxelem) {
                if (net_ratelimit())
                        pr_warning("Set %s is full, maxelem %u reached\n",
@@ -933,8 +939,8 @@ type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)
        if (j != AHASH_MAX(h) + 1) {
                data = ahash_tdata(n, j);
 #ifdef IP_SET_HASH_WITH_NETS
-               del_cidr(h, CIDR(data->cidr), HOST_MASK);
-               add_cidr(h, CIDR(d->cidr), HOST_MASK);
+               del_cidr(h, CIDR(data->cidr), NETS_LENGTH(set->family));
+               add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
 #endif
                type_pf_data_copy(data, d);
                type_pf_data_timeout_set(data, timeout);
@@ -952,7 +958,7 @@ type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)
        }
 
 #ifdef IP_SET_HASH_WITH_NETS
-       add_cidr(h, CIDR(d->cidr), HOST_MASK);
+       add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
 #endif
        h->elements++;
 out:
@@ -986,7 +992,7 @@ type_pf_tdel(struct ip_set *set, void *value, u32 timeout, u32 flags)
                n->pos--;
                h->elements--;
 #ifdef IP_SET_HASH_WITH_NETS
-               del_cidr(h, CIDR(d->cidr), HOST_MASK);
+               del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
 #endif
                if (n->pos + AHASH_INIT_SIZE < n->size) {
                        void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
@@ -1016,9 +1022,9 @@ type_pf_ttest_cidrs(struct ip_set *set, struct type_pf_elem *d, u32 timeout)
        struct hbucket *n;
        int i, j = 0;
        u32 key, multi = 0;
-       u8 host_mask = SET_HOST_MASK(set->family);
+       u8 nets_length = NETS_LENGTH(set->family);
 
-       for (; j < host_mask && h->nets[j].cidr && !multi; j++) {
+       for (; j < nets_length && h->nets[j].nets && !multi; j++) {
                type_pf_data_netmask(d, h->nets[j].cidr);
                key = HKEY(d, h->initval, t->htable_bits);
                n = hbucket(t, key);
@@ -1147,7 +1153,7 @@ type_pf_gc(unsigned long ul_set)
 
        pr_debug("called\n");
        write_lock_bh(&set->lock);
-       type_pf_expire(h);
+       type_pf_expire(h, NETS_LENGTH(set->family));
        write_unlock_bh(&set->lock);
 
        h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;