diff --git a/node/BandwidthAccount.hpp b/node/BandwidthAccount.hpp index 12c303e41..927037f8f 100644 --- a/node/BandwidthAccount.hpp +++ b/node/BandwidthAccount.hpp @@ -28,16 +28,14 @@ #ifndef _ZT_BWACCOUNT_HPP #define _ZT_BWACCOUNT_HPP +#include #include +#include + #include "Constants.hpp" #include "Utils.hpp" -#ifdef __WINDOWS__ -#define fmin(a,b) (((a) <= (b)) ? (a) : (b)) -#define fmax(a,b) (((a) >= (b)) ? (a) : (b)) -#endif - namespace ZeroTier { /** @@ -56,27 +54,6 @@ namespace ZeroTier { class BandwidthAccount { public: - /** - * Rate of balance accrual and min/max - */ - struct Accrual - { - /** - * Rate of balance accrual in bytes per second - */ - double bytesPerSecond; - - /** - * Maximum balance that can ever be accrued (should be > 0.0) - */ - double maxBalance; - - /** - * Minimum balance, or maximum allowable "debt" (should be <= 0.0) - */ - double minBalance; - }; - /** * Create an uninitialized account * @@ -88,43 +65,55 @@ public: * Create and initialize * * @param preload Initial balance to place in account + * @param minb Minimum allowed balance (or maximum debt) (<= 0) + * @param maxb Maximum allowed balance (> 0) + * @param acc Rate of accrual in bytes per second */ - BandwidthAccount(double preload) + BandwidthAccount(int32_t preload,int32_t minb,int32_t maxb,int32_t acc) throw() { - init(preload); + init(preload,minb,maxb,acc); } /** * Initialize or re-initialize account * * @param preload Initial balance to place in account + * @param minb Minimum allowed balance (or maximum debt) (<= 0) + * @param maxb Maximum allowed balance (> 0) + * @param acc Rate of accrual in bytes per second */ - inline void init(double preload) + inline void init(int32_t preload,int32_t minb,int32_t maxb,int32_t acc) throw() { _lastTime = Utils::nowf(); _balance = preload; + _minBalance = minb; + _maxBalance = maxb; + _accrual = acc; } /** * Update balance by accruing and then deducting * - * @param ar Current rate of accrual * @param deduct Amount to deduct, or 0.0 to just update * @return New balance with deduction applied */ - inline double update(const Accrual &ar,double deduct) + inline int32_t update(int32_t deduct) throw() { double lt = _lastTime; - double now = _lastTime = Utils::nowf(); - return (_balance = fmax(ar.minBalance,fmin(ar.maxBalance,(_balance + (ar.bytesPerSecond * (now - lt))) - deduct))); + double now = Utils::nowf(); + _lastTime = now; + return (_balance = std::max(_minBalance,std::min(_maxBalance,(int32_t)round(((double)_balance) + (((double)_accrual) * (now - lt))) - deduct))); } private: double _lastTime; - double _balance; + int32_t _balance; + int32_t _minBalance; + int32_t _maxBalance; + int32_t _accrual; }; } // namespace ZeroTier diff --git a/node/Network.cpp b/node/Network.cpp index bc651661e..60f87f927 100644 --- a/node/Network.cpp +++ b/node/Network.cpp @@ -76,27 +76,13 @@ bool Network::Certificate::qualifyMembership(const Network::Certificate &mc) con if (myField->second != theirField->second) return false; } else { - // Otherwise compare range with max delta. Presence of a dot in delta - // indicates a floating point comparison. Otherwise an integer - // comparison occurs. - if (deltaField->second.find('.') != std::string::npos) { - double my = Utils::strToDouble(myField->second.c_str()); - double their = Utils::strToDouble(theirField->second.c_str()); - double delta = Utils::strToDouble(deltaField->second.c_str()); - if (fabs(my - their) > delta) - return false; - } else { - uint64_t my = Utils::hexStrToU64(myField->second.c_str()); - uint64_t their = Utils::hexStrToU64(theirField->second.c_str()); - uint64_t delta = Utils::hexStrToU64(deltaField->second.c_str()); - if (my > their) { - if ((my - their) > delta) - return false; - } else { - if ((their - my) > delta) - return false; - } - } + // Otherwise compare the absolute value of the difference between + // the two values against the max delta. + int64_t my = Utils::hexStrTo64(myField->second.c_str()); + int64_t their = Utils::hexStrTo64(theirField->second.c_str()); + int64_t delta = Utils::hexStrTo64(deltaField->second.c_str()); + if (llabs((long long)(my - their)) > delta) + return false; } } } diff --git a/node/Network.hpp b/node/Network.hpp index 6e79705d5..880355866 100644 --- a/node/Network.hpp +++ b/node/Network.hpp @@ -142,7 +142,8 @@ public: * Key is multicast group in lower case hex format: MAC (without :s) / * ADI (hex). Value is a comma-delimited list of: preload, min, max, * rate of accrual for bandwidth accounts. A key called '*' indicates - * the default for unlisted groups. + * the default for unlisted groups. Values are in hexadecimal and may + * be prefixed with '-' to indicate a negative value. */ class MulticastRates : private Dictionary { @@ -153,16 +154,17 @@ public: struct Rate { Rate() {} - Rate(double pl,double minr,double maxr,double bps) + Rate(int32_t pl,int32_t minb,int32_t maxb,int32_t acc) { preload = pl; - accrual.bytesPerSecond = bps; - accrual.maxBalance = maxr; - accrual.minBalance = minr; + minBalance = minb; + maxBalance = maxb; + accrual = acc; } - - double preload; - BandwidthAccount::Accrual accrual; + int32_t preload; + int32_t minBalance; + int32_t maxBalance; + int32_t accrual; }; MulticastRates() {} @@ -178,7 +180,7 @@ public: /** * @return Default rate, or GLOBAL_DEFAULT_RATE if not specified */ - Rate defaultRate() const + inline Rate defaultRate() const { Rate r; const_iterator dfl(find("*")); @@ -193,7 +195,7 @@ public: * @param mg Multicast group * @return Rate or default() rate if not specified */ - Rate get(const MulticastGroup &mg) const + inline Rate get(const MulticastGroup &mg) const { const_iterator r(find(mg.toString())); if (r == end()) @@ -206,26 +208,22 @@ public: { char tmp[16384]; Utils::scopy(tmp,sizeof(tmp),s.c_str()); - Rate r; - r.preload = 0.0; - r.accrual.bytesPerSecond = 0.0; - r.accrual.maxBalance = 0.0; - r.accrual.minBalance = 0.0; + Rate r(0,0,0,0); char *saveptr = (char *)0; unsigned int fn = 0; for(char *f=Utils::stok(tmp,",",&saveptr);(f);f=Utils::stok((char *)0,",",&saveptr)) { switch(fn++) { case 0: - r.preload = Utils::strToDouble(f); + r.preload = (int32_t)Utils::hexStrToLong(f); break; case 1: - r.accrual.minBalance = Utils::strToDouble(f); + r.minBalance = (int32_t)Utils::hexStrToLong(f); break; case 2: - r.accrual.maxBalance = Utils::strToDouble(f); + r.maxBalance = (int32_t)Utils::hexStrToLong(f); break; case 3: - r.accrual.bytesPerSecond = Utils::strToDouble(f); + r.accrual = (int32_t)Utils::hexStrToLong(f); break; } } @@ -538,10 +536,24 @@ public: else return ((_etWhitelist[etherType / 8] & (unsigned char)(1 << (etherType % 8))) != 0); } + /** + * Update multicast balance for an address and multicast group, return whether packet is allowed + * + * @param a Address that wants to send/relay packet + * @param mg Multicast group + * @param bytes Size of packet + * @return True if packet is within budget + */ inline bool updateAndCheckMulticastBalance(const Address &a,const MulticastGroup &mg,unsigned int bytes) { Mutex::Lock _l(_lock); - std::map< std::pair,BandwidthAccount >::iterator bal(_multicastRateAccounts.find(std::pair(a,mg))); + std::pair k(a,mg); + std::map< std::pair,BandwidthAccount >::iterator bal(_multicastRateAccounts.find(k)); + if (bal == _multicastRateAccounts.end()) { + MulticastRates::Rate r(_mcRates.get(mg)); + bal = _multicastRateAccounts.insert(std::make_pair(k,BandwidthAccount(r.preload,r.minBalance,r.maxBalance,r.accrual))).first; + } + return (bal->second.update((int32_t)bytes) < (int32_t)bytes); } private: @@ -563,6 +575,7 @@ private: // Configuration from network master node Config _configuration; Certificate _myCertificate; + MulticastRates _mcRates; // Ethertype whitelist bit field, set from config, for really fast lookup unsigned char _etWhitelist[65536 / 8]; diff --git a/node/Utils.hpp b/node/Utils.hpp index 15121e282..329d4c4c7 100644 --- a/node/Utils.hpp +++ b/node/Utils.hpp @@ -461,24 +461,44 @@ public: #endif } - // String to number converters + // String to number converters -- defined here to permit portability + // ifdefs for platforms that lack some of the strtoXX functions. static inline unsigned int strToUInt(const char *s) throw() { return (unsigned int)strtoul(s,(char **)0,10); } + static inline int strToInt(const char *s) + throw() + { + return (int)strtol(s,(char **)0,10); + } static inline unsigned long strToULong(const char *s) throw() { return strtoul(s,(char **)0,10); } + static inline long strToLong(const char *s) + throw() + { + return strtol(s,(char **)0,10); + } static inline unsigned long long strToU64(const char *s) throw() { #ifdef __WINDOWS__ - return _strtoui64(s,(char **)0,10); + return (unsigned long long)_strtoui64(s,(char **)0,10); #else return strtoull(s,(char **)0,10); +#endif + } + static inline long long strTo64(const char *s) + throw() + { +#ifdef __WINDOWS__ + return (long long)_strtoi64(s,(char **)0,10); +#else + return strtoll(s,(char **)0,10); #endif } static inline unsigned int hexStrToUInt(const char *s) @@ -486,18 +506,37 @@ public: { return (unsigned int)strtoul(s,(char **)0,16); } + static inline int hexStrToInt(const char *s) + throw() + { + return (int)strtol(s,(char **)0,16); + } static inline unsigned long hexStrToULong(const char *s) throw() { return strtoul(s,(char **)0,16); } + static inline long hexStrToLong(const char *s) + throw() + { + return strtol(s,(char **)0,16); + } static inline unsigned long long hexStrToU64(const char *s) throw() { #ifdef __WINDOWS__ - return _strtoui64(s,(char **)0,16); + return (unsigned long long)_strtoui64(s,(char **)0,16); #else return strtoull(s,(char **)0,16); +#endif + } + static inline long long hexStrTo64(const char *s) + throw() + { +#ifdef __WINDOWS__ + return (long long)_strtoi64(s,(char **)0,16); +#else + return strtoll(s,(char **)0,16); #endif } static inline double strToDouble(const char *s)