From c8d5215870a6361a080b3fcf82f7a1707ae20bb5 Mon Sep 17 00:00:00 2001
From: Grant Limberg <grant.limberg@zerotier.com>
Date: Mon, 3 Dec 2018 15:19:15 -0800
Subject: [PATCH] add handling for PGBOUNCER_CONNSTR environment variable.

alows Central controllers to connect to PGBouncer on some threads.  LISTEN threads still require a direct connection to the DB
---
 controller/PostgreSQL.cpp | 25 ++++++++++++++++++-------
 controller/PostgreSQL.hpp |  7 +++++++
 2 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/controller/PostgreSQL.cpp b/controller/PostgreSQL.cpp
index 4182e0f83..5c9b15796 100644
--- a/controller/PostgreSQL.cpp
+++ b/controller/PostgreSQL.cpp
@@ -510,7 +510,7 @@ void PostgreSQL::heartbeat()
 	const char *publicIdentity = publicId;
 	const char *hostname = hostnameTmp;
 
-	PGconn *conn = PQconnectdb(_connString.c_str());
+	PGconn *conn = getPgConn();
 	if (PQstatus(conn) == CONNECTION_BAD) {
 		fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
 		PQfinish(conn);
@@ -519,7 +519,7 @@ void PostgreSQL::heartbeat()
 	while (_run == 1) {
 		if(PQstatus(conn) != CONNECTION_OK) {
 			PQfinish(conn);
-			conn = PQconnectdb(_connString.c_str());
+			conn = getPgConn();
 		}
 		if (conn) {
 			std::string major = std::to_string(ZEROTIER_ONE_VERSION_MAJOR);
@@ -566,7 +566,7 @@ void PostgreSQL::heartbeat()
 
 void PostgreSQL::membersDbWatcher()
 {
-	PGconn *conn = PQconnectdb(_connString.c_str());
+	PGconn *conn = getPgConn(NO_OVERRIDE);
 	if (PQstatus(conn) == CONNECTION_BAD) {
 		fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
 		PQfinish(conn);
@@ -619,7 +619,7 @@ void PostgreSQL::membersDbWatcher()
 
 void PostgreSQL::networksDbWatcher()
 {
-	PGconn *conn = PQconnectdb(_connString.c_str());
+	PGconn *conn = getPgConn(NO_OVERRIDE);
 	if (PQstatus(conn) == CONNECTION_BAD) {
 		fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
 		PQfinish(conn);
@@ -670,7 +670,7 @@ void PostgreSQL::networksDbWatcher()
 
 void PostgreSQL::commitThread()
 {
-	PGconn *conn = PQconnectdb(_connString.c_str());
+	PGconn *conn = getPgConn();
 	if (PQstatus(conn) == CONNECTION_BAD) {
 		fprintf(stderr, "ERROR: Connection to database failed: %s\n", PQerrorMessage(conn));
 		PQfinish(conn);
@@ -1146,7 +1146,7 @@ void PostgreSQL::commitThread()
 
 void PostgreSQL::onlineNotificationThread()
 {
-	PGconn *conn = PQconnectdb(_connString.c_str());
+	PGconn *conn = getPgConn();
 	if (PQstatus(conn) == CONNECTION_BAD) {
 		fprintf(stderr, "Connection to database failed: %s\n", PQerrorMessage(conn));
 		PQfinish(conn);
@@ -1161,7 +1161,7 @@ void PostgreSQL::onlineNotificationThread()
 		if (PQstatus(conn) != CONNECTION_OK) {
 			fprintf(stderr, "ERROR: Online Notification thread lost connection to Postgres.");
 			PQfinish(conn);
-			conn = PQconnectdb(_connString.c_str());
+			conn = getPgConn();
 			continue;
 		}
 
@@ -1328,4 +1328,15 @@ void PostgreSQL::onlineNotificationThread()
 	PQfinish(conn);
 	exit(5);
 }
+
+PGconn *PostgreSQL::getPgConn(OverrideMode m) {
+	if (m == ALLOW_PGBOUNCER_OVERRIDE) {
+		char *connStr = getenv("PGBOUNCER_CONNSTR");
+		if (connStr != NULL) {
+			return PQconnectdb(connStr);
+		}
+	}
+
+	return PQconnectdb(_connString.c_str());
+}
 #endif //ZT_CONTROLLER_USE_LIBPQ
diff --git a/controller/PostgreSQL.hpp b/controller/PostgreSQL.hpp
index 36fe8c9f2..6f127d5ac 100644
--- a/controller/PostgreSQL.hpp
+++ b/controller/PostgreSQL.hpp
@@ -64,6 +64,13 @@ private:
     void commitThread();
     void onlineNotificationThread();
 
+    enum OverrideMode {
+        ALLOW_PGBOUNCER_OVERRIDE = 0,
+        NO_OVERRIDE = 1
+    };
+
+    PGconn * getPgConn( OverrideMode m = ALLOW_PGBOUNCER_OVERRIDE );
+
     std::string _connString;
 
     BlockingQueue<nlohmann::json *> _commitQueue;