diff --git a/root/root.cpp b/root/root.cpp index c54f58691..28a9d481b 100644 --- a/root/root.cpp +++ b/root/root.cpp @@ -208,15 +208,13 @@ static std::string s_planet; static std::vector< SharedPtr > s_peers; static std::vector< SharedPtr > s_peersToValidate; static std::unordered_map< uint64_t,std::unordered_map< MulticastGroup,std::unordered_map< Address,int64_t,AddressHasher >,MulticastGroupHasher > > s_multicastSubscriptions; -static std::unordered_map< Identity,SharedPtr,IdentityHasher > s_peersByIdentity; -static std::unordered_map< Address,std::set< SharedPtr >,AddressHasher > s_peersByVirtAddr; +static std::unordered_map< Address,SharedPtr,AddressHasher > s_peersByVirtAddr; static std::unordered_map< RendezvousKey,RendezvousStats,RendezvousKey::Hasher > s_rendezvousTracking; static std::mutex s_planet_l; static std::mutex s_peers_l; static std::mutex s_peersToValidate_l; static std::mutex s_multicastSubscriptions_l; -static std::mutex s_peersByIdentity_l; static std::mutex s_peersByVirtAddr_l; static std::mutex s_rendezvousTracking_l; @@ -261,38 +259,34 @@ static void handlePacket(const int sock,const InetAddress *const ip,Packet &pkt) Identity id; if (id.deserialize(pkt,ZT_PROTO_VERB_HELLO_IDX_IDENTITY)) { { - std::lock_guard pbi_l(s_peersByIdentity_l); - auto pById = s_peersByIdentity.find(id); - if (pById != s_peersByIdentity.end()) { - peer = pById->second; - //printf("%s has %s (known (1))" ZT_EOL_S,ip->toString(ipstr),source().toString(astr)); + std::lock_guard p_l(s_peersByVirtAddr_l); + auto p = s_peersByVirtAddr.find(source); + if (p != s_peersByVirtAddr.end()) { + peer = p->second; } } - if (peer) { - // Peer found with this identity. - if (!pkt.dearmor(peer->key)) { - printf("%s HELLO rejected: packet authentication failed" ZT_EOL_S,ip->toString(ipstr)); - return; - } - } else { - // Check to ensure that there is no peer with the same address as this identity. If there is, - // verify both identities to pick the one with this address. - bool needsValidation = false; - bool identityValidated = false; - { - std::lock_guard pbv_l(s_peersByVirtAddr_l); - needsValidation = s_peersByVirtAddr.find(id.address()) != s_peersByVirtAddr.end(); - } - if (unlikely(needsValidation)) { - if (!id.locallyValidate()) { - printf("%s HELLO rejected: identity validate failed" ZT_EOL_S,ip->toString(ipstr)); - return; - } - identityValidated = true; - } + if (peer) { + if (unlikely(peer->id != id)) { + if (!peer->identityValidated) { + peer->identityValidated = peer->id.locallyValidate(); + if (peer->identityValidated) { + printf("%s HELLO rejected: identity address collision!" ZT_EOL_S,ip->toString(ipstr)); + // TODO: send error + return; + } else { + printf("* invalid identity found and discarded: %s" ZT_EOL_S,id.toString(false, tmpstr)); + std::lock_guard p_l(s_peersByVirtAddr_l); + s_peersByVirtAddr.erase(source); + peer.zero(); + } + } + } + } + + if (!peer) { peer.set(new RootPeer); - peer->identityValidated = identityValidated; + peer->identityValidated = false; if (!s_self.agree(id,peer->key)) { printf("%s HELLO rejected: key agreement failed" ZT_EOL_S,ip->toString(ipstr)); @@ -310,64 +304,41 @@ static void handlePacket(const int sock,const InetAddress *const ip,Packet &pkt) peer->id = id; peer->lastReceive = now; - bool added = false; { - std::lock_guard pbi_l(s_peersByIdentity_l); - auto existing = s_peersByIdentity.find(id); // make sure another thread didn't do this while we were - if (likely(existing == s_peersByIdentity.end())) { - s_peersByIdentity.emplace(id,peer); - added = true; - } else { - peer = existing->second; - } + std::lock_guard pbv_l(s_peersByVirtAddr_l); + s_peersByVirtAddr[id.address()] = peer; } - if (likely(added)) { - { - std::lock_guard pl(s_peers_l); - s_peers.emplace_back(peer); - } - if (!peer->identityValidated) { - std::lock_guard pv(s_peersToValidate_l); - s_peersToValidate.emplace_back(peer); - } - { - std::lock_guard pbv_l(s_peersByVirtAddr_l); - std::set< SharedPtr > &byVirt = s_peersByVirtAddr[id.address()]; - for(auto i=byVirt.begin();i!=byVirt.end();++i) { - if (!(*i)->identityValidated) - (*i)->identityInvalid = !(*i)->id.locallyValidate(); - } - s_peersByVirtAddr[id.address()].emplace(peer); - } + { + std::lock_guard pl(s_peers_l); + s_peers.emplace_back(peer); + } + { + std::lock_guard pv(s_peersToValidate_l); + s_peersToValidate.emplace_back(peer); } } } } - // If it wasn't a HELLO, check to see if any known identities for the sender's - // short ZT address successfully decrypt the packet. if (!peer) { - std::lock_guard pbv_l(s_peersByVirtAddr_l); - auto peers = s_peersByVirtAddr.find(source); - if (peers != s_peersByVirtAddr.end()) { - for(auto p=peers->second.begin();p!=peers->second.end();++p) { - if (!(*p)->identityInvalid) { - if (pkt.dearmor((*p)->key)) { - if (!pkt.uncompress()) { - printf("%s packet rejected: decompression failed" ZT_EOL_S,ip->toString(ipstr)); - return; - } - peer = (*p); - break; - } - } + { + std::lock_guard pbv_l(s_peersByVirtAddr_l); + auto p = s_peersByVirtAddr.find(source); + if (p != s_peersByVirtAddr.end()) { + peer = p->second; } } + if (!pkt.dearmor(peer->key)) { + printf("%s HELLO rejected: packet authentication failed" ZT_EOL_S,ip->toString(ipstr)); + return; + } + if (!pkt.uncompress()) { + printf("%s packet rejected: decompression failed" ZT_EOL_S,ip->toString(ipstr)); + return; + } } - // If we found the peer, update IP and/or time and handle certain key packet types that the - // root must concern itself with. - if (peer) { + if (likely(peer)) { const int64_t now = OSUtils::now(); if (ip->isV4()) { @@ -452,12 +423,9 @@ static void handlePacket(const int sock,const InetAddress *const ip,Packet &pkt) { std::lock_guard l(s_peersByVirtAddr_l); for(unsigned int ptr=ZT_PACKET_IDX_PAYLOAD;(ptr+ZT_ADDRESS_LENGTH)<=pkt.size();ptr+=ZT_ADDRESS_LENGTH) { - auto peers = s_peersByVirtAddr.find(Address(pkt.field(ptr,ZT_ADDRESS_LENGTH),ZT_ADDRESS_LENGTH)); - if (peers != s_peersByVirtAddr.end()) { - for(auto p=peers->second.begin();p!=peers->second.end();++p) { - if (!(*p)->identityInvalid) - results.push_back(*p); - } + auto p = s_peersByVirtAddr.find(Address(pkt.field(ptr,ZT_ADDRESS_LENGTH),ZT_ADDRESS_LENGTH)); + if (p != s_peersByVirtAddr.end()) { + results.push_back(p->second); } } } @@ -581,109 +549,91 @@ static void handlePacket(const int sock,const InetAddress *const ip,Packet &pkt) } } - std::vector< std::pair< InetAddress *,SharedPtr > > toAddrs; - toAddrs.reserve(4); + std::pair< InetAddress,SharedPtr > forwardTo; { std::lock_guard pbv_l(s_peersByVirtAddr_l); - auto peers = s_peersByVirtAddr.find(dest); - if (peers != s_peersByVirtAddr.end()) { - for(auto p=peers->second.begin();p!=peers->second.end();++p) { - if (!(*p)->identityInvalid) { - if (((*p)->v4s >= 0)&&((*p)->v6s >= 0)) { - if ((*p)->lastReceiveV4 > (*p)->lastReceiveV6) { - toAddrs.emplace_back(std::pair< InetAddress *,SharedPtr >(&((*p)->ip4),*p)); - } else { - toAddrs.emplace_back(std::pair< InetAddress *,SharedPtr >(&((*p)->ip6),*p)); - } - } else if ((*p)->v4s >= 0) { - toAddrs.emplace_back(std::pair< InetAddress *,SharedPtr >(&((*p)->ip4),*p)); - } else if ((*p)->v6s >= 0) { - toAddrs.emplace_back(std::pair< InetAddress *,SharedPtr >(&((*p)->ip6),*p)); - } + auto p = s_peersByVirtAddr.find(dest); + if (p != s_peersByVirtAddr.end()) { + if ((p->second->v4s >= 0)&&(p->second->v6s >= 0)) { + if (p->second->lastReceiveV4 > p->second->lastReceiveV6) { + forwardTo = std::pair< InetAddress,SharedPtr >(p->second->ip4,p->second); + } else { + forwardTo = std::pair< InetAddress,SharedPtr >(p->second->ip6,p->second); } + } else if (p->second->v4s >= 0) { + forwardTo = std::pair< InetAddress,SharedPtr >(p->second->ip4,p->second); + } else if (p->second->v6s >= 0) { + forwardTo = std::pair< InetAddress,SharedPtr >(p->second->ip6,p->second); } } } - if (toAddrs.empty()) { + + if (unlikely(!forwardTo.second)) { s_discardedForwardRate.log(now,pkt.size()); return; } if (introduce) { std::lock_guard l(s_peersByVirtAddr_l); - auto sources = s_peersByVirtAddr.find(source); - if (sources != s_peersByVirtAddr.end()) { - for(auto a=sources->second.begin();a!=sources->second.end();++a) { - for(auto b=toAddrs.begin();b!=toAddrs.end();++b) { - if (((*a)->v6s >= 0)&&(b->second->v6s >= 0)) { - //printf("* introducing %s(%s) to %s(%s)" ZT_EOL_S,ip->toString(ipstr),source.toString(astr),b->second->ip6.toString(ipstr2),dest.toString(astr2)); + auto sp = s_peersByVirtAddr.find(source); - // Introduce source to destination (V6) - Packet outp(source,s_self.address(),Packet::VERB_RENDEZVOUS); - outp.append((uint8_t)0); - dest.appendTo(outp); - outp.append((uint16_t)b->second->ip6.port()); - outp.append((uint8_t)16); - outp.append((const uint8_t *)(b->second->ip6.rawIpData()),16); - outp.armor((*a)->key,true); - sendto((*a)->v6s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&((*a)->ip6),(socklen_t)sizeof(struct sockaddr_in6)); + if ((sp->second->v6s >= 0)&&(forwardTo.second->v6s >= 0)) { + Packet outp(source,s_self.address(),Packet::VERB_RENDEZVOUS); + outp.append((uint8_t)0); + dest.appendTo(outp); + outp.append((uint16_t)sp->second->ip6.port()); + outp.append((uint8_t)16); + outp.append((const uint8_t *)(sp->second->ip6.rawIpData()),16); + outp.armor(forwardTo.second->key,true); + sendto(forwardTo.second->v6s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&forwardTo.first,(socklen_t)sizeof(struct sockaddr_in6)); - s_outputRate.log(now,outp.size()); - (*a)->lastSend = now; + s_outputRate.log(now,outp.size()); + forwardTo.second->lastSend = now; - // Introduce destination to source (V6) - outp.reset(dest,s_self.address(),Packet::VERB_RENDEZVOUS); - outp.append((uint8_t)0); - source.appendTo(outp); - outp.append((uint16_t)(*a)->ip6.port()); - outp.append((uint8_t)16); - outp.append((const uint8_t *)((*a)->ip6.rawIpData()),16); - outp.armor(b->second->key,true); - sendto(b->second->v6s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&(b->second->ip6),(socklen_t)sizeof(struct sockaddr_in6)); + outp.reset(dest,s_self.address(),Packet::VERB_RENDEZVOUS); + outp.append((uint8_t)0); + source.appendTo(outp); + outp.append((uint16_t)forwardTo.first.port()); + outp.append((uint8_t)16); + outp.append((const uint8_t *)(forwardTo.first.rawIpData()),16); + outp.armor(sp->second->key,true); + sendto(sp->second->v6s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&(sp->second->ip6),(socklen_t)sizeof(struct sockaddr_in6)); - s_outputRate.log(now,outp.size()); - b->second->lastSend = now; - } - if (((*a)->v4s >= 0)&&(b->second->v4s >= 0)) { - //printf("* introducing %s(%s) to %s(%s)" ZT_EOL_S,ip->toString(ipstr),source.toString(astr),b->second->ip4.toString(ipstr2),dest.toString(astr2)); + s_outputRate.log(now,outp.size()); + sp->second->lastSend = now; + } - // Introduce source to destination (V4) - Packet outp(source,s_self.address(),Packet::VERB_RENDEZVOUS); - outp.append((uint8_t)0); - dest.appendTo(outp); - outp.append((uint16_t)b->second->ip4.port()); - outp.append((uint8_t)4); - outp.append((const uint8_t *)b->second->ip4.rawIpData(),4); - outp.armor((*a)->key,true); - sendto((*a)->v4s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&((*a)->ip4),(socklen_t)sizeof(struct sockaddr_in)); + if ((sp->second->v4s >= 0)&&(forwardTo.second->v4s >= 0)) { + Packet outp(source,s_self.address(),Packet::VERB_RENDEZVOUS); + outp.append((uint8_t)0); + dest.appendTo(outp); + outp.append((uint16_t)sp->second->ip4.port()); + outp.append((uint8_t)4); + outp.append((const uint8_t *)sp->second->ip4.rawIpData(),4); + outp.armor(forwardTo.second->key,true); + sendto(forwardTo.second->v4s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&forwardTo.first,(socklen_t)sizeof(struct sockaddr_in)); - s_outputRate.log(now,outp.size()); - (*a)->lastSend = now; + s_outputRate.log(now,outp.size()); + forwardTo.second->lastSend = now; - // Introduce destination to source (V4) - outp.reset(dest,s_self.address(),Packet::VERB_RENDEZVOUS); - outp.append((uint8_t)0); - source.appendTo(outp); - outp.append((uint16_t)(*a)->ip4.port()); - outp.append((uint8_t)4); - outp.append((const uint8_t *)((*a)->ip4.rawIpData()),4); - outp.armor(b->second->key,true); - sendto(b->second->v6s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&(b->second->ip4),(socklen_t)sizeof(struct sockaddr_in)); + outp.reset(dest,s_self.address(),Packet::VERB_RENDEZVOUS); + outp.append((uint8_t)0); + source.appendTo(outp); + outp.append((uint16_t)forwardTo.first.port()); + outp.append((uint8_t)4); + outp.append((const uint8_t *)(forwardTo.first.rawIpData()),4); + outp.armor(sp->second->key,true); + sendto(sp->second->v6s,outp.data(),outp.size(),SENDTO_FLAGS,(const struct sockaddr *)&(sp->second->ip4),(socklen_t)sizeof(struct sockaddr_in)); - s_outputRate.log(now,outp.size()); - b->second->lastSend = now; - } - } - } + s_outputRate.log(now,outp.size()); + sp->second->lastSend = now; } } - for(auto i=toAddrs.begin();i!=toAddrs.end();++i) { - if (sendto(i->first->isV4() ? i->second->v4s : i->second->v6s,pkt.data(),pkt.size(),SENDTO_FLAGS,(const struct sockaddr *)i->first,(socklen_t)(i->first->isV4() ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6))) > 0) { - s_outputRate.log(now,pkt.size()); - s_forwardRate.log(now,pkt.size()); - i->second->lastSend = now; - } + if (sendto(forwardTo.first.isV4() ? forwardTo.second->v4s : forwardTo.second->v6s,pkt.data(),pkt.size(),SENDTO_FLAGS,(const struct sockaddr *)&forwardTo.first,(socklen_t)(forwardTo.first.isV4() ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6))) > 0) { + s_outputRate.log(now,pkt.size()); + s_forwardRate.log(now,pkt.size()); + forwardTo.second->lastSend = now; } } @@ -1046,9 +996,9 @@ int main(int argc,char **argv) std::ostringstream o; o << "ZeroTier Root Server " << ZEROTIER_ONE_VERSION_MAJOR << '.' << ZEROTIER_ONE_VERSION_MINOR << '.' << ZEROTIER_ONE_VERSION_REVISION << ZT_EOL_S; o << "(c)2019 ZeroTier, Inc." ZT_EOL_S "Licensed under the ZeroTier BSL 1.1" ZT_EOL_S ZT_EOL_S; - s_peersByIdentity_l.lock(); - o << "Peers Online: " << s_peersByIdentity.size() << ZT_EOL_S; - s_peersByIdentity_l.unlock(); + s_peersByVirtAddr_l.lock(); + o << "Peers Online: " << s_peersByVirtAddr.size() << ZT_EOL_S; + s_peersByVirtAddr_l.unlock(); res.set_content(o.str(),"text/plain"); }); @@ -1060,9 +1010,9 @@ int main(int argc,char **argv) const char *root_id = s_self.address().toString(buf); o << "# HELP root_peers_online Number of active peers online" << ZT_EOL_S; o << "# TYPE root_peers_online gauge" << ZT_EOL_S; - s_peersByIdentity_l.lock(); - o << "root_peers_online{root_id=\"" << root_id << "\"} " << s_peersByIdentity.size() << ZT_EOL_S; - s_peersByIdentity_l.unlock(); + s_peersByVirtAddr_l.lock(); + o << "root_peers_online{root_id=\"" << root_id << "\"} " << s_peersByVirtAddr.size() << ZT_EOL_S; + s_peersByVirtAddr_l.unlock(); o << "# HELP root_input_rate Input rate MiB/s" << ZT_EOL_S; o << "# TYPE root_input_rate gauge" << ZT_EOL_S; o << "root_input_rate{root_id=\"" << root_id << "\"} " << std::setprecision(5) << (s_inputRate.perSecond(now)/1048576.0) << ZT_EOL_S; @@ -1279,17 +1229,11 @@ int main(int argc,char **argv) newPeers.swap(s_peers); } for(auto p=toRemove.begin();p!=toRemove.end();++p) { - { - std::lock_guard pbi_l(s_peersByIdentity_l); - s_peersByIdentity.erase((*p)->id); - } { std::lock_guard pbv_l(s_peersByVirtAddr_l); auto pbv = s_peersByVirtAddr.find((*p)->id.address()); - if (pbv != s_peersByVirtAddr.end()) { - pbv->second.erase(*p); - if (pbv->second.empty()) - s_peersByVirtAddr.erase(pbv); + if ((pbv != s_peersByVirtAddr.end())&&(pbv->second == *p)) { + s_peersByVirtAddr.erase(pbv); } } } @@ -1375,12 +1319,8 @@ int main(int argc,char **argv) FILE *sf = fopen(statsFilePath.c_str(),"wb"); if (sf) { fprintf(sf,"Uptime (seconds) : %ld" ZT_EOL_S,(long)((now - s_startTime) / 1000)); - s_peersByIdentity_l.lock(); - auto peersByIdentitySize = s_peersByIdentity.size(); - s_peersByIdentity_l.unlock(); - fprintf(sf,"Peers : %llu" ZT_EOL_S,(unsigned long long)peersByIdentitySize); s_peersByVirtAddr_l.lock(); - fprintf(sf,"Virtual Address Collisions : %lld" ZT_EOL_S,(long long)peersByIdentitySize - (long long)s_peersByVirtAddr.size()); + fprintf(sf,"Peers : %llu" ZT_EOL_S,(unsigned long long)s_peersByVirtAddr.size()); s_peersByVirtAddr_l.unlock(); s_rendezvousTracking_l.lock(); uint64_t unsuccessfulp2p = 0;