]> granicus.if.org Git - pdns/commitdiff
dnsdist: reduce resprulactions/cachehitresprulactions code deuplication
authorChris Hofstaedtler <chris.hofstaedtler@deduktiva.com>
Mon, 15 Jan 2018 20:11:29 +0000 (21:11 +0100)
committerChris Hofstaedtler <chris.hofstaedtler@deduktiva.com>
Mon, 22 Jan 2018 14:43:30 +0000 (15:43 +0100)
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-rules.cc
pdns/dnsdist-web.cc
regression-tests.dnsdist/test_API.py

index abc618a8aae4811381a62be58aef71941753d650..9b86a89974abd27b3b3de2058345a711202250f0 100644 (file)
@@ -808,6 +808,19 @@ private:
   std::string d_value;
 };
 
+template<typename T, typename ActionT>
+static void addAction(GlobalStateHolder<vector<T> > *someRulActions, luadnsrule_t var, std::shared_ptr<ActionT> action, boost::optional<luaruleparams_t> params) {
+  setLuaSideEffect();
+
+  boost::uuids::uuid uuid;
+  parseRuleParams(params, uuid);
+
+  auto rule=makeRule(var);
+  someRulActions->modify([rule, action, uuid](vector<T>& rulactions){
+      rulactions.push_back({rule, action, uuid});
+    });
+}
+
 void setupLuaActions()
 {
   g_lua.writeFunction("newRuleAction", [](luadnsrule_t dnsrule, std::shared_ptr<DNSAction> action, boost::optional<luaruleparams_t> params) {
@@ -820,76 +833,35 @@ void setupLuaActions()
     });
 
   g_lua.writeFunction("addAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction> > era, boost::optional<luaruleparams_t> params) {
-      if (era.type() == typeid(std::shared_ptr<DNSResponseAction>)) {
+      if (era.type() != typeid(std::shared_ptr<DNSAction>)) {
         throw std::runtime_error("addAction() can only be called with query-related actions, not response-related ones. Are you looking for addResponseAction()?");
       }
 
-      boost::uuids::uuid uuid;
-      parseRuleParams(params, uuid);
-
-      auto ea = *boost::get<std::shared_ptr<DNSAction>>(&era);
-      setLuaSideEffect();
-      auto rule=makeRule(var);
-      g_rulactions.modify([rule, ea, uuid](decltype(g_rulactions)::value_type& rulactions){
-          rulactions.push_back({rule, ea, uuid});
-        });
+      addAction(&g_rulactions, var, boost::get<std::shared_ptr<DNSAction> >(era), params);
     });
 
   g_lua.writeFunction("addLuaAction", [](luadnsrule_t var, LuaAction::func_t func, boost::optional<luaruleparams_t> params) {
-      setLuaSideEffect();
-
-      boost::uuids::uuid uuid;
-      parseRuleParams(params, uuid);
-
-      auto rule=makeRule(var);
-      g_rulactions.modify([rule, func, uuid](decltype(g_rulactions)::value_type& rulactions){
-          rulactions.push_back({rule, std::make_shared<LuaAction>(func), uuid});
-        });
+      addAction(&g_rulactions, var, std::make_shared<LuaAction>(func), params);
     });
 
   g_lua.writeFunction("addLuaResponseAction", [](luadnsrule_t var, LuaResponseAction::func_t func, boost::optional<luaruleparams_t> params) {
-      setLuaSideEffect();
-
-      boost::uuids::uuid uuid;
-      parseRuleParams(params, uuid);
-
-      auto rule=makeRule(var);
-      g_resprulactions.modify([rule, func, uuid](decltype(g_resprulactions)::value_type& rulactions){
-          rulactions.push_back({rule, std::make_shared<LuaResponseAction>(func), uuid});
-        });
+      addAction(&g_resprulactions, var, std::make_shared<LuaResponseAction>(func), params);
     });
 
   g_lua.writeFunction("addResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction> > era, boost::optional<luaruleparams_t> params) {
-      if (era.type() == typeid(std::shared_ptr<DNSAction>)) {
+      if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) {
         throw std::runtime_error("addResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?");
       }
 
-      auto ea = *boost::get<std::shared_ptr<DNSResponseAction>>(&era);
-      boost::uuids::uuid uuid;
-      parseRuleParams(params, uuid);
-
-      setLuaSideEffect();
-      auto rule=makeRule(var);
-      g_resprulactions.modify([rule, ea, uuid](decltype(g_resprulactions)::value_type& rulactions){
-          rulactions.push_back({rule, ea, uuid});
-        });
+      addAction(&g_resprulactions, var, boost::get<std::shared_ptr<DNSResponseAction> >(era), params);
     });
 
   g_lua.writeFunction("addCacheHitResponseAction", [](luadnsrule_t var, boost::variant<std::shared_ptr<DNSAction>, std::shared_ptr<DNSResponseAction>> era, boost::optional<luaruleparams_t> params) {
-      if (era.type() == typeid(std::shared_ptr<DNSAction>)) {
+      if (era.type() != typeid(std::shared_ptr<DNSResponseAction>)) {
         throw std::runtime_error("addCacheHitResponseAction() can only be called with response-related actions, not query-related ones. Are you looking for addAction()?");
       }
 
-      setLuaSideEffect();
-      auto rule=makeRule(var);
-
-      boost::uuids::uuid uuid;
-      parseRuleParams(params, uuid);
-
-      auto ea = *boost::get<std::shared_ptr<DNSResponseAction>>(&era);
-      g_cachehitresprulactions.modify([rule, ea, uuid](decltype(g_cachehitresprulactions)::value_type& rulactions){
-          rulactions.push_back({rule, ea, uuid});
-        });
+      addAction(&g_cachehitresprulactions, var, boost::get<std::shared_ptr<DNSResponseAction> >(era), params);
     });
 
   g_lua.registerFunction<void(DNSAction::*)()>("printStats", [](const DNSAction& ta) {
index f2b83662035336d8f2bd279582703babd2d78775..4e2e53c8cf58c2a664bab9d2ec1c37c97549fe45 100644 (file)
@@ -915,6 +915,87 @@ void parseRuleParams(boost::optional<luaruleparams_t> params, boost::uuids::uuid
   uuid = makeRuleID(uuidStr);
 }
 
+template<typename T>
+static void showRules(GlobalStateHolder<vector<T> > *someRulActions, boost::optional<bool> showUUIDs) {
+  setLuaNoSideEffect();
+  int num=0;
+  if (showUUIDs.get_value_or(false)) {
+    boost::format fmt("%-3d %-38s %9d %-56s %s\n");
+    g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
+    for(const auto& lim : someRulActions->getCopy()) {
+      string name = lim.d_rule->toString();
+      g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
+      ++num;
+    }
+  }
+  else {
+    boost::format fmt("%-3d %9d %-56s %s\n");
+    g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
+    for(const auto& lim : someRulActions->getCopy()) {
+      string name = lim.d_rule->toString();
+      g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
+      ++num;
+    }
+  }
+}
+
+template<typename T>
+static void rmRule(GlobalStateHolder<vector<T> > *someRulActions, boost::variant<unsigned int, std::string> id) {
+  setLuaSideEffect();
+  auto rules = someRulActions->getCopy();
+  if (auto str = boost::get<std::string>(&id)) {
+    boost::uuids::string_generator gen;
+    const auto uuid = gen(*str);
+    if (rules.erase(std::remove_if(rules.begin(),
+                                    rules.end(),
+                                    [uuid](const T& a) { return a.d_id == uuid; }),
+                    rules.end()) == rules.end()) {
+      g_outputBuffer = "Error: no rule matched\n";
+      return;
+    }
+  }
+  else if (auto pos = boost::get<unsigned int>(&id)) {
+    if (*pos >= rules.size()) {
+      g_outputBuffer = "Error: attempt to delete non-existing rule\n";
+      return;
+    }
+    rules.erase(rules.begin()+*pos);
+  }
+  someRulActions->setState(rules);
+}
+
+template<typename T>
+static void topRule(GlobalStateHolder<vector<T> > *someRulActions) {
+  setLuaSideEffect();
+  auto rules = someRulActions->getCopy();
+  if(rules.empty())
+    return;
+  auto subject = *rules.rbegin();
+  rules.erase(std::prev(rules.end()));
+  rules.insert(rules.begin(), subject);
+  someRulActions->setState(rules);
+}
+
+template<typename T>
+static void mvRule(GlobalStateHolder<vector<T> > *someRespRulActions, unsigned int from, unsigned int to) {
+  setLuaSideEffect();
+  auto rules = someRespRulActions->getCopy();
+  if(from >= rules.size() || to > rules.size()) {
+    g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
+    return;
+  }
+  auto subject = rules[from];
+  rules.erase(rules.begin()+from);
+  if(to == rules.size())
+    rules.push_back(subject);
+  else {
+    if(from < to)
+      --to;
+    rules.insert(rules.begin()+to, subject);
+  }
+  someRespRulActions->setState(rules);
+}
+
 void setupLuaRules()
 {
   g_lua.writeFunction("makeRule", makeRule);
@@ -922,209 +1003,47 @@ void setupLuaRules()
   g_lua.registerFunction<string(std::shared_ptr<DNSRule>::*)()>("toString", [](const std::shared_ptr<DNSRule>& rule) { return rule->toString(); });
 
   g_lua.writeFunction("showResponseRules", [](boost::optional<bool> showUUIDs) {
-      setLuaNoSideEffect();
-      int num=0;
-      if (showUUIDs.get_value_or(false)) {
-        boost::format fmt("%-3d %-38s %9d %-50s %s\n");
-        g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
-        for(const auto& lim : g_resprulactions.getCopy()) {
-          string name = lim.d_rule->toString();
-          g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
-          ++num;
-        }
-      }
-      else {
-        boost::format fmt("%-3d %9d %-50s %s\n");
-        g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
-        for(const auto& lim : g_resprulactions.getCopy()) {
-          string name = lim.d_rule->toString();
-          g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
-          ++num;
-        }
-      }
+      showRules(&g_resprulactions, showUUIDs);
     });
 
   g_lua.writeFunction("rmResponseRule", [](boost::variant<unsigned int, std::string> id) {
-      setLuaSideEffect();
-      auto rules = g_resprulactions.getCopy();
-      if (auto str = boost::get<std::string>(&id)) {
-        boost::uuids::string_generator gen;
-        const auto uuid = gen(*str);
-        if (rules.erase(std::remove_if(rules.begin(),
-                                       rules.end(),
-                                       [uuid](const DNSDistResponseRuleAction& a) { return a.d_id == uuid; }),
-                        rules.end()) == rules.end()) {
-          g_outputBuffer = "Error: no rule matched\n";
-        }
-      }
-      else if (auto pos = boost::get<unsigned int>(&id)) {
-        if (*pos >= rules.size()) {
-          g_outputBuffer = "Error: attempt to delete non-existing rule\n";
-          return;
-        }
-        rules.erase(rules.begin()+*pos);
-      }
-      g_resprulactions.setState(rules);
+      rmRule(&g_resprulactions, id);
     });
 
   g_lua.writeFunction("topResponseRule", []() {
-      setLuaSideEffect();
-      auto rules = g_resprulactions.getCopy();
-      if(rules.empty())
-          return;
-      auto subject = *rules.rbegin();
-      rules.erase(std::prev(rules.end()));
-      rules.insert(rules.begin(), subject);
-      g_resprulactions.setState(rules);
+      topRule(&g_resprulactions);
     });
 
   g_lua.writeFunction("mvResponseRule", [](unsigned int from, unsigned int to) {
-      setLuaSideEffect();
-      auto rules = g_resprulactions.getCopy();
-      if(from >= rules.size() || to > rules.size()) {
-        g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
-        return;
-      }
-      auto subject = rules[from];
-      rules.erase(rules.begin()+from);
-      if(to == rules.size())
-        rules.push_back(subject);
-      else {
-        if(from < to)
-          --to;
-        rules.insert(rules.begin()+to, subject);
-      }
-      g_resprulactions.setState(rules);
+      mvRule(&g_resprulactions, from, to);
     });
 
   g_lua.writeFunction("showCacheHitResponseRules", [](boost::optional<bool> showUUIDs) {
-      setLuaNoSideEffect();
-      int num=0;
-      if (showUUIDs.get_value_or(false)) {
-        boost::format fmt("%-3d %-38s %9d %-50s %s\n");
-        g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
-        for(const auto& lim : g_cachehitresprulactions.getCopy()) {
-          string name = lim.d_rule->toString();
-          g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
-          ++num;
-        }
-      }
-      else {
-        boost::format fmt("%-3d %9d %-50s %s\n");
-        g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
-        for(const auto& lim : g_cachehitresprulactions.getCopy()) {
-          string name = lim.d_rule->toString();
-          g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
-          ++num;
-        }
-      }
+      showRules(&g_cachehitresprulactions, showUUIDs);
     });
 
   g_lua.writeFunction("rmCacheHitResponseRule", [](boost::variant<unsigned int, std::string> id) {
-      setLuaSideEffect();
-      auto rules = g_cachehitresprulactions.getCopy();
-      if (auto str = boost::get<std::string>(&id)) {
-        boost::uuids::string_generator gen;
-        const auto uuid = gen(*str);
-        if (rules.erase(std::remove_if(rules.begin(),
-                                       rules.end(),
-                                       [uuid](const DNSDistResponseRuleAction& a) { return a.d_id == uuid; }),
-                        rules.end()) == rules.end()) {
-          g_outputBuffer = "Error: no rule matched\n";
-        }
-      }
-      else if (auto pos = boost::get<unsigned int>(&id)) {
-        if (*pos >= rules.size()) {
-          g_outputBuffer = "Error: attempt to delete non-existing rule\n";
-          return;
-        }
-        rules.erase(rules.begin()+*pos);
-      }
-      g_cachehitresprulactions.setState(rules);
+      rmRule(&g_cachehitresprulactions, id);
     });
 
   g_lua.writeFunction("topCacheHitResponseRule", []() {
-      setLuaSideEffect();
-      auto rules = g_cachehitresprulactions.getCopy();
-      if(rules.empty())
-        return;
-      auto subject = *rules.rbegin();
-      rules.erase(std::prev(rules.end()));
-      rules.insert(rules.begin(), subject);
-      g_cachehitresprulactions.setState(rules);
+      topRule(&g_cachehitresprulactions);
     });
 
   g_lua.writeFunction("mvCacheHitResponseRule", [](unsigned int from, unsigned int to) {
-      setLuaSideEffect();
-      auto rules = g_cachehitresprulactions.getCopy();
-      if(from >= rules.size() || to > rules.size()) {
-        g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
-        return;
-      }
-      auto subject = rules[from];
-      rules.erase(rules.begin()+from);
-      if(to == rules.size())
-        rules.push_back(subject);
-      else {
-        if(from < to)
-          --to;
-        rules.insert(rules.begin()+to, subject);
-      }
-      g_cachehitresprulactions.setState(rules);
+      mvRule(&g_cachehitresprulactions, from, to);
     });
 
   g_lua.writeFunction("rmRule", [](boost::variant<unsigned int, std::string> id) {
-      setLuaSideEffect();
-      auto rules = g_rulactions.getCopy();
-      if (auto str = boost::get<std::string>(&id)) {
-        boost::uuids::string_generator gen;
-        const auto uuid = gen(*str);
-        if (rules.erase(std::remove_if(rules.begin(),
-                                       rules.end(),
-                                       [uuid](const DNSDistRuleAction& a) { return a.d_id == uuid; }),
-                        rules.end()) == rules.end()) {
-          g_outputBuffer = "Error: no rule matched\n";
-        }
-      }
-      else if (auto pos = boost::get<unsigned int>(&id)) {
-        if (*pos >= rules.size()) {
-          g_outputBuffer = "Error: attempt to delete non-existing rule\n";
-          return;
-        }
-        rules.erase(rules.begin()+*pos);
-      }
-      g_rulactions.setState(rules);
+      rmRule(&g_rulactions, id);
     });
 
   g_lua.writeFunction("topRule", []() {
-      setLuaSideEffect();
-      auto rules = g_rulactions.getCopy();
-      if(rules.empty())
-       return;
-      auto subject = *rules.rbegin();
-      rules.erase(std::prev(rules.end()));
-      rules.insert(rules.begin(), subject);
-      g_rulactions.setState(rules);
+      topRule(&g_rulactions);
     });
 
   g_lua.writeFunction("mvRule", [](unsigned int from, unsigned int to) {
-      setLuaSideEffect();
-      auto rules = g_rulactions.getCopy();
-      if(from >= rules.size() || to > rules.size()) {
-       g_outputBuffer = "Error: attempt to move rules from/to invalid index\n";
-       return;
-      }
-
-      auto subject = rules[from];
-      rules.erase(rules.begin()+from);
-      if(to == rules.size())
-       rules.push_back(subject);
-      else {
-       if(from < to)
-         --to;
-       rules.insert(rules.begin()+to, subject);
-      }
-      g_rulactions.setState(rules);
+      mvRule(&g_rulactions, from, to);
     });
 
   g_lua.writeFunction("clearRules", []() {
@@ -1298,26 +1217,7 @@ void setupLuaRules()
     });
 
   g_lua.writeFunction("showRules", [](boost::optional<bool> showUUIDs) {
-      setLuaNoSideEffect();
-      int num=0;
-      if (showUUIDs.get_value_or(false)) {
-        boost::format fmt("%-3d %-38s %9d %-56s %s\n");
-        g_outputBuffer += (fmt % "#" % "UUID" % "Matches" % "Rule" % "Action").str();
-        for(const auto& lim : g_rulactions.getCopy()) {
-          string name = lim.d_rule->toString();
-          g_outputBuffer += (fmt % num % boost::uuids::to_string(lim.d_id) % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
-          ++num;
-        }
-      }
-      else {
-        boost::format fmt("%-3d %9d %-50s %s\n");
-        g_outputBuffer += (fmt % "#" % "Matches" % "Rule" % "Action").str();
-        for(const auto& lim : g_rulactions.getCopy()) {
-          string name = lim.d_rule->toString();
-          g_outputBuffer += (fmt % num % lim.d_rule->d_matches % name % lim.d_action->toString()).str();
-          ++num;
-        }
-      }
+      showRules(&g_rulactions, showUUIDs);
     });
 
   g_lua.writeFunction("RDRule", []() {
index 75b8f3190e9660baf3a8e0c32f480339db84b5c2..6af354d562eb7e4eae8412fee6510e2a285863d0 100644 (file)
@@ -219,6 +219,26 @@ static void addCustomHeaders(YaHTTP::Response& resp, const boost::optional<std::
   }
 }
 
+template<typename T>
+static json11::Json::array someResponseRulesToJson(GlobalStateHolder<vector<T>>* someResponseRules)
+{
+  using namespace json11;
+  Json::array responseRules;
+  int num=0;
+  auto localResponseRules = someResponseRules->getCopy();
+  for(const auto& a : localResponseRules) {
+    Json::object rule{
+      {"id", num++},
+      {"uuid", boost::uuids::to_string(a.d_id)},
+      {"matches", (double)a.d_rule->d_matches},
+      {"rule", a.d_rule->toString()},
+      {"action", a.d_action->toString()},
+    };
+    responseRules.push_back(rule);
+  }
+  return responseRules;
+}
+
 static void connectionThread(int sock, ComboAddress remote, string password, string apiKey, const boost::optional<std::map<std::string, std::string> >& customHeaders)
 {
   using namespace json11;
@@ -457,33 +477,8 @@ static void connectionThread(int sock, ComboAddress remote, string password, str
        rules.push_back(rule);
       }
       
-      Json::array responseRules;
-      auto localResponseRules = g_resprulactions.getCopy();
-      num=0;
-      for(const auto& a : localResponseRules) {
-        Json::object rule{
-          {"id", num++},
-          {"uuid", boost::uuids::to_string(a.d_id)},
-          {"matches", (double)a.d_rule->d_matches},
-          {"rule", a.d_rule->toString()},
-          {"action", a.d_action->toString()},
-        };
-        responseRules.push_back(rule);
-      }
-
-      Json::array cacheHitResponseRules;
-      num=0;
-      auto localCacheHitResponseRules = g_cachehitresprulactions.getCopy();
-      for(const auto& a : localCacheHitResponseRules) {
-        Json::object rule{
-          {"id", num++},
-          {"uuid", boost::uuids::to_string(a.d_id)},
-          {"matches", (double)a.d_rule->d_matches},
-          {"rule", a.d_rule->toString()},
-          {"action", a.d_action->toString()},
-        };
-        cacheHitResponseRules.push_back(rule);
-      }
+      auto responseRules = someResponseRulesToJson(&g_resprulactions);
+      auto cacheHitResponseRules = someResponseRulesToJson(&g_cachehitresprulactions);
 
       string acl;
 
index 3d1e802b3e26fd4468a87f03709983b6cfa6f824..942dfbed1d31aad75f91d7d6b8349c6d1c08dd5f 100644 (file)
@@ -78,7 +78,8 @@ class TestAPIBasics(DNSDistTest):
 
         self.assertEquals(content['daemon_type'], 'dnsdist')
 
-        for key in ['version', 'acl', 'local', 'rules', 'response-rules', 'cache-hit-response-rules', 'servers', 'frontends', 'pools']:
+        rule_groups = ['response-rules', 'cache-hit-response-rules']
+        for key in ['version', 'acl', 'local', 'rules', 'servers', 'frontends', 'pools'] + rule_groups:
             self.assertIn(key, content)
 
         for rule in content['rules']:
@@ -87,17 +88,12 @@ class TestAPIBasics(DNSDistTest):
             for key in ['id', 'matches']:
                 self.assertTrue(rule[key] >= 0)
 
-        for rule in content['response-rules']:
-            for key in ['id', 'matches', 'rule', 'action', 'uuid']:
-                self.assertIn(key, rule)
-            for key in ['id', 'matches']:
-                self.assertTrue(rule[key] >= 0)
-
-        for rule in content['cache-hit-response-rules']:
-            for key in ['id', 'matches', 'rule', 'action']:
-                self.assertIn(key, rule)
-            for key in ['id', 'matches']:
-                self.assertTrue(rule[key] >= 0)
+        for rule_group in rule_groups:
+            for rule in content[rule_group]:
+                for key in ['id', 'matches', 'rule', 'action', 'uuid']:
+                    self.assertIn(key, rule)
+                for key in ['id', 'matches']:
+                    self.assertTrue(rule[key] >= 0)
 
         for server in content['servers']:
             for key in ['id', 'latency', 'name', 'weight', 'outstanding', 'qpsLimit',