Semaphore *TCPNameserver::d_connectionroom_sem;
PacketHandler *TCPNameserver::s_P;
NetmaskGroup TCPNameserver::d_ng;
+size_t TCPNameserver::d_maxTransactionsPerConn;
+size_t TCPNameserver::d_maxConnectionsPerClient;
+unsigned int TCPNameserver::d_idleTimeout;
+unsigned int TCPNameserver::d_maxConnectionDuration;
+std::mutex TCPNameserver::s_clientsCountMutex;
+std::map<ComboAddress,size_t,ComboAddress::addressOnlyLessThan> TCPNameserver::s_clientsCount;
void TCPNameserver::go()
{
}
// throws PDNSException if things didn't go according to plan, returns 0 if really 0 bytes were read
-int readnWithTimeout(int fd, void* buffer, unsigned int n, bool throwOnEOF=true)
+static int readnWithTimeout(int fd, void* buffer, unsigned int n, unsigned int idleTimeout, bool throwOnEOF=true, unsigned int totalTimeout=0)
{
unsigned int bytes=n;
char *ptr = (char*)buffer;
int ret;
+ time_t start = 0;
+ unsigned int remainingTotal = totalTimeout;
+ if (totalTimeout) {
+ start = time(NULL);
+ }
while(bytes) {
ret=read(fd, ptr, bytes);
if(ret < 0) {
if(errno==EAGAIN) {
- ret=waitForData(fd, 5);
+ ret=waitForData(fd, (totalTimeout == 0 || idleTimeout <= remainingTotal) ? idleTimeout : remainingTotal);
if(ret < 0)
throw NetworkError("Waiting for data read");
if(!ret)
ptr += ret;
bytes -= ret;
+ if (totalTimeout) {
+ time_t now = time(NULL);
+ unsigned int elapsed = now - start;
+ if (elapsed >= remainingTotal) {
+ throw NetworkError("Timeout while reading data");
+ }
+ start = now;
+ remainingTotal -= elapsed;
+ }
}
return n;
}
// ditto
-void writenWithTimeout(int fd, const void *buffer, unsigned int n)
+static void writenWithTimeout(int fd, const void *buffer, unsigned int n, unsigned int idleTimeout)
{
unsigned int bytes=n;
const char *ptr = (char*)buffer;
ret=write(fd, ptr, bytes);
if(ret < 0) {
if(errno==EAGAIN) {
- ret=waitForRWData(fd, false, 5, 0);
+ ret=waitForRWData(fd, false, idleTimeout, 0);
if(ret < 0)
throw NetworkError("Waiting for data write");
if(!ret)
uint16_t len=htons(p->getString().length());
string buffer((const char*)&len, 2);
buffer.append(p->getString());
- writenWithTimeout(outsock, buffer.c_str(), buffer.length());
+ writenWithTimeout(outsock, buffer.c_str(), buffer.length(), d_idleTimeout);
}
-void TCPNameserver::getQuestion(int fd, char *mesg, int pktlen, const ComboAddress &remote)
+void TCPNameserver::getQuestion(int fd, char *mesg, int pktlen, const ComboAddress &remote, unsigned int totalTime)
try
{
- readnWithTimeout(fd, mesg, pktlen);
+ readnWithTimeout(fd, mesg, pktlen, d_idleTimeout, true, totalTime);
}
catch(NetworkError& ae) {
throw NetworkError("Error reading DNS data from TCP client "+remote.toString()+": "+ae.what());
}
-static void proxyQuestion(shared_ptr<DNSPacket> packet)
+static void proxyQuestion(shared_ptr<DNSPacket> packet, unsigned int idleTimeout)
{
int sock=socket(AF_INET, SOCK_STREAM, 0);
uint16_t len=htons(buffer.length()), slen;
- writenWithTimeout(sock, &len, 2);
- writenWithTimeout(sock, buffer.c_str(), buffer.length());
+ writenWithTimeout(sock, &len, 2, idleTimeout);
+ writenWithTimeout(sock, buffer.c_str(), buffer.length(), idleTimeout);
- readnWithTimeout(sock, &len, 2);
+ readnWithTimeout(sock, &len, 2, idleTimeout);
len=ntohs(len);
char answer[len];
- readnWithTimeout(sock, answer, len);
+ readnWithTimeout(sock, answer, len, idleTimeout);
slen=htons(len);
- writenWithTimeout(packet->getSocket(), &slen, 2);
+ writenWithTimeout(packet->getSocket(), &slen, 2, idleTimeout);
- writenWithTimeout(packet->getSocket(), answer, len);
+ writenWithTimeout(packet->getSocket(), answer, len, idleTimeout);
}
catch(NetworkError& ae) {
close(sock);
else
S.inc("tcp4-answers");
}
+
+static bool maxConnectionDurationReached(unsigned int maxConnectionDuration, time_t start, unsigned int& remainingTime)
+{
+ if (maxConnectionDuration) {
+ time_t elapsed = time(NULL) - start;
+ if (elapsed >= maxConnectionDuration) {
+ return true;
+ }
+ remainingTime = maxConnectionDuration - elapsed;
+ }
+ return false;
+}
+
+void TCPNameserver::decrementClientCount(const ComboAddress& remote)
+{
+ if (d_maxConnectionsPerClient) {
+ std::lock_guard<std::mutex> lock(s_clientsCountMutex);
+ s_clientsCount[remote]--;
+ if (s_clientsCount[remote] == 0) {
+ s_clientsCount.erase(remote);
+ }
+ }
+}
+
void *TCPNameserver::doConnection(void *data)
{
shared_ptr<DNSPacket> packet;
int fd=(int)(long)data; // gotta love C (generates a harmless warning on opteron)
ComboAddress remote;
socklen_t remotelen=sizeof(remote);
+ size_t transactions = 0;
+ time_t start = 0;
+ if (d_maxConnectionDuration) {
+ start = time(NULL);
+ }
pthread_detach(pthread_self());
if(getpeername(fd, (struct sockaddr *)&remote, &remotelen) < 0) {
DLOG(L<<"TCP Connection accepted on fd "<<fd<<endl);
bool logDNSQueries= ::arg().mustDo("log-dns-queries");
for(;;) {
+ unsigned int remainingTime = 0;
+ transactions++;
+ if (d_maxTransactionsPerConn && transactions > d_maxTransactionsPerConn) {
+ L << Logger::Notice<<"TCP Remote "<< remote <<" exceeded the number of transactions per connection, dropping.";
+ break;
+ }
+ if (maxConnectionDurationReached(d_maxConnectionDuration, start, remainingTime)) {
+ L << Logger::Notice<<"TCP Remote "<< remote <<" exceeded the maximum TCP connection duration, dropping.";
+ break;
+ }
uint16_t pktlen;
- if(!readnWithTimeout(fd, &pktlen, 2, false))
+ if(!readnWithTimeout(fd, &pktlen, 2, d_idleTimeout, false, remainingTime))
break;
else
pktlen=ntohs(pktlen);
break;
}
- getQuestion(fd, mesg.get(), pktlen, remote);
+ if (maxConnectionDurationReached(d_maxConnectionDuration, start, remainingTime)) {
+ L << Logger::Notice<<"TCP Remote "<< remote <<" exceeded the maximum TCP connection duration, dropping.";
+ break;
+ }
+
+ getQuestion(fd, mesg.get(), pktlen, remote, remainingTime);
S.inc("tcp-queries");
if(remote.sin4.sin_family == AF_INET6)
S.inc("tcp6-queries");
if(LPE) LPE->police(&(*packet), &(*reply), true);
if(shouldRecurse) {
- proxyQuestion(packet);
+ proxyQuestion(packet, d_idleTimeout);
continue;
}
}
catch(const PDNSException& e) {
L<<Logger::Error<<"Error closing TCP socket: "<<e.reason<<endl;
}
+ decrementClientCount(remote);
return 0;
}
TCPNameserver::TCPNameserver()
{
+ d_maxTransactionsPerConn = ::arg().asNum("max-tcp-transactions-per-conn");
+ d_idleTimeout = ::arg().asNum("tcp-idle-timeout");
+ d_maxConnectionDuration = ::arg().asNum("max-tcp-connection-duration");
+ d_maxConnectionsPerClient = ::arg().asNum("max-tcp-connections-per-client");
+
// sem_init(&d_connectionroom_sem,0,::arg().asNum("max-tcp-connections"));
d_connectionroom_sem = new Semaphore( ::arg().asNum( "max-tcp-connections" ));
d_tid=0;
try {
for(;;) {
int fd;
- struct sockaddr_in remote;
- Utility::socklen_t addrlen=sizeof(remote);
+ ComboAddress remote;
+ Utility::socklen_t addrlen=remote.getSocklen();
int ret=poll(&d_prfds[0], d_prfds.size(), -1); // blocks, forever if need be
if(ret <= 0)
for(const pollfd& pfd : d_prfds) {
if(pfd.revents == POLLIN) {
sock = pfd.fd;
- addrlen=sizeof(remote);
+ remote.sin4.sin_family = AF_INET6;
+ addrlen=remote.getSocklen();
if((fd=accept(sock, (sockaddr*)&remote, &addrlen))<0) {
L<<Logger::Error<<"TCP question accept error: "<<strerror(errno)<<endl;
}
}
else {
+ if (d_maxConnectionsPerClient) {
+ std::lock_guard<std::mutex> lock(s_clientsCountMutex);
+ if (s_clientsCount[remote] >= d_maxConnectionsPerClient) {
+ L<<Logger::Notice<<"Limit of simultaneous TCP connections per client reached for "<< remote<<", dropping"<<endl;
+ close(fd);
+ continue;
+ }
+ s_clientsCount[remote]++;
+ }
+
pthread_t tid;
d_connectionroom_sem->wait(); // blocks if no connections are available
L<<Logger::Error<<"Error creating thread: "<<stringerror()<<endl;
d_connectionroom_sem->post();
close(fd);
+ decrementClientCount(remote);
}
}
}