]> granicus.if.org Git - pdns/commitdiff
auth: register lua functions only once when shared context
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 30 Aug 2019 09:14:53 +0000 (11:14 +0200)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Tue, 17 Sep 2019 15:55:40 +0000 (17:55 +0200)
pdns/lua-record.cc

index 623b4bee07de1cbe6cb4cee8b1a7de1041b19086..31b756271a5a20d958ed57b8c24d2b7172d74553 100644 (file)
@@ -468,51 +468,34 @@ static vector<pair<int, ComboAddress> > convWIplist(std::unordered_map<int, wipl
 static thread_local unique_ptr<AuthLua4> s_LUA;
 bool g_LuaRecordSharedState;
 
-std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, const DNSName& query, const DNSName& zone, int zoneid, const DNSPacket& dnsp, uint16_t qtype)
+typedef struct AuthLuaRecordContext
 {
-  if(!s_LUA ||                  // we don't have a Lua state yet
-     !g_LuaRecordSharedState) { // or we want a new one even if we had one
-    s_LUA = make_unique<AuthLua4>();
-  }
+  ComboAddress          bestwho;
+  DNSName               qname;
+  DNSName               zone;
+  int                   zoneid;
+} lua_record_ctx_t;
 
-  std::vector<shared_ptr<DNSRecordContent>> ret;
+static thread_local unique_ptr<lua_record_ctx_t> s_lua_record_ctx;
 
+void setupLuaRecords()
+{
   LuaContext& lua = *s_LUA->getLua();
-  lua.writeVariable("qname", query);
-  lua.writeVariable("who", dnsp.getRemote());
-  lua.writeVariable("dh", (dnsheader*)&dnsp.d);
-  lua.writeVariable("dnssecOK", dnsp.d_dnssecOk);
-  lua.writeVariable("tcp", dnsp.d_tcp);
-  lua.writeVariable("ednsPKTSize", dnsp.d_ednsRawPacketSizeLimit);
-  ComboAddress bestwho;
-  if(dnsp.hasEDNSSubnet()) {
-    lua.writeVariable("ecswho", dnsp.getRealRemote());
-    bestwho=dnsp.getRealRemote().getNetwork();
-  }
-  else {
-    lua.writeVariable("ecswho", nullptr);
-    bestwho=dnsp.getRemote();
-  }
 
-  lua.writeVariable("bestwho", bestwho);
-
-  lua.writeFunction("latlon", [&bestwho]() {
+  lua.writeFunction("latlon", []() {
       double lat, lon;
-      getLatLon(bestwho.toString(), lat, lon);
+      getLatLon(s_lua_record_ctx->bestwho.toString(), lat, lon);
       return std::to_string(lat)+" "+std::to_string(lon);
     });
-
-  lua.writeFunction("latlonloc", [&bestwho]() {
+  lua.writeFunction("latlonloc", []() {
       string loc;
-      getLatLon(bestwho.toString(), loc);
+      getLatLon(s_lua_record_ctx->bestwho.toString(), loc);
       return loc;
   });
-
-
-  lua.writeFunction("closestMagic", [&bestwho,&query]() {
+  lua.writeFunction("closestMagic", []() {
       vector<ComboAddress> candidates;
       // Getting something like 192-0-2-1.192-0-2-2.198-51-100-1.example.org
-      for(auto l : query.getRawLabels()) {
+      for(auto l : s_lua_record_ctx->qname.getRawLabels()) {
         boost::replace_all(l, "-", ".");
         try {
           candidates.emplace_back(l);
@@ -521,11 +504,10 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
           break ;
         }
       }
-      return pickclosest(bestwho, candidates).toString();
+      return pickclosest(s_lua_record_ctx->bestwho, candidates).toString();
     });
-
-  lua.writeFunction("latlonMagic", [&query](){
-      auto labels= query.getRawLabels();
+  lua.writeFunction("latlonMagic", [](){
+      auto labels= s_lua_record_ctx->qname.getRawLabels();
       if(labels.size()<4)
         return std::string("unknown");
       double lat, lon;
@@ -534,49 +516,46 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
     });
 
 
-  lua.writeFunction("createReverse", [&query](string suffix, boost::optional<std::unordered_map<string,string>> e){
+  lua.writeFunction("createReverse", [](string suffix, boost::optional<std::unordered_map<string,string>> e){
       try {
-      auto labels= query.getRawLabels();
-      if(labels.size()<4)
-        return std::string("unknown");
-
-      vector<ComboAddress> candidates;
-
-      // exceptions are relative to zone
-      // so, query comes in for 4.3.2.1.in-addr.arpa, zone is called 2.1.in-addr.arpa
-      // e["1.2.3.4"]="bert.powerdns.com" - should match, easy enough to do
-      // the issue is with classless delegation..
-      if(e) {
-        ComboAddress req(labels[3]+"."+labels[2]+"."+labels[1]+"."+labels[0], 0);
-        const auto& uom = *e;
-        for(const auto& c : uom)
-          if(ComboAddress(c.first, 0) == req)
-            return c.second;
-      }
-
-
-      boost::format fmt(suffix);
-      fmt.exceptions( boost::io::all_error_bits ^ ( boost::io::too_many_args_bit | boost::io::too_few_args_bit )  );
-      fmt % labels[3] % labels[2] % labels[1] % labels[0];
-
-      fmt % (labels[3]+"-"+labels[2]+"-"+labels[1]+"-"+labels[0]);
+        auto labels = s_lua_record_ctx->qname.getRawLabels();
+        if(labels.size()<4)
+          return std::string("unknown");
+        
+        vector<ComboAddress> candidates;
+        
+        // exceptions are relative to zone
+        // so, query comes in for 4.3.2.1.in-addr.arpa, zone is called 2.1.in-addr.arpa
+        // e["1.2.3.4"]="bert.powerdns.com" - should match, easy enough to do
+        // the issue is with classless delegation..
+        if(e) {
+          ComboAddress req(labels[3]+"."+labels[2]+"."+labels[1]+"."+labels[0], 0);
+          const auto& uom = *e;
+          for(const auto& c : uom)
+            if(ComboAddress(c.first, 0) == req)
+              return c.second;
+        }
+        boost::format fmt(suffix);
+        fmt.exceptions( boost::io::all_error_bits ^ ( boost::io::too_many_args_bit | boost::io::too_few_args_bit )  );
+        fmt % labels[3] % labels[2] % labels[1] % labels[0];
+        
+        fmt % (labels[3]+"-"+labels[2]+"-"+labels[1]+"-"+labels[0]);
 
-      boost::format fmt2("%02x%02x%02x%02x");
-      for(int i=3; i>=0; --i)
-        fmt2 % atoi(labels[i].c_str());
+        boost::format fmt2("%02x%02x%02x%02x");
+        for(int i=3; i>=0; --i)
+          fmt2 % atoi(labels[i].c_str());
 
-      fmt % (fmt2.str());
+        fmt % (fmt2.str());
 
-      return fmt.str();
+        return fmt.str();
       }
       catch(std::exception& e) {
         g_log<<Logger::Error<<"error: "<<e.what()<<endl;
       }
       return std::string("error");
     });
-
-  lua.writeFunction("createForward", [&zone, &query]() {
-      DNSName rel=query.makeRelative(zone);
+  lua.writeFunction("createForward", []() {
+      DNSName rel=s_lua_record_ctx->qname.makeRelative(s_lua_record_ctx->zone);
       auto parts = rel.getRawLabels();
       if(parts.size()==4)
         return parts[0]+"."+parts[1]+"."+parts[2]+"."+parts[3];
@@ -593,8 +572,8 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
       return std::string("0.0.0.0");
     });
 
-  lua.writeFunction("createForward6", [&query,&zone]() {
-      DNSName rel=query.makeRelative(zone);
+  lua.writeFunction("createForward6", []() {
+      DNSName rel=s_lua_record_ctx->qname.makeRelative(s_lua_record_ctx->zone);
       auto parts = rel.getRawLabels();
       if(parts.size()==8) {
         string tot;
@@ -614,13 +593,11 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
 
       return std::string("::");
     });
-
-
-  lua.writeFunction("createReverse6", [&query](string suffix, boost::optional<std::unordered_map<string,string>> e){
+  lua.writeFunction("createReverse6", [](string suffix, boost::optional<std::unordered_map<string,string>> e){
       vector<ComboAddress> candidates;
 
       try {
-        auto labels= query.getRawLabels();
+        auto labels= s_lua_record_ctx->qname.getRawLabels();
         if(labels.size()<32)
           return std::string("unknown");
         boost::format fmt(suffix);
@@ -671,7 +648,6 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
       return std::string("unknown");
     });
 
-
   /*
    * Simplistic test to see if an IP address listens on a certain port
    * Will return a single IP address from the set of available IP addresses. If
@@ -680,7 +656,7 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
    *
    * @example ifportup(443, { '1.2.3.4', '5.4.3.2' })"
    */
-  lua.writeFunction("ifportup", [&bestwho](int port, const vector<pair<int, string> >& ips, const boost::optional<std::unordered_map<string,string>> options) {
+  lua.writeFunction("ifportup", [](int port, const vector<pair<int, string> >& ips, const boost::optional<std::unordered_map<string,string>> options) {
       vector<ComboAddress> candidates, unavailables;
       opts_t opts;
       vector<ComboAddress > conv;
@@ -706,14 +682,13 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
         selector = getOptionValue(options, "backupSelector", "random");
       }
 
-      vector<ComboAddress> res = useSelector(selector, bestwho, candidates);
+      vector<ComboAddress> res = useSelector(selector, s_lua_record_ctx->bestwho, candidates);
       return convIpListToString(res);
     });
 
-  lua.writeFunction("ifurlup", [&bestwho](const std::string& url,
+  lua.writeFunction("ifurlup", [](const std::string& url,
                                           const boost::variant<iplist_t, ipunitlist_t>& ips,
                                           boost::optional<opts_t> options) {
-
       vector<vector<ComboAddress> > candidates;
       opts_t opts;
       if(options)
@@ -737,7 +712,7 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
           }
         }
         if(!available.empty()) {
-          vector<ComboAddress> res = useSelector(getOptionValue(options, "selector", "random"), bestwho, available);
+          vector<ComboAddress> res = useSelector(getOptionValue(options, "selector", "random"), s_lua_record_ctx->bestwho, available);
           return convIpListToString(res);
         }
       }
@@ -748,15 +723,9 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
         ret.insert(ret.end(), unit.begin(), unit.end());
       }
 
-      vector<ComboAddress> res = useSelector(getOptionValue(options, "backupSelector", "random"), bestwho, ret);
+      vector<ComboAddress> res = useSelector(getOptionValue(options, "backupSelector", "random"), s_lua_record_ctx->bestwho, ret);
       return convIpListToString(res);
     });
-
-
-  /* idea: we have policies on vectors of ComboAddresses, like
-     random, pickwrandom, pickwhashed, pickclosest. In C++ this is ComboAddress in,
-     ComboAddress out. In Lua, vector string in, string out */
-
   /*
    * Returns a random IP address from the supplied list
    * @example pickrandom({ '1.2.3.4', '5.4.3.2' })"
@@ -784,68 +753,63 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
    * supplied, as weighted by the various `weight` parameters
    * @example pickwhashed({ {15, '1.2.3.4'}, {50, '5.4.3.2'} })
    */
-  lua.writeFunction("pickwhashed", [&bestwho](std::unordered_map<int, wiplist_t > ips) {
+  lua.writeFunction("pickwhashed", [](std::unordered_map<int, wiplist_t > ips) {
       vector<pair<int,ComboAddress> > conv;
 
       for(auto& i : ips)
         conv.emplace_back(atoi(i.second[1].c_str()), ComboAddress(i.second[2]));
 
-      return pickwhashed(bestwho, conv).toString();
+      return pickwhashed(s_lua_record_ctx->bestwho, conv).toString();
     });
 
 
-  lua.writeFunction("pickclosest", [&bestwho](const iplist_t& ips) {
+  lua.writeFunction("pickclosest", [](const iplist_t& ips) {
       vector<ComboAddress > conv = convIplist(ips);
 
-      return pickclosest(bestwho, conv).toString();
+      return pickclosest(s_lua_record_ctx->bestwho, conv).toString();
 
     });
 
+  if (g_luaRecordExecLimit > 0) {
+      lua.executeCode(boost::str(boost::format("debug.sethook(report, '', %d)") % g_luaRecordExecLimit));
+  }
 
   lua.writeFunction("report", [](string event, boost::optional<string> line){
       throw std::runtime_error("Script took too long");
     });
-  if (g_luaRecordExecLimit > 0) {
-      lua.executeCode(boost::str(boost::format("debug.sethook(report, '', %d)") % g_luaRecordExecLimit));
-  }
 
-  // TODO: make this better. Accept netmask/CA objects; provide names for the attr constants
   lua.writeFunction("geoiplookup", [](const string &ip, const GeoIPInterface::GeoIPQueryAttribute attr) {
     return getGeo(ip, attr);
   });
 
   typedef const boost::variant<string,vector<pair<int,string> > > combovar_t;
-  lua.writeFunction("continent", [&bestwho](const combovar_t& continent) {
-      string res=getGeo(bestwho.toString(), GeoIPInterface::Continent);
+  lua.writeFunction("continent", [](const combovar_t& continent) {
+     string res=getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Continent);
       return doCompare(continent, res, [](const std::string& a, const std::string& b) {
           return !strcasecmp(a.c_str(), b.c_str());
         });
     });
-
-  lua.writeFunction("asnum", [&bestwho](const combovar_t& asns) {
-      string res=getGeo(bestwho.toString(), GeoIPInterface::ASn);
+  lua.writeFunction("asnum", [](const combovar_t& asns) {
+      string res=getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::ASn);
       return doCompare(asns, res, [](const std::string& a, const std::string& b) {
           return !strcasecmp(a.c_str(), b.c_str());
         });
     });
-
-  lua.writeFunction("country", [&bestwho](const combovar_t& var) {
-      string res = getGeo(bestwho.toString(), GeoIPInterface::Country2);
+  lua.writeFunction("country", [](const combovar_t& var) {
+      string res = getGeo(s_lua_record_ctx->bestwho.toString(), GeoIPInterface::Country2);
       return doCompare(var, res, [](const std::string& a, const std::string& b) {
           return !strcasecmp(a.c_str(), b.c_str());
         });
 
     });
-
-  lua.writeFunction("netmask", [bestwho](const iplist_t& ips) {
+  lua.writeFunction("netmask", [](const iplist_t& ips) {
       for(const auto& i :ips) {
         Netmask nm(i.second);
-        if(nm.match(bestwho))
+        if(nm.match(s_lua_record_ctx->bestwho))
           return true;
       }
       return false;
     });
-
   /* {
        {
         {'192.168.0.0/16', '10.0.0.0/8'},
@@ -856,13 +820,13 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
        }
      }
   */
-  lua.writeFunction("view", [bestwho](const vector<pair<int, vector<pair<int, iplist_t> > > >& in) {
+  lua.writeFunction("view", [](const vector<pair<int, vector<pair<int, iplist_t> > > >& in) {
       for(const auto& rule : in) {
         const auto& netmasks=rule.second[0].second;
         const auto& destinations=rule.second[1].second;
         for(const auto& nmpair : netmasks) {
           Netmask nm(nmpair.second);
-          if(nm.match(bestwho)) {
+          if(nm.match(s_lua_record_ctx->bestwho)) {
             return destinations[dns_random(destinations.size())].second;
           }
         }
@@ -872,18 +836,54 @@ std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, cons
     );
 
 
-  lua.writeFunction("include", [&lua,zone,zoneid](string record) {
+  lua.writeFunction("include", [&lua](string record) {
       try {
-        vector<DNSZoneRecord> drs = lookup(DNSName(record) +zone, QType::LUA, zoneid);
+        vector<DNSZoneRecord> drs = lookup(DNSName(record) + s_lua_record_ctx->zone, QType::LUA, s_lua_record_ctx->zoneid);
         for(const auto& dr : drs) {
           auto lr = getRR<LUARecordContent>(dr.dr);
           lua.executeCode(lr->getCode());
         }
       }
       catch(std::exception& e) {
-        g_log<<Logger::Error<<"Failed to load include record for LUArecord "<<(DNSName(record)+zone)<<": "<<e.what()<<endl;
+        g_log<<Logger::Error<<"Failed to load include record for LUArecord "<<(DNSName(record)+s_lua_record_ctx->zone)<<": "<<e.what()<<endl;
       }
     });
+}
+
+std::vector<shared_ptr<DNSRecordContent>> luaSynth(const std::string& code, const DNSName& query, const DNSName& zone, int zoneid, const DNSPacket& dnsp, uint16_t qtype)
+{
+  if(!s_LUA ||                  // we don't have a Lua state yet
+     !g_LuaRecordSharedState) { // or we want a new one even if we had one
+    s_LUA = make_unique<AuthLua4>();
+    setupLuaRecords();
+  }
+
+  std::vector<shared_ptr<DNSRecordContent>> ret;
+
+  LuaContext& lua = *s_LUA->getLua();
+
+  s_lua_record_ctx = std::unique_ptr<lua_record_ctx_t>(new lua_record_ctx_t());
+  s_lua_record_ctx->qname = query;
+  s_lua_record_ctx->zone = zone;
+  s_lua_record_ctx->zoneid = zoneid;
+  
+  lua.writeVariable("qname", query);
+  lua.writeVariable("zone", zone);
+  lua.writeVariable("zoneid", zoneid);
+  lua.writeVariable("who", dnsp.getRemote());
+  lua.writeVariable("dh", (dnsheader*)&dnsp.d);
+  lua.writeVariable("dnssecOK", dnsp.d_dnssecOk);
+  lua.writeVariable("tcp", dnsp.d_tcp);
+  lua.writeVariable("ednsPKTSize", dnsp.d_ednsRawPacketSizeLimit);
+  if(dnsp.hasEDNSSubnet()) {
+    lua.writeVariable("ecswho", dnsp.getRealRemote());
+    s_lua_record_ctx->bestwho = dnsp.getRealRemote().getNetwork();
+  }
+  else {
+    lua.writeVariable("ecswho", nullptr);
+    s_lua_record_ctx->bestwho = dnsp.getRemote();
+  }
+  lua.writeVariable("bestwho", s_lua_record_ctx->bestwho);
 
   try {
     string actual;