From 6daf396879b6502c41146cdb1013568654f52ff6 Mon Sep 17 00:00:00 2001
From: Bruce Momjian <bruce@momjian.us>
Date: Wed, 24 Mar 2004 03:45:00 +0000
Subject: [PATCH] Add thread locking to SSL and Kerberos connections.

I have removed the docs mentioning that SSL and Kerberos are not
thread-safe.

Manfred Spraul
---
 doc/src/sgml/libpq.sgml           |   5 +-
 src/backend/libpq/md5.c           |   4 +-
 src/interfaces/libpq/fe-auth.c    |  10 +-
 src/interfaces/libpq/fe-connect.c |  45 +++++++-
 src/interfaces/libpq/fe-secure.c  | 175 ++++++++++++++++++++++++------
 src/interfaces/libpq/libpq-fe.h   |  16 ++-
 src/interfaces/libpq/libpq-int.h  |  13 ++-
 7 files changed, 222 insertions(+), 46 deletions(-)

diff --git a/doc/src/sgml/libpq.sgml b/doc/src/sgml/libpq.sgml
index 4ac3743970..d4819c8b34 100644
--- a/doc/src/sgml/libpq.sgml
+++ b/doc/src/sgml/libpq.sgml
@@ -1,5 +1,5 @@
 <!--
-$PostgreSQL: pgsql/doc/src/sgml/libpq.sgml,v 1.149 2004/03/23 23:37:17 tgl Exp $
+$PostgreSQL: pgsql/doc/src/sgml/libpq.sgml,v 1.150 2004/03/24 03:44:58 momjian Exp $
 -->
 
  <chapter id="libpq">
@@ -3654,8 +3654,7 @@ call <function>fe_setauthsvc</function> at all.
 <literal>crypt()</literal> operating system function, which is often
 not thread-safe.<indexterm><primary>crypt</><secondary>thread
 safety</></> It is better to use the <literal>md5</literal> method,
-which is thread-safe on all platforms.  <application>SSL</> connections
-and <application>kerberos</> authentication are also not thread-safe.
+which is thread-safe on all platforms.
 </para>
 
 <para>
diff --git a/src/backend/libpq/md5.c b/src/backend/libpq/md5.c
index b1bb90a8da..4c194dde05 100644
--- a/src/backend/libpq/md5.c
+++ b/src/backend/libpq/md5.c
@@ -14,7 +14,7 @@
  *	Portions Copyright (c) 1994, Regents of the University of California
  *
  * IDENTIFICATION
- *	  $PostgreSQL: pgsql/src/backend/libpq/md5.c,v 1.22 2003/11/29 19:51:49 pgsql Exp $
+ *	  $PostgreSQL: pgsql/src/backend/libpq/md5.c,v 1.23 2004/03/24 03:44:58 momjian Exp $
  */
 
 
@@ -271,7 +271,7 @@ calculateDigestFromBuffer(uint8 *b, uint32 len, uint8 sum[16])
 static void
 bytesToHex(uint8 b[16], char *s)
 {
-	static char *hex = "0123456789abcdef";
+	static const char *hex = "0123456789abcdef";
 	int			q,
 				w;
 
diff --git a/src/interfaces/libpq/fe-auth.c b/src/interfaces/libpq/fe-auth.c
index 1113b3abb0..28d3e7ec6f 100644
--- a/src/interfaces/libpq/fe-auth.c
+++ b/src/interfaces/libpq/fe-auth.c
@@ -10,7 +10,7 @@
  * exceed INITIAL_EXPBUFFER_SIZE (currently 256 bytes).
  *
  * IDENTIFICATION
- *	  $PostgreSQL: pgsql/src/interfaces/libpq/fe-auth.c,v 1.89 2004/01/07 18:56:29 neilc Exp $
+ *	  $PostgreSQL: pgsql/src/interfaces/libpq/fe-auth.c,v 1.90 2004/03/24 03:44:59 momjian Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -590,6 +590,7 @@ fe_sendauth(AuthRequest areq, PGconn *conn, const char *hostname,
 
 		case AUTH_REQ_KRB4:
 #ifdef KRB4
+			pglock_thread();
 			if (pg_krb4_sendauth(PQerrormsg, conn->sock,
 							   (struct sockaddr_in *) & conn->laddr.addr,
 							   (struct sockaddr_in *) & conn->raddr.addr,
@@ -597,8 +598,10 @@ fe_sendauth(AuthRequest areq, PGconn *conn, const char *hostname,
 			{
 				snprintf(PQerrormsg, PQERRORMSG_LENGTH,
 					libpq_gettext("Kerberos 4 authentication failed\n"));
+				pgunlock_thread();
 				return STATUS_ERROR;
 			}
+			pgunlock_thread();
 			break;
 #else
 			snprintf(PQerrormsg, PQERRORMSG_LENGTH,
@@ -608,13 +611,16 @@ fe_sendauth(AuthRequest areq, PGconn *conn, const char *hostname,
 
 		case AUTH_REQ_KRB5:
 #ifdef KRB5
+			pglock_thread();
 			if (pg_krb5_sendauth(PQerrormsg, conn->sock,
 								 hostname) != STATUS_OK)
 			{
 				snprintf(PQerrormsg, PQERRORMSG_LENGTH,
 					libpq_gettext("Kerberos 5 authentication failed\n"));
+				pgunlock_thread();
 				return STATUS_ERROR;
 			}
+			pgunlock_thread();
 			break;
 #else
 			snprintf(PQerrormsg, PQERRORMSG_LENGTH,
@@ -722,6 +728,7 @@ fe_getauthname(char *PQerrormsg)
 	if (authsvc == 0)
 		return NULL;			/* leave original error message in place */
 
+	pglock_thread();
 #ifdef KRB4
 	if (authsvc == STARTUP_KRB4_MSG)
 		name = pg_krb4_authname(PQerrormsg);
@@ -759,5 +766,6 @@ fe_getauthname(char *PQerrormsg)
 
 	if (name && (authn = (char *) malloc(strlen(name) + 1)))
 		strcpy(authn, name);
+	pgunlock_thread();
 	return authn;
 }
diff --git a/src/interfaces/libpq/fe-connect.c b/src/interfaces/libpq/fe-connect.c
index 6bf07e1e20..b94504c03b 100644
--- a/src/interfaces/libpq/fe-connect.c
+++ b/src/interfaces/libpq/fe-connect.c
@@ -8,7 +8,7 @@
  *
  *
  * IDENTIFICATION
- *	  $PostgreSQL: pgsql/src/interfaces/libpq/fe-connect.c,v 1.268 2004/03/10 21:12:47 momjian Exp $
+ *	  $PostgreSQL: pgsql/src/interfaces/libpq/fe-connect.c,v 1.269 2004/03/24 03:44:59 momjian Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -2902,7 +2902,7 @@ int
 PQsetClientEncoding(PGconn *conn, const char *encoding)
 {
 	char		qbuf[128];
-	static char query[] = "set client_encoding to '%s'";
+	static const char query[] = "set client_encoding to '%s'";
 	PGresult   *res;
 	int			status;
 
@@ -3164,3 +3164,44 @@ PasswordFromFile(char *hostname, char *port, char *dbname, char *username)
 #undef LINELEN
 }
 
+/*
+ * To keep the API consistent, the locking stubs are always provided, even
+ * if they are not required.
+ */
+
+void
+PQinitSSL(int do_init)
+{
+#ifdef USE_SSL
+	pq_initssllib = do_init;
+#endif
+}
+
+static pgthreadlock_t default_threadlock;
+static void
+default_threadlock(int acquire)
+{
+#ifdef ENABLE_THREAD_SAFETY
+	static pthread_mutex_t singlethread_lock = PTHREAD_MUTEX_INITIALIZER;
+	if (acquire)
+		pthread_mutex_lock(&singlethread_lock);
+	else
+		pthread_mutex_unlock(&singlethread_lock);
+#endif
+}
+
+pgthreadlock_t *g_threadlock = default_threadlock;
+
+pgthreadlock_t *
+PQregisterThreadLock(pgthreadlock_t *newhandler)
+{
+	pgthreadlock_t *prev;
+
+	prev = g_threadlock;
+	if (newhandler)
+		g_threadlock = newhandler;
+	else
+		g_threadlock = default_threadlock;
+	return prev;
+}
+
diff --git a/src/interfaces/libpq/fe-secure.c b/src/interfaces/libpq/fe-secure.c
index 0504bdfb34..c72f5f1032 100644
--- a/src/interfaces/libpq/fe-secure.c
+++ b/src/interfaces/libpq/fe-secure.c
@@ -11,7 +11,7 @@
  *
  *
  * IDENTIFICATION
- *	  $PostgreSQL: pgsql/src/interfaces/libpq/fe-secure.c,v 1.37 2004/02/10 15:21:24 momjian Exp $
+ *	  $PostgreSQL: pgsql/src/interfaces/libpq/fe-secure.c,v 1.38 2004/03/24 03:44:59 momjian Exp $
  *
  * NOTES
  *	  The client *requires* a valid server certificate.  Since
@@ -135,11 +135,13 @@ static DH  *load_dh_file(int keylength);
 static DH  *load_dh_buffer(const char *, size_t);
 static DH  *tmp_dh_cb(SSL *s, int is_export, int keylength);
 static int	client_cert_cb(SSL *, X509 **, EVP_PKEY **);
+static int	init_ssl_system(PGconn *conn);
 static int	initialize_SSL(PGconn *);
 static void destroy_SSL(void);
 static PostgresPollingStatusType open_client_SSL(PGconn *);
 static void close_SSL(PGconn *);
-static const char *SSLerrmessage(void);
+static char *SSLerrmessage(void);
+static void SSLerrfree(char *buf);
 #endif
 
 #ifdef USE_SSL
@@ -251,9 +253,11 @@ pqsecure_open_client(PGconn *conn)
 			!SSL_set_app_data(conn->ssl, conn) ||
 			!SSL_set_fd(conn->ssl, conn->sock))
 		{
+			char *err = SSLerrmessage();
 			printfPQExpBuffer(&conn->errorMessage,
 			   libpq_gettext("could not establish SSL connection: %s\n"),
-							  SSLerrmessage());
+							  err);
+			SSLerrfree(err);
 			close_SSL(conn);
 			return PGRES_POLLING_FAILED;
 		}
@@ -327,8 +331,12 @@ rloop:
 					break;
 				}
 			case SSL_ERROR_SSL:
-				printfPQExpBuffer(&conn->errorMessage,
-					  libpq_gettext("SSL error: %s\n"), SSLerrmessage());
+				{
+					char *err = SSLerrmessage();
+					printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("SSL error: %s\n"), err);
+					SSLerrfree(err);
+				}
 				/* fall through */
 			case SSL_ERROR_ZERO_RETURN:
 				SOCK_ERRNO_SET(ECONNRESET);
@@ -402,8 +410,12 @@ pqsecure_write(PGconn *conn, const void *ptr, size_t len)
 					break;
 				}
 			case SSL_ERROR_SSL:
-				printfPQExpBuffer(&conn->errorMessage,
-					  libpq_gettext("SSL error: %s\n"), SSLerrmessage());
+				{
+					char *err = SSLerrmessage();
+					printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("SSL error: %s\n"), err);
+					SSLerrfree(err);
+				}
 				/* fall through */
 			case SSL_ERROR_ZERO_RETURN:
 				SOCK_ERRNO_SET(ECONNRESET);
@@ -750,9 +762,11 @@ client_cert_cb(SSL *ssl, X509 **x509, EVP_PKEY **pkey)
 	}
 	if (PEM_read_X509(fp, x509, NULL, NULL) == NULL)
 	{
+		char *err = SSLerrmessage();
 		printfPQExpBuffer(&conn->errorMessage,
 				  libpq_gettext("could not read certificate (%s): %s\n"),
-						  fnbuf, SSLerrmessage());
+						  fnbuf, err);
+		SSLerrfree(err);
 		fclose(fp);
 		return -1;
 	}
@@ -795,9 +809,11 @@ client_cert_cb(SSL *ssl, X509 **x509, EVP_PKEY **pkey)
 	}
 	if (PEM_read_PrivateKey(fp, pkey, cb, NULL) == NULL)
 	{
+		char *err = SSLerrmessage();
 		printfPQExpBuffer(&conn->errorMessage,
 				  libpq_gettext("could not read private key (%s): %s\n"),
-						  fnbuf, SSLerrmessage());
+						  fnbuf, err);
+		SSLerrfree(err);
 		X509_free(*x509);
 		fclose(fp);
 		return -1;
@@ -807,9 +823,11 @@ client_cert_cb(SSL *ssl, X509 **x509, EVP_PKEY **pkey)
 	/* verify that the cert and key go together */
 	if (!X509_check_private_key(*x509, *pkey))
 	{
+		char *err = SSLerrmessage();
 		printfPQExpBuffer(&conn->errorMessage,
 			libpq_gettext("certificate/private key mismatch (%s): %s\n"),
-						  fnbuf, SSLerrmessage());
+						  fnbuf, err);
+		SSLerrfree(err);
 		X509_free(*x509);
 		EVP_PKEY_free(*pkey);
 		return -1;
@@ -819,6 +837,77 @@ client_cert_cb(SSL *ssl, X509 **x509, EVP_PKEY **pkey)
 #endif
 }
 
+#ifdef ENABLE_THREAD_SAFETY
+
+static unsigned long
+pq_threadidcallback(void)
+{
+	return (unsigned long)pthread_self();
+}
+
+static pthread_mutex_t *pq_lockarray;
+static void
+pq_lockingcallback(int mode, int n, const char *file, int line)
+{
+	if (mode & CRYPTO_LOCK) {
+		pthread_mutex_lock(&pq_lockarray[n]);
+	} else {
+		pthread_mutex_unlock(&pq_lockarray[n]);
+	}
+}
+
+bool pq_initssllib = true;
+
+#endif /* ENABLE_THRAD_SAFETY */
+
+static int
+init_ssl_system(PGconn *conn)
+{
+#ifdef ENABLE_THREAD_SAFETY
+static pthread_mutex_t init_mutex = PTHREAD_MUTEX_INITIALIZER;
+
+	pthread_mutex_lock(&init_mutex);
+	
+	if (pq_initssllib && pq_lockarray == NULL) {
+		int i;
+		CRYPTO_set_id_callback(pq_threadidcallback);
+
+		pq_lockarray = malloc(sizeof(pthread_mutex_t)*CRYPTO_num_locks());
+		if (!pq_lockarray) {
+			pthread_mutex_unlock(&init_mutex);
+			return -1;
+		}
+		for (i=0;i<CRYPTO_num_locks();i++)
+			pthread_mutex_init(&pq_lockarray[i], NULL);
+
+		CRYPTO_set_locking_callback(pq_lockingcallback);
+	}
+#endif
+	if (!SSL_context)
+	{
+		if (pq_initssllib) {
+			SSL_library_init();
+			SSL_load_error_strings();
+		}
+		SSL_context = SSL_CTX_new(TLSv1_method());
+		if (!SSL_context)
+		{
+			char *err = SSLerrmessage();
+			printfPQExpBuffer(&conn->errorMessage,
+					 libpq_gettext("could not create SSL context: %s\n"),
+							  err);
+			SSLerrfree(err);
+#ifdef ENABLE_THREAD_SAFETY
+			pthread_mutex_unlock(&init_mutex);
+#endif
+			return -1;
+		}
+	}
+#ifdef ENABLE_THREAD_SAFETY
+	pthread_mutex_unlock(&init_mutex);
+#endif
+	return 0;
+}
 /*
  *	Initialize global SSL context.
  */
@@ -833,19 +922,8 @@ initialize_SSL(PGconn *conn)
 	char		fnbuf[2048];
 #endif
 
-	if (!SSL_context)
-	{
-		SSL_library_init();
-		SSL_load_error_strings();
-		SSL_context = SSL_CTX_new(TLSv1_method());
-		if (!SSL_context)
-		{
-			printfPQExpBuffer(&conn->errorMessage,
-					 libpq_gettext("could not create SSL context: %s\n"),
-							  SSLerrmessage());
-			return -1;
-		}
-	}
+	if(!init_ssl_system(conn))
+		return -1;
 
 #ifndef WIN32
 	if (pqGetpwuid(getuid(), &pwdstr, pwdbuf, sizeof(pwdbuf), &pwd) == 0)
@@ -867,9 +945,11 @@ initialize_SSL(PGconn *conn)
 		}
 		if (!SSL_CTX_load_verify_locations(SSL_context, fnbuf, 0))
 		{
+			char *err = SSLerrmessage();
 			printfPQExpBuffer(&conn->errorMessage,
 							  libpq_gettext("could not read root certificate list (%s): %s\n"),
-							  fnbuf, SSLerrmessage());
+							  fnbuf, err);
+			SSLerrfree(err);
 			return -1;
 		}
 	}
@@ -936,10 +1016,14 @@ open_client_SSL(PGconn *conn)
 					return PGRES_POLLING_FAILED;
 				}
 			case SSL_ERROR_SSL:
-				printfPQExpBuffer(&conn->errorMessage,
-					  libpq_gettext("SSL error: %s\n"), SSLerrmessage());
-				close_SSL(conn);
-				return PGRES_POLLING_FAILED;
+				{
+					char *err = SSLerrmessage();
+					printfPQExpBuffer(&conn->errorMessage,
+						  libpq_gettext("SSL error: %s\n"), err);
+					SSLerrfree(err);
+					close_SSL(conn);
+					return PGRES_POLLING_FAILED;
+				}
 
 			default:
 				printfPQExpBuffer(&conn->errorMessage,
@@ -973,9 +1057,11 @@ open_client_SSL(PGconn *conn)
 	conn->peer = SSL_get_peer_certificate(conn->ssl);
 	if (conn->peer == NULL)
 	{
+		char *err = SSLerrmessage();
 		printfPQExpBuffer(&conn->errorMessage,
 				libpq_gettext("certificate could not be obtained: %s\n"),
-						  SSLerrmessage());
+						  err);
+		SSLerrfree(err);
 		close_SSL(conn);
 		return PGRES_POLLING_FAILED;
 	}
@@ -1036,23 +1122,40 @@ close_SSL(PGconn *conn)
  * return NULL if it doesn't recognize the error code.  We don't
  * want to return NULL ever.
  */
-static const char *
+static char ssl_nomem[] = "Out of memory allocating error description";
+#define SSL_ERR_LEN	128
+
+static char *
 SSLerrmessage(void)
 {
 	unsigned long errcode;
 	const char *errreason;
-	static char errbuf[32];
+	char *errbuf;
 
+	errbuf = malloc(SSL_ERR_LEN);
+	if (!errbuf)
+		return ssl_nomem;
 	errcode = ERR_get_error();
-	if (errcode == 0)
-		return "No SSL error reported";
+	if (errcode == 0) {
+		strcpy(errbuf, "No SSL error reported");
+		return errbuf;
+	}
 	errreason = ERR_reason_error_string(errcode);
-	if (errreason != NULL)
-		return errreason;
-	snprintf(errbuf, sizeof(errbuf), "SSL error code %lu", errcode);
+	if (errreason != NULL) {
+		strncpy(errbuf, errreason, SSL_ERR_LEN-1);
+		errbuf[SSL_ERR_LEN-1] = '\0';
+		return errbuf;
+	}
+	snprintf(errbuf, SSL_ERR_LEN, "SSL error code %lu", errcode);
 	return errbuf;
 }
 
+static void
+SSLerrfree(char *buf)
+{
+	if (buf != ssl_nomem)
+		free(buf);
+}
 /*
  *	Return pointer to SSL object.
  */
diff --git a/src/interfaces/libpq/libpq-fe.h b/src/interfaces/libpq/libpq-fe.h
index 293d50e690..7a143888bb 100644
--- a/src/interfaces/libpq/libpq-fe.h
+++ b/src/interfaces/libpq/libpq-fe.h
@@ -7,7 +7,7 @@
  * Portions Copyright (c) 1996-2003, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
  *
- * $PostgreSQL: pgsql/src/interfaces/libpq/libpq-fe.h,v 1.103 2004/03/15 10:41:26 ishii Exp $
+ * $PostgreSQL: pgsql/src/interfaces/libpq/libpq-fe.h,v 1.104 2004/03/24 03:44:59 momjian Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -274,6 +274,20 @@ extern PQnoticeProcessor PQsetNoticeProcessor(PGconn *conn,
 					 PQnoticeProcessor proc,
 					 void *arg);
 
+/*
+ *     Used to set callback that prevents concurrent access to
+ *     non-thread safe functions that libpq needs.
+ *     The default implementation uses a libpq internal mutex.
+ *     Only required for multithreaded apps that use kerberos
+ *     both within their app and for postgresql connections.
+ */
+typedef void (pgthreadlock_t)(int acquire);
+
+extern pgthreadlock_t * PQregisterThreadLock(pgthreadlock_t *newhandler);
+
+void
+PQinitSSL(int do_init);
+
 /* === in fe-exec.c === */
 
 /* Simple synchronous query */
diff --git a/src/interfaces/libpq/libpq-int.h b/src/interfaces/libpq/libpq-int.h
index f2acf2af20..589bf8b076 100644
--- a/src/interfaces/libpq/libpq-int.h
+++ b/src/interfaces/libpq/libpq-int.h
@@ -12,7 +12,7 @@
  * Portions Copyright (c) 1996-2003, PostgreSQL Global Development Group
  * Portions Copyright (c) 1994, Regents of the University of California
  *
- * $PostgreSQL: pgsql/src/interfaces/libpq/libpq-int.h,v 1.85 2004/03/05 01:53:59 tgl Exp $
+ * $PostgreSQL: pgsql/src/interfaces/libpq/libpq-int.h,v 1.86 2004/03/24 03:45:00 momjian Exp $
  *
  *-------------------------------------------------------------------------
  */
@@ -359,6 +359,16 @@ extern char *const pgresStatus[];
 extern int pqPacketSend(PGconn *conn, char pack_type,
 			 const void *buf, size_t buf_len);
 
+#ifdef ENABLE_THREAD_SAFETY
+extern pgthreadlock_t *g_threadlock;
+#define pglock_thread() g_threadlock(true);
+#define pgunlock_thread() g_threadlock(false);
+#else
+#define pglock_thread() ((void)0)
+#define pgunlock_thread() ((void)0)
+#endif
+	 
+
 /* === in fe-exec.c === */
 
 extern void pqSetResultError(PGresult *res, const char *msg);
@@ -448,6 +458,7 @@ extern ssize_t pqsecure_write(PGconn *, const void *ptr, size_t len);
 #ifdef ENABLE_THREAD_SAFETY
 extern void check_sigpipe_handler(void);
 extern pthread_key_t thread_in_send;
+extern bool pq_initssllib;
 #endif
 
 /*
-- 
2.49.0