diff --git a/controller/DB.cpp b/controller/DB.cpp index 6a97746e7..2edcadbbe 100644 --- a/controller/DB.cpp +++ b/controller/DB.cpp @@ -59,6 +59,7 @@ void DB::initNetwork(nlohmann::json &network) void DB::initMember(nlohmann::json &member) { if (!member.count("authorized")) member["authorized"] = false; + if (!member.count("ssoExempt")) member["ssoExempt"] = false; if (!member.count("ipAssignments")) member["ipAssignments"] = nlohmann::json::array(); if (!member.count("activeBridge")) member["activeBridge"] = false; if (!member.count("tags")) member["tags"] = nlohmann::json::array(); diff --git a/controller/DB.hpp b/controller/DB.hpp index 0776c13bd..67017f855 100644 --- a/controller/DB.hpp +++ b/controller/DB.hpp @@ -31,6 +31,7 @@ #include #include #include +#include #include "../ext/json/json.hpp" diff --git a/controller/DBMirrorSet.cpp b/controller/DBMirrorSet.cpp index c89b7762d..fd508342d 100644 --- a/controller/DBMirrorSet.cpp +++ b/controller/DBMirrorSet.cpp @@ -36,7 +36,7 @@ DBMirrorSet::DBMirrorSet(DB::ChangeListener *listener) : } for(auto db=dbs.begin();db!=dbs.end();++db) { - (*db)->each([this,&dbs,&db](uint64_t networkId,const nlohmann::json &network,uint64_t memberId,const nlohmann::json &member) { + (*db)->each([&dbs,&db](uint64_t networkId,const nlohmann::json &network,uint64_t memberId,const nlohmann::json &member) { try { if (network.is_object()) { if (memberId == 0) { @@ -240,4 +240,52 @@ void DBMirrorSet::onNetworkMemberDeauthorize(const void *db,uint64_t networkId,u _listener->onNetworkMemberDeauthorize(this,networkId,memberId); } +std::vector> DBMirrorSet::membersExpiringSoon() +{ + std::vector> soon; + std::unique_lock l(_membersExpiringSoon_l); + int64_t now = OSUtils::now(); + for(auto next=_membersExpiringSoon.begin();next!=_membersExpiringSoon.end();) { + if (next->first <= now) { + // Already expired, so the node will need to re-auth. + _membersExpiringSoon.erase(next++); + } else { + const uint64_t nwid = next->second.first; + const uint64_t memberId = next->second.second; + nlohmann::json network, member; + if (this->get(nwid, network, memberId, member)) { + try { + const bool authorized = member["authorized"]; + const bool ssoExempt = member["ssoExempt"]; + const int64_t authenticationExpiryTime = member["authenticationExpiryTime"]; + if ((authenticationExpiryTime == next->first)&&(authorized)&&(!ssoExempt)) { + if ((authenticationExpiryTime - now) > 10000) { + // Stop when we get to entries more than 10s in the future. + break; + } else { + soon.push_back(std::pair(nwid, memberId)); + } + } else { + // Obsolete entry, no longer authorized, or SSO exempt. + _membersExpiringSoon.erase(next++); + } + } catch ( ... ) { + // Invalid member object, erase. + _membersExpiringSoon.erase(next++); + } + } else { + // Not found, so erase. + _membersExpiringSoon.erase(next++); + } + } + } + return soon; +} + +void DBMirrorSet::memberExpiring(int64_t expTime, uint64_t nwid, uint64_t memberId) +{ + std::unique_lock l(_membersExpiringSoon_l); + _membersExpiringSoon.insert(std::pair< int64_t, std::pair< uint64_t, uint64_t > >(expTime, std::pair< uint64_t, uint64_t >(nwid, memberId))); +} + } // namespace ZeroTier diff --git a/controller/DBMirrorSet.hpp b/controller/DBMirrorSet.hpp index d6dd0744a..0a9996eda 100644 --- a/controller/DBMirrorSet.hpp +++ b/controller/DBMirrorSet.hpp @@ -60,12 +60,17 @@ public: _dbs.push_back(db); } + std::vector> membersExpiringSoon(); + void memberExpiring(int64_t expTime, uint64_t nwid, uint64_t memberId); + private: DB::ChangeListener *const _listener; std::atomic_bool _running; std::thread _syncCheckerThread; std::vector< std::shared_ptr< DB > > _dbs; mutable std::mutex _dbs_l; + std::multimap< int64_t, std::pair > _membersExpiringSoon; + mutable std::mutex _membersExpiringSoon_l; }; } // namespace ZeroTier diff --git a/controller/EmbeddedNetworkController.cpp b/controller/EmbeddedNetworkController.cpp index e2eaf75b6..0f5a4efba 100644 --- a/controller/EmbeddedNetworkController.cpp +++ b/controller/EmbeddedNetworkController.cpp @@ -1815,17 +1815,37 @@ void EmbeddedNetworkController::_startThreads() _threads.emplace_back([this]() { for(;;) { _RQEntry *qe = (_RQEntry *)0; - if (!_queue.get(qe)) + auto timedWaitResult = _queue.get(qe, 1000); + if (timedWaitResult == BlockingQueue<_RQEntry *>::STOP) { break; - try { + } else if (timedWaitResult == BlockingQueue<_RQEntry *>::OK) { if (qe) { - _request(qe->nwid,qe->fromAddr,qe->requestPacketId,qe->identity,qe->metaData); + try { + _request(qe->nwid,qe->fromAddr,qe->requestPacketId,qe->identity,qe->metaData); + } catch (std::exception &e) { + fprintf(stderr,"ERROR: exception in controller request handling thread: %s" ZT_EOL_S,e.what()); + } catch ( ... ) { + fprintf(stderr,"ERROR: exception in controller request handling thread: unknown exception" ZT_EOL_S); + } delete qe; } - } catch (std::exception &e) { - fprintf(stderr,"ERROR: exception in controller request handling thread: %s" ZT_EOL_S,e.what()); - } catch ( ... ) { - fprintf(stderr,"ERROR: exception in controller request handling thread: unknown exception" ZT_EOL_S); + } + + auto expiringSoon = _db.membersExpiringSoon(); + for(auto soon=expiringSoon.begin();soon!=expiringSoon.end();++soon) { + Identity identity; + Dictionary lastMetaData; + { + std::unique_lock ll(_memberStatus_l); + auto ms = _memberStatus.find(_MemberStatusKey(soon->first, soon->second)); + if (ms != _memberStatus.end()) { + lastMetaData = ms->second.lastRequestMetaData; + identity = ms->second.identity; + } + } + if (identity) { + request(soon->first,InetAddress(),0,identity,lastMetaData); + } } } });