]> granicus.if.org Git - pdns/commitdiff
auth: Reconnect to the server if the My/Pg connection has been closed
authorRemi Gacogne <remi.gacogne@powerdns.com>
Sat, 15 Apr 2017 18:09:32 +0000 (20:09 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 13 Jun 2017 13:46:58 +0000 (15:46 +0200)
modules/gmysqlbackend/gmysqlbackend.cc
modules/gmysqlbackend/gmysqlbackend.hh
modules/gmysqlbackend/smysql.cc
modules/gmysqlbackend/smysql.hh
modules/gpgsqlbackend/spgsql.cc
modules/gpgsqlbackend/spgsql.hh
pdns/backends/gsql/gsqlbackend.cc
pdns/backends/gsql/gsqlbackend.hh
pdns/backends/gsql/ssql.hh

index be7dd3e64b07e43df22fd68f2cf1fd989d0a9231..cb84b21a6667e23ca9cfcca03d03217809a51272 100644 (file)
 gMySQLBackend::gMySQLBackend(const string &mode, const string &suffix)  : GSQLBackend(mode,suffix)
 {
   try {
-    setDB(new SMySQL(getArg("dbname"),
-                     getArg("host"),
-                     getArgAsNum("port"),
-                     getArg("socket"),
-                     getArg("user"),
-                     getArg("password"),
-                     getArg("group"),
-                     mustDo("innodb-read-committed"),
-                     getArgAsNum("timeout")));
+    reconnect();
   }
 
   catch(SSqlException &e) {
@@ -57,6 +49,26 @@ gMySQLBackend::gMySQLBackend(const string &mode, const string &suffix)  : GSQLBa
   L<<Logger::Info<<mode<<" Connection successful. Connected to database '"<<getArg("dbname")<<"' on '"<<(getArg("host").empty() ? getArg("socket") : getArg("host"))<<"'."<<endl;
 }
 
+void gMySQLBackend::reconnect()
+{
+  setDB(new SMySQL(getArg("dbname"),
+                   getArg("host"),
+                   getArgAsNum("port"),
+                   getArg("socket"),
+                   getArg("user"),
+                   getArg("password"),
+                   getArg("group"),
+                   mustDo("innodb-read-committed"),
+                   getArgAsNum("timeout")));
+}
+
+void gMySQLBackend::reconnectIfNeeded()
+{
+  if (!isConnectionUsable()) {
+    reconnect();
+  }
+}
+
 class gMySQLFactory : public BackendFactory
 {
 public:
index 579ee1ad563dd03027ff46be77a2425a70e53867..be6acfea236655654677c057dfbe143846f4388e 100644 (file)
@@ -34,6 +34,9 @@ class gMySQLBackend : public GSQLBackend
 {
 public:
   gMySQLBackend(const string &mode, const string &suffix); //!< Makes our connection to the database. Throws an exception if it fails.
+protected:
+  void reconnectIfNeeded() override;
+  void reconnect();
 };
 
 #endif /* PDNS_GMYSQLBACKEND_HH */
index 646feb2aca7fc2ce9944c9b000a28629363967c9..0309aca95df88b43892de3a7a61c69693f8a8858 100644 (file)
@@ -393,8 +393,7 @@ private:
   int d_residx;
 };
 
-SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const string &msocket, const string &user,
-               const string &password, const string &group, bool setIsolation, unsigned int timeout)
+void SMySQL::connect()
 {
   int retry=1;
 
@@ -410,9 +409,9 @@ SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const
 #endif
 
 #if MYSQL_VERSION_ID >= 50100
-    if(timeout) {
-      mysql_options(&d_db, MYSQL_OPT_READ_TIMEOUT, &timeout);
-      mysql_options(&d_db, MYSQL_OPT_WRITE_TIMEOUT, &timeout);
+    if(d_timeout) {
+      mysql_options(&d_db, MYSQL_OPT_READ_TIMEOUT, &d_timeout);
+      mysql_options(&d_db, MYSQL_OPT_WRITE_TIMEOUT, &d_timeout);
     }
 #endif
 
@@ -420,18 +419,18 @@ SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const
     mysql_options(&d_db, MYSQL_SET_CHARSET_NAME, MYSQL_AUTODETECT_CHARSET_NAME);
 #endif
 
-    if (setIsolation && (retry == 1))
+    if (d_setIsolation && (retry == 1))
       mysql_options(&d_db, MYSQL_INIT_COMMAND,"SET SESSION tx_isolation='READ-COMMITTED'");
 
-    mysql_options(&d_db, MYSQL_READ_DEFAULT_GROUP, group.c_str());
+    mysql_options(&d_db, MYSQL_READ_DEFAULT_GROUP, d_group.c_str());
 
-    if (!mysql_real_connect(&d_db, host.empty() ? NULL : host.c_str(),
-                          user.empty() ? NULL : user.c_str(),
-                          password.empty() ? NULL : password.c_str(),
-                          database.empty() ? NULL : database.c_str(),
-                          port,
-                          msocket.empty() ? NULL : msocket.c_str(),
-                          CLIENT_MULTI_RESULTS)) {
+    if (!mysql_real_connect(&d_db, d_host.empty() ? NULL : d_host.c_str(),
+                            d_user.empty() ? NULL : d_user.c_str(),
+                            d_password.empty() ? NULL : d_password.c_str(),
+                            d_database.empty() ? NULL : d_database.c_str(),
+                            d_port,
+                            d_msocket.empty() ? NULL : d_msocket.c_str(),
+                            CLIENT_MULTI_RESULTS)) {
 
       if (retry == 0)
         throw sPerrorException("Unable to connect to database");
@@ -446,6 +445,13 @@ SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const
   } while (retry >= 0);
 }
 
+SMySQL::SMySQL(const string &database, const string &host, uint16_t port, const string &msocket, const string &user,
+               const string &password, const string &group, bool setIsolation, unsigned int timeout):
+  d_database(database), d_host(host), d_msocket(msocket), d_user(user), d_password(password), d_group(group), d_timeout(timeout), d_port(port), d_setIsolation(setIsolation)
+{
+  connect();
+}
+
 void SMySQL::setLog(bool state)
 {
   s_dolog=state;
@@ -487,3 +493,26 @@ void SMySQL::commit() {
 void SMySQL::rollback() {
   execute("rollback");
 }
+
+bool SMySQL::isConnectionUsable()
+{
+  bool usable = false;
+  int sd = d_db.net.fd;
+  bool wasNonBlocking = isNonBlocking(sd);
+
+  if (!wasNonBlocking) {
+    if (!setNonBlocking(sd)) {
+      return usable;
+    }
+  }
+
+  usable = isTCPSocketUsable(sd);
+
+  if (!wasNonBlocking) {
+    if (!setBlocking(sd)) {
+      usable = false;
+    }
+  }
+
+  return usable;
+}
index 904c19ca7170aa6542ce2e33fd6d11f3ae884903..f79af781383b47d67ceceb9cbd13f4e8526ebce5 100644 (file)
@@ -36,19 +36,31 @@ public:
 
   ~SMySQL();
 
-  SSqlException sPerrorException(const string &reason);
-  void setLog(bool state);
-  SSqlStatement* prepare(const string& query, int nparams);
-  void execute(const string& query);
-
-  void startTransaction();
-  void commit();
-  void rollback();
+  SSqlException sPerrorException(const string &reason) override;
+  void setLog(bool state) override;
+  SSqlStatement* prepare(const string& query, int nparams) override;
+  void execute(const string& query) override;
 
+  void startTransaction() override;
+  void commit() override;
+  void rollback() override;
+  bool isConnectionUsable() override;
 private:
-  MYSQL d_db;
+  void connect();
+
   static bool s_dolog;
   static pthread_mutex_t s_myinitlock;
+
+  MYSQL d_db;
+  std::string d_database;
+  std::string d_host;
+  std::string d_msocket;
+  std::string d_user;
+  std::string d_password;
+  std::string d_group;
+  unsigned int d_timeout;
+  uint16_t d_port;
+  bool d_setIsolation;
 };
 
 #endif /* SSMYSQL_HH */
index 52a3da8018392dc30f2411066671b889c3209e0f..78e6539a36f8515454d191c57ef9efbe0c26396e 100644 (file)
@@ -346,3 +346,13 @@ void SPgSQL::rollback() {
   execute("rollback");
   d_in_trx = false;
 }
+
+bool SPgSQL::isConnectionUsable()
+{
+  return PQstatus(d_db) == CONNECTION_OK;
+}
+
+void SPgSQL::reconnect()
+{
+  PQreset(d_db);
+}
index 579ceb3ff0d71360cd0ada8197cfdc56d36ff153..bf39b3ab5fdd77d7156a9947fdebdbe0eac42fa2 100644 (file)
@@ -34,14 +34,17 @@ public:
 
   ~SPgSQL();
   
-  SSqlException sPerrorException(const string &reason);
-  void setLog(bool state);
-  SSqlStatement* prepare(const string& query, int nparams);
-  void execute(const string& query);
+  SSqlException sPerrorException(const string &reason) override;
+  void setLog(bool state) override;
+  SSqlStatement* prepare(const string& query, int nparams) override;
+  void execute(const string& query) override;
 
-  void startTransaction();
-  void rollback();
-  void commit();
+  void startTransaction() override;
+  void rollback() override;
+  void commit() override;
+
+  bool isConnectionUsable() override;
+  void reconnect() override;
 
   PGconn* db() { return d_db; }
   bool in_trx() { return d_in_trx; }
index 89e89a4983f1068adfe128ccb6556a73be58bcc3..53b2faf6a01c5b1bd4db2aa437953c4380d50eb2 100644 (file)
@@ -187,6 +187,8 @@ GSQLBackend::GSQLBackend(const string &mode, const string &suffix)
 void GSQLBackend::setNotified(uint32_t domain_id, uint32_t serial)
 {
   try {
+    reconnectIfNeeded();
+
     d_UpdateSerialOfZoneQuery_stmt->
       bind("serial", serial)->
       bind("domain_id", domain_id)->
@@ -201,6 +203,8 @@ void GSQLBackend::setNotified(uint32_t domain_id, uint32_t serial)
 void GSQLBackend::setFresh(uint32_t domain_id)
 {
   try {
+    reconnectIfNeeded();
+
     d_UpdateLastCheckofZoneQuery_stmt->
       bind("last_check", time(0))->
       bind("domain_id", domain_id)->
@@ -215,6 +219,8 @@ void GSQLBackend::setFresh(uint32_t domain_id)
 bool GSQLBackend::isMaster(const DNSName &domain, const string &ip)
 {
   try {
+    reconnectIfNeeded();
+
     d_MasterOfDomainsZoneQuery_stmt->
       bind("domain", domain)->
       execute()->
@@ -246,6 +252,8 @@ bool GSQLBackend::isMaster(const DNSName &domain, const string &ip)
 bool GSQLBackend::setMaster(const DNSName &domain, const string &ip)
 {
   try {
+    reconnectIfNeeded();
+
     d_UpdateMasterOfZoneQuery_stmt->
       bind("master", ip)->
       bind("domain", domain)->
@@ -261,6 +269,8 @@ bool GSQLBackend::setMaster(const DNSName &domain, const string &ip)
 bool GSQLBackend::setKind(const DNSName &domain, const DomainInfo::DomainKind kind)
 {
   try {
+    reconnectIfNeeded();
+
     d_UpdateKindOfZoneQuery_stmt->
       bind("kind", toUpper(DomainInfo::getKindString(kind)))->
       bind("domain", domain)->
@@ -276,6 +286,8 @@ bool GSQLBackend::setKind(const DNSName &domain, const DomainInfo::DomainKind ki
 bool GSQLBackend::setAccount(const DNSName &domain, const string &account)
 {
   try {
+    reconnectIfNeeded();
+
     d_UpdateAccountOfZoneQuery_stmt->
             bind("account", account)->
             bind("domain", domain)->
@@ -293,6 +305,8 @@ bool GSQLBackend::getDomainInfo(const DNSName &domain, DomainInfo &di)
   /* fill DomainInfo from database info:
      id,name,master IP(s),last_check,notified_serial,type,account */
   try {
+    reconnectIfNeeded();
+
     d_InfoOfDomainsZoneQuery_stmt->
       bind("domain", domain)->
       execute()->
@@ -344,6 +358,8 @@ void GSQLBackend::getUnfreshSlaveInfos(vector<DomainInfo> *unfreshDomains)
   /* list all domains that need refreshing for which we are slave, and insert into SlaveDomain:
      id,name,master IP,serial */
   try {
+    reconnectIfNeeded();
+
     d_InfoOfAllSlaveDomainsQuery_stmt->
       execute()->
       getResult(d_result)->
@@ -388,6 +404,8 @@ void GSQLBackend::getUpdatedMasters(vector<DomainInfo> *updatedDomains)
   /* list all domains that need notifications for which we are master, and insert into updatedDomains
      id,name,master IP,serial */
   try {
+    reconnectIfNeeded();
+
     d_InfoOfAllMasterDomainsQuery_stmt->
       execute()->
       getResult(d_result)->
@@ -435,6 +453,8 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
   if (!ordername.empty()) {
     if (qtype == QType::ANY) {
       try {
+        reconnectIfNeeded();
+
         d_updateOrderNameAndAuthQuery_stmt->
           bind("ordername", ordername.labelReverse().toString(" ", false))->
           bind("auth", auth)->
@@ -448,6 +468,8 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
       }
     } else {
       try {
+        reconnectIfNeeded();
+
         d_updateOrderNameAndAuthTypeQuery_stmt->
           bind("ordername", ordername.labelReverse().toString(" ", false))->
           bind("auth", auth)->
@@ -463,6 +485,8 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
     }
   } else {
     if (qtype == QType::ANY) {
+      reconnectIfNeeded();
+
       try {
         d_nullifyOrderNameAndUpdateAuthQuery_stmt->
           bind("auth", auth)->
@@ -476,6 +500,8 @@ bool GSQLBackend::updateDNSSECOrderNameAndAuth(uint32_t domain_id, const DNSName
       }
     } else {
       try {
+        reconnectIfNeeded();
+
         d_nullifyOrderNameAndUpdateAuthTypeQuery_stmt->
           bind("auth", auth)->
           bind("domain_id", domain_id)->
@@ -496,6 +522,8 @@ bool GSQLBackend::updateEmptyNonTerminals(uint32_t domain_id, set<DNSName>& inse
 {
   if(remove) {
     try {
+      reconnectIfNeeded();
+
       d_RemoveEmptyNonTerminalsFromZoneQuery_stmt->
         bind("domain_id", domain_id)->
         execute()->
@@ -510,6 +538,8 @@ bool GSQLBackend::updateEmptyNonTerminals(uint32_t domain_id, set<DNSName>& inse
   {
     for(const auto& qname: erase) {
       try {
+        reconnectIfNeeded();
+
         d_DeleteEmptyNonTerminalQuery_stmt->
           bind("domain_id", domain_id)->
           bind("qname", qname)->
@@ -525,6 +555,8 @@ bool GSQLBackend::updateEmptyNonTerminals(uint32_t domain_id, set<DNSName>& inse
 
   for(const auto& qname: insert) {
     try {
+      reconnectIfNeeded();
+
       d_InsertEmptyNonTerminalOrderQuery_stmt->
         bind("domain_id", domain_id)->
         bind("qname", qname)->
@@ -555,6 +587,8 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
 
   SSqlStatement::row_t row;
   try {
+    reconnectIfNeeded();
+
     d_afterOrderQuery_stmt->
       bind("ordername", qname.labelReverse().toString(" ", false))->
       bind("domain_id", id)->
@@ -574,6 +608,8 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
 
   if(after.empty()) {
     try {
+      reconnectIfNeeded();
+
       d_firstOrderQuery_stmt->
         bind("domain_id", id)->
         execute();
@@ -593,6 +629,8 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
     unhashed.clear();
 
     try {
+      reconnectIfNeeded();
+
       d_beforeOrderQuery_stmt->
         bind("ordername", qname.labelReverse().toString(" ", false))->
         bind("domain_id", id)->
@@ -620,6 +658,8 @@ bool GSQLBackend::getBeforeAndAfterNamesAbsolute(uint32_t id, const DNSName& qna
     }
 
     try {
+      reconnectIfNeeded();
+
       d_lastOrderQuery_stmt->
         bind("domain_id", id)->
         execute();
@@ -651,6 +691,8 @@ bool GSQLBackend::addDomainKey(const DNSName& name, const KeyData& key, int64_t&
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_AddDomainKeyQuery_stmt->
       bind("flags", key.flags)->
       bind("active", key.active)->
@@ -664,6 +706,8 @@ bool GSQLBackend::addDomainKey(const DNSName& name, const KeyData& key, int64_t&
   }
 
   try {
+    reconnectIfNeeded();
+
     d_GetLastInsertedKeyIdQuery_stmt->execute();
     if (!d_GetLastInsertedKeyIdQuery_stmt->hasNextRow()) {
       id = -2;
@@ -690,6 +734,8 @@ bool GSQLBackend::activateDomainKey(const DNSName& name, unsigned int id)
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_ActivateDomainKeyQuery_stmt->
       bind("domain", name)->
       bind("key_id", id)->
@@ -708,6 +754,8 @@ bool GSQLBackend::deactivateDomainKey(const DNSName& name, unsigned int id)
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_DeactivateDomainKeyQuery_stmt->
       bind("domain", name)->
       bind("key_id", id)->
@@ -726,6 +774,8 @@ bool GSQLBackend::removeDomainKey(const DNSName& name, unsigned int id)
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_RemoveDomainKeyQuery_stmt->
       bind("domain", name)->
       bind("key_id", id)->
@@ -741,6 +791,8 @@ bool GSQLBackend::removeDomainKey(const DNSName& name, unsigned int id)
 bool GSQLBackend::getTSIGKey(const DNSName& name, DNSName* algorithm, string* content)
 {
   try {
+    reconnectIfNeeded();
+
     d_getTSIGKeyQuery_stmt->
       bind("key_name", name)->
       execute();
@@ -771,6 +823,8 @@ bool GSQLBackend::getTSIGKey(const DNSName& name, DNSName* algorithm, string* co
 bool GSQLBackend::setTSIGKey(const DNSName& name, const DNSName& algorithm, const string& content)
 {
   try {
+    reconnectIfNeeded();
+
     d_setTSIGKeyQuery_stmt->
       bind("key_name", name)->
       bind("algorithm", algorithm)->
@@ -787,6 +841,8 @@ bool GSQLBackend::setTSIGKey(const DNSName& name, const DNSName& algorithm, cons
 bool GSQLBackend::deleteTSIGKey(const DNSName& name)
 {
   try {
+    reconnectIfNeeded();
+
     d_deleteTSIGKeyQuery_stmt->
       bind("key_name", name)->
       execute()->
@@ -801,6 +857,8 @@ bool GSQLBackend::deleteTSIGKey(const DNSName& name)
 bool GSQLBackend::getTSIGKeys(std::vector< struct TSIGKey > &keys)
 {
   try {
+    reconnectIfNeeded();
+
     d_getTSIGKeysQuery_stmt->
       execute();
 
@@ -835,6 +893,8 @@ bool GSQLBackend::getDomainKeys(const DNSName& name, std::vector<KeyData>& keys)
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_ListDomainKeysQuery_stmt->
       bind("domain", name)->
       execute();
@@ -875,6 +935,8 @@ void GSQLBackend::alsoNotifies(const DNSName &domain, set<string> *ips)
 bool GSQLBackend::getAllDomainMetadata(const DNSName& name, std::map<std::string, std::vector<std::string> >& meta)
 {
   try {
+    reconnectIfNeeded();
+
     d_GetAllDomainMetadataQuery_stmt->
       bind("domain", name)->
       execute();
@@ -905,6 +967,8 @@ bool GSQLBackend::getDomainMetadata(const DNSName& name, const std::string& kind
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_GetDomainMetadataQuery_stmt->
       bind("domain", name)->
       bind("kind", kind)->
@@ -933,6 +997,8 @@ bool GSQLBackend::setDomainMetadata(const DNSName& name, const std::string& kind
     return false;
 
   try {
+    reconnectIfNeeded();
+
     d_ClearDomainMetadataQuery_stmt->
       bind("domain", name)->
       bind("kind", kind)->
@@ -959,6 +1025,8 @@ bool GSQLBackend::setDomainMetadata(const DNSName& name, const std::string& kind
 void GSQLBackend::lookup(const QType &qtype,const DNSName &qname, DNSPacket *pkt_p, int domain_id)
 {
   try {
+    reconnectIfNeeded();
+
     if(qtype.getCode()!=QType::ANY) {
       if(domain_id < 0) {
         d_query_name = "basic-query";
@@ -1005,6 +1073,8 @@ bool GSQLBackend::list(const DNSName &target, int domain_id, bool include_disabl
   DLOG(L<<"GSQLBackend constructing handle for list of domain id '"<<domain_id<<"'"<<endl);
 
   try {
+    reconnectIfNeeded();
+
     d_query_name = "list-query";
     d_query_stmt = d_listQuery_stmt;
     d_query_stmt->
@@ -1025,6 +1095,8 @@ bool GSQLBackend::listSubZone(const DNSName &zone, int domain_id) {
   string wildzone = "%." + toLower(zone.toStringNoDot());
 
   try {
+    reconnectIfNeeded();
+
     d_query_name = "list-subzone-query";
     d_query_stmt = d_listSubZoneQuery_stmt;
     d_query_stmt->
@@ -1075,6 +1147,8 @@ bool GSQLBackend::superMasterBackend(const string &ip, const DNSName &domain, co
   // check if we know the ip/ns couple in the database
   for(vector<DNSResourceRecord>::const_iterator i=nsset.begin();i!=nsset.end();++i) {
     try {
+      reconnectIfNeeded();
+
       d_SuperMasterInfoQuery_stmt->
         bind("ip", ip)->
         bind("nameserver", i->content)->
@@ -1099,6 +1173,8 @@ bool GSQLBackend::superMasterBackend(const string &ip, const DNSName &domain, co
 bool GSQLBackend::createDomain(const DNSName &domain, const string &type, const string &masters, const string &account)
 {
   try {
+    reconnectIfNeeded();
+
     d_InsertZoneQuery_stmt->
       bind("type", type)->
       bind("domain", domain)->
@@ -1120,6 +1196,8 @@ bool GSQLBackend::createSlaveDomain(const string &ip, const DNSName &domain, con
   try {
     if (!nameserver.empty()) {
       // figure out all IP addresses for the master
+      reconnectIfNeeded();
+
       d_GetSuperMasterIPs_stmt->
         bind("nameserver", nameserver)->
         bind("account", account)->
@@ -1153,6 +1231,8 @@ bool GSQLBackend::deleteDomain(const DNSName &domain)
   }
 
   try {
+    reconnectIfNeeded();
+
     d_DeleteZoneQuery_stmt->
       bind("domain_id", di.id)->
       execute()->
@@ -1185,6 +1265,8 @@ void GSQLBackend::getAllDomains(vector<DomainInfo> *domains, bool include_disabl
   DLOG(L<<"GSQLBackend retrieving all domains."<<endl);
 
   try {
+    reconnectIfNeeded();
+
     d_getAllDomainsQuery_stmt->
       bind("include_disabled", (int)include_disabled)->
       execute();
@@ -1233,6 +1315,8 @@ void GSQLBackend::getAllDomains(vector<DomainInfo> *domains, bool include_disabl
 bool GSQLBackend::replaceRRSet(uint32_t domain_id, const DNSName& qname, const QType& qt, const vector<DNSResourceRecord>& rrset)
 {
   try {
+    reconnectIfNeeded();
+
     if (qt != QType::ANY) {
       d_DeleteRRSetQuery_stmt->
         bind("domain_id", domain_id)->
@@ -1254,6 +1338,8 @@ bool GSQLBackend::replaceRRSet(uint32_t domain_id, const DNSName& qname, const Q
 
   if (rrset.empty()) {
     try {
+      reconnectIfNeeded();
+
       d_DeleteCommentRRsetQuery_stmt->
         bind("domain_id", domain_id)->
         bind("qname", qname)->
@@ -1286,6 +1372,8 @@ bool GSQLBackend::feedRecord(const DNSResourceRecord &r, const DNSName &ordernam
   }
 
   try {
+    reconnectIfNeeded();
+
     d_InsertRecordQuery_stmt->
       bind("content",content)->
       bind("ttl",r.ttl)->
@@ -1319,6 +1407,8 @@ bool GSQLBackend::feedEnts(int domain_id, map<DNSName,bool>& nonterm)
 {
   for(const auto& nt: nonterm) {
     try {
+      reconnectIfNeeded();
+
       d_InsertEmptyNonTerminalOrderQuery_stmt->
         bind("domain_id",domain_id)->
         bind("qname", nt.first)->
@@ -1343,6 +1433,8 @@ bool GSQLBackend::feedEnts3(int domain_id, const DNSName &domain, map<DNSName,bo
 
   for(const auto& nt: nonterm) {
     try {
+      reconnectIfNeeded();
+
       d_InsertEmptyNonTerminalOrderQuery_stmt->
         bind("domain_id",domain_id)->
         bind("qname", nt.first);
@@ -1369,6 +1461,8 @@ bool GSQLBackend::feedEnts3(int domain_id, const DNSName &domain, map<DNSName,bo
 bool GSQLBackend::startTransaction(const DNSName &domain, int domain_id)
 {
   try {
+    reconnectIfNeeded();
+
     d_db->startTransaction();
     if(domain_id >= 0) {
       d_DeleteZoneQuery_stmt->
@@ -1414,6 +1508,8 @@ bool GSQLBackend::calculateSOASerial(const DNSName& domain, const SOAData& sd, t
   }
   
   try {
+    reconnectIfNeeded();
+
     d_ZoneLastChangeQuery_stmt->
       bind("domain_id", sd.domain_id)->
       execute()->
@@ -1437,6 +1533,8 @@ bool GSQLBackend::calculateSOASerial(const DNSName& domain, const SOAData& sd, t
 bool GSQLBackend::listComments(const uint32_t domain_id)
 {
   try {
+    reconnectIfNeeded();
+
     d_query_name = "list-comments-query";
     d_query_stmt = d_ListCommentsQuery_stmt;
     d_query_stmt->
@@ -1483,6 +1581,8 @@ bool GSQLBackend::getComment(Comment& comment)
 void GSQLBackend::feedComment(const Comment& comment)
 {
   try {
+    reconnectIfNeeded();
+
     d_InsertCommentQuery_stmt->
       bind("domain_id",comment.domain_id)->
       bind("qname",comment.qname)->
@@ -1501,6 +1601,8 @@ void GSQLBackend::feedComment(const Comment& comment)
 bool GSQLBackend::replaceComments(const uint32_t domain_id, const DNSName& qname, const QType& qt, const vector<Comment>& comments)
 {
   try {
+    reconnectIfNeeded();
+
     d_DeleteCommentRRsetQuery_stmt->
       bind("domain_id",domain_id)->
       bind("qname", qname)->
@@ -1526,6 +1628,8 @@ string GSQLBackend::directBackendCmd(const string &query)
 
    unique_ptr<SSqlStatement> stmt(d_db->prepare(query,0));
 
+   reconnectIfNeeded();
+
    stmt->execute();
 
    SSqlStatement::row_t row;
@@ -1560,6 +1664,8 @@ bool GSQLBackend::searchRecords(const string &pattern, int maxResults, vector<DN
   try {
     string escaped_pattern = pattern2SQLPattern(pattern);
 
+    reconnectIfNeeded();
+
     d_SearchRecordsQuery_stmt->
       bind("value", escaped_pattern)->
       bind("value2", escaped_pattern)->
@@ -1597,6 +1703,8 @@ bool GSQLBackend::searchComments(const string &pattern, int maxResults, vector<C
   try {
     string escaped_pattern = pattern2SQLPattern(pattern);
 
+    reconnectIfNeeded();
+
     d_SearchCommentsQuery_stmt->
       bind("value", escaped_pattern)->
       bind("value2", escaped_pattern)->
index 559c8043c370ebeb849af13e9ac5824ef632a472..31224b650564fd8dcaa6a7a9820f5ae7d8a9bd68 100644 (file)
@@ -244,6 +244,20 @@ protected:
   string pattern2SQLPattern(const string& pattern);
   void extractRecord(const SSqlStatement::row_t& row, DNSResourceRecord& rr);
   void extractComment(const SSqlStatement::row_t& row, Comment& c);
+  bool isConnectionUsable() {
+    if (d_db) {
+      return d_db->isConnectionUsable();
+    }
+    return false;
+  }
+  virtual void reconnectIfNeeded()
+  {
+    if (d_db) {
+      if(!d_db->isConnectionUsable()) {
+        d_db->reconnect();
+      }
+    }
+  }
 
 private:
   string d_query_name;
index 8e36b9d6392c17cf4ca0106a3c248e8b42a547a2..642b4a662855b73572b85202aa8c2a6dea659d19 100644 (file)
@@ -82,6 +82,11 @@ public:
   virtual void rollback()=0;
   virtual void commit()=0;
   virtual void setLog(bool state){}
+  virtual bool isConnectionUsable()
+  {
+    return true;
+  }
+  virtual void reconnect() {};
   virtual ~SSql(){};
 };