]> granicus.if.org Git - pdns/commitdiff
it compiles & appears to function somewhat: dnsdist
authorbert hubert <bert.hubert@netherlabs.nl>
Tue, 25 Jun 2013 10:12:03 +0000 (12:12 +0200)
committerbert hubert <bert.hubert@netherlabs.nl>
Tue, 25 Jun 2013 10:12:03 +0000 (12:12 +0200)
pdns/dnsdist.cc

index 96e7b740df9edddb952e0bc5c29d77320a9e5c25..755d61d6c965b8a0848698ceb6b5d0308346070a 100644 (file)
@@ -1,6 +1,6 @@
 /*
     PowerDNS Versatile Database Driven Nameserver
-    Copyright (C) 2002-2013  PowerDNS.COM BV
+    Copyright (C) 2013  PowerDNS.COM BV
 
     This program is free software; you can redistribute it and/or modify
     it under the terms of the GNU General Public License version 2
 #include <netinet/tcp.h>
 #include <boost/array.hpp>
 #include <boost/program_options.hpp>
+#include <boost/foreach.hpp>
+
+/* syntax: dnsdist 8.8.8.8 8.8.4.4 208.67.222.222 208.67.220.220
+   Added downstream server 8.8.8.8:53
+   Added downstream server 8.8.4.4:53
+   Added downstream server 208.67.222.222:53
+   Added downstream server 208.67.220.220:53
+   Listening on [::]:53
+
+   And you are in business!
+ */
 
 StatBag S;
 namespace po = boost::program_options;
 po::variables_map g_vm;
-bool g_verbose;
-bool g_onlyTCP;
-bool g_tcpNoDelay;
-unsigned int g_timeoutMsec;
-AtomicCounter g_networkErrors, g_otherErrors, g_OK, g_truncates, g_authAnswers, g_timeOuts;
-ComboAddress g_dest;
 
-/* On Linux, run echo 1 > /proc/sys/net/ipv4/tcp_tw_recycle 
-   to prevent running out of free TCP ports */
+bool g_verbose;
+AtomicCounter g_pos, g_timeouts;
 
-struct BenchQuery
+int Socket(int family, int type, int flags)
 {
-  BenchQuery(const std::string& qname_, uint16_t qtype_) : qname(qname_), qtype(qtype_), udpMsec(0), tcpMsec(0) {}
-  BenchQuery(){}
-  std::string qname;
-  uint16_t qtype;
-  uint16_t udpMsec, tcpMsec;
-};
+  int ret = socket(family, type, flags);
+  if(ret < 0)
+    throw runtime_error((boost::format("creating socket of type %d: %s") % family % strerror(errno)).str());
+  return ret;
+}
 
+int Connect(int sockfd, const ComboAddress& remote)
+{
+  int ret = connect(sockfd, (struct sockaddr*)&remote, remote.getSocklen());
+  if(ret < 0)
+    throw runtime_error((boost::format("connecting socket to %s: %s") % remote.toStringWithPort() % strerror(errno)).str());
+  return ret;
+}
 
-void doQuery(BenchQuery* q)
-try
+int Bind(int sockfd, const ComboAddress& local)
 {
-  vector<uint8_t> packet;
-  DNSPacketWriter pw(packet, q->qname, q->qtype);
-  int res;
-  string reply;
+  int ret = bind(sockfd, (struct sockaddr*)&local, local.getSocklen());
+  if(ret < 0)
+    throw runtime_error((boost::format("binding socket to %s: %s") % local.toStringWithPort() % strerror(errno)).str());
+  return ret;
+}
 
-  if(!g_onlyTCP) {
-    Socket udpsock((AddressFamily)g_dest.sin4.sin_family, Datagram);
-    
-    udpsock.sendTo(string((char*)&*packet.begin(), (char*)&*packet.end()), g_dest);
-    ComboAddress origin;
-    res = waitForData(udpsock.getHandle(), 0, 1000 * g_timeoutMsec);
-    if(res < 0)
-      throw NetworkError("Error waiting for response");
-    if(!res) {
-      g_timeOuts++;
-      return;
-    }
+/* the grand design. Per socket we listen on for incoming queries there is one thread.
+   Then we have a bunch of connected sockets for talking to downstream servers. 
+   We send directly to those sockets.
 
-    udpsock.recvFrom(reply, origin);
-    MOADNSParser mdp(reply);
-    if(!mdp.d_header.tc)
-      return;
-    g_truncates++;
-  }
+   For the return path, per downstream server we have a thread that listens to responses.
 
-  Socket sock((AddressFamily)g_dest.sin4.sin_family, Stream);
-  int tmp=1;
-  if(setsockopt(sock.getHandle(),SOL_SOCKET,SO_REUSEADDR,(char*)&tmp,sizeof tmp)<0) 
-    throw runtime_error("Unable to set socket reuse: "+string(strerror(errno)));
-    
-  if(g_tcpNoDelay && setsockopt(sock.getHandle(), IPPROTO_TCP, TCP_NODELAY,(char*)&tmp,sizeof tmp)<0) 
-    throw runtime_error("Unable to set socket no delay: "+string(strerror(errno)));
-
-  sock.connect(g_dest);
-  uint16_t len = htons(packet.size());
-  string tcppacket((char*)& len, 2);
-  tcppacket.append((char*)&*packet.begin(), (char*)&*packet.end());
-
-  sock.writen(tcppacket);
-
-  res = waitForData(sock.getHandle(), 0, 1000 * g_timeoutMsec);
-  if(res < 0)
-    throw NetworkError("Error waiting for response");
-  if(!res) {
-    g_timeOuts++;
-    return;
-  }
-  
-  if(sock.read((char *) &len, 2) != 2)
-    throw AhuException("tcp read failed");
-  
-  len=ntohs(len);
-  char *creply = new char[len];
-  int n=0;
-  int numread;
-  while(n<len) {
-    numread=sock.read(creply+n, len-n);
-    if(numread<0)
-      throw AhuException("tcp read failed");
-    n+=numread;
-  }
-  
-  reply=string(creply, len);
-  delete[] creply;
-  
-  MOADNSParser mdp(reply);
-  //  cout<<"Had correct TCP/IP response, "<<mdp.d_answers.size()<<" answers, aabit="<<mdp.d_header.aa<<endl;
-  if(mdp.d_header.aa)
-    g_authAnswers++;
-  g_OK++;
-}
-catch(NetworkError& ne)
+   Per socket there is an array of 2^16 states, when we send out a packet downstream, we note
+   there the original requestor and the original id. The new ID is the offset in the array.
+
+   When an answer comes in on a socket, we look up the offset by the id, and lob it to the 
+   original requestor.
+
+   IDs are assigned by atomic increments of the socket offset.
+ */
+
+struct IDState
 {
-  cerr<<"Network error: "<<ne.what()<<endl;
-  g_networkErrors++;
-}
-catch(...)
+  uint16_t origID;
+  ComboAddress origRemote;
+  int origFD;
+};
+
+struct SocketState
 {
-  g_otherErrors++;
-}
+  int fd;
+  pthread_t tid;
+  ComboAddress remote;
+  vector<IDState> idStates;
+  AtomicCounter idOffset;
+};
 
-/* read queries from stdin, put in vector
-   launch n worker threads, each picks a query using AtomicCounter
-   If a worker reaches the end of its queue, it stops */
+SocketState* g_socketstates;
+unsigned int g_numremotes;
 
-AtomicCounter g_pos;
+void* responderThread(void *p)
+{
+  SocketState* state = (SocketState*)p;
+  if(g_verbose)
+    cout << "Added downstream server "<<state->remote.toStringWithPort()<<endl;
+  char packet[65536];
 
-vector<BenchQuery> g_queries;
+  struct dnsheader* dh = (struct dnsheader*)packet;
+  int len;
+  for(;;) {
+    len = recv(state->fd, packet, sizeof(packet), 0);
+    if(len < 0)
+      continue;
+    IDState* ids = &state->idStates[dh->id];
+    if(ids->origFD < 0)
+      continue;
+    dh->id = ids->origID;
+    sendto(ids->origFD, packet, len, 0, (struct sockaddr*)&ids->origRemote, ids->origRemote.getSocklen());
+    if(g_verbose)
+      cout << "Got answer from "<<state->remote.toStringWithPort()<<", relayed to "<<ids->origRemote.toStringWithPort()<<endl;
+
+    ids->origFD = -1;
+  }
+  return 0;
+}
 
+struct ClientState
+{
+  ComboAddress local;
+  int fd;
+};
 
-static void* worker(void*)
+// listens to incoming queries, sends out to downstream servers
+void* clientThread(void* p)
 {
-  for(;;) {
-    unsigned int pos = g_pos++; 
-    if(pos >= g_queries.size())
-      break;
+  ClientState* cs = (ClientState*) p;
+  if(g_verbose)
+    cout<<"Listening on "<<cs->local.toStringWithPort()<<endl;
 
-    doQuery(&g_queries[pos]); // this is safe as long as nobody *inserts* to g_queries
+  ComboAddress remote;
+  remote.sin4.sin_family = cs->local.sin4.sin_family;
+  socklen_t socklen = cs->local.getSocklen();
+  
+  char packet[1500];
+  struct dnsheader* dh = (struct dnsheader*) packet;
+  int len;
+
+  for(;;) {
+    len = recvfrom(cs->fd, packet, sizeof(packet), 0, (struct sockaddr*) &remote, &socklen);
+    if(len < 0)
+      continue;
+    
+    SocketState& ss = g_socketstates[(g_pos++) % g_numremotes];
+    unsigned int idOffset = ss.idOffset++;
+    IDState* ids = &ss.idStates[idOffset];
+    ids->origFD = cs->fd;
+    ids->origID = dh->id;
+    ids->origRemote = remote;
+    dh->id = idOffset;
+    
+    send(ss.fd, packet, len, 0);
+    if(g_verbose)
+      cout<<"Got query from "<<remote.toStringWithPort()<<",relayed to "<<ss.remote.toStringWithPort()<<endl;
   }
   return 0;
 }
@@ -159,21 +174,16 @@ try
   po::options_description desc("Allowed options"), hidden, alloptions;
   desc.add_options()
     ("help,h", "produce help message")
-    ("verbose,v", "be verbose")
-    ("udp-first,u", "try UDP first")
-    ("tcp-no-delay", po::value<bool>()->default_value(true), "use TCP_NODELAY socket option")
-    ("timeout-msec", po::value<int>()->default_value(10), "wait for this amount of milliseconds for an answer")
-    ("workers", po::value<int>()->default_value(100), "number of parallel workers");
+    ("local", po::value<vector<string> >(), "Listen on which address")
+    ("verbose,v", "be verbose");
     
-
   hidden.add_options()
-    ("remote-host", po::value<string>(), "remote-host")
-    ("remote-port", po::value<int>()->default_value(53), "remote-port");
+    ("remotes", po::value<vector<string> >(), "remote-host");
+
   alloptions.add(desc).add(hidden); 
 
   po::positional_options_description p;
-  p.add("remote-host", 1);
-  p.add("remote-port", 1);
+  p.add("remotes", -1);
 
   po::store(po::command_line_parser(argc, argv).options(alloptions).positional(p).run(), g_vm);
   po::notify(g_vm);
@@ -182,57 +192,64 @@ try
     cout << desc<<endl;
     exit(EXIT_SUCCESS);
   }
-  g_tcpNoDelay = g_vm["tcp-no-delay"].as<bool>();
-
-  g_onlyTCP = !g_vm.count("udp-first");
-  g_verbose = g_vm.count("verbose");
-  g_timeoutMsec = g_vm["timeout-msec"].as<int>();
 
-  reportAllTypes();
-
-  if(g_vm["remote-host"].empty()) {
-    cerr<<"Syntax: tcpbench remote [port] < queries"<<endl;
-    cerr<<"Where queries is one query per line, format: qname qtype, just 1 space"<<endl;
-    cerr<<desc<<endl;
+  g_verbose=g_vm.count("verbose");
+  
+  if(!g_vm.count("remotes")) {
+    cerr<<"Need to specify at least one remote address"<<endl;
+    cout<<desc<<endl;
     exit(EXIT_FAILURE);
   }
+  vector<string> remotes = g_vm["remotes"].as<vector<string> >();
+
+  g_numremotes = remotes.size();
+  g_socketstates = new SocketState[g_numremotes];
+  int pos=0;
+  BOOST_FOREACH(const string& remote, remotes) {
+    SocketState& ss = g_socketstates[pos++];
+    ss.remote = ComboAddress(remote, 53);
+    
+    ss.fd = Socket(ss.remote.sin4.sin_family, SOCK_DGRAM, 0);
+    Connect(ss.fd, ss.remote);
 
-  g_dest = ComboAddress(g_vm["remote-host"].as<string>().c_str(), g_vm["remote-port"].as<int>());
+    ss.idStates.resize(65536);
+    BOOST_FOREACH(IDState& ids, ss.idStates) {
+      ids.origFD = -1;
+    }
 
-  unsigned int numworkers=g_vm["workers"].as<int>();
-  
-  if(g_verbose) {
-    cout<<"Sending queries to: "<<g_dest.toStringWithPort()<<endl;
-    cout<<"Attempting UDP first: " << (g_onlyTCP ? "no" : "yes") <<endl;
-    cout<<"Timeout: "<< g_timeoutMsec<<"msec"<<endl;
-    cout << "Using TCP_NODELAY: "<<g_tcpNoDelay<<endl;
+    pthread_create(&ss.tid, 0, responderThread, (void*)&ss);
   }
 
-
-  pthread_t workers[numworkers];
-
-  FILE* fp=fdopen(0, "r");
-  pair<string, string> q;
-  string line;
-  while(stringfgets(fp, line)) {
-    trim_right(line);
-    q=splitField(line, ' ');
-    g_queries.push_back(BenchQuery(q.first, DNSRecordContent::TypeToNumber(q.second)));
-  }
-  fclose(fp);
+  pthread_t tid;
+  vector<string> locals;
+  if(g_vm.count("local"))
+    locals = g_vm["local"].as<vector<string> >();
+  else
+    locals.push_back("::");
+
+  BOOST_FOREACH(const string& local, locals) {
+    ClientState* cs = new ClientState;
+    cs->local= ComboAddress(local, 53);
+    cs->fd = Socket(cs->local.sin4.sin_family, SOCK_DGRAM, 0);
+    if(cs->local.sin4.sin_family == AF_INET6) {
+      int val = 1;
+      setsockopt(cs->fd, IPPROTO_IPV6, IPV6_V6ONLY, &val, sizeof(val));
+    }
+    Bind(cs->fd, cs->local);
     
-  for(unsigned int n = 0; n < numworkers; ++n) {
-    pthread_create(&workers[n], 0, worker, 0);
+    pthread_create(&tid, 0, clientThread, (void*) cs);
   }
-  for(unsigned int n = 0; n < numworkers; ++n) {
-    void* status;
-    pthread_join(workers[n], &status);
-  }
-  cout<<"OK: "<<g_OK<<", network errors: "<<g_networkErrors<<", other errors: "<<g_otherErrors<<endl;
-  cout<<"Timeouts: "<<g_timeOuts<<endl;
-  cout<<"Truncateds: "<<g_truncates<<", auth answers: "<<g_authAnswers<<endl;
+
+  void* status;
+  pthread_join(tid, &status);
+
 }
 catch(std::exception &e)
 {
   cerr<<"Fatal: "<<e.what()<<endl;
 }
+catch(AhuException &ae)
+{
+  cerr<<"Fatal: "<<ae.reason<<endl;
+}