From 1e26e48bfc80e5a66722c6916957395f9d1ea4a9 Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Wed, 3 Apr 2019 18:10:55 +0200 Subject: [PATCH] dnsdist: Try reading from the TCP backend right away Instead of waiting for the socket to be readable, as it might already be, so we save a multiplexer trip, and prevent an issue if we ever add a TLS layer between dnsdist and the backends. --- pdns/dnsdist-tcp.cc | 81 +++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 47 deletions(-) diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 0e81173a7..fa341c208 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -531,8 +531,9 @@ public: static void handleIOCallback(int fd, FDMultiplexer::funcparam_t& param); static void handleNewIOState(std::shared_ptr& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional ttd=boost::none); static void handleIO(std::shared_ptr& state, struct timeval& now); +static void handleDownstreamIO(std::shared_ptr& state, struct timeval& now); -static void handleResponseSent(std::shared_ptr& state) +static void handleResponseSent(std::shared_ptr& state, struct timeval& now) { handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback); @@ -540,10 +541,7 @@ static void handleResponseSent(std::shared_ptr& stat /* we need to resume reading from the backend! */ state->d_state = IncomingTCPConnectionState::State::readingResponseSizeFromBackend; state->d_currentPos = 0; - //cerr<<__func__<<": add read client FD "<d_ci.fd<d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendReadTTD()); + handleDownstreamIO(state, now); return; } @@ -552,8 +550,6 @@ static void handleResponseSent(std::shared_ptr& stat return; } - struct timeval now; - gettimeofday(&now, 0); if (state->maxConnectionDurationReached(g_maxTCPConnectionDuration, now)) { vinfolog("Terminating TCP connection from %s because it reached the maximum TCP connection duration", state->d_ci.remote.toStringWithPort()); return; @@ -564,7 +560,7 @@ static void handleResponseSent(std::shared_ptr& stat handleIO(state, now); } -static void sendResponse(std::shared_ptr& state) +static void sendResponse(std::shared_ptr& state, struct timeval& now) { state->d_state = IncomingTCPConnectionState::State::sendingResponse; const uint8_t sizeBytes[] = { static_cast(state->d_responseSize / 256), static_cast(state->d_responseSize % 256) }; @@ -575,26 +571,10 @@ static void sendResponse(std::shared_ptr& state) state->d_currentPos = 0; - try { - auto iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size()); - if (iostate == IOState::Done) { - - handleResponseSent(state); - return; - } - else { - //cerr<<__func__<<": adding client write FD "<d_ci.fd<d_ci.fd, handleIOCallback, state->getClientWriteTTD()); - } - } - catch (const std::exception& e) { - vinfolog("Got an exception while writing TCP response to %s: %s", state->d_ci.remote.toStringWithPort(), e.what()); - ++state->d_ci.cs->tcpDiedSendingResponse; - handleNewIOState(state, IOState::Done, state->d_ci.fd, handleIOCallback); - } + handleIO(state, now); } -static void handleResponse(std::shared_ptr& state) +static void handleResponse(std::shared_ptr& state, struct timeval& now) { if (state->d_responseSize < sizeof(dnsheader)) { return; @@ -643,7 +623,7 @@ static void handleResponse(std::shared_ptr& state) state->d_xfrStarted = true; } - sendResponse(state); + sendResponse(state, now); ++g_stats.responses; struct timespec answertime; @@ -652,7 +632,7 @@ static void handleResponse(std::shared_ptr& state) g_rings.insertResponse(answertime, state->d_ci.remote, *dr.qname, dr.qtype, static_cast(udiff), static_cast(state->d_responseBuffer.size()), cleartextDH, state->d_ds->remote); } -static void sendQueryToBackend(std::shared_ptr& state) +static void sendQueryToBackend(std::shared_ptr& state, struct timeval& now) { auto ds = state->d_ds; state->d_state = IncomingTCPConnectionState::State::sendingQueryToBackend; @@ -677,8 +657,7 @@ static void sendQueryToBackend(std::shared_ptr& stat return; } - //cerr<<__func__<<": add write backend FD "<d_downstreamSocket->getHandle()<d_downstreamSocket->getHandle(), handleDownstreamIOCallback, state->getBackendWriteTTD()); + handleDownstreamIO(state, now); return; } @@ -687,7 +666,7 @@ static void sendQueryToBackend(std::shared_ptr& stat vinfolog("Downstream connection to %s failed %u times in a row, giving up.", ds->getName(), state->d_downstreamFailures); } -static void handleQuery(std::shared_ptr& state) +static void handleQuery(std::shared_ptr& state, struct timeval& now) { if (state->d_querySize < sizeof(dnsheader)) { ++g_stats.nonCompliantQueries; @@ -702,9 +681,7 @@ static void handleQuery(std::shared_ptr& state) /* we need an accurate ("real") value for the response and to store into the IDS, but not for insertion into the rings for example */ - struct timespec now; struct timespec queryRealTime; - gettime(&now); gettime(&queryRealTime, true); auto query = reinterpret_cast(&state->d_buffer.at(0)); @@ -713,7 +690,7 @@ static void handleQuery(std::shared_ptr& state) if (dnsCryptResponse) { state->d_responseBuffer = std::move(*dnsCryptResponse); state->d_responseSize = state->d_responseBuffer.size(); - sendResponse(state); + sendResponse(state, now); return; } @@ -744,7 +721,7 @@ static void handleQuery(std::shared_ptr& state) state->d_buffer.resize(dq.len); state->d_responseBuffer = std::move(state->d_buffer); state->d_responseSize = state->d_responseBuffer.size(); - sendResponse(state); + sendResponse(state, now); return; } @@ -760,7 +737,7 @@ static void handleQuery(std::shared_ptr& state) that could occur if we had to deal with the size during the processing, especially alignment issues */ state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2); - sendQueryToBackend(state); + sendQueryToBackend(state, now); } static void handleNewIOState(std::shared_ptr& state, IOState iostate, const int fd, FDMultiplexer::callbackfunc_t callback, boost::optional ttd) @@ -805,20 +782,15 @@ static void handleNewIOState(std::shared_ptr& state, } } -static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param) +static void handleDownstreamIO(std::shared_ptr& state, struct timeval& now) { - auto state = boost::any_cast>(param); if (state->d_downstreamSocket == nullptr) { throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!"); } - if (fd != state->d_downstreamSocket->getHandle()) { - throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle())); - } + int fd = state->d_downstreamSocket->getHandle(); IOState iostate = IOState::Done; bool connectionDied = false; - struct timeval now; - gettimeofday(&now, 0); try { if (state->d_state == IncomingTCPConnectionState::State::sendingQueryToBackend) { @@ -877,7 +849,7 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param } fd = -1; - handleResponse(state); + handleResponse(state, now); return; } } @@ -922,8 +894,23 @@ static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param } if (connectionDied) { - sendQueryToBackend(state); + sendQueryToBackend(state, now); + } +} + +static void handleDownstreamIOCallback(int fd, FDMultiplexer::funcparam_t& param) +{ + auto state = boost::any_cast>(param); + if (state->d_downstreamSocket == nullptr) { + throw std::runtime_error("No downstream socket in " + std::string(__func__) + "!"); + } + if (fd != state->d_downstreamSocket->getHandle()) { + throw std::runtime_error("Unexpected socket descriptor " + std::to_string(fd) + " received in " + std::string(__func__) + ", expected " + std::to_string(state->d_downstreamSocket->getHandle())); } + + struct timeval now; + gettimeofday(&now, 0); + handleDownstreamIO(state, now); } static void handleIO(std::shared_ptr& state, struct timeval& now) @@ -967,7 +954,7 @@ static void handleIO(std::shared_ptr& state, struct iostate = state->d_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize); if (iostate == IOState::Done) { handleNewIOState(state, IOState::Done, fd, handleIOCallback); - handleQuery(state); + handleQuery(state, now); return; } } @@ -975,7 +962,7 @@ static void handleIO(std::shared_ptr& state, struct if (state->d_state == IncomingTCPConnectionState::State::sendingResponse) { iostate = state->d_handler.tryWrite(state->d_responseBuffer, state->d_currentPos, state->d_responseBuffer.size()); if (iostate == IOState::Done) { - handleResponseSent(state); + handleResponseSent(state, now); return; } } -- 2.40.0