]> granicus.if.org Git - pdns/commitdiff
dnsdist: Implement ref counting for the DOHUnit object
authorRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 14 Oct 2019 14:18:46 +0000 (16:18 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 14 Oct 2019 14:18:46 +0000 (16:18 +0200)
It turns out that, at least when testing with ASAN enabled, we
sometimes trigger use-after-free detection because we get the
response from the backend, send it to the client then delete the
object before the send() call to the backend even returned.

pdns/dnsdistdist/doh.cc
pdns/doh.hh

index f93abdd82e46cfcca699b9851f4b1d192fe9fcd1..eecf787801e9031f63d373471a23b37fe25f56c4 100644 (file)
@@ -145,10 +145,13 @@ void handleDOHTimeout(DOHUnit* oldDU)
 /* we are about to erase an existing DU */
   oldDU->status_code = 502;
 
+  /* increase the ref counter before sending the pointer */
+  oldDU->get();
   if (send(oldDU->rsock, &oldDU, sizeof(oldDU), 0) != sizeof(oldDU)) {
-    delete oldDU;
-    oldDU = nullptr;
+    oldDU->release();
   }
+  oldDU->release();
+  oldDU = nullptr;
 }
 
 static void on_socketclose(void *data)
@@ -158,7 +161,6 @@ static void on_socketclose(void *data)
   ctx->release();
 }
 
-
 static const std::string& getReasonFromStatusCode(uint16_t statusCode)
 {
   /* no need to care too much about this, HTTP/2 has no 'reason' anyway */
@@ -343,7 +345,11 @@ static int processDOHQuery(DOHUnit* du)
       if (du->response.empty()) {
         du->response = std::string(reinterpret_cast<char*>(dq.dh), dq.len);
       }
-      send(du->rsock, &du, sizeof(du), 0);
+      /* increase the ref counter before sending the pointer */
+      du->get();
+      if (send(du->rsock, &du, sizeof(du), 0) != sizeof(du)) {
+        du->release();
+      }
       return 0;
     }
 
@@ -411,21 +417,30 @@ static int processDOHQuery(DOHUnit* du)
     dh->id = idOffset;
 
     int fd = pickBackendSocketForSending(ss);
-    /* you can't touch du after this line, because it might already have been freed */
-    ssize_t ret = udpClientSendRequestToBackend(ss, fd, query, dq.len);
-
-    if(ret < 0) {
-      /* we are about to handle the error, make sure that
-         this pointer is not accessed when the state is cleaned,
-         but first check that it still belongs to us */
-      if (ids->tryMarkUnused(generation)) {
-        ids->du = nullptr;
-        --ss->outstanding;
+    try {
+      /* increase the ref count since we are about to send the pointer */
+         du->get();
+      /* you can't touch du after this line, because it might already have been freed */
+      ssize_t ret = udpClientSendRequestToBackend(ss, fd, query, dq.len);
+
+      if(ret < 0) {
+        du->release();
+        /* we are about to handle the error, make sure that
+           this pointer is not accessed when the state is cleaned,
+           but first check that it still belongs to us */
+        if (ids->tryMarkUnused(generation)) {
+          ids->du = nullptr;
+          --ss->outstanding;
+        }
+        ++ss->sendErrors;
+        ++g_stats.downstreamSendErrors;
+        du->status_code = 502;
+        return -1;
       }
-      ++ss->sendErrors;
-      ++g_stats.downstreamSendErrors;
-      du->status_code = 502;
-      return -1;
+    }
+    catch (const std::exception& e) {
+      du->release();
+      throw;
     }
 
     vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).getName(), remote.toStringWithPort(), ss->getName());
@@ -529,13 +544,13 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re
     *(ptr->self) = ptr;
     try  {
       if(send(dsc->dohquerypair[0], &ptr, sizeof(ptr), 0) != sizeof(ptr)) {
-        delete ptr;
+        ptr->release();
         ptr = nullptr;
         h2o_send_error_500(req, "Internal Server Error", "Internal Server Error", 0);
       }
     }
     catch(...) {
-      delete ptr;
+      ptr->release();
     }
   }
   catch(const std::exception& e) {
@@ -829,9 +844,13 @@ void dnsdistclient(int qsock, int rsock)
 
       if(processDOHQuery(du) < 0) {
         du->status_code = 500;
-        if(send(du->rsock, &du, sizeof(du), 0) != sizeof(du))
-          delete du;     // XXX but now what - will h2o time this out for us?
+        /* increase the ref count before sending the pointer */
+        du->get();
+        if(send(du->rsock, &du, sizeof(du), 0) != sizeof(du)) {
+          du->release();     // XXX but now what - will h2o time this out for us?
+        }
       }
+      du->release();
     }
     catch(const std::exception& e) {
       errlog("Error while processing query received over DoH: %s", e.what());
@@ -859,7 +878,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
 
   if(!du->req) { // it got killed in flight
 //    cout << "du "<<(void*)du<<" came back from dnsdist, but it was killed"<<endl;
-    delete du;
+    du->release();
     return;
   }
 
@@ -867,7 +886,7 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err)
 
   handleResponse(*dsc->df, du->req, du->status_code, du->response, dsc->df->d_customResponseHeaders, du->contentType, true);
 
-  delete du;
+  du->release();
 }
 
 static void on_accept(h2o_socket_t *listener, const char *err)
index 6fc3d2b2c742088a9d2a0b5f21796f36df185bd4..147c5a5483c16211a7c64efa48eef9ac47a79226 100644 (file)
@@ -156,6 +156,25 @@ struct st_h2o_req_t;
 
 struct DOHUnit
 {
+  DOHUnit()
+  {
+  }
+  DOHUnit(const DOHUnit&) = delete;
+  DOHUnit& operator=(const DOHUnit&) = delete;
+
+  void get()
+  {
+    ++d_refcnt;
+  }
+
+  void release()
+  {
+    --d_refcnt;
+    if (d_refcnt == 0) {
+      delete this;
+    }
+  }
+
   std::string query;
   std::string response;
   ComboAddress remote;
@@ -163,6 +182,7 @@ struct DOHUnit
   st_h2o_req_t* req{nullptr};
   DOHUnit** self{nullptr};
   std::string contentType;
+  std::atomic<uint64_t> d_refcnt{1};
   int rsock;
   uint16_t qtype;
   /* the status_code is set from