]> granicus.if.org Git - ipset/blob - kernel/net/netfilter/ipset/ip_set_core.c
b424c0410d708c91ceb9d18050744d908903f0b9
[ipset] / kernel / net / netfilter / ipset / ip_set_core.c
1 /* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
2  *                         Patrick Schaaf <bof@bof.de>
3  * Copyright (C) 2003-2013 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
4  *
5  * This program is free software; you can redistribute it and/or modify
6  * it under the terms of the GNU General Public License version 2 as
7  * published by the Free Software Foundation.
8  */
9
10 /* Kernel module for IP set management */
11
12 #include <linux/init.h>
13 #include <linux/module.h>
14 #include <linux/moduleparam.h>
15 #include <linux/ip.h>
16 #include <linux/skbuff.h>
17 #include <linux/spinlock.h>
18 #include <linux/rculist.h>
19 #include <net/netlink.h>
20 #include <net/net_namespace.h>
21 #include <net/netns/generic.h>
22
23 #include <linux/netfilter.h>
24 #include <linux/netfilter/x_tables.h>
25 #include <linux/netfilter/nfnetlink.h>
26 #include <linux/netfilter/ipset/ip_set.h>
27
28 static LIST_HEAD(ip_set_type_list);             /* all registered set types */
29 static DEFINE_MUTEX(ip_set_type_mutex);         /* protects ip_set_type_list */
30 static DEFINE_RWLOCK(ip_set_ref_lock);          /* protects the set refs */
31
32 struct ip_set_net {
33         struct ip_set * __rcu *ip_set_list;     /* all individual sets */
34         ip_set_id_t     ip_set_max;     /* max number of sets */
35         bool            is_deleted;     /* deleted by ip_set_net_exit */
36         bool            is_destroyed;   /* all sets are destroyed */
37 };
38
39 static unsigned int ip_set_net_id __read_mostly;
40
41 static inline struct ip_set_net *ip_set_pernet(struct net *net)
42 {
43         return net_generic(net, ip_set_net_id);
44 }
45
46 #define IP_SET_INC      64
47 #define STRNCMP(a, b)   (strncmp(a, b, IPSET_MAXNAMELEN) == 0)
48
49 static unsigned int max_sets;
50
51 module_param(max_sets, int, 0600);
52 MODULE_PARM_DESC(max_sets, "maximal number of sets");
53 MODULE_LICENSE("GPL");
54 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
55 MODULE_DESCRIPTION("ip_set: protocol " __stringify(IPSET_PROTOCOL));
56 MODULE_ALIAS_NFNL_SUBSYS(NFNL_SUBSYS_IPSET);
57
58 /* When the nfnl mutex or ip_set_ref_lock is held: */
59 #define ip_set_dereference(p)           \
60         rcu_dereference_protected(p,    \
61                 lockdep_nfnl_is_held(NFNL_SUBSYS_IPSET) || \
62                 lockdep_is_held(&ip_set_ref_lock))
63 #define ip_set(inst, id)                \
64         ip_set_dereference((inst)->ip_set_list)[id]
65 #define ip_set_ref_netlink(inst,id)     \
66         rcu_dereference_raw((inst)->ip_set_list)[id]
67
68 /* The set types are implemented in modules and registered set types
69  * can be found in ip_set_type_list. Adding/deleting types is
70  * serialized by ip_set_type_mutex.
71  */
72
73 static inline void
74 ip_set_type_lock(void)
75 {
76         mutex_lock(&ip_set_type_mutex);
77 }
78
79 static inline void
80 ip_set_type_unlock(void)
81 {
82         mutex_unlock(&ip_set_type_mutex);
83 }
84
85 /* Register and deregister settype */
86
87 static struct ip_set_type *
88 find_set_type(const char *name, u8 family, u8 revision)
89 {
90         struct ip_set_type *type;
91
92         list_for_each_entry_rcu(type, &ip_set_type_list, list)
93                 if (STRNCMP(type->name, name) &&
94                     (type->family == family ||
95                      type->family == NFPROTO_UNSPEC) &&
96                     revision >= type->revision_min &&
97                     revision <= type->revision_max)
98                         return type;
99         return NULL;
100 }
101
102 /* Unlock, try to load a set type module and lock again */
103 static bool
104 load_settype(const char *name)
105 {
106         nfnl_unlock(NFNL_SUBSYS_IPSET);
107         pr_debug("try to load ip_set_%s\n", name);
108         if (request_module("ip_set_%s", name) < 0) {
109                 pr_warn("Can't find ip_set type %s\n", name);
110                 nfnl_lock(NFNL_SUBSYS_IPSET);
111                 return false;
112         }
113         nfnl_lock(NFNL_SUBSYS_IPSET);
114         return true;
115 }
116
117 /* Find a set type and reference it */
118 #define find_set_type_get(name, family, revision, found)        \
119         __find_set_type_get(name, family, revision, found, false)
120
121 static int
122 __find_set_type_get(const char *name, u8 family, u8 revision,
123                     struct ip_set_type **found, bool retry)
124 {
125         struct ip_set_type *type;
126         int err;
127
128         if (retry && !load_settype(name))
129                 return -IPSET_ERR_FIND_TYPE;
130
131         rcu_read_lock();
132         *found = find_set_type(name, family, revision);
133         if (*found) {
134                 err = !try_module_get((*found)->me) ? -EFAULT : 0;
135                 goto unlock;
136         }
137         /* Make sure the type is already loaded
138          * but we don't support the revision
139          */
140         list_for_each_entry_rcu(type, &ip_set_type_list, list)
141                 if (STRNCMP(type->name, name)) {
142                         err = -IPSET_ERR_FIND_TYPE;
143                         goto unlock;
144                 }
145         rcu_read_unlock();
146
147         return retry ? -IPSET_ERR_FIND_TYPE :
148                 __find_set_type_get(name, family, revision, found, true);
149
150 unlock:
151         rcu_read_unlock();
152         return err;
153 }
154
155 /* Find a given set type by name and family.
156  * If we succeeded, the supported minimal and maximum revisions are
157  * filled out.
158  */
159 #define find_set_type_minmax(name, family, min, max) \
160         __find_set_type_minmax(name, family, min, max, false)
161
162 static int
163 __find_set_type_minmax(const char *name, u8 family, u8 *min, u8 *max,
164                        bool retry)
165 {
166         struct ip_set_type *type;
167         bool found = false;
168
169         if (retry && !load_settype(name))
170                 return -IPSET_ERR_FIND_TYPE;
171
172         *min = 255; *max = 0;
173         rcu_read_lock();
174         list_for_each_entry_rcu(type, &ip_set_type_list, list)
175                 if (STRNCMP(type->name, name) &&
176                     (type->family == family ||
177                      type->family == NFPROTO_UNSPEC)) {
178                         found = true;
179                         if (type->revision_min < *min)
180                                 *min = type->revision_min;
181                         if (type->revision_max > *max)
182                                 *max = type->revision_max;
183                 }
184         rcu_read_unlock();
185         if (found)
186                 return 0;
187
188         return retry ? -IPSET_ERR_FIND_TYPE :
189                 __find_set_type_minmax(name, family, min, max, true);
190 }
191
192 #define family_name(f)  ((f) == NFPROTO_IPV4 ? "inet" : \
193                          (f) == NFPROTO_IPV6 ? "inet6" : "any")
194
195 /* Register a set type structure. The type is identified by
196  * the unique triple of name, family and revision.
197  */
198 int
199 ip_set_type_register(struct ip_set_type *type)
200 {
201         int ret = 0;
202
203         if (type->protocol != IPSET_PROTOCOL) {
204                 pr_warn("ip_set type %s, family %s, revision %u:%u uses wrong protocol version %u (want %u)\n",
205                         type->name, family_name(type->family),
206                         type->revision_min, type->revision_max,
207                         type->protocol, IPSET_PROTOCOL);
208                 return -EINVAL;
209         }
210
211         ip_set_type_lock();
212         if (find_set_type(type->name, type->family, type->revision_min)) {
213                 /* Duplicate! */
214                 pr_warn("ip_set type %s, family %s with revision min %u already registered!\n",
215                         type->name, family_name(type->family),
216                         type->revision_min);
217                 ip_set_type_unlock();
218                 return -EINVAL;
219         }
220         list_add_rcu(&type->list, &ip_set_type_list);
221         pr_debug("type %s, family %s, revision %u:%u registered.\n",
222                  type->name, family_name(type->family),
223                  type->revision_min, type->revision_max);
224         ip_set_type_unlock();
225
226         return ret;
227 }
228 EXPORT_SYMBOL_GPL(ip_set_type_register);
229
230 /* Unregister a set type. There's a small race with ip_set_create */
231 void
232 ip_set_type_unregister(struct ip_set_type *type)
233 {
234         ip_set_type_lock();
235         if (!find_set_type(type->name, type->family, type->revision_min)) {
236                 pr_warn("ip_set type %s, family %s with revision min %u not registered\n",
237                         type->name, family_name(type->family),
238                         type->revision_min);
239                 ip_set_type_unlock();
240                 return;
241         }
242         list_del_rcu(&type->list);
243         pr_debug("type %s, family %s with revision min %u unregistered.\n",
244                  type->name, family_name(type->family), type->revision_min);
245         ip_set_type_unlock();
246
247         synchronize_rcu();
248 }
249 EXPORT_SYMBOL_GPL(ip_set_type_unregister);
250
251 /* Utility functions */
252 void *
253 ip_set_alloc(size_t size)
254 {
255         void *members = NULL;
256
257         if (size < KMALLOC_MAX_SIZE)
258                 members = kzalloc(size, GFP_KERNEL | __GFP_NOWARN);
259
260         if (members) {
261                 pr_debug("%p: allocated with kmalloc\n", members);
262                 return members;
263         }
264
265         members = vzalloc(size);
266         if (!members)
267                 return NULL;
268         pr_debug("%p: allocated with vmalloc\n", members);
269
270         return members;
271 }
272 EXPORT_SYMBOL_GPL(ip_set_alloc);
273
274 void
275 ip_set_free(void *members)
276 {
277         pr_debug("%p: free with %s\n", members,
278                  is_vmalloc_addr(members) ? "vfree" : "kfree");
279         kvfree(members);
280 }
281 EXPORT_SYMBOL_GPL(ip_set_free);
282
283 static inline bool
284 flag_nested(const struct nlattr *nla)
285 {
286         return nla->nla_type & NLA_F_NESTED;
287 }
288
289 static const struct nla_policy ipaddr_policy[IPSET_ATTR_IPADDR_MAX + 1] = {
290         [IPSET_ATTR_IPADDR_IPV4]        = { .type = NLA_U32 },
291         [IPSET_ATTR_IPADDR_IPV6]        = { .type = NLA_BINARY,
292                                             .len = sizeof(struct in6_addr) },
293 };
294
295 int
296 ip_set_get_ipaddr4(struct nlattr *nla,  __be32 *ipaddr)
297 {
298         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
299
300         if (unlikely(!flag_nested(nla)))
301                 return -IPSET_ERR_PROTOCOL;
302         if (NLA_PARSE_NESTED(tb, IPSET_ATTR_IPADDR_MAX, nla,
303                              ipaddr_policy, NULL))
304                 return -IPSET_ERR_PROTOCOL;
305         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV4)))
306                 return -IPSET_ERR_PROTOCOL;
307
308         *ipaddr = nla_get_be32(tb[IPSET_ATTR_IPADDR_IPV4]);
309         return 0;
310 }
311 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr4);
312
313 int
314 ip_set_get_ipaddr6(struct nlattr *nla, union nf_inet_addr *ipaddr)
315 {
316         struct nlattr *tb[IPSET_ATTR_IPADDR_MAX + 1];
317
318         if (unlikely(!flag_nested(nla)))
319                 return -IPSET_ERR_PROTOCOL;
320
321         if (NLA_PARSE_NESTED(tb, IPSET_ATTR_IPADDR_MAX, nla,
322                              ipaddr_policy, NULL))
323                 return -IPSET_ERR_PROTOCOL;
324         if (unlikely(!ip_set_attr_netorder(tb, IPSET_ATTR_IPADDR_IPV6)))
325                 return -IPSET_ERR_PROTOCOL;
326
327         memcpy(ipaddr, nla_data(tb[IPSET_ATTR_IPADDR_IPV6]),
328                sizeof(struct in6_addr));
329         return 0;
330 }
331 EXPORT_SYMBOL_GPL(ip_set_get_ipaddr6);
332
333 typedef void (*destroyer)(struct ip_set *, void *);
334 /* ipset data extension types, in size order */
335
336 const struct ip_set_ext_type ip_set_extensions[] = {
337         [IPSET_EXT_ID_COUNTER] = {
338                 .type   = IPSET_EXT_COUNTER,
339                 .flag   = IPSET_FLAG_WITH_COUNTERS,
340                 .len    = sizeof(struct ip_set_counter),
341                 .align  = __alignof__(struct ip_set_counter),
342         },
343         [IPSET_EXT_ID_TIMEOUT] = {
344                 .type   = IPSET_EXT_TIMEOUT,
345                 .len    = sizeof(unsigned long),
346                 .align  = __alignof__(unsigned long),
347         },
348         [IPSET_EXT_ID_SKBINFO] = {
349                 .type   = IPSET_EXT_SKBINFO,
350                 .flag   = IPSET_FLAG_WITH_SKBINFO,
351                 .len    = sizeof(struct ip_set_skbinfo),
352                 .align  = __alignof__(struct ip_set_skbinfo),
353         },
354         [IPSET_EXT_ID_COMMENT] = {
355                 .type    = IPSET_EXT_COMMENT | IPSET_EXT_DESTROY,
356                 .flag    = IPSET_FLAG_WITH_COMMENT,
357                 .len     = sizeof(struct ip_set_comment),
358                 .align   = __alignof__(struct ip_set_comment),
359                 .destroy = (destroyer) ip_set_comment_free,
360         },
361 };
362 EXPORT_SYMBOL_GPL(ip_set_extensions);
363
364 static inline bool
365 add_extension(enum ip_set_ext_id id, u32 flags, struct nlattr *tb[])
366 {
367         return ip_set_extensions[id].flag ?
368                 (flags & ip_set_extensions[id].flag) :
369                 !!tb[IPSET_ATTR_TIMEOUT];
370 }
371
372 size_t
373 ip_set_elem_len(struct ip_set *set, struct nlattr *tb[], size_t len,
374                 size_t align)
375 {
376         enum ip_set_ext_id id;
377         u32 cadt_flags = 0;
378
379         if (tb[IPSET_ATTR_CADT_FLAGS])
380                 cadt_flags = ip_set_get_h32(tb[IPSET_ATTR_CADT_FLAGS]);
381         if (cadt_flags & IPSET_FLAG_WITH_FORCEADD)
382                 set->flags |= IPSET_CREATE_FLAG_FORCEADD;
383         if (!align)
384                 align = 1;
385         for (id = 0; id < IPSET_EXT_ID_MAX; id++) {
386                 if (!add_extension(id, cadt_flags, tb))
387                         continue;
388                 len = ALIGN(len, ip_set_extensions[id].align);
389                 set->offset[id] = len;
390                 set->extensions |= ip_set_extensions[id].type;
391                 len += ip_set_extensions[id].len;
392         }
393         return ALIGN(len, align);
394 }
395 EXPORT_SYMBOL_GPL(ip_set_elem_len);
396
397 int
398 ip_set_get_extensions(struct ip_set *set, struct nlattr *tb[],
399                       struct ip_set_ext *ext)
400 {
401         u64 fullmark;
402
403         if (unlikely(!ip_set_optattr_netorder(tb, IPSET_ATTR_TIMEOUT) ||
404                      !ip_set_optattr_netorder(tb, IPSET_ATTR_PACKETS) ||
405                      !ip_set_optattr_netorder(tb, IPSET_ATTR_BYTES) ||
406                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBMARK) ||
407                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBPRIO) ||
408                      !ip_set_optattr_netorder(tb, IPSET_ATTR_SKBQUEUE)))
409                 return -IPSET_ERR_PROTOCOL;
410
411         if (tb[IPSET_ATTR_TIMEOUT]) {
412                 if (!SET_WITH_TIMEOUT(set))
413                         return -IPSET_ERR_TIMEOUT;
414                 ext->timeout = ip_set_timeout_uget(tb[IPSET_ATTR_TIMEOUT]);
415         }
416         if (tb[IPSET_ATTR_BYTES] || tb[IPSET_ATTR_PACKETS]) {
417                 if (!SET_WITH_COUNTER(set))
418                         return -IPSET_ERR_COUNTER;
419                 if (tb[IPSET_ATTR_BYTES])
420                         ext->bytes = be64_to_cpu(nla_get_be64(
421                                                  tb[IPSET_ATTR_BYTES]));
422                 if (tb[IPSET_ATTR_PACKETS])
423                         ext->packets = be64_to_cpu(nla_get_be64(
424                                                    tb[IPSET_ATTR_PACKETS]));
425         }
426         if (tb[IPSET_ATTR_COMMENT]) {
427                 if (!SET_WITH_COMMENT(set))
428                         return -IPSET_ERR_COMMENT;
429                 ext->comment = ip_set_comment_uget(tb[IPSET_ATTR_COMMENT]);
430         }
431         if (tb[IPSET_ATTR_SKBMARK]) {
432                 if (!SET_WITH_SKBINFO(set))
433                         return -IPSET_ERR_SKBINFO;
434                 fullmark = be64_to_cpu(nla_get_be64(tb[IPSET_ATTR_SKBMARK]));
435                 ext->skbinfo.skbmark = fullmark >> 32;
436                 ext->skbinfo.skbmarkmask = fullmark & 0xffffffff;
437         }
438         if (tb[IPSET_ATTR_SKBPRIO]) {
439                 if (!SET_WITH_SKBINFO(set))
440                         return -IPSET_ERR_SKBINFO;
441                 ext->skbinfo.skbprio =
442                         be32_to_cpu(nla_get_be32(tb[IPSET_ATTR_SKBPRIO]));
443         }
444         if (tb[IPSET_ATTR_SKBQUEUE]) {
445                 if (!SET_WITH_SKBINFO(set))
446                         return -IPSET_ERR_SKBINFO;
447                 ext->skbinfo.skbqueue =
448                         be16_to_cpu(nla_get_be16(tb[IPSET_ATTR_SKBQUEUE]));
449         }
450         return 0;
451 }
452 EXPORT_SYMBOL_GPL(ip_set_get_extensions);
453
454 int
455 ip_set_put_extensions(struct sk_buff *skb, const struct ip_set *set,
456                       const void *e, bool active)
457 {
458         if (SET_WITH_TIMEOUT(set)) {
459                 unsigned long *timeout = ext_timeout(e, set);
460
461                 if (nla_put_net32(skb, IPSET_ATTR_TIMEOUT,
462                         htonl(active ? ip_set_timeout_get(timeout)
463                               : *timeout)))
464                         return -EMSGSIZE;
465         }
466         if (SET_WITH_COUNTER(set) &&
467             ip_set_put_counter(skb, ext_counter(e, set)))
468                 return -EMSGSIZE;
469         if (SET_WITH_COMMENT(set) &&
470             ip_set_put_comment(skb, ext_comment(e, set)))
471                 return -EMSGSIZE;
472         if (SET_WITH_SKBINFO(set) &&
473             ip_set_put_skbinfo(skb, ext_skbinfo(e, set)))
474                 return -EMSGSIZE;
475         return 0;
476 }
477 EXPORT_SYMBOL_GPL(ip_set_put_extensions);
478
479 bool
480 ip_set_match_extensions(struct ip_set *set, const struct ip_set_ext *ext,
481                         struct ip_set_ext *mext, u32 flags, void *data)
482 {
483         if (SET_WITH_TIMEOUT(set) &&
484             ip_set_timeout_expired(ext_timeout(data, set)))
485                 return false;
486         if (SET_WITH_COUNTER(set)) {
487                 struct ip_set_counter *counter = ext_counter(data, set);
488
489                 if (flags & IPSET_FLAG_MATCH_COUNTERS &&
490                     !(ip_set_match_counter(ip_set_get_packets(counter),
491                                 mext->packets, mext->packets_op) &&
492                       ip_set_match_counter(ip_set_get_bytes(counter),
493                                 mext->bytes, mext->bytes_op)))
494                         return false;
495                 ip_set_update_counter(counter, ext, flags);
496         }
497         if (SET_WITH_SKBINFO(set))
498                 ip_set_get_skbinfo(ext_skbinfo(data, set),
499                                    ext, mext, flags);
500         return true;
501 }
502 EXPORT_SYMBOL_GPL(ip_set_match_extensions);
503
504 /* Creating/destroying/renaming/swapping affect the existence and
505  * the properties of a set. All of these can be executed from userspace
506  * only and serialized by the nfnl mutex indirectly from nfnetlink.
507  *
508  * Sets are identified by their index in ip_set_list and the index
509  * is used by the external references (set/SET netfilter modules).
510  *
511  * The set behind an index may change by swapping only, from userspace.
512  */
513
514 static inline void
515 __ip_set_get(struct ip_set *set)
516 {
517         write_lock_bh(&ip_set_ref_lock);
518         set->ref++;
519         write_unlock_bh(&ip_set_ref_lock);
520 }
521
522 static inline void
523 __ip_set_put(struct ip_set *set)
524 {
525         write_lock_bh(&ip_set_ref_lock);
526         BUG_ON(set->ref == 0);
527         set->ref--;
528         write_unlock_bh(&ip_set_ref_lock);
529 }
530
531 /* set->ref can be swapped out by ip_set_swap, netlink events (like dump) need
532  * a separate reference counter
533  */
534 static inline void
535 __ip_set_put_netlink(struct ip_set *set)
536 {
537         write_lock_bh(&ip_set_ref_lock);
538         BUG_ON(set->ref_netlink == 0);
539         set->ref_netlink--;
540         write_unlock_bh(&ip_set_ref_lock);
541 }
542
543 /* Add, del and test set entries from kernel.
544  *
545  * The set behind the index must exist and must be referenced
546  * so it can't be destroyed (or changed) under our foot.
547  */
548
549 static inline struct ip_set *
550 ip_set_rcu_get(struct net *net, ip_set_id_t index)
551 {
552         struct ip_set *set;
553         struct ip_set_net *inst = ip_set_pernet(net);
554
555         rcu_read_lock();
556         /* ip_set_list itself needs to be protected */
557         set = rcu_dereference(inst->ip_set_list)[index];
558         rcu_read_unlock();
559
560         return set;
561 }
562
563 int
564 ip_set_test(ip_set_id_t index, const struct sk_buff *skb,
565             const struct xt_action_param *par, struct ip_set_adt_opt *opt)
566 {
567         struct ip_set *set = ip_set_rcu_get(IPSET_DEV_NET(par), index);
568         int ret = 0;
569
570         BUG_ON(!set);
571         pr_debug("set %s, index %u\n", set->name, index);
572
573         if (opt->dim < set->type->dimension ||
574             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
575                 return 0;
576
577         rcu_read_lock_bh();
578         ret = set->variant->kadt(set, skb, par, IPSET_TEST, opt);
579         rcu_read_unlock_bh();
580
581         if (ret == -EAGAIN) {
582                 /* Type requests element to be completed */
583                 pr_debug("element must be completed, ADD is triggered\n");
584                 spin_lock_bh(&set->lock);
585                 set->variant->kadt(set, skb, par, IPSET_ADD, opt);
586                 spin_unlock_bh(&set->lock);
587                 ret = 1;
588         } else {
589                 /* --return-nomatch: invert matched element */
590                 if ((opt->cmdflags & IPSET_FLAG_RETURN_NOMATCH) &&
591                     (set->type->features & IPSET_TYPE_NOMATCH) &&
592                     (ret > 0 || ret == -ENOTEMPTY))
593                         ret = -ret;
594         }
595
596         /* Convert error codes to nomatch */
597         return (ret < 0 ? 0 : ret);
598 }
599 EXPORT_SYMBOL_GPL(ip_set_test);
600
601 int
602 ip_set_add(ip_set_id_t index, const struct sk_buff *skb,
603            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
604 {
605         struct ip_set *set = ip_set_rcu_get(IPSET_DEV_NET(par), index);
606         int ret;
607
608         BUG_ON(!set);
609         pr_debug("set %s, index %u\n", set->name, index);
610
611         if (opt->dim < set->type->dimension ||
612             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
613                 return -IPSET_ERR_TYPE_MISMATCH;
614
615         spin_lock_bh(&set->lock);
616         ret = set->variant->kadt(set, skb, par, IPSET_ADD, opt);
617         spin_unlock_bh(&set->lock);
618
619         return ret;
620 }
621 EXPORT_SYMBOL_GPL(ip_set_add);
622
623 int
624 ip_set_del(ip_set_id_t index, const struct sk_buff *skb,
625            const struct xt_action_param *par, struct ip_set_adt_opt *opt)
626 {
627         struct ip_set *set = ip_set_rcu_get(IPSET_DEV_NET(par), index);
628         int ret = 0;
629
630         BUG_ON(!set);
631         pr_debug("set %s, index %u\n", set->name, index);
632
633         if (opt->dim < set->type->dimension ||
634             !(opt->family == set->family || set->family == NFPROTO_UNSPEC))
635                 return -IPSET_ERR_TYPE_MISMATCH;
636
637         spin_lock_bh(&set->lock);
638         ret = set->variant->kadt(set, skb, par, IPSET_DEL, opt);
639         spin_unlock_bh(&set->lock);
640
641         return ret;
642 }
643 EXPORT_SYMBOL_GPL(ip_set_del);
644
645 /* Find set by name, reference it once. The reference makes sure the
646  * thing pointed to, does not go away under our feet.
647  *
648  */
649 ip_set_id_t
650 ip_set_get_byname(struct net *net, const char *name, struct ip_set **set)
651 {
652         ip_set_id_t i, index = IPSET_INVALID_ID;
653         struct ip_set *s;
654         struct ip_set_net *inst = ip_set_pernet(net);
655
656         rcu_read_lock();
657         for (i = 0; i < inst->ip_set_max; i++) {
658                 s = rcu_dereference(inst->ip_set_list)[i];
659                 if (s && STRNCMP(s->name, name)) {
660                         __ip_set_get(s);
661                         index = i;
662                         *set = s;
663                         break;
664                 }
665         }
666         rcu_read_unlock();
667
668         return index;
669 }
670 EXPORT_SYMBOL_GPL(ip_set_get_byname);
671
672 /* If the given set pointer points to a valid set, decrement
673  * reference count by 1. The caller shall not assume the index
674  * to be valid, after calling this function.
675  *
676  */
677
678 static inline void
679 __ip_set_put_byindex(struct ip_set_net *inst, ip_set_id_t index)
680 {
681         struct ip_set *set;
682
683         rcu_read_lock();
684         set = rcu_dereference(inst->ip_set_list)[index];
685         if (set)
686                 __ip_set_put(set);
687         rcu_read_unlock();
688 }
689
690 void
691 ip_set_put_byindex(struct net *net, ip_set_id_t index)
692 {
693         struct ip_set_net *inst = ip_set_pernet(net);
694
695         __ip_set_put_byindex(inst, index);
696 }
697 EXPORT_SYMBOL_GPL(ip_set_put_byindex);
698
699 /* Get the name of a set behind a set index.
700  * Set itself is protected by RCU, but its name isn't: to protect against
701  * renaming, grab ip_set_ref_lock as reader (see ip_set_rename()) and copy the
702  * name.
703  */
704 void
705 ip_set_name_byindex(struct net *net, ip_set_id_t index, char *name)
706 {
707         struct ip_set *set = ip_set_rcu_get(net, index);
708
709         BUG_ON(!set);
710
711         read_lock_bh(&ip_set_ref_lock);
712         strncpy(name, set->name, IPSET_MAXNAMELEN);
713         read_unlock_bh(&ip_set_ref_lock);
714 }
715 EXPORT_SYMBOL_GPL(ip_set_name_byindex);
716
717 /* Routines to call by external subsystems, which do not
718  * call nfnl_lock for us.
719  */
720
721 /* Find set by index, reference it once. The reference makes sure the
722  * thing pointed to, does not go away under our feet.
723  *
724  * The nfnl mutex is used in the function.
725  */
726 ip_set_id_t
727 ip_set_nfnl_get_byindex(struct net *net, ip_set_id_t index)
728 {
729         struct ip_set *set;
730         struct ip_set_net *inst = ip_set_pernet(net);
731
732         if (index >= inst->ip_set_max)
733                 return IPSET_INVALID_ID;
734
735         nfnl_lock(NFNL_SUBSYS_IPSET);
736         set = ip_set(inst, index);
737         if (set)
738                 __ip_set_get(set);
739         else
740                 index = IPSET_INVALID_ID;
741         nfnl_unlock(NFNL_SUBSYS_IPSET);
742
743         return index;
744 }
745 EXPORT_SYMBOL_GPL(ip_set_nfnl_get_byindex);
746
747 /* If the given set pointer points to a valid set, decrement
748  * reference count by 1. The caller shall not assume the index
749  * to be valid, after calling this function.
750  *
751  * The nfnl mutex is used in the function.
752  */
753 void
754 ip_set_nfnl_put(struct net *net, ip_set_id_t index)
755 {
756         struct ip_set *set;
757         struct ip_set_net *inst = ip_set_pernet(net);
758
759         nfnl_lock(NFNL_SUBSYS_IPSET);
760         if (!inst->is_deleted) { /* already deleted from ip_set_net_exit() */
761                 set = ip_set(inst, index);
762                 if (set)
763                         __ip_set_put(set);
764         }
765         nfnl_unlock(NFNL_SUBSYS_IPSET);
766 }
767 EXPORT_SYMBOL_GPL(ip_set_nfnl_put);
768
769 /* Communication protocol with userspace over netlink.
770  *
771  * The commands are serialized by the nfnl mutex.
772  */
773
774 static inline u8 protocol(const struct nlattr * const tb[])
775 {
776         return nla_get_u8(tb[IPSET_ATTR_PROTOCOL]);
777 }
778
779 static inline bool
780 protocol_failed(const struct nlattr * const tb[])
781 {
782         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) != IPSET_PROTOCOL;
783 }
784
785 static inline bool
786 protocol_min_failed(const struct nlattr * const tb[])
787 {
788         return !tb[IPSET_ATTR_PROTOCOL] || protocol(tb) < IPSET_PROTOCOL_MIN;
789 }
790
791 static inline u32
792 flag_exist(const struct nlmsghdr *nlh)
793 {
794         return nlh->nlmsg_flags & NLM_F_EXCL ? 0 : IPSET_FLAG_EXIST;
795 }
796
797 static struct nlmsghdr *
798 start_msg(struct sk_buff *skb, u32 portid, u32 seq, unsigned int flags,
799           enum ipset_cmd cmd)
800 {
801         struct nlmsghdr *nlh;
802         struct nfgenmsg *nfmsg;
803
804         nlh = nlmsg_put(skb, portid, seq, nfnl_msg_type(NFNL_SUBSYS_IPSET, cmd),
805                         sizeof(*nfmsg), flags);
806         if (!nlh)
807                 return NULL;
808
809         nfmsg = nlmsg_data(nlh);
810         nfmsg->nfgen_family = NFPROTO_IPV4;
811         nfmsg->version = NFNETLINK_V0;
812         nfmsg->res_id = 0;
813
814         return nlh;
815 }
816
817 /* Create a set */
818
819 static const struct nla_policy ip_set_create_policy[IPSET_ATTR_CMD_MAX + 1] = {
820         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
821         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
822                                     .len = IPSET_MAXNAMELEN - 1 },
823         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
824                                     .len = IPSET_MAXNAMELEN - 1},
825         [IPSET_ATTR_REVISION]   = { .type = NLA_U8 },
826         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
827         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
828 };
829
830 static struct ip_set *
831 find_set_and_id(struct ip_set_net *inst, const char *name, ip_set_id_t *id)
832 {
833         struct ip_set *set = NULL;
834         ip_set_id_t i;
835
836         *id = IPSET_INVALID_ID;
837         for (i = 0; i < inst->ip_set_max; i++) {
838                 set = ip_set(inst, i);
839                 if (set && STRNCMP(set->name, name)) {
840                         *id = i;
841                         break;
842                 }
843         }
844         return (*id == IPSET_INVALID_ID ? NULL : set);
845 }
846
847 static inline struct ip_set *
848 find_set(struct ip_set_net *inst, const char *name)
849 {
850         ip_set_id_t id;
851
852         return find_set_and_id(inst, name, &id);
853 }
854
855 static int
856 find_free_id(struct ip_set_net *inst, const char *name, ip_set_id_t *index,
857              struct ip_set **set)
858 {
859         struct ip_set *s;
860         ip_set_id_t i;
861
862         *index = IPSET_INVALID_ID;
863         for (i = 0;  i < inst->ip_set_max; i++) {
864                 s = ip_set(inst, i);
865                 if (!s) {
866                         if (*index == IPSET_INVALID_ID)
867                                 *index = i;
868                 } else if (STRNCMP(name, s->name)) {
869                         /* Name clash */
870                         *set = s;
871                         return -EEXIST;
872                 }
873         }
874         if (*index == IPSET_INVALID_ID)
875                 /* No free slot remained */
876                 return -IPSET_ERR_MAX_SETS;
877         return 0;
878 }
879
880 static int
881 IPSET_CBFN(ip_set_none, struct net *net, struct sock *ctnl,
882            struct sk_buff *skb, const struct nlmsghdr *nlh,
883            const struct nlattr * const attr[],
884            struct netlink_ext_ack *extack)
885 {
886         return -EOPNOTSUPP;
887 }
888
889 static int
890 IPSET_CBFN(ip_set_create, struct net *n, struct sock *ctnl,
891            struct sk_buff *skb, const struct nlmsghdr *nlh,
892            const struct nlattr * const attr[],
893            struct netlink_ext_ack *extack)
894 {
895         struct net *net = IPSET_SOCK_NET(n, ctnl);
896         struct ip_set_net *inst = ip_set_pernet(net);
897         struct ip_set *set, *clash = NULL;
898         ip_set_id_t index = IPSET_INVALID_ID;
899         struct nlattr *tb[IPSET_ATTR_CREATE_MAX + 1] = {};
900         const char *name, *typename;
901         u8 family, revision;
902         u32 flags = flag_exist(nlh);
903         int ret = 0;
904
905         if (unlikely(protocol_min_failed(attr) ||
906                      !attr[IPSET_ATTR_SETNAME] ||
907                      !attr[IPSET_ATTR_TYPENAME] ||
908                      !attr[IPSET_ATTR_REVISION] ||
909                      !attr[IPSET_ATTR_FAMILY] ||
910                      (attr[IPSET_ATTR_DATA] &&
911                       !flag_nested(attr[IPSET_ATTR_DATA]))))
912                 return -IPSET_ERR_PROTOCOL;
913
914         name = nla_data(attr[IPSET_ATTR_SETNAME]);
915         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
916         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
917         revision = nla_get_u8(attr[IPSET_ATTR_REVISION]);
918         pr_debug("setname: %s, typename: %s, family: %s, revision: %u\n",
919                  name, typename, family_name(family), revision);
920
921         /* First, and without any locks, allocate and initialize
922          * a normal base set structure.
923          */
924         set = kzalloc(sizeof(*set), GFP_KERNEL);
925         if (!set)
926                 return -ENOMEM;
927         spin_lock_init(&set->lock);
928         strlcpy(set->name, name, IPSET_MAXNAMELEN);
929         set->family = family;
930         set->revision = revision;
931
932         /* Next, check that we know the type, and take
933          * a reference on the type, to make sure it stays available
934          * while constructing our new set.
935          *
936          * After referencing the type, we try to create the type
937          * specific part of the set without holding any locks.
938          */
939         ret = find_set_type_get(typename, family, revision, &set->type);
940         if (ret)
941                 goto out;
942
943         /* Without holding any locks, create private part. */
944         if (attr[IPSET_ATTR_DATA] &&
945             NLA_PARSE_NESTED(tb, IPSET_ATTR_CREATE_MAX, attr[IPSET_ATTR_DATA],
946                              set->type->create_policy, NULL)) {
947                 ret = -IPSET_ERR_PROTOCOL;
948                 goto put_out;
949         }
950
951         ret = set->type->create(net, set, tb, flags);
952         if (ret != 0)
953                 goto put_out;
954
955         /* BTW, ret==0 here. */
956
957         /* Here, we have a valid, constructed set and we are protected
958          * by the nfnl mutex. Find the first free index in ip_set_list
959          * and check clashing.
960          */
961         ret = find_free_id(inst, set->name, &index, &clash);
962         if (ret == -EEXIST) {
963                 /* If this is the same set and requested, ignore error */
964                 if ((flags & IPSET_FLAG_EXIST) &&
965                     STRNCMP(set->type->name, clash->type->name) &&
966                     set->type->family == clash->type->family &&
967                     set->type->revision_min == clash->type->revision_min &&
968                     set->type->revision_max == clash->type->revision_max &&
969                     set->variant->same_set(set, clash))
970                         ret = 0;
971                 goto cleanup;
972         } else if (ret == -IPSET_ERR_MAX_SETS) {
973                 struct ip_set **list, **tmp;
974                 ip_set_id_t i = inst->ip_set_max + IP_SET_INC;
975
976                 if (i < inst->ip_set_max || i == IPSET_INVALID_ID)
977                         /* Wraparound */
978                         goto cleanup;
979
980                 list = kvcalloc(i, sizeof(struct ip_set *), GFP_KERNEL);
981                 if (!list)
982                         goto cleanup;
983                 /* nfnl mutex is held, both lists are valid */
984                 tmp = ip_set_dereference(inst->ip_set_list);
985                 memcpy(list, tmp, sizeof(struct ip_set *) * inst->ip_set_max);
986                 rcu_assign_pointer(inst->ip_set_list, list);
987                 /* Make sure all current packets have passed through */
988                 synchronize_net();
989                 /* Use new list */
990                 index = inst->ip_set_max;
991                 inst->ip_set_max = i;
992                 kvfree(tmp);
993                 ret = 0;
994         } else if (ret) {
995                 goto cleanup;
996         }
997
998         /* Finally! Add our shiny new set to the list, and be done. */
999         pr_debug("create: '%s' created with index %u!\n", set->name, index);
1000         ip_set(inst, index) = set;
1001
1002         return ret;
1003
1004 cleanup:
1005         set->variant->destroy(set);
1006 put_out:
1007         module_put(set->type->me);
1008 out:
1009         kfree(set);
1010         return ret;
1011 }
1012
1013 /* Destroy sets */
1014
1015 static const struct nla_policy
1016 ip_set_setname_policy[IPSET_ATTR_CMD_MAX + 1] = {
1017         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1018         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1019                                     .len = IPSET_MAXNAMELEN - 1 },
1020 };
1021
1022 static void
1023 ip_set_destroy_set(struct ip_set *set)
1024 {
1025         pr_debug("set: %s\n",  set->name);
1026
1027         /* Must call it without holding any lock */
1028         set->variant->destroy(set);
1029         module_put(set->type->me);
1030         kfree(set);
1031 }
1032
1033 static int
1034 IPSET_CBFN(ip_set_destroy, struct net *net, struct sock *ctnl,
1035            struct sk_buff *skb, const struct nlmsghdr *nlh,
1036            const struct nlattr * const attr[],
1037            struct netlink_ext_ack *extack)
1038 {
1039         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1040         struct ip_set *s;
1041         ip_set_id_t i;
1042         int ret = 0;
1043
1044         if (unlikely(protocol_min_failed(attr)))
1045                 return -IPSET_ERR_PROTOCOL;
1046
1047         /* Must wait for flush to be really finished in list:set */
1048         rcu_barrier();
1049
1050         /* Commands are serialized and references are
1051          * protected by the ip_set_ref_lock.
1052          * External systems (i.e. xt_set) must call
1053          * ip_set_put|get_nfnl_* functions, that way we
1054          * can safely check references here.
1055          *
1056          * list:set timer can only decrement the reference
1057          * counter, so if it's already zero, we can proceed
1058          * without holding the lock.
1059          */
1060         read_lock_bh(&ip_set_ref_lock);
1061         if (!attr[IPSET_ATTR_SETNAME]) {
1062                 for (i = 0; i < inst->ip_set_max; i++) {
1063                         s = ip_set(inst, i);
1064                         if (s && (s->ref || s->ref_netlink)) {
1065                                 ret = -IPSET_ERR_BUSY;
1066                                 goto out;
1067                         }
1068                 }
1069                 inst->is_destroyed = true;
1070                 read_unlock_bh(&ip_set_ref_lock);
1071                 for (i = 0; i < inst->ip_set_max; i++) {
1072                         s = ip_set(inst, i);
1073                         if (s) {
1074                                 ip_set(inst, i) = NULL;
1075                                 ip_set_destroy_set(s);
1076                         }
1077                 }
1078                 /* Modified by ip_set_destroy() only, which is serialized */
1079                 inst->is_destroyed = false;
1080         } else {
1081                 s = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1082                                     &i);
1083                 if (!s) {
1084                         ret = -ENOENT;
1085                         goto out;
1086                 } else if (s->ref || s->ref_netlink) {
1087                         ret = -IPSET_ERR_BUSY;
1088                         goto out;
1089                 }
1090                 ip_set(inst, i) = NULL;
1091                 read_unlock_bh(&ip_set_ref_lock);
1092
1093                 ip_set_destroy_set(s);
1094         }
1095         return 0;
1096 out:
1097         read_unlock_bh(&ip_set_ref_lock);
1098         return ret;
1099 }
1100
1101 /* Flush sets */
1102
1103 static void
1104 ip_set_flush_set(struct ip_set *set)
1105 {
1106         pr_debug("set: %s\n",  set->name);
1107
1108         spin_lock_bh(&set->lock);
1109         set->variant->flush(set);
1110         spin_unlock_bh(&set->lock);
1111 }
1112
1113 static int
1114 IPSET_CBFN(ip_set_flush, struct net *net, struct sock *ctnl,
1115            struct sk_buff *skb, const struct nlmsghdr *nlh,
1116            const struct nlattr * const attr[],
1117            struct netlink_ext_ack *extack)
1118 {
1119         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1120         struct ip_set *s;
1121         ip_set_id_t i;
1122
1123         if (unlikely(protocol_min_failed(attr)))
1124                 return -IPSET_ERR_PROTOCOL;
1125
1126         if (!attr[IPSET_ATTR_SETNAME]) {
1127                 for (i = 0; i < inst->ip_set_max; i++) {
1128                         s = ip_set(inst, i);
1129                         if (s)
1130                                 ip_set_flush_set(s);
1131                 }
1132         } else {
1133                 s = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1134                 if (!s)
1135                         return -ENOENT;
1136
1137                 ip_set_flush_set(s);
1138         }
1139
1140         return 0;
1141 }
1142
1143 /* Rename a set */
1144
1145 static const struct nla_policy
1146 ip_set_setname2_policy[IPSET_ATTR_CMD_MAX + 1] = {
1147         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1148         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1149                                     .len = IPSET_MAXNAMELEN - 1 },
1150         [IPSET_ATTR_SETNAME2]   = { .type = NLA_NUL_STRING,
1151                                     .len = IPSET_MAXNAMELEN - 1 },
1152 };
1153
1154 static int
1155 IPSET_CBFN(ip_set_rename, struct net *net, struct sock *ctnl,
1156            struct sk_buff *skb, const struct nlmsghdr *nlh,
1157            const struct nlattr * const attr[],
1158            struct netlink_ext_ack *extack)
1159 {
1160         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1161         struct ip_set *set, *s;
1162         const char *name2;
1163         ip_set_id_t i;
1164         int ret = 0;
1165
1166         if (unlikely(protocol_min_failed(attr) ||
1167                      !attr[IPSET_ATTR_SETNAME] ||
1168                      !attr[IPSET_ATTR_SETNAME2]))
1169                 return -IPSET_ERR_PROTOCOL;
1170
1171         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1172         if (!set)
1173                 return -ENOENT;
1174
1175         write_lock_bh(&ip_set_ref_lock);
1176         if (set->ref != 0) {
1177                 ret = -IPSET_ERR_REFERENCED;
1178                 goto out;
1179         }
1180
1181         name2 = nla_data(attr[IPSET_ATTR_SETNAME2]);
1182         for (i = 0; i < inst->ip_set_max; i++) {
1183                 s = ip_set(inst, i);
1184                 if (s && STRNCMP(s->name, name2)) {
1185                         ret = -IPSET_ERR_EXIST_SETNAME2;
1186                         goto out;
1187                 }
1188         }
1189         strncpy(set->name, name2, IPSET_MAXNAMELEN);
1190
1191 out:
1192         write_unlock_bh(&ip_set_ref_lock);
1193         return ret;
1194 }
1195
1196 /* Swap two sets so that name/index points to the other.
1197  * References and set names are also swapped.
1198  *
1199  * The commands are serialized by the nfnl mutex and references are
1200  * protected by the ip_set_ref_lock. The kernel interfaces
1201  * do not hold the mutex but the pointer settings are atomic
1202  * so the ip_set_list always contains valid pointers to the sets.
1203  */
1204
1205 static int
1206 IPSET_CBFN(ip_set_swap, struct net *net, struct sock *ctnl,
1207            struct sk_buff *skb, const struct nlmsghdr *nlh,
1208            const struct nlattr * const attr[],
1209            struct netlink_ext_ack *extack)
1210 {
1211         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1212         struct ip_set *from, *to;
1213         ip_set_id_t from_id, to_id;
1214         char from_name[IPSET_MAXNAMELEN];
1215
1216         if (unlikely(protocol_min_failed(attr) ||
1217                      !attr[IPSET_ATTR_SETNAME] ||
1218                      !attr[IPSET_ATTR_SETNAME2]))
1219                 return -IPSET_ERR_PROTOCOL;
1220
1221         from = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]),
1222                                &from_id);
1223         if (!from)
1224                 return -ENOENT;
1225
1226         to = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME2]),
1227                              &to_id);
1228         if (!to)
1229                 return -IPSET_ERR_EXIST_SETNAME2;
1230
1231         /* Features must not change.
1232          * Not an artifical restriction anymore, as we must prevent
1233          * possible loops created by swapping in setlist type of sets.
1234          */
1235         if (!(from->type->features == to->type->features &&
1236               from->family == to->family))
1237                 return -IPSET_ERR_TYPE_MISMATCH;
1238
1239         write_lock_bh(&ip_set_ref_lock);
1240
1241         if (from->ref_netlink || to->ref_netlink) {
1242                 write_unlock_bh(&ip_set_ref_lock);
1243                 return -EBUSY;
1244         }
1245
1246         strncpy(from_name, from->name, IPSET_MAXNAMELEN);
1247         strncpy(from->name, to->name, IPSET_MAXNAMELEN);
1248         strncpy(to->name, from_name, IPSET_MAXNAMELEN);
1249
1250         swap(from->ref, to->ref);
1251         ip_set(inst, from_id) = to;
1252         ip_set(inst, to_id) = from;
1253         write_unlock_bh(&ip_set_ref_lock);
1254
1255         return 0;
1256 }
1257
1258 /* List/save set data */
1259
1260 #define DUMP_INIT       0
1261 #define DUMP_ALL        1
1262 #define DUMP_ONE        2
1263 #define DUMP_LAST       3
1264
1265 #define DUMP_TYPE(arg)          (((u32)(arg)) & 0x0000FFFF)
1266 #define DUMP_FLAGS(arg)         (((u32)(arg)) >> 16)
1267
1268 static int
1269 ip_set_dump_done(struct netlink_callback *cb)
1270 {
1271         if (cb->args[IPSET_CB_ARG0]) {
1272                 struct ip_set_net *inst =
1273                         (struct ip_set_net *)cb->args[IPSET_CB_NET];
1274                 ip_set_id_t index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1275                 struct ip_set *set = ip_set_ref_netlink(inst, index);
1276
1277                 if (set->variant->uref)
1278                         set->variant->uref(set, cb, false);
1279                 pr_debug("release set %s\n", set->name);
1280                 __ip_set_put_netlink(set);
1281         }
1282         return 0;
1283 }
1284
1285 static inline void
1286 dump_attrs(struct nlmsghdr *nlh)
1287 {
1288         const struct nlattr *attr;
1289         int rem;
1290
1291         pr_debug("dump nlmsg\n");
1292         nlmsg_for_each_attr(attr, nlh, sizeof(struct nfgenmsg), rem) {
1293                 pr_debug("type: %u, len %u\n", nla_type(attr), attr->nla_len);
1294         }
1295 }
1296
1297 static int
1298 dump_init(struct netlink_callback *cb, struct ip_set_net *inst)
1299 {
1300         struct nlmsghdr *nlh = nlmsg_hdr(cb->skb);
1301         int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1302         struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1303         struct nlattr *attr = (void *)nlh + min_len;
1304         u32 dump_type;
1305         ip_set_id_t index;
1306
1307         /* Second pass, so parser can't fail */
1308         NLA_PARSE(cda, IPSET_ATTR_CMD_MAX, attr, nlh->nlmsg_len - min_len,
1309                   ip_set_setname_policy, NULL);
1310
1311         cb->args[IPSET_CB_PROTO] = nla_get_u8(cda[IPSET_ATTR_PROTOCOL]);
1312         if (cda[IPSET_ATTR_SETNAME]) {
1313                 struct ip_set *set;
1314
1315                 set = find_set_and_id(inst, nla_data(cda[IPSET_ATTR_SETNAME]),
1316                                       &index);
1317                 if (!set)
1318                         return -ENOENT;
1319
1320                 dump_type = DUMP_ONE;
1321                 cb->args[IPSET_CB_INDEX] = index;
1322         } else {
1323                 dump_type = DUMP_ALL;
1324         }
1325
1326         if (cda[IPSET_ATTR_FLAGS]) {
1327                 u32 f = ip_set_get_h32(cda[IPSET_ATTR_FLAGS]);
1328
1329                 dump_type |= (f << 16);
1330         }
1331         cb->args[IPSET_CB_NET] = (unsigned long)inst;
1332         cb->args[IPSET_CB_DUMP] = dump_type;
1333
1334         return 0;
1335 }
1336
1337 static int
1338 ip_set_dump_start(struct sk_buff *skb, struct netlink_callback *cb)
1339 {
1340         ip_set_id_t index = IPSET_INVALID_ID, max;
1341         struct ip_set *set = NULL;
1342         struct nlmsghdr *nlh = NULL;
1343         unsigned int flags = NETLINK_PORTID(cb->skb) ? NLM_F_MULTI : 0;
1344         struct ip_set_net *inst = ip_set_pernet(sock_net(skb->sk));
1345         u32 dump_type, dump_flags;
1346         bool is_destroyed;
1347         int ret = 0;
1348
1349         if (!cb->args[IPSET_CB_DUMP]) {
1350                 ret = dump_init(cb, inst);
1351                 if (ret < 0) {
1352                         nlh = nlmsg_hdr(cb->skb);
1353                         /* We have to create and send the error message
1354                          * manually :-(
1355                          */
1356                         if (nlh->nlmsg_flags & NLM_F_ACK)
1357                                 NETLINK_ACK(cb->skb, nlh, ret, NULL);
1358                         return ret;
1359                 }
1360         }
1361
1362         if (cb->args[IPSET_CB_INDEX] >= inst->ip_set_max)
1363                 goto out;
1364
1365         dump_type = DUMP_TYPE(cb->args[IPSET_CB_DUMP]);
1366         dump_flags = DUMP_FLAGS(cb->args[IPSET_CB_DUMP]);
1367         max = dump_type == DUMP_ONE ? cb->args[IPSET_CB_INDEX] + 1
1368                                     : inst->ip_set_max;
1369 dump_last:
1370         pr_debug("dump type, flag: %u %u index: %ld\n",
1371                  dump_type, dump_flags, cb->args[IPSET_CB_INDEX]);
1372         for (; cb->args[IPSET_CB_INDEX] < max; cb->args[IPSET_CB_INDEX]++) {
1373                 index = (ip_set_id_t)cb->args[IPSET_CB_INDEX];
1374                 write_lock_bh(&ip_set_ref_lock);
1375                 set = ip_set(inst, index);
1376                 is_destroyed = inst->is_destroyed;
1377                 if (!set || is_destroyed) {
1378                         write_unlock_bh(&ip_set_ref_lock);
1379                         if (dump_type == DUMP_ONE) {
1380                                 ret = -ENOENT;
1381                                 goto out;
1382                         }
1383                         if (is_destroyed) {
1384                                 /* All sets are just being destroyed */
1385                                 ret = 0;
1386                                 goto out;
1387                         }
1388                         continue;
1389                 }
1390                 /* When dumping all sets, we must dump "sorted"
1391                  * so that lists (unions of sets) are dumped last.
1392                  */
1393                 if (dump_type != DUMP_ONE &&
1394                     ((dump_type == DUMP_ALL) ==
1395                      !!(set->type->features & IPSET_DUMP_LAST))) {
1396                         write_unlock_bh(&ip_set_ref_lock);
1397                         continue;
1398                 }
1399                 pr_debug("List set: %s\n", set->name);
1400                 if (!cb->args[IPSET_CB_ARG0]) {
1401                         /* Start listing: make sure set won't be destroyed */
1402                         pr_debug("reference set\n");
1403                         set->ref_netlink++;
1404                 }
1405                 write_unlock_bh(&ip_set_ref_lock);
1406                 nlh = start_msg(skb, NETLINK_PORTID(cb->skb),
1407                                 cb->nlh->nlmsg_seq, flags,
1408                                 IPSET_CMD_LIST);
1409                 if (!nlh) {
1410                         ret = -EMSGSIZE;
1411                         goto release_refcount;
1412                 }
1413                 if (nla_put_u8(skb, IPSET_ATTR_PROTOCOL,
1414                                cb->args[IPSET_CB_PROTO]) ||
1415                     nla_put_string(skb, IPSET_ATTR_SETNAME, set->name))
1416                         goto nla_put_failure;
1417                 if (dump_flags & IPSET_FLAG_LIST_SETNAME)
1418                         goto next_set;
1419                 switch (cb->args[IPSET_CB_ARG0]) {
1420                 case 0:
1421                         /* Core header data */
1422                         if (nla_put_string(skb, IPSET_ATTR_TYPENAME,
1423                                            set->type->name) ||
1424                             nla_put_u8(skb, IPSET_ATTR_FAMILY,
1425                                        set->family) ||
1426                             nla_put_u8(skb, IPSET_ATTR_REVISION,
1427                                        set->revision))
1428                                 goto nla_put_failure;
1429                         if (cb->args[IPSET_CB_PROTO] > IPSET_PROTOCOL_MIN &&
1430                             nla_put_net16(skb, IPSET_ATTR_INDEX, htons(index)))
1431                                 goto nla_put_failure;
1432                         ret = set->variant->head(set, skb);
1433                         if (ret < 0)
1434                                 goto release_refcount;
1435                         if (dump_flags & IPSET_FLAG_LIST_HEADER)
1436                                 goto next_set;
1437                         if (set->variant->uref)
1438                                 set->variant->uref(set, cb, true);
1439                         /* fall through */
1440                 default:
1441                         ret = set->variant->list(set, skb, cb);
1442                         if (!cb->args[IPSET_CB_ARG0])
1443                                 /* Set is done, proceed with next one */
1444                                 goto next_set;
1445                         goto release_refcount;
1446                 }
1447         }
1448         /* If we dump all sets, continue with dumping last ones */
1449         if (dump_type == DUMP_ALL) {
1450                 dump_type = DUMP_LAST;
1451                 cb->args[IPSET_CB_DUMP] = dump_type | (dump_flags << 16);
1452                 cb->args[IPSET_CB_INDEX] = 0;
1453                 if (set && set->variant->uref)
1454                         set->variant->uref(set, cb, false);
1455                 goto dump_last;
1456         }
1457         goto out;
1458
1459 nla_put_failure:
1460         ret = -EFAULT;
1461 next_set:
1462         if (dump_type == DUMP_ONE)
1463                 cb->args[IPSET_CB_INDEX] = IPSET_INVALID_ID;
1464         else
1465                 cb->args[IPSET_CB_INDEX]++;
1466 release_refcount:
1467         /* If there was an error or set is done, release set */
1468         if (ret || !cb->args[IPSET_CB_ARG0]) {
1469                 set = ip_set_ref_netlink(inst, index);
1470                 if (set->variant->uref)
1471                         set->variant->uref(set, cb, false);
1472                 pr_debug("release set %s\n", set->name);
1473                 __ip_set_put_netlink(set);
1474                 cb->args[IPSET_CB_ARG0] = 0;
1475         }
1476 out:
1477         if (nlh) {
1478                 nlmsg_end(skb, nlh);
1479                 pr_debug("nlmsg_len: %u\n", nlh->nlmsg_len);
1480                 dump_attrs(nlh);
1481         }
1482
1483         return ret < 0 ? ret : skb->len;
1484 }
1485
1486 static int
1487 IPSET_CBFN(ip_set_dump, struct net *net, struct sock *ctnl,
1488            struct sk_buff *skb, const struct nlmsghdr *nlh,
1489            const struct nlattr * const attr[],
1490            struct netlink_ext_ack *extack)
1491 {
1492         if (unlikely(protocol_min_failed(attr)))
1493                 return -IPSET_ERR_PROTOCOL;
1494
1495 #if HAVE_NETLINK_DUMP_START_ARGS == 5
1496         return netlink_dump_start(ctnl, skb, nlh,
1497                                   ip_set_dump_start,
1498                                   ip_set_dump_done);
1499 #elif HAVE_NETLINK_DUMP_START_ARGS == 6
1500         return netlink_dump_start(ctnl, skb, nlh,
1501                                   ip_set_dump_start,
1502                                   ip_set_dump_done, 0);
1503 #else
1504         {
1505                 struct netlink_dump_control c = {
1506                         .dump = ip_set_dump_start,
1507                         .done = ip_set_dump_done,
1508                 };
1509                 return netlink_dump_start(ctnl, skb, nlh, &c);
1510         }
1511 #endif
1512 }
1513
1514 /* Add, del and test */
1515
1516 static const struct nla_policy ip_set_adt_policy[IPSET_ATTR_CMD_MAX + 1] = {
1517         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1518         [IPSET_ATTR_SETNAME]    = { .type = NLA_NUL_STRING,
1519                                     .len = IPSET_MAXNAMELEN - 1 },
1520         [IPSET_ATTR_LINENO]     = { .type = NLA_U32 },
1521         [IPSET_ATTR_DATA]       = { .type = NLA_NESTED },
1522         [IPSET_ATTR_ADT]        = { .type = NLA_NESTED },
1523 };
1524
1525 static int
1526 call_ad(struct sock *ctnl, struct sk_buff *skb, struct ip_set *set,
1527         struct nlattr *tb[], enum ipset_adt adt,
1528         u32 flags, bool use_lineno)
1529 {
1530         int ret;
1531         u32 lineno = 0;
1532         bool eexist = flags & IPSET_FLAG_EXIST, retried = false;
1533
1534         do {
1535                 spin_lock_bh(&set->lock);
1536                 ret = set->variant->uadt(set, tb, adt, &lineno, flags, retried);
1537                 spin_unlock_bh(&set->lock);
1538                 retried = true;
1539         } while (ret == -EAGAIN &&
1540                  set->variant->resize &&
1541                  (ret = set->variant->resize(set, retried)) == 0);
1542
1543         if (!ret || (ret == -IPSET_ERR_EXIST && eexist))
1544                 return 0;
1545         if (lineno && use_lineno) {
1546                 /* Error in restore/batch mode: send back lineno */
1547                 struct nlmsghdr *rep, *nlh = nlmsg_hdr(skb);
1548                 struct sk_buff *skb2;
1549                 struct nlmsgerr *errmsg;
1550                 size_t payload = min(SIZE_MAX,
1551                                      sizeof(*errmsg) + nlmsg_len(nlh));
1552                 int min_len = nlmsg_total_size(sizeof(struct nfgenmsg));
1553                 struct nlattr *cda[IPSET_ATTR_CMD_MAX + 1];
1554                 struct nlattr *cmdattr;
1555                 u32 *errline;
1556
1557                 skb2 = nlmsg_new(payload, GFP_KERNEL);
1558                 if (!skb2)
1559                         return -ENOMEM;
1560                 rep = __nlmsg_put(skb2, NETLINK_PORTID(skb),
1561                                   nlh->nlmsg_seq, NLMSG_ERROR, payload, 0);
1562                 errmsg = nlmsg_data(rep);
1563                 errmsg->error = ret;
1564                 memcpy(&errmsg->msg, nlh, nlh->nlmsg_len);
1565                 cmdattr = (void *)&errmsg->msg + min_len;
1566
1567                 NLA_PARSE(cda, IPSET_ATTR_CMD_MAX, cmdattr,
1568                           nlh->nlmsg_len - min_len, ip_set_adt_policy, NULL);
1569
1570                 errline = nla_data(cda[IPSET_ATTR_LINENO]);
1571
1572                 *errline = lineno;
1573
1574                 netlink_unicast(ctnl, skb2, NETLINK_PORTID(skb),
1575                                 MSG_DONTWAIT);
1576                 /* Signal netlink not to send its ACK/errmsg.  */
1577                 return -EINTR;
1578         }
1579
1580         return ret;
1581 }
1582
1583 static int
1584 IPSET_CBFN(ip_set_uadd, struct net *net, struct sock *ctnl,
1585            struct sk_buff *skb, const struct nlmsghdr *nlh,
1586            const struct nlattr * const attr[],
1587            struct netlink_ext_ack *extack)
1588 {
1589         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1590         struct ip_set *set;
1591         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1592         const struct nlattr *nla;
1593         u32 flags = flag_exist(nlh);
1594         bool use_lineno;
1595         int ret = 0;
1596
1597         if (unlikely(protocol_min_failed(attr) ||
1598                      !attr[IPSET_ATTR_SETNAME] ||
1599                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1600                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1601                      (attr[IPSET_ATTR_DATA] &&
1602                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1603                      (attr[IPSET_ATTR_ADT] &&
1604                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1605                        !attr[IPSET_ATTR_LINENO]))))
1606                 return -IPSET_ERR_PROTOCOL;
1607
1608         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1609         if (!set)
1610                 return -ENOENT;
1611
1612         use_lineno = !!attr[IPSET_ATTR_LINENO];
1613         if (attr[IPSET_ATTR_DATA]) {
1614                 if (NLA_PARSE_NESTED(tb, IPSET_ATTR_ADT_MAX,
1615                                      attr[IPSET_ATTR_DATA],
1616                                      set->type->adt_policy, NULL))
1617                         return -IPSET_ERR_PROTOCOL;
1618                 ret = call_ad(ctnl, skb, set, tb, IPSET_ADD, flags,
1619                               use_lineno);
1620         } else {
1621                 int nla_rem;
1622
1623                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1624                         memset(tb, 0, sizeof(tb));
1625                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1626                             !flag_nested(nla) ||
1627                             NLA_PARSE_NESTED(tb, IPSET_ATTR_ADT_MAX, nla,
1628                                              set->type->adt_policy, NULL))
1629                                 return -IPSET_ERR_PROTOCOL;
1630                         ret = call_ad(ctnl, skb, set, tb, IPSET_ADD,
1631                                       flags, use_lineno);
1632                         if (ret < 0)
1633                                 return ret;
1634                 }
1635         }
1636         return ret;
1637 }
1638
1639 static int
1640 IPSET_CBFN(ip_set_udel, struct net *net, struct sock *ctnl,
1641            struct sk_buff *skb, const struct nlmsghdr *nlh,
1642            const struct nlattr * const attr[],
1643            struct netlink_ext_ack *extack)
1644 {
1645         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1646         struct ip_set *set;
1647         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1648         const struct nlattr *nla;
1649         u32 flags = flag_exist(nlh);
1650         bool use_lineno;
1651         int ret = 0;
1652
1653         if (unlikely(protocol_min_failed(attr) ||
1654                      !attr[IPSET_ATTR_SETNAME] ||
1655                      !((attr[IPSET_ATTR_DATA] != NULL) ^
1656                        (attr[IPSET_ATTR_ADT] != NULL)) ||
1657                      (attr[IPSET_ATTR_DATA] &&
1658                       !flag_nested(attr[IPSET_ATTR_DATA])) ||
1659                      (attr[IPSET_ATTR_ADT] &&
1660                       (!flag_nested(attr[IPSET_ATTR_ADT]) ||
1661                        !attr[IPSET_ATTR_LINENO]))))
1662                 return -IPSET_ERR_PROTOCOL;
1663
1664         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1665         if (!set)
1666                 return -ENOENT;
1667
1668         use_lineno = !!attr[IPSET_ATTR_LINENO];
1669         if (attr[IPSET_ATTR_DATA]) {
1670                 if (NLA_PARSE_NESTED(tb, IPSET_ATTR_ADT_MAX,
1671                                      attr[IPSET_ATTR_DATA],
1672                                      set->type->adt_policy, NULL))
1673                         return -IPSET_ERR_PROTOCOL;
1674                 ret = call_ad(ctnl, skb, set, tb, IPSET_DEL, flags,
1675                               use_lineno);
1676         } else {
1677                 int nla_rem;
1678
1679                 nla_for_each_nested(nla, attr[IPSET_ATTR_ADT], nla_rem) {
1680                         memset(tb, 0, sizeof(*tb));
1681                         if (nla_type(nla) != IPSET_ATTR_DATA ||
1682                             !flag_nested(nla) ||
1683                             NLA_PARSE_NESTED(tb, IPSET_ATTR_ADT_MAX, nla,
1684                                              set->type->adt_policy, NULL))
1685                                 return -IPSET_ERR_PROTOCOL;
1686                         ret = call_ad(ctnl, skb, set, tb, IPSET_DEL,
1687                                       flags, use_lineno);
1688                         if (ret < 0)
1689                                 return ret;
1690                 }
1691         }
1692         return ret;
1693 }
1694
1695 static int
1696 IPSET_CBFN(ip_set_utest, struct net *net, struct sock *ctnl,
1697            struct sk_buff *skb,
1698            const struct nlmsghdr *nlh,
1699            const struct nlattr * const attr[],
1700            struct netlink_ext_ack *extack)
1701 {
1702         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1703         struct ip_set *set;
1704         struct nlattr *tb[IPSET_ATTR_ADT_MAX + 1] = {};
1705         int ret = 0;
1706
1707         if (unlikely(protocol_min_failed(attr) ||
1708                      !attr[IPSET_ATTR_SETNAME] ||
1709                      !attr[IPSET_ATTR_DATA] ||
1710                      !flag_nested(attr[IPSET_ATTR_DATA])))
1711                 return -IPSET_ERR_PROTOCOL;
1712
1713         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1714         if (!set)
1715                 return -ENOENT;
1716
1717         if (NLA_PARSE_NESTED(tb, IPSET_ATTR_ADT_MAX, attr[IPSET_ATTR_DATA],
1718                              set->type->adt_policy, NULL))
1719                 return -IPSET_ERR_PROTOCOL;
1720
1721         rcu_read_lock_bh();
1722         ret = set->variant->uadt(set, tb, IPSET_TEST, NULL, 0, 0);
1723         rcu_read_unlock_bh();
1724         /* Userspace can't trigger element to be re-added */
1725         if (ret == -EAGAIN)
1726                 ret = 1;
1727
1728         return ret > 0 ? 0 : -IPSET_ERR_EXIST;
1729 }
1730
1731 /* Get headed data of a set */
1732
1733 static int
1734 IPSET_CBFN(ip_set_header, struct net *net, struct sock *ctnl,
1735            struct sk_buff *skb, const struct nlmsghdr *nlh,
1736            const struct nlattr * const attr[],
1737            struct netlink_ext_ack *extack)
1738 {
1739         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1740         const struct ip_set *set;
1741         struct sk_buff *skb2;
1742         struct nlmsghdr *nlh2;
1743         int ret = 0;
1744
1745         if (unlikely(protocol_min_failed(attr) ||
1746                      !attr[IPSET_ATTR_SETNAME]))
1747                 return -IPSET_ERR_PROTOCOL;
1748
1749         set = find_set(inst, nla_data(attr[IPSET_ATTR_SETNAME]));
1750         if (!set)
1751                 return -ENOENT;
1752
1753         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1754         if (!skb2)
1755                 return -ENOMEM;
1756
1757         nlh2 = start_msg(skb2, NETLINK_PORTID(skb), nlh->nlmsg_seq, 0,
1758                          IPSET_CMD_HEADER);
1759         if (!nlh2)
1760                 goto nlmsg_failure;
1761         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1762             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name) ||
1763             nla_put_string(skb2, IPSET_ATTR_TYPENAME, set->type->name) ||
1764             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1765             nla_put_u8(skb2, IPSET_ATTR_REVISION, set->revision))
1766                 goto nla_put_failure;
1767         nlmsg_end(skb2, nlh2);
1768
1769         ret = netlink_unicast(ctnl, skb2, NETLINK_PORTID(skb), MSG_DONTWAIT);
1770         if (ret < 0)
1771                 return ret;
1772
1773         return 0;
1774
1775 nla_put_failure:
1776         nlmsg_cancel(skb2, nlh2);
1777 nlmsg_failure:
1778         kfree_skb(skb2);
1779         return -EMSGSIZE;
1780 }
1781
1782 /* Get type data */
1783
1784 static const struct nla_policy ip_set_type_policy[IPSET_ATTR_CMD_MAX + 1] = {
1785         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1786         [IPSET_ATTR_TYPENAME]   = { .type = NLA_NUL_STRING,
1787                                     .len = IPSET_MAXNAMELEN - 1 },
1788         [IPSET_ATTR_FAMILY]     = { .type = NLA_U8 },
1789 };
1790
1791 static int
1792 IPSET_CBFN(ip_set_type, struct net *net, struct sock *ctnl,
1793            struct sk_buff *skb, const struct nlmsghdr *nlh,
1794            const struct nlattr * const attr[],
1795            struct netlink_ext_ack *extack)
1796 {
1797         struct sk_buff *skb2;
1798         struct nlmsghdr *nlh2;
1799         u8 family, min, max;
1800         const char *typename;
1801         int ret = 0;
1802
1803         if (unlikely(protocol_min_failed(attr) ||
1804                      !attr[IPSET_ATTR_TYPENAME] ||
1805                      !attr[IPSET_ATTR_FAMILY]))
1806                 return -IPSET_ERR_PROTOCOL;
1807
1808         family = nla_get_u8(attr[IPSET_ATTR_FAMILY]);
1809         typename = nla_data(attr[IPSET_ATTR_TYPENAME]);
1810         ret = find_set_type_minmax(typename, family, &min, &max);
1811         if (ret)
1812                 return ret;
1813
1814         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1815         if (!skb2)
1816                 return -ENOMEM;
1817
1818         nlh2 = start_msg(skb2, NETLINK_PORTID(skb), nlh->nlmsg_seq, 0,
1819                          IPSET_CMD_TYPE);
1820         if (!nlh2)
1821                 goto nlmsg_failure;
1822         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1823             nla_put_string(skb2, IPSET_ATTR_TYPENAME, typename) ||
1824             nla_put_u8(skb2, IPSET_ATTR_FAMILY, family) ||
1825             nla_put_u8(skb2, IPSET_ATTR_REVISION, max) ||
1826             nla_put_u8(skb2, IPSET_ATTR_REVISION_MIN, min))
1827                 goto nla_put_failure;
1828         nlmsg_end(skb2, nlh2);
1829
1830         pr_debug("Send TYPE, nlmsg_len: %u\n", nlh2->nlmsg_len);
1831         ret = netlink_unicast(ctnl, skb2, NETLINK_PORTID(skb), MSG_DONTWAIT);
1832         if (ret < 0)
1833                 return ret;
1834
1835         return 0;
1836
1837 nla_put_failure:
1838         nlmsg_cancel(skb2, nlh2);
1839 nlmsg_failure:
1840         kfree_skb(skb2);
1841         return -EMSGSIZE;
1842 }
1843
1844 /* Get protocol version */
1845
1846 static const struct nla_policy
1847 ip_set_protocol_policy[IPSET_ATTR_CMD_MAX + 1] = {
1848         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1849 };
1850
1851 static int
1852 IPSET_CBFN(ip_set_protocol, struct net *net, struct sock *ctnl,
1853            struct sk_buff *skb, const struct nlmsghdr *nlh,
1854            const struct nlattr * const attr[],
1855            struct netlink_ext_ack *extack)
1856 {
1857         struct sk_buff *skb2;
1858         struct nlmsghdr *nlh2;
1859         int ret = 0;
1860
1861         if (unlikely(!attr[IPSET_ATTR_PROTOCOL]))
1862                 return -IPSET_ERR_PROTOCOL;
1863
1864         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1865         if (!skb2)
1866                 return -ENOMEM;
1867
1868         nlh2 = start_msg(skb2, NETLINK_PORTID(skb), nlh->nlmsg_seq, 0,
1869                          IPSET_CMD_PROTOCOL);
1870         if (!nlh2)
1871                 goto nlmsg_failure;
1872         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL))
1873                 goto nla_put_failure;
1874         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL_MIN, IPSET_PROTOCOL_MIN))
1875                 goto nla_put_failure;
1876         nlmsg_end(skb2, nlh2);
1877
1878         ret = netlink_unicast(ctnl, skb2, NETLINK_PORTID(skb), MSG_DONTWAIT);
1879         if (ret < 0)
1880                 return ret;
1881
1882         return 0;
1883
1884 nla_put_failure:
1885         nlmsg_cancel(skb2, nlh2);
1886 nlmsg_failure:
1887         kfree_skb(skb2);
1888         return -EMSGSIZE;
1889 }
1890
1891 /* Get set by name or index, from userspace */
1892
1893 static int
1894 IPSET_CBFN(ip_set_byname, struct net *net, struct sock *ctnl,
1895            struct sk_buff *skb, const struct nlmsghdr *nlh,
1896            const struct nlattr * const attr[],
1897            struct netlink_ext_ack *extack)
1898 {
1899         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1900         struct sk_buff *skb2;
1901         struct nlmsghdr *nlh2;
1902         ip_set_id_t id = IPSET_INVALID_ID;
1903         const struct ip_set *set;
1904         int ret = 0;
1905
1906         if (unlikely(protocol_failed(attr) ||
1907                      !attr[IPSET_ATTR_SETNAME]))
1908                 return -IPSET_ERR_PROTOCOL;
1909
1910         set = find_set_and_id(inst, nla_data(attr[IPSET_ATTR_SETNAME]), &id);
1911         if (id == IPSET_INVALID_ID)
1912                 return -ENOENT;
1913
1914         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1915         if (!skb2)
1916                 return -ENOMEM;
1917
1918         nlh2 = start_msg(skb2, NETLINK_PORTID(skb), nlh->nlmsg_seq, 0,
1919                          IPSET_CMD_GET_BYNAME);
1920         if (!nlh2)
1921                 goto nlmsg_failure;
1922         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1923             nla_put_u8(skb2, IPSET_ATTR_FAMILY, set->family) ||
1924             nla_put_net16(skb2, IPSET_ATTR_INDEX, htons(id)))
1925                 goto nla_put_failure;
1926         nlmsg_end(skb2, nlh2);
1927
1928         ret = netlink_unicast(ctnl, skb2, NETLINK_PORTID(skb), MSG_DONTWAIT);
1929         if (ret < 0)
1930                 return ret;
1931
1932         return 0;
1933
1934 nla_put_failure:
1935         nlmsg_cancel(skb2, nlh2);
1936 nlmsg_failure:
1937         kfree_skb(skb2);
1938         return -EMSGSIZE;
1939 }
1940
1941 static const struct nla_policy ip_set_index_policy[IPSET_ATTR_CMD_MAX + 1] = {
1942         [IPSET_ATTR_PROTOCOL]   = { .type = NLA_U8 },
1943         [IPSET_ATTR_INDEX]      = { .type = NLA_U16 },
1944 };
1945
1946 static int
1947 IPSET_CBFN(ip_set_byindex, struct net *net, struct sock *ctnl,
1948            struct sk_buff *skb, const struct nlmsghdr *nlh,
1949            const struct nlattr * const attr[],
1950            struct netlink_ext_ack *extack)
1951 {
1952         struct ip_set_net *inst = ip_set_pernet(IPSET_SOCK_NET(net, ctnl));
1953         struct sk_buff *skb2;
1954         struct nlmsghdr *nlh2;
1955         ip_set_id_t id = IPSET_INVALID_ID;
1956         const struct ip_set *set;
1957         int ret = 0;
1958
1959         if (unlikely(protocol_failed(attr) ||
1960                      !attr[IPSET_ATTR_INDEX]))
1961                 return -IPSET_ERR_PROTOCOL;
1962
1963         id = ip_set_get_h16(attr[IPSET_ATTR_INDEX]);
1964         if (id >= inst->ip_set_max)
1965                 return -ENOENT;
1966         set = ip_set(inst, id);
1967         if (set == NULL)
1968                 return -ENOENT;
1969
1970         skb2 = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1971         if (!skb2)
1972                 return -ENOMEM;
1973
1974         nlh2 = start_msg(skb2, NETLINK_PORTID(skb), nlh->nlmsg_seq, 0,
1975                          IPSET_CMD_GET_BYINDEX);
1976         if (!nlh2)
1977                 goto nlmsg_failure;
1978         if (nla_put_u8(skb2, IPSET_ATTR_PROTOCOL, protocol(attr)) ||
1979             nla_put_string(skb2, IPSET_ATTR_SETNAME, set->name))
1980                 goto nla_put_failure;
1981         nlmsg_end(skb2, nlh2);
1982
1983         ret = netlink_unicast(ctnl, skb2, NETLINK_PORTID(skb), MSG_DONTWAIT);
1984         if (ret < 0)
1985                 return ret;
1986
1987         return 0;
1988
1989 nla_put_failure:
1990         nlmsg_cancel(skb2, nlh2);
1991 nlmsg_failure:
1992         kfree_skb(skb2);
1993         return -EMSGSIZE;
1994 }
1995
1996 static const struct nfnl_callback ip_set_netlink_subsys_cb[IPSET_MSG_MAX] = {
1997         [IPSET_CMD_NONE]        = {
1998                 .call           = ip_set_none,
1999                 .attr_count     = IPSET_ATTR_CMD_MAX,
2000         },
2001         [IPSET_CMD_CREATE]      = {
2002                 .call           = ip_set_create,
2003                 .attr_count     = IPSET_ATTR_CMD_MAX,
2004                 .policy         = ip_set_create_policy,
2005         },
2006         [IPSET_CMD_DESTROY]     = {
2007                 .call           = ip_set_destroy,
2008                 .attr_count     = IPSET_ATTR_CMD_MAX,
2009                 .policy         = ip_set_setname_policy,
2010         },
2011         [IPSET_CMD_FLUSH]       = {
2012                 .call           = ip_set_flush,
2013                 .attr_count     = IPSET_ATTR_CMD_MAX,
2014                 .policy         = ip_set_setname_policy,
2015         },
2016         [IPSET_CMD_RENAME]      = {
2017                 .call           = ip_set_rename,
2018                 .attr_count     = IPSET_ATTR_CMD_MAX,
2019                 .policy         = ip_set_setname2_policy,
2020         },
2021         [IPSET_CMD_SWAP]        = {
2022                 .call           = ip_set_swap,
2023                 .attr_count     = IPSET_ATTR_CMD_MAX,
2024                 .policy         = ip_set_setname2_policy,
2025         },
2026         [IPSET_CMD_LIST]        = {
2027                 .call           = ip_set_dump,
2028                 .attr_count     = IPSET_ATTR_CMD_MAX,
2029                 .policy         = ip_set_setname_policy,
2030         },
2031         [IPSET_CMD_SAVE]        = {
2032                 .call           = ip_set_dump,
2033                 .attr_count     = IPSET_ATTR_CMD_MAX,
2034                 .policy         = ip_set_setname_policy,
2035         },
2036         [IPSET_CMD_ADD] = {
2037                 .call           = ip_set_uadd,
2038                 .attr_count     = IPSET_ATTR_CMD_MAX,
2039                 .policy         = ip_set_adt_policy,
2040         },
2041         [IPSET_CMD_DEL] = {
2042                 .call           = ip_set_udel,
2043                 .attr_count     = IPSET_ATTR_CMD_MAX,
2044                 .policy         = ip_set_adt_policy,
2045         },
2046         [IPSET_CMD_TEST]        = {
2047                 .call           = ip_set_utest,
2048                 .attr_count     = IPSET_ATTR_CMD_MAX,
2049                 .policy         = ip_set_adt_policy,
2050         },
2051         [IPSET_CMD_HEADER]      = {
2052                 .call           = ip_set_header,
2053                 .attr_count     = IPSET_ATTR_CMD_MAX,
2054                 .policy         = ip_set_setname_policy,
2055         },
2056         [IPSET_CMD_TYPE]        = {
2057                 .call           = ip_set_type,
2058                 .attr_count     = IPSET_ATTR_CMD_MAX,
2059                 .policy         = ip_set_type_policy,
2060         },
2061         [IPSET_CMD_PROTOCOL]    = {
2062                 .call           = ip_set_protocol,
2063                 .attr_count     = IPSET_ATTR_CMD_MAX,
2064                 .policy         = ip_set_protocol_policy,
2065         },
2066         [IPSET_CMD_GET_BYNAME]  = {
2067                 .call           = ip_set_byname,
2068                 .attr_count     = IPSET_ATTR_CMD_MAX,
2069                 .policy         = ip_set_setname_policy,
2070         },
2071         [IPSET_CMD_GET_BYINDEX] = {
2072                 .call           = ip_set_byindex,
2073                 .attr_count     = IPSET_ATTR_CMD_MAX,
2074                 .policy         = ip_set_index_policy,
2075         },
2076 };
2077
2078 static struct nfnetlink_subsystem ip_set_netlink_subsys __read_mostly = {
2079         .name           = "ip_set",
2080         .subsys_id      = NFNL_SUBSYS_IPSET,
2081         .cb_count       = IPSET_MSG_MAX,
2082         .cb             = ip_set_netlink_subsys_cb,
2083 };
2084
2085 /* Interface to iptables/ip6tables */
2086
2087 static int
2088 ip_set_sockfn_get(struct sock *sk, int optval, void __user *user, int *len)
2089 {
2090         unsigned int *op;
2091         void *data;
2092         int copylen = *len, ret = 0;
2093         struct net *net = sock_net(sk);
2094         struct ip_set_net *inst = ip_set_pernet(net);
2095
2096         if (!ns_capable(net->user_ns, CAP_NET_ADMIN))
2097                 return -EPERM;
2098         if (optval != SO_IP_SET)
2099                 return -EBADF;
2100         if (*len < sizeof(unsigned int))
2101                 return -EINVAL;
2102
2103         data = vmalloc(*len);
2104         if (!data)
2105                 return -ENOMEM;
2106         if (copy_from_user(data, user, *len) != 0) {
2107                 ret = -EFAULT;
2108                 goto done;
2109         }
2110         op = data;
2111
2112         if (*op < IP_SET_OP_VERSION) {
2113                 /* Check the version at the beginning of operations */
2114                 struct ip_set_req_version *req_version = data;
2115
2116                 if (*len < sizeof(struct ip_set_req_version)) {
2117                         ret = -EINVAL;
2118                         goto done;
2119                 }
2120
2121                 if (req_version->version < IPSET_PROTOCOL_MIN) {
2122                         ret = -EPROTO;
2123                         goto done;
2124                 }
2125         }
2126
2127         switch (*op) {
2128         case IP_SET_OP_VERSION: {
2129                 struct ip_set_req_version *req_version = data;
2130
2131                 if (*len != sizeof(struct ip_set_req_version)) {
2132                         ret = -EINVAL;
2133                         goto done;
2134                 }
2135
2136                 req_version->version = IPSET_PROTOCOL;
2137                 ret = copy_to_user(user, req_version,
2138                                    sizeof(struct ip_set_req_version));
2139                 goto done;
2140         }
2141         case IP_SET_OP_GET_BYNAME: {
2142                 struct ip_set_req_get_set *req_get = data;
2143                 ip_set_id_t id;
2144
2145                 if (*len != sizeof(struct ip_set_req_get_set)) {
2146                         ret = -EINVAL;
2147                         goto done;
2148                 }
2149                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2150                 nfnl_lock(NFNL_SUBSYS_IPSET);
2151                 find_set_and_id(inst, req_get->set.name, &id);
2152                 req_get->set.index = id;
2153                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2154                 goto copy;
2155         }
2156         case IP_SET_OP_GET_FNAME: {
2157                 struct ip_set_req_get_set_family *req_get = data;
2158                 ip_set_id_t id;
2159
2160                 if (*len != sizeof(struct ip_set_req_get_set_family)) {
2161                         ret = -EINVAL;
2162                         goto done;
2163                 }
2164                 req_get->set.name[IPSET_MAXNAMELEN - 1] = '\0';
2165                 nfnl_lock(NFNL_SUBSYS_IPSET);
2166                 find_set_and_id(inst, req_get->set.name, &id);
2167                 req_get->set.index = id;
2168                 if (id != IPSET_INVALID_ID)
2169                         req_get->family = ip_set(inst, id)->family;
2170                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2171                 goto copy;
2172         }
2173         case IP_SET_OP_GET_BYINDEX: {
2174                 struct ip_set_req_get_set *req_get = data;
2175                 struct ip_set *set;
2176
2177                 if (*len != sizeof(struct ip_set_req_get_set) ||
2178                     req_get->set.index >= inst->ip_set_max) {
2179                         ret = -EINVAL;
2180                         goto done;
2181                 }
2182                 nfnl_lock(NFNL_SUBSYS_IPSET);
2183                 set = ip_set(inst, req_get->set.index);
2184                 ret = strscpy(req_get->set.name, set ? set->name : "",
2185                               IPSET_MAXNAMELEN);
2186                 nfnl_unlock(NFNL_SUBSYS_IPSET);
2187                 if (ret < 0)
2188                         goto done;
2189                 goto copy;
2190         }
2191         default:
2192                 ret = -EBADMSG;
2193                 goto done;
2194         }       /* end of switch(op) */
2195
2196 copy:
2197         ret = copy_to_user(user, data, copylen);
2198
2199 done:
2200         vfree(data);
2201         if (ret > 0)
2202                 ret = 0;
2203         return ret;
2204 }
2205
2206 static struct nf_sockopt_ops so_set __read_mostly = {
2207         .pf             = PF_INET,
2208         .get_optmin     = SO_IP_SET,
2209         .get_optmax     = SO_IP_SET + 1,
2210         .get            = ip_set_sockfn_get,
2211         .owner          = THIS_MODULE,
2212 };
2213
2214 static int __net_init
2215 ip_set_net_init(struct net *net)
2216 {
2217         struct ip_set_net *inst;
2218         struct ip_set **list;
2219
2220 #ifdef HAVE_NET_OPS_ID
2221         inst = ip_set_pernet(net);
2222 #else
2223         int err;
2224
2225         inst = kzalloc(sizeof(struct ip_set_net), GFP_KERNEL);
2226         if (!inst)
2227                 return -ENOMEM;
2228         err = net_assign_generic(net, ip_set_net_id, inst);
2229         if (err < 0)
2230                 goto err_alloc;
2231 #endif
2232         inst->ip_set_max = max_sets ? max_sets : CONFIG_IP_SET_MAX;
2233         if (inst->ip_set_max >= IPSET_INVALID_ID)
2234                 inst->ip_set_max = IPSET_INVALID_ID - 1;
2235
2236         list = kvcalloc(inst->ip_set_max, sizeof(struct ip_set *), GFP_KERNEL);
2237         if (!list)
2238 #ifdef HAVE_NET_OPS_ID
2239                 return -ENOMEM;
2240 #else
2241                 goto err_alloc;
2242 #endif
2243         inst->is_deleted = false;
2244         inst->is_destroyed = false;
2245         rcu_assign_pointer(inst->ip_set_list, list);
2246         return 0;
2247
2248 #ifndef HAVE_NET_OPS_ID
2249 err_alloc:
2250         kfree(inst);
2251         return err;
2252 #endif
2253 }
2254
2255 static void __net_exit
2256 ip_set_net_exit(struct net *net)
2257 {
2258         struct ip_set_net *inst = ip_set_pernet(net);
2259
2260         struct ip_set *set = NULL;
2261         ip_set_id_t i;
2262
2263         inst->is_deleted = true; /* flag for ip_set_nfnl_put */
2264
2265         nfnl_lock(NFNL_SUBSYS_IPSET);
2266         for (i = 0; i < inst->ip_set_max; i++) {
2267                 set = ip_set(inst, i);
2268                 if (set) {
2269                         ip_set(inst, i) = NULL;
2270                         ip_set_destroy_set(set);
2271                 }
2272         }
2273         nfnl_unlock(NFNL_SUBSYS_IPSET);
2274         kvfree(rcu_dereference_protected(inst->ip_set_list, 1));
2275 #ifndef HAVE_NET_OPS_ID
2276         kvfree(inst);
2277 #endif
2278 }
2279
2280 static struct pernet_operations ip_set_net_ops = {
2281         .init   = ip_set_net_init,
2282         .exit   = ip_set_net_exit,
2283 #ifdef HAVE_NET_OPS_ID
2284         .id     = &ip_set_net_id,
2285         .size   = sizeof(struct ip_set_net),
2286 #ifdef HAVE_NET_OPS_ASYNC
2287         .async  = true,
2288 #endif
2289 #endif
2290 };
2291
2292 #ifdef HAVE_NET_OPS_ID
2293 #define REGISTER_PERNET_SUBSYS(s) \
2294         register_pernet_subsys(s)
2295 #define UNREGISTER_PERNET_SUBSYS(s) \
2296         unregister_pernet_subsys(s);
2297 #else
2298 #define REGISTER_PERNET_SUBSYS(s) \
2299         register_pernet_gen_device(&ip_set_net_id, s)
2300 #define UNREGISTER_PERNET_SUBSYS(s) \
2301         unregister_pernet_gen_device(ip_set_net_id, s);
2302 #endif
2303
2304
2305 static int __init
2306 ip_set_init(void)
2307 {
2308         int ret = REGISTER_PERNET_SUBSYS(&ip_set_net_ops);
2309
2310         if (ret) {
2311                 pr_err("ip_set: cannot register pernet_subsys.\n");
2312                 return ret;
2313         }
2314
2315         ret = nfnetlink_subsys_register(&ip_set_netlink_subsys);
2316         if (ret != 0) {
2317                 pr_err("ip_set: cannot register with nfnetlink.\n");
2318                 UNREGISTER_PERNET_SUBSYS(&ip_set_net_ops);
2319                 return ret;
2320         }
2321
2322         ret = nf_register_sockopt(&so_set);
2323         if (ret != 0) {
2324                 pr_err("SO_SET registry failed: %d\n", ret);
2325                 nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2326                 UNREGISTER_PERNET_SUBSYS(&ip_set_net_ops);
2327                 return ret;
2328         }
2329
2330         return 0;
2331 }
2332
2333 static void __exit
2334 ip_set_fini(void)
2335 {
2336         nf_unregister_sockopt(&so_set);
2337         nfnetlink_subsys_unregister(&ip_set_netlink_subsys);
2338
2339         UNREGISTER_PERNET_SUBSYS(&ip_set_net_ops);
2340         pr_debug("these are the famous last words\n");
2341 }
2342
2343 module_init(ip_set_init);
2344 module_exit(ip_set_fini);