Stop copying data around in MOADNSParser, remove the weird -12 dance
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 30 Jul 2018 07:17:57 +0000 (09:17 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 6 Aug 2018 08:00:22 +0000 (10:00 +0200)
modules/tinydnsbackend/tinydnsbackend.cc
pdns/dnsdist-ecs.cc
pdns/dnsparser.cc
pdns/dnsparser.hh
pdns/lwres.cc
pdns/pdns_recursor.cc
pdns/protobuf.cc

index 2fbe115cfeb81af0c89c8d3a9ff8495bd44dde93..228ac34a5ea88070e244a0734c1265542978079d 100644 (file)
@@ -232,12 +232,7 @@ bool TinyDNSBackend::get(DNSResourceRecord &rr)
     }
 
 
-    vector<uint8_t> bytes;
-    const char *sval = val.c_str();
-    unsigned int len = val.size();
-    bytes.resize(len);
-    copy(sval, sval+len, bytes.begin());
-    PacketReader pr(bytes);
+    PacketReader pr(val, 0);
     rr.qtype = QType(pr.get16BitInt());
 
     if(d_isAxfr || d_qtype.getCode() == QType::ANY || rr.qtype == d_qtype) {
@@ -298,7 +293,7 @@ bool TinyDNSBackend::get(DNSResourceRecord &rr)
         DNSRecord dr;
         dr.d_class = 1;
         dr.d_type = rr.qtype.getCode();
-        dr.d_clen = val.size()-pr.d_pos;
+        dr.d_clen = val.size()-pr.getPosition();
 
         auto drc = DNSRecordContent::mastermake(dr, pr);
         rr.content = drc->getZoneRepresentation();
index 158a7352d630c05128b914ea978d594a8d0b46c6..0df7c210cc2927317ba5137bebaf8a559cfbfaed 100644 (file)
@@ -47,10 +47,9 @@ int rewriteResponseWithoutEDNS(const char * packet, const size_t len, vector<uin
 
   if (ntohs(dh->qdcount) == 0)
     return ENOENT;
-    
-  vector<uint8_t> content(len - sizeof(dnsheader));
-  copy(packet + sizeof(dnsheader), packet + len, content.begin());
-  PacketReader pr(content);
+
+  std::string packetStr(packet, len);
+  PacketReader pr(packetStr);
   
   size_t idx = 0;
   DNSName rrname;
@@ -117,7 +116,7 @@ int rewriteResponseWithoutEDNS(const char * packet, const size_t len, vector<uin
       pr.xfrBlob(blob);
       pw.xfrBlob(blob);
     } else {
-      pr.d_pos += ah.d_clen;
+      pr.skip(ah.d_clen);
     }
   }
   pw.commit();
@@ -136,9 +135,8 @@ int locateEDNSOptRR(char * packet, const size_t len, char ** optStart, size_t *
   if (ntohs(dh->arcount) == 0)
     return ENOENT;
 
-  vector<uint8_t> content(len - sizeof(dnsheader));
-  copy(packet + sizeof(dnsheader), packet + len, content.begin());
-  PacketReader pr(content);
+  std::string packetStr(packet, len);
+  PacketReader pr(packetStr);
   size_t idx = 0;
   DNSName rrname;
   uint16_t qdcount = ntohs(dh->qdcount);
@@ -162,18 +160,18 @@ int locateEDNSOptRR(char * packet, const size_t len, char ** optStart, size_t *
   for (idx = 0; idx < ancount + nscount; idx++) {
     rrname = pr.getName();
     pr.getDnsrecordheader(ah);
-    pr.d_pos += ah.d_clen;
+    pr.skip(ah.d_clen);
   }
 
   /* consume AR, looking for OPT */
   for (idx = 0; idx < arcount; idx++) {
-    uint16_t start = pr.d_pos;
+    uint16_t start = pr.getPosition();
     rrname = pr.getName();
     pr.getDnsrecordheader(ah);
 
     if (ah.d_type == QType::OPT) {
-      *optStart = packet + sizeof(dnsheader) + start;
-      *optLen = (pr.d_pos - start) + ah.d_clen;
+      *optStart = packet + start;
+      *optLen = (pr.getPosition() - start) + ah.d_clen;
 
       if ((packet + len) < (*optStart + *optLen)) {
         throw std::range_error("Opt record overflow");
@@ -187,7 +185,7 @@ int locateEDNSOptRR(char * packet, const size_t len, char ** optStart, size_t *
       }
       return 0;
     }
-    pr.d_pos += ah.d_clen;
+    pr.skip(ah.d_clen);
   }
 
   return ENOENT;
@@ -445,9 +443,8 @@ int rewriteResponseWithoutEDNSOption(const char * packet, const size_t len, cons
   if (ntohs(dh->qdcount) == 0)
     return ENOENT;
 
-  vector<uint8_t> content(len - sizeof(dnsheader));
-  copy(packet + sizeof(dnsheader), packet + len, content.begin());
-  PacketReader pr(content);
+  std::string packetStr(packet, len);
+  PacketReader pr(packetStr);
 
   size_t idx = 0;
   DNSName rrname;
index 33fac5f5e20916ba35aff38c3838b7d5ed943340..a5fd5b4f0e7a025a54effca5626e22d5526bf93a 100644 (file)
@@ -109,7 +109,7 @@ shared_ptr<DNSRecordContent> DNSRecordContent::unserialize(const DNSName& qname,
 
   struct dnsrecordheader drh;
   drh.d_type=htons(qtype);
-  drh.d_class=htons(1);
+  drh.d_class=htons(QClass::IN);
   drh.d_ttl=0;
   drh.d_clen=htons(serialized.size());
 
@@ -221,12 +221,12 @@ DNSResourceRecord DNSResourceRecord::fromWire(const DNSRecord& d) {
   return rr;
 }
 
-void MOADNSParser::init(bool query, const char *packet, unsigned int len)
+void MOADNSParser::init(bool query, const std::string& packet)
 {
-  if(len < sizeof(dnsheader))
+  if (packet.size() < sizeof(dnsheader))
     throw MOADNSException("Packet shorter than minimal header");
   
-  memcpy(&d_header, packet, sizeof(dnsheader));
+  memcpy(&d_header, packet.data(), sizeof(dnsheader));
 
   if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
     throw MOADNSException("Can't parse non-query packet with opcode="+ std::to_string(d_header.opcode));
@@ -238,15 +238,10 @@ void MOADNSParser::init(bool query, const char *packet, unsigned int len)
 
   if (query && (d_header.qdcount > 1))
     throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
-
-  uint16_t contentlen=len-sizeof(dnsheader);
-
-  d_content.resize(contentlen);
-  copy(packet+sizeof(dnsheader), packet+len, d_content.begin());
   
   unsigned int n=0;
 
-  PacketReader pr(d_content);
+  PacketReader pr(packet);
   bool validPacket=false;
   try {
     d_qtype = d_qclass = 0; // sometimes replies come in with no question, don't present garbage then
@@ -272,7 +267,7 @@ void MOADNSParser::init(bool query, const char *packet, unsigned int len)
       else 
         dr.d_place=DNSResourceRecord::ADDITIONAL;
 
-      unsigned int recordStartPos=pr.d_pos;
+      unsigned int recordStartPos=pr.getPosition();
 
       DNSName name=pr.getName();
 
@@ -295,7 +290,7 @@ void MOADNSParser::init(bool query, const char *packet, unsigned int len)
         dr.d_content=DNSRecordContent::mastermake(dr, pr, d_header.opcode);
       }
 
-      d_answers.push_back(make_pair(dr, pr.d_pos));
+      d_answers.push_back(make_pair(dr, pr.getPosition() - sizeof(dnsheader)));
 
       /* XXX: XPF records should be allowed after TSIG as soon as the actual XPF option code has been assigned:
          if (dr.d_place == DNSResourceRecord::ADDITIONAL && seenTSIG && dr.d_type != QType::XPF)
@@ -310,18 +305,18 @@ void MOADNSParser::init(bool query, const char *packet, unsigned int len)
           throw MOADNSException("Packet ("+d_qname.toLogString()+"|#"+std::to_string(d_qtype)+") has a TSIG record in an invalid position.");
         }
         seenTSIG = true;
-        d_tsigPos = recordStartPos + sizeof(struct dnsheader);
+        d_tsigPos = recordStartPos;
       }
     }
 
 #if 0
-    if(pr.d_pos!=contentlen) {
-      throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.d_pos) + " < " + 
-                            std::to_string(contentlen) + ")");
+    if(pr.getPosition()!=packet.size()) {
+      throw MOADNSException("Packet ("+d_qname+"|#"+std::to_string(d_qtype)+") has trailing garbage ("+ std::to_string(pr.getPosition()) + " < " +
+                            std::to_string(packet.size()) + ")");
     }
 #endif
   }
-  catch(std::out_of_range &re) {
+  catch(const std::out_of_range &re) {
     if(validPacket && d_header.tc) { // don't sweat it over truncated packets, but do adjust an, ns and arcount
       if(n < d_header.ancount) {
         d_header.ancount=n; d_header.nscount = d_header.arcount = 0;
@@ -334,7 +329,7 @@ void MOADNSParser::init(bool query, const char *packet, unsigned int len)
       }
     }
     else {
-      throw MOADNSException("Error parsing packet of "+std::to_string(len)+" bytes (rd="+
+      throw MOADNSException("Error parsing packet of "+std::to_string(packet.size())+" bytes (rd="+
                             std::to_string(d_header.rd)+
                             "), out of bounds: "+string(re.what()));
     }
@@ -383,45 +378,40 @@ void PacketReader::copyRecord(unsigned char* dest, uint16_t len)
 void PacketReader::xfr48BitInt(uint64_t& ret)
 {
   ret=0;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
 }
 
 uint32_t PacketReader::get32BitInt()
 {
   uint32_t ret=0;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=d_content.at(d_pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   
   return ret;
 }
 
 
 uint16_t PacketReader::get16BitInt()
-{
-  return get16BitInt(d_content, d_pos);
-}
-
-uint16_t PacketReader::get16BitInt(const vector<unsigned char>&content, uint16_t& pos)
 {
   uint16_t ret=0;
-  ret+=content.at(pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
-  ret+=content.at(pos++);
+  ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   
   return ret;
 }
@@ -435,10 +425,8 @@ DNSName PacketReader::getName()
 {
   unsigned int consumed;
   try {
-    DNSName dn((const char*) d_content.data() - 12, d_content.size() + 12, d_pos + sizeof(dnsheader), true /* uncompress */, 0 /* qtype */, 0 /* qclass */, &consumed, sizeof(dnsheader));
+    DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, 0 /* qtype */, 0 /* qclass */, &consumed, sizeof(dnsheader));
     
-    // the -12 fakery is because we don't have the header in 'd_content', but we do need to get 
-    // the internal offsets to work
     d_pos+=consumed;
     return dn;
   }
@@ -482,7 +470,7 @@ string PacketReader::getText(bool multi, bool lenField)
     }
     uint16_t labellen;
     if(lenField)
-      labellen=d_content.at(d_pos++);
+      labellen=static_cast<uint8_t>(d_content.at(d_pos++));
     else
       labellen=d_recordlen - (d_pos - d_startrecordpos);
     
@@ -504,7 +492,7 @@ string PacketReader::getUnquotedText(bool lenField)
 {
   uint16_t stop_at;
   if(lenField)
-    stop_at = (uint8_t)d_content.at(d_pos) + d_pos + 1;
+    stop_at = static_cast<uint8_t>(d_content.at(d_pos)) + d_pos + 1;
   else
     stop_at = d_recordlen;
 
index 353efccdd07256806e6b9d04c64114078132204c..5892b5b5879ad5c9b2e7246e57617dddb8c8709a 100644 (file)
@@ -68,8 +68,8 @@ class MOADNSParser;
 class PacketReader
 {
 public:
-  PacketReader(const vector<uint8_t>& content) 
-    : d_pos(0), d_startrecordpos(0), d_content(content)
+  PacketReader(const std::string& content, uint16_t initialPos=sizeof(dnsheader))
+    : d_pos(initialPos), d_startrecordpos(initialPos), d_content(content)
   {
     if(content.size() > std::numeric_limits<uint16_t>::max())
       throw std::out_of_range("packet too large");
@@ -155,8 +155,6 @@ public:
   void xfrBlob(string& blob, int length);
   void xfrHexBlob(string& blob, bool keepReading=false);
 
-  static uint16_t get16BitInt(const vector<unsigned char>&content, uint16_t& pos);
-
   void getDnsrecordheader(struct dnsrecordheader &ah);
   void copyRecord(vector<unsigned char>& dest, uint16_t len);
   void copyRecord(unsigned char* dest, uint16_t len);
@@ -165,18 +163,28 @@ public:
   string getText(bool multi, bool lenField);
   string getUnquotedText(bool lenField);
 
-  uint16_t d_pos;
 
   bool eof() { return true; };
   const string getRemaining() const {
     return "";
   };
 
+  uint16_t getPosition() const
+  {
+    return d_pos;
+  }
+
+  void skip(uint16_t n)
+  {
+    d_pos += n;
+  }
+
 private:
+  uint16_t d_pos;
   uint16_t d_startrecordpos; // needed for getBlob later on
   uint16_t d_recordlen;      // ditto
   uint16_t not_used; // Aligns the whole class on 8-byte boundries
-  const vector<uint8_t>& d_content;
+  const std::string& d_content;
 };
 
 struct DNSRecord;
@@ -356,15 +364,15 @@ class MOADNSParser : public boost::noncopyable
 {
 public:
   //! Parse from a string
-  MOADNSParser(bool query, const string& buffer)  : d_tsigPos(0)
+  MOADNSParser(bool query, const string& buffer): d_tsigPos(0)
   {
-    init(query, buffer.c_str(), (unsigned int)buffer.size());
+    init(query, buffer);
   }
 
   //! Parse from a pointer and length
   MOADNSParser(bool query, const char *packet, unsigned int len) : d_tsigPos(0)
   {
-    init(query, packet, len);
+    init(query, std::string(packet, len));
   }
 
   DNSName d_qname;
@@ -377,21 +385,12 @@ public:
   //! All answers contained in this packet (everything *but* the question section)
   answers_t d_answers;
 
-  shared_ptr<PacketReader> getPacketReader(uint16_t offset)
-  {
-    shared_ptr<PacketReader> pr(new PacketReader(d_content));
-    pr->d_pos=offset;
-    return pr;
-  }
-
   uint16_t getTSIGPos() const
   {
     return d_tsigPos;
   }
 private:
-  void getDnsrecordheader(struct dnsrecordheader &ah);
-  void init(bool query, const char *packet, unsigned int len);
-  vector<uint8_t> d_content;
+  void init(bool query, const std::string& packet);
   uint16_t d_tsigPos;
 };
 
index b007aca107df0f133b7bfc8675dea26d724eed99..4739ae10a3abb34a8b36f6fbf7912503ade07667 100644 (file)
@@ -94,7 +94,8 @@ int asyncresolve(const ComboAddress& ip, const DNSName& domain, int type, bool d
 {
   size_t len;
   size_t bufsize=g_outgoingEDNSBufsize;
-  scoped_array<unsigned char> buf(new unsigned char[bufsize]);
+  std::string buf;
+  buf.resize(bufsize);
   vector<uint8_t> vpacket;
   //  string mapped0x20=dns0x20(domain);
   uint16_t qid = dns_random(0xffff);
@@ -169,7 +170,7 @@ int asyncresolve(const ComboAddress& ip, const DNSName& domain, int type, bool d
 
     // sleep until we see an answer to this, interface to mtasker
     
-    ret=arecvfrom(reinterpret_cast<char *>(buf.get()), bufsize, 0, ip, &len, qid,
+    ret=arecvfrom(const_cast<char *>(buf.data()), buf.size(), 0, ip, &len, qid,
                   domain, type, queryfd, now);
   }
   else {
@@ -204,12 +205,8 @@ int asyncresolve(const ComboAddress& ip, const DNSName& domain, int type, bool d
       if(!(ret > 0))
         return ret;
       
-      if(len > bufsize) {
-        bufsize=len;
-        scoped_array<unsigned char> narray(new unsigned char[bufsize]);
-        buf.swap(narray);
-      }
-      memcpy(buf.get(), packet.c_str(), len);
+      buf.resize(len);
+      memcpy(const_cast<char*>(buf.data()), packet.c_str(), len);
 
       ret=1;
     }
@@ -225,10 +222,11 @@ int asyncresolve(const ComboAddress& ip, const DNSName& domain, int type, bool d
   if(ret <= 0) // includes 'timeout'
     return ret;
 
+  buf.resize(len);
   lwr->d_records.clear();
   try {
     lwr->d_tcbit=0;
-    MOADNSParser mdp(false, (const char*)buf.get(), len);
+    MOADNSParser mdp(false, buf);
     lwr->d_aabit=mdp.d_header.aa;
     lwr->d_tcbit=mdp.d_header.tc;
     lwr->d_rcode=mdp.d_header.rcode;
index e6ac33fc06976fee8822247b4bb8ed30ecd03ae2..4901ca45b80c578d185328242ca9a63528a45e1d 100644 (file)
@@ -234,11 +234,11 @@ bool g_logRPZChanges{false};
 
 //! used to send information to a newborn mthread
 struct DNSComboWriter {
-  DNSComboWriter(const std::string& query, const struct timeval& now): d_mdp(true, query.c_str(), query.size()), d_now(now)
+  DNSComboWriter(const std::string& query, const struct timeval& now): d_mdp(true, query), d_now(now)
   {
   }
 
-  DNSComboWriter(const std::string& query, const struct timeval& now, std::vector<std::string>&& policyTags, LuaContext::LuaObject&& data): d_mdp(true, query.c_str(), query.size()), d_now(now), d_policyTags(std::move(policyTags)), d_data(std::move(data))
+  DNSComboWriter(const std::string& query, const struct timeval& now, std::vector<std::string>&& policyTags, LuaContext::LuaObject&& data): d_mdp(true, query), d_now(now), d_policyTags(std::move(policyTags)), d_data(std::move(data))
   {
   }
 
index 3c9b49a0b3af07fc35d44eb78b5b3dd6cc7084be..60dc3ada31bfa43ffe4f73e8672e08a9b01762f0 100644 (file)
@@ -147,9 +147,8 @@ void DNSProtoBufMessage::addRRsFromPacket(const char* packet, const size_t len,
   if (!response)
     return;
 
-  vector<uint8_t> content(len - sizeof(dnsheader));
-  copy(packet + sizeof(dnsheader), packet + len, content.begin());
-  PacketReader pr(content);
+  std::string packetStr(packet, len);
+  PacketReader pr(packetStr);
 
   size_t idx = 0;
   DNSName rrname;