From c470c6255e1b77857e4d83f616e0537e1bd35048 Mon Sep 17 00:00:00 2001 From: Adam Ierymenko Date: Fri, 28 May 2021 17:08:24 -0400 Subject: [PATCH] Postgres code for SSO (almost certainly needs work) --- controller/DB.cpp | 9 ++- controller/DB.hpp | 3 +- controller/EmbeddedNetworkController.cpp | 8 --- controller/EmbeddedNetworkController.hpp | 1 - controller/PostgreSQL.cpp | 92 ++++++++++++++++++++++++ controller/PostgreSQL.hpp | 2 + 6 files changed, 102 insertions(+), 13 deletions(-) diff --git a/controller/DB.cpp b/controller/DB.cpp index 80cdb7fa1..edfbb16b8 100644 --- a/controller/DB.cpp +++ b/controller/DB.cpp @@ -68,7 +68,6 @@ void DB::initMember(nlohmann::json &member) if (!member.count("lastAuthorizedCredentialType")) member["lastAuthorizedCredentialType"] = nlohmann::json(); if (!member.count("lastAuthorizedCredential")) member["lastAuthorizedCredential"] = nlohmann::json(); if (!member.count("authenticationExpiryTime")) member["authenticationExpiryTime"] = -1LL; - if (!member.count("authenticationURL")) member["authenticationURL"] = nlohmann::json(); if (!member.count("vMajor")) member["vMajor"] = -1; if (!member.count("vMinor")) member["vMinor"] = -1; if (!member.count("vRev")) member["vRev"] = -1; @@ -94,6 +93,8 @@ void DB::cleanMember(nlohmann::json &member) member.erase("recentLog"); member.erase("lastModified"); member.erase("lastRequestMetaData"); + member.erase("authenticationURL"); // computed + member.erase("authenticationClientID"); // computed } DB::DB() {} @@ -135,6 +136,7 @@ bool DB::get(const uint64_t networkId,nlohmann::json &network,const uint64_t mem if (m == nw->members.end()) return false; member = m->second; + updateMemberOnLoad(networkId, memberId, member); } return true; } @@ -158,6 +160,7 @@ bool DB::get(const uint64_t networkId,nlohmann::json &network,const uint64_t mem if (m == nw->members.end()) return false; member = m->second; + updateMemberOnLoad(networkId, memberId, member); } return true; } @@ -176,8 +179,10 @@ bool DB::get(const uint64_t networkId,nlohmann::json &network,std::vector l2(nw->lock); network = nw->config; - for(auto m=nw->members.begin();m!=nw->members.end();++m) + for(auto m=nw->members.begin();m!=nw->members.end();++m) { members.push_back(m->second); + updateMemberOnLoad(networkId, m->first, members.back()); + } } return true; } diff --git a/controller/DB.hpp b/controller/DB.hpp index 6a6906eff..8d9e2202c 100644 --- a/controller/DB.hpp +++ b/controller/DB.hpp @@ -101,11 +101,10 @@ public: } virtual bool save(nlohmann::json &record,bool notifyListeners) = 0; - virtual void eraseNetwork(const uint64_t networkId) = 0; virtual void eraseMember(const uint64_t networkId,const uint64_t memberId) = 0; - virtual void nodeIsOnline(const uint64_t networkId,const uint64_t memberId,const InetAddress &physicalAddress) = 0; + virtual void updateMemberOnLoad(const uint64_t networkId, const uint64_t memberId, nlohmann::json &member) {} inline void addListener(DB::ChangeListener *const listener) { diff --git a/controller/EmbeddedNetworkController.cpp b/controller/EmbeddedNetworkController.cpp index a3fc4a6e5..b4eb28b87 100644 --- a/controller/EmbeddedNetworkController.cpp +++ b/controller/EmbeddedNetworkController.cpp @@ -466,14 +466,6 @@ EmbeddedNetworkController::EmbeddedNetworkController(Node *node,const char *ztPa _db(this), _rc(rc) { - memset(_ssoPsk, 0, sizeof(_ssoPsk)); - char *const ssoPskHex = getenv("ZT_SSO_PSK"); - if (ssoPskHex) { - // SECURITY: note that ssoPskHex will always be null-terminated if libc acatually - // returns something non-NULL. If the hex encodes something shorter than 48 bytes, - // it will be padded at the end with zeroes. If longer, it'll be truncated. - Utils::unhex(ssoPskHex, _ssoPsk, sizeof(_ssoPsk)); - } } EmbeddedNetworkController::~EmbeddedNetworkController() diff --git a/controller/EmbeddedNetworkController.hpp b/controller/EmbeddedNetworkController.hpp index 326bdce87..e499dd647 100644 --- a/controller/EmbeddedNetworkController.hpp +++ b/controller/EmbeddedNetworkController.hpp @@ -140,7 +140,6 @@ private: Identity _signingId; std::string _signingIdAddressString; NetworkController::Sender *_sender; - uint8_t _ssoPsk[48]; DBMirrorSet _db; BlockingQueue< _RQEntry * > _queue; diff --git a/controller/PostgreSQL.cpp b/controller/PostgreSQL.cpp index a031c1ff0..eca923a19 100644 --- a/controller/PostgreSQL.cpp +++ b/controller/PostgreSQL.cpp @@ -16,6 +16,7 @@ #ifdef ZT_CONTROLLER_USE_LIBPQ #include "../node/Constants.hpp" +#include "../node/SHA512.hpp" #include "EmbeddedNetworkController.hpp" #include "../version.h" #include "Redis.hpp" @@ -90,6 +91,15 @@ PostgreSQL::PostgreSQL(const Identity &myId, const char *path, int listenPort, R _myAddressStr = myId.address().toString(myAddress); _connString = std::string(path) + " application_name=controller_" + _myAddressStr; + memset(_ssoPsk, 0, sizeof(_ssoPsk)); + char *const ssoPskHex = getenv("ZT_SSO_PSK"); + if (ssoPskHex) { + // SECURITY: note that ssoPskHex will always be null-terminated if libc acatually + // returns something non-NULL. If the hex encodes something shorter than 48 bytes, + // it will be padded at the end with zeroes. If longer, it'll be truncated. + Utils::unhex(ssoPskHex, _ssoPsk, sizeof(_ssoPsk)); + } + // Database Schema Version Check PGconn *conn = getPgConn(); if (PQstatus(conn) != CONNECTION_OK) { @@ -263,6 +273,88 @@ void PostgreSQL::nodeIsOnline(const uint64_t networkId, const uint64_t memberId, } } +void PostgreSQL::updateMemberOnLoad(const uint64_t networkId, const uint64_t memberId, nlohmann::json &member) +{ + const uint64_t nwid = OSUtils::jsonIntHex(member["nwid"],0ULL); + const uint64_t id = OSUtils::jsonIntHex(member["id"],0ULL); + char nwids[24],ids[24]; + OSUtils::ztsnprintf(nwids, sizeof(nwids), "%.16llx", nwid); + OSUtils::ztsnprintf(ids, sizeof(ids), "%.10llx", id); + + bool have_auth = false; + try { + PGconn *conn = getPgConn(); + if (PQstatus(conn) != CONNECTION_OK) { + fprintf(stderr, "Bad Database Connection: %s", PQerrorMessage(conn)); + exit(1); + } + + const char *params[1] = { nwids }; + PGresult *res = PQexecParams(conn, "SELECT org.client_id, org.authorization_endpoint FROM ztc_network AS nw, ztc_org AS org WHERE nw.id = $1 AND nw.sso_enabled = true AND org.owner_id = nw.owner_id", + 1, + NULL, + params, + NULL, + NULL, + 0); + if (PQresultStatus(res) != PGRES_TUPLES_OK) { + fprintf(stderr, "Org client_id and authorization_endpoint lookup failed: %s", PQerrorMessage(conn)); + PQclear(res); + exit(1); + } + + if (PQntuples(res) >= 1) { + std::string client_id = PQgetvalue(res, 0, 0); + std::string authorization_endpoint = PQgetvalue(res, 0, 1); + PQclear(res); + if ((!client_id.empty())&&(!authorization_endpoint.empty())) { + const char *params2[2] = { nwids, ids }; + res = PQexecParams(conn, "SELECT e.nonce, e.authentication_expiry_time FROM ztc_sso_expiry AS e WHERE e.network_id = $1 AND e.member_id = $2 ORDER BY n.authentication_expiry_time DESC LIMIT 1", + 1, + NULL, + params2, + NULL, + NULL, + 0); + if (PQntuples(res) >= 1) { + std::string nonce = PQgetvalue(res, 0, 0); + int64_t authentication_expiry_time = std::stoll(PQgetvalue(res, 0, 1)); + if ((authentication_expiry_time >= 0)&&(!nonce.empty())) { + have_auth = true; + + uint8_t state[48]; + HMACSHA384(_ssoPsk, nonce.data(), (unsigned int)nonce.length(), state); + char state_hex[256]; + Utils::hex(state, 48, state_hex); + char authenticationURL[4096]; + const char *redirect_url = "redirect_uri=http%3A%2F%2Fmy.zerotier.com%2Fapi%2Fnetwork%2Fsso-auth"; // TODO: this should be configurable + Utils::ztsnprintf(authenticationURL, sizeof(authenticationURL), + "%s?response_type=id_token&response_mode=form_post&scope=openid+email+profile&redriect_uri=%s&nonce=%s&state=%s&client_id=%s", + authorization_endpoint.c_str(), + redirect_url, + nonce.c_str(), + state_hex, // NOTE: should these be URL escaped? Don't think there's a risk as they are not user definable. + client_id.c_str()); + + member["authenticationExpiryTime"] = authentication_expiry_time; + member["authenticationURL"] = authenticationURL; + } + } + PQclear(res); + } + } else { + PQclear(res); + } + + } catch (sw::redis::Error &e) { + fprintf(stderr, "ERROR: Error updating member on load, in Redis: %s\n", e.what()); + exit(-1); + } catch (std::exception &e) { + fprintf(stderr, "ERROR: Error updating member on load: %s\n", e.what()); + exit(-1); + } +} + void PostgreSQL::initializeNetworks(PGconn *conn) { try { diff --git a/controller/PostgreSQL.hpp b/controller/PostgreSQL.hpp index c1d9dfd1a..1870b5afb 100644 --- a/controller/PostgreSQL.hpp +++ b/controller/PostgreSQL.hpp @@ -49,6 +49,7 @@ public: virtual void eraseNetwork(const uint64_t networkId); virtual void eraseMember(const uint64_t networkId, const uint64_t memberId); virtual void nodeIsOnline(const uint64_t networkId, const uint64_t memberId, const InetAddress &physicalAddress); + virtual void updateMemberOnLoad(const uint64_t networkId, const uint64_t memberId, nlohmann::json &member); protected: struct _PairHasher @@ -103,6 +104,7 @@ private: mutable volatile bool _waitNoticePrinted; int _listenPort; + uint8_t _ssoPsk[48]; RedisConfig *_rc; std::shared_ptr _redis;