From 4c3d0ea6cc87edc28f39e73766865245eb59d1d2 Mon Sep 17 00:00:00 2001
From: meejah
Date: Sat, 21 Dec 2019 00:03:38 -0700
Subject: [PATCH] use 'with open' for more file-opens

---
 src/allmydata/blacklist.py       |  15 ++--
 src/allmydata/control.py         |  13 +--
 src/allmydata/frontends/auth.py  |  29 +++---
 src/allmydata/scripts/common.py  |  26 +++---
 src/allmydata/stats.py           |  10 +--
 src/allmydata/storage/crawler.py |  10 +--
 src/allmydata/storage/expirer.py |  16 ++--
 src/allmydata/storage/mutable.py | 148 +++++++++++++++----------------
 src/allmydata/storage/server.py  |  22 +++--
 src/allmydata/storage/shares.py  |   5 +-
 src/allmydata/util/configutil.py |  10 +--
 src/allmydata/version_checks.py  |  24 ++---
 12 files changed, 155 insertions(+), 173 deletions(-)

diff --git a/src/allmydata/blacklist.py b/src/allmydata/blacklist.py
index 874ff95ca..af1d185d0 100644
--- a/src/allmydata/blacklist.py
+++ b/src/allmydata/blacklist.py
@@ -34,13 +34,14 @@ class Blacklist(object):
         try:
             if self.last_mtime is None or current_mtime > self.last_mtime:
                 self.entries.clear()
-                for line in open(self.blacklist_fn, "r").readlines():
-                    line = line.strip()
-                    if not line or line.startswith("#"):
-                        continue
-                    si_s, reason = line.split(None, 1)
-                    si = base32.a2b(si_s) # must be valid base32
-                    self.entries[si] = reason
+                with open(self.blacklist_fn, "r") as f:
+                    for line in f.readlines():
+                        line = line.strip()
+                        if not line or line.startswith("#"):
+                            continue
+                        si_s, reason = line.split(None, 1)
+                        si = base32.a2b(si_s) # must be valid base32
+                        self.entries[si] = reason
             self.last_mtime = current_mtime
         except Exception as e:
             twisted_log.err(e, "unparseable blacklist file")
diff --git a/src/allmydata/control.py b/src/allmydata/control.py
index 07802efba..55f6db4f9 100644
--- a/src/allmydata/control.py
+++ b/src/allmydata/control.py
@@ -19,12 +19,13 @@ def get_memory_usage():
                   "VmData")
     stats = {}
     try:
-        for line in open("/proc/self/status", "r").readlines():
-            name, right = line.split(":",2)
-            if name in stat_names:
-                assert right.endswith(" kB\n")
-                right = right[:-4]
-                stats[name] = int(right) * 1024
+        with open("/proc/self/status", "r") as f:
+            for line in f.readlines():
+                name, right = line.split(":",2)
+                if name in stat_names:
+                    assert right.endswith(" kB\n")
+                    right = right[:-4]
+                    stats[name] = int(right) * 1024
     except:
         # Probably not on (a compatible version of) Linux
         stats['VmSize'] = 0
diff --git a/src/allmydata/frontends/auth.py b/src/allmydata/frontends/auth.py
index 49647bc60..ab56bf94d 100644
--- a/src/allmydata/frontends/auth.py
+++ b/src/allmydata/frontends/auth.py
@@ -31,20 +31,21 @@ class AccountFileChecker(object):
         self.passwords = {}
         self.pubkeys = {}
         self.rootcaps = {}
-        for line in open(abspath_expanduser_unicode(accountfile), "r"):
-            line = line.strip()
-            if line.startswith("#") or not line:
-                continue
-            name, passwd, rest = line.split(None, 2)
-            if passwd.startswith("ssh-"):
-                bits = rest.split()
-                keystring = " ".join([passwd] + bits[:-1])
-                rootcap = bits[-1]
-                self.pubkeys[name] = keystring
-            else:
-                self.passwords[name] = passwd
-                rootcap = rest
-            self.rootcaps[name] = rootcap
+        with open(abspath_expanduser_unicode(accountfile), "r") as f:
+            for line in f.readlines():
+                line = line.strip()
+                if line.startswith("#") or not line:
+                    continue
+                name, passwd, rest = line.split(None, 2)
+                if passwd.startswith("ssh-"):
+                    bits = rest.split()
+                    keystring = " ".join([passwd] + bits[:-1])
+                    rootcap = bits[-1]
+                    self.pubkeys[name] = keystring
+                else:
+                    self.passwords[name] = passwd
+                    rootcap = rest
+                self.rootcaps[name] = rootcap

     def _avatarId(self, username):
         return FTPAvatarID(username, self.rootcaps[username])
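
Note: the pattern in the three hunks above recurs throughout this patch. A bare
open() leaves the file object dangling (until the garbage collector gets to it)
if any line in the loop raises; the with form closes it on every exit path. A
minimal illustration of the equivalence — the names here are hypothetical, not
code from this patch:

    def read_entries(path, process):
        # old style: the file is closed only if no line raises
        f = open(path, "r")
        for line in f.readlines():
            process(line)
        f.close()

    def read_entries_safely(path, process):
        # new style: __exit__() closes the file on normal exit and on
        # exception alike, as if the loop were wrapped in try/finally
        with open(path, "r") as f:
            for line in f.readlines():
                process(line)

Iterating the file object directly (for line in f:) would also work and avoids
materializing the whole file as a list, but the patch keeps readlines() so the
loop bodies themselves stay untouched.
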
diff --git a/src/allmydata/scripts/common.py b/src/allmydata/scripts/common.py
index d3dde9a72..89c1f94d1 100644
--- a/src/allmydata/scripts/common.py
+++ b/src/allmydata/scripts/common.py
@@ -131,22 +131,22 @@ def get_aliases(nodedir):
     aliasfile = os.path.join(nodedir, "private", "aliases")
     rootfile = os.path.join(nodedir, "private", "root_dir.cap")
     try:
-        f = open(rootfile, "r")
-        rootcap = f.read().strip()
-        if rootcap:
-            aliases[DEFAULT_ALIAS] = rootcap
+        with open(rootfile, "r") as f:
+            rootcap = f.read().strip()
+            if rootcap:
+                aliases[DEFAULT_ALIAS] = rootcap
     except EnvironmentError:
         pass
     try:
-        f = codecs.open(aliasfile, "r", "utf-8")
-        for line in f.readlines():
-            line = line.strip()
-            if line.startswith("#") or not line:
-                continue
-            name, cap = line.split(u":", 1)
-            # normalize it: remove http: prefix, urldecode
-            cap = cap.strip().encode('utf-8')
-            aliases[name] = cap
+        with codecs.open(aliasfile, "r", "utf-8") as f:
+            for line in f.readlines():
+                line = line.strip()
+                if line.startswith("#") or not line:
+                    continue
+                name, cap = line.split(u":", 1)
+                # normalize it: remove http: prefix, urldecode
+                cap = cap.strip().encode('utf-8')
+                aliases[name] = cap
     except EnvironmentError:
         pass
     return aliases
diff --git a/src/allmydata/stats.py b/src/allmydata/stats.py
index 73e311fe3..34907b1fa 100644
--- a/src/allmydata/stats.py
+++ b/src/allmydata/stats.py
@@ -250,16 +250,15 @@ class JSONStatsGatherer(StdOutStatsGatherer):
         self.jsonfile = os.path.join(basedir, "stats.json")

         if os.path.exists(self.jsonfile):
-            f = open(self.jsonfile, 'rb')
             try:
-                self.gathered_stats = json.load(f)
+                with open(self.jsonfile, 'rb') as f:
+                    self.gathered_stats = json.load(f)
             except Exception:
                 print("Error while attempting to load stats file %s.\n"
                       "You may need to restore this file from a backup,"
                       " or delete it if no backup is available.\n" %
                       quote_local_unicode_path(self.jsonfile))
                 raise
-            f.close()
         else:
             self.gathered_stats = {}

@@ -272,9 +271,8 @@ def dump_json(self):
         tmp = "%s.tmp" % (self.jsonfile,)
-        f = open(tmp, 'wb')
-        json.dump(self.gathered_stats, f)
-        f.close()
+        with open(tmp, 'wb') as f:
+            json.dump(self.gathered_stats, f)
         if os.path.exists(self.jsonfile):
             os.unlink(self.jsonfile)
         os.rename(tmp, self.jsonfile)
diff --git a/src/allmydata/storage/crawler.py b/src/allmydata/storage/crawler.py
index 438dd5e31..008619d64 100644
--- a/src/allmydata/storage/crawler.py
+++ b/src/allmydata/storage/crawler.py
@@ -191,9 +191,8 @@ class ShareCrawler(service.MultiService):
         #                            of the last bucket to be processed, or
         #                            None if we are sleeping between cycles
         try:
-            f = open(self.statefile, "rb")
-            state = pickle.load(f)
-            f.close()
+            with open(self.statefile, "rb") as f:
+                state = pickle.load(f)
         except Exception:
             state = {"version": 1,
                      "last-cycle-finished": None,
@@ -230,9 +229,8 @@ class ShareCrawler(service.MultiService):
             last_complete_prefix = self.prefixes[lcpi]
         self.state["last-complete-prefix"] = last_complete_prefix
         tmpfile = self.statefile + ".tmp"
-        f = open(tmpfile, "wb")
-        pickle.dump(self.state, f)
-        f.close()
+        with open(tmpfile, "wb") as f:
+            pickle.dump(self.state, f)
         fileutil.move_into_place(tmpfile, self.statefile)

     def startService(self):
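
Note: save_state() above and dump_json() in stats.py share a second idiom worth
naming: serialize into a sibling .tmp file, then move it over the real one, so
a crash mid-write can never leave a truncated state file behind. A sketch of
the shape, assuming POSIX rename semantics (fileutil.move_into_place is Tahoe's
Windows-aware wrapper for the same step, and dump_json unlinks the target first
for the same reason):

    import os
    import pickle

    def save_atomically(state, path):
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(state, f)  # a torn write only ever hits the .tmp file
        os.rename(tmp, path)       # atomic replacement on POSIX filesystems
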
diff --git a/src/allmydata/storage/expirer.py b/src/allmydata/storage/expirer.py
index a284d4f74..a13c188bd 100644
--- a/src/allmydata/storage/expirer.py
+++ b/src/allmydata/storage/expirer.py
@@ -84,9 +84,8 @@ class LeaseCheckingCrawler(ShareCrawler):
         # initialize history
         if not os.path.exists(self.historyfile):
             history = {} # cyclenum -> dict
-            f = open(self.historyfile, "wb")
-            pickle.dump(history, f)
-            f.close()
+            with open(self.historyfile, "wb") as f:
+                pickle.dump(history, f)

     def create_empty_cycle_dict(self):
         recovered = self.create_empty_recovered_dict()
@@ -303,14 +302,14 @@ class LeaseCheckingCrawler(ShareCrawler):
             # copy() needs to become a deepcopy
            h["space-recovered"] = s["space-recovered"].copy()

-        history = pickle.load(open(self.historyfile, "rb"))
+        with open(self.historyfile, "rb") as f:
+            history = pickle.load(f)
         history[cycle] = h
         while len(history) > 10:
             oldcycles = sorted(history.keys())
             del history[oldcycles[0]]
-        f = open(self.historyfile, "wb")
-        pickle.dump(history, f)
-        f.close()
+        with open(self.historyfile, "wb") as f:
+            pickle.dump(history, f)

     def get_state(self):
         """In addition to the crawler state described in
@@ -379,7 +378,8 @@ class LeaseCheckingCrawler(ShareCrawler):
         progress = self.get_progress()

         state = ShareCrawler.get_state(self) # does a shallow copy
-        history = pickle.load(open(self.historyfile, "rb"))
+        with open(self.historyfile, "rb") as f:
+            history = pickle.load(f)
         state["history"] = history

         if not progress["cycle-in-progress"]:
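
Note: the expirer's history file shows the load-modify-save shape: each open()
is now scoped to exactly one read or one write, so the file is never held open
across the pruning logic in between. A compact sketch of the same round-trip —
the helper name and signature are illustrative, not part of this patch:

    import pickle

    def update_history(path, cycle, entry, keep=10):
        with open(path, "rb") as f:          # first open: read only
            history = pickle.load(f)
        history[cycle] = entry               # mutate while the file is closed
        while len(history) > keep:
            del history[min(history)]        # drop the oldest cycle number
        with open(path, "wb") as f:          # second open: rewrite
            pickle.dump(history, f)
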
diff --git a/src/allmydata/storage/mutable.py b/src/allmydata/storage/mutable.py
index 37f773c0d..287ed8fb9 100644
--- a/src/allmydata/storage/mutable.py
+++ b/src/allmydata/storage/mutable.py
@@ -57,8 +57,8 @@ class MutableShareFile(object):
         self.home = filename
         if os.path.exists(self.home):
             # we don't cache anything, just check the magic
-            f = open(self.home, 'rb')
-            data = f.read(self.HEADER_SIZE)
+            with open(self.home, 'rb') as f:
+                data = f.read(self.HEADER_SIZE)
             (magic,
              write_enabler_nodeid, write_enabler,
              data_length, extra_least_offset) = \
@@ -80,17 +80,17 @@ class MutableShareFile(object):
                          + data_length)
         assert extra_lease_offset == self.DATA_OFFSET # true at creation
         num_extra_leases = 0
-        f = open(self.home, 'wb')
-        header = struct.pack(">32s20s32sQQ",
-                             self.MAGIC, my_nodeid, write_enabler,
-                             data_length, extra_lease_offset,
-                             )
-        leases = ("\x00"*self.LEASE_SIZE) * 4
-        f.write(header + leases)
-        # data goes here, empty after creation
-        f.write(struct.pack(">L", num_extra_leases))
-        # extra leases go here, none at creation
-        f.close()
+        with open(self.home, 'wb') as f:
+            header = struct.pack(
+                ">32s20s32sQQ",
+                self.MAGIC, my_nodeid, write_enabler,
+                data_length, extra_lease_offset,
+            )
+            leases = ("\x00" * self.LEASE_SIZE) * 4
+            f.write(header + leases)
+            # data goes here, empty after creation
+            f.write(struct.pack(">L", num_extra_leases))
+            # extra leases go here, none at creation

     def unlink(self):
         os.unlink(self.home)
@@ -261,10 +261,9 @@ class MutableShareFile(object):

     def get_leases(self):
         """Yields a LeaseInfo instance for all leases."""
-        f = open(self.home, 'rb')
-        for i, lease in self._enumerate_leases(f):
-            yield lease
-        f.close()
+        with open(self.home, 'rb') as f:
+            for i, lease in self._enumerate_leases(f):
+                yield lease

     def _enumerate_leases(self, f):
         for i in range(self._get_num_lease_slots(f)):
@@ -277,29 +276,26 @@ class MutableShareFile(object):

     def add_lease(self, lease_info):
         precondition(lease_info.owner_num != 0) # 0 means "no lease here"
-        f = open(self.home, 'rb+')
-        num_lease_slots = self._get_num_lease_slots(f)
-        empty_slot = self._get_first_empty_lease_slot(f)
-        if empty_slot is not None:
-            self._write_lease_record(f, empty_slot, lease_info)
-        else:
-            self._write_lease_record(f, num_lease_slots, lease_info)
-        f.close()
+        with open(self.home, 'rb+') as f:
+            num_lease_slots = self._get_num_lease_slots(f)
+            empty_slot = self._get_first_empty_lease_slot(f)
+            if empty_slot is not None:
+                self._write_lease_record(f, empty_slot, lease_info)
+            else:
+                self._write_lease_record(f, num_lease_slots, lease_info)

     def renew_lease(self, renew_secret, new_expire_time):
         accepting_nodeids = set()
-        f = open(self.home, 'rb+')
-        for (leasenum,lease) in self._enumerate_leases(f):
-            if timing_safe_compare(lease.renew_secret, renew_secret):
-                # yup. See if we need to update the owner time.
-                if new_expire_time > lease.expiration_time:
-                    # yes
-                    lease.expiration_time = new_expire_time
-                    self._write_lease_record(f, leasenum, lease)
-                f.close()
-                return
-            accepting_nodeids.add(lease.nodeid)
-        f.close()
+        with open(self.home, 'rb+') as f:
+            for (leasenum,lease) in self._enumerate_leases(f):
+                if timing_safe_compare(lease.renew_secret, renew_secret):
+                    # yup. See if we need to update the owner time.
+                    if new_expire_time > lease.expiration_time:
+                        # yes
+                        lease.expiration_time = new_expire_time
+                        self._write_lease_record(f, leasenum, lease)
+                    return
+                accepting_nodeids.add(lease.nodeid)
         # Return the accepting_nodeids set, to give the client a chance to
         # update the leases on a share which has been migrated from its
         # original server to a new one.
@@ -333,21 +329,21 @@ class MutableShareFile(object):
                                cancel_secret="\x00"*32,
                                expiration_time=0,
                                nodeid="\x00"*20)
-        f = open(self.home, 'rb+')
-        for (leasenum,lease) in self._enumerate_leases(f):
-            accepting_nodeids.add(lease.nodeid)
-            if timing_safe_compare(lease.cancel_secret, cancel_secret):
-                self._write_lease_record(f, leasenum, blank_lease)
-                modified += 1
-            else:
-                remaining += 1
-        if modified:
-            freed_space = self._pack_leases(f)
-            f.close()
-            if not remaining:
-                freed_space += os.stat(self.home)[stat.ST_SIZE]
-                self.unlink()
-            return freed_space
+        with open(self.home, 'rb+') as f:
+            for (leasenum,lease) in self._enumerate_leases(f):
+                accepting_nodeids.add(lease.nodeid)
+                if timing_safe_compare(lease.cancel_secret, cancel_secret):
+                    self._write_lease_record(f, leasenum, blank_lease)
+                    modified += 1
+                else:
+                    remaining += 1
+            if modified:
+                freed_space = self._pack_leases(f)
+                f.close()
+                if not remaining:
+                    freed_space += os.stat(self.home)[stat.ST_SIZE]
+                    self.unlink()
+                return freed_space

         msg = ("Unable to cancel non-existent lease. I have leases "
                "accepted by nodeids: ")
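
Note: cancel_lease() above is the one place where an explicit f.close()
survives inside the new with block, and it is correct there: the rewritten
leases must be flushed before os.stat() measures the file, and unlink() of a
still-open file would fail on Windows. Python file objects make close()
idempotent, so the with block's implicit close afterwards is a harmless no-op.
A minimal demonstration of the idiom (the filename is illustrative only):

    import os

    with open("demo-share", "wb") as f:
        f.write(b"lease data")
        f.close()                # flush now; file.close() is idempotent
        size = os.stat("demo-share").st_size  # sees all ten bytes
        os.unlink("demo-share")  # deleting an open file would fail on Windows
    # __exit__ calls close() again, a no-op on an already-closed file
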
@@ -372,10 +368,9 @@ class MutableShareFile(object):

     def readv(self, readv):
         datav = []
-        f = open(self.home, 'rb')
-        for (offset, length) in readv:
-            datav.append(self._read_share_data(f, offset, length))
-        f.close()
+        with open(self.home, 'rb') as f:
+            for (offset, length) in readv:
+                datav.append(self._read_share_data(f, offset, length))
         return datav

     # def remote_get_length(self):
@@ -385,10 +380,9 @@ class MutableShareFile(object):
     #     f = open(self.home, 'rb')
     #     data_length = self._read_data_length(f)
     #     f.close()
     #     return data_length

     def check_write_enabler(self, write_enabler, si_s):
-        f = open(self.home, 'rb+')
-        (real_write_enabler, write_enabler_nodeid) = \
-            self._read_write_enabler_and_nodeid(f)
-        f.close()
+        with open(self.home, 'rb+') as f:
+            (real_write_enabler, write_enabler_nodeid) = \
+                self._read_write_enabler_and_nodeid(f)
         # avoid a timing attack
         #if write_enabler != real_write_enabler:
         if not timing_safe_compare(write_enabler, real_write_enabler):
@@ -405,27 +399,25 @@ class MutableShareFile(object):

     def check_testv(self, testv):
         test_good = True
-        f = open(self.home, 'rb+')
-        for (offset, length, operator, specimen) in testv:
-            data = self._read_share_data(f, offset, length)
-            if not testv_compare(data, operator, specimen):
-                test_good = False
-                break
-        f.close()
+        with open(self.home, 'rb+') as f:
+            for (offset, length, operator, specimen) in testv:
+                data = self._read_share_data(f, offset, length)
+                if not testv_compare(data, operator, specimen):
+                    test_good = False
+                    break
         return test_good

     def writev(self, datav, new_length):
-        f = open(self.home, 'rb+')
-        for (offset, data) in datav:
-            self._write_share_data(f, offset, data)
-        if new_length is not None:
-            cur_length = self._read_data_length(f)
-            if new_length < cur_length:
-                self._write_data_length(f, new_length)
-                # TODO: if we're going to shrink the share file when the
-                # share data has shrunk, then call
-                # self._change_container_size() here.
-        f.close()
+        with open(self.home, 'rb+') as f:
+            for (offset, data) in datav:
+                self._write_share_data(f, offset, data)
+            if new_length is not None:
+                cur_length = self._read_data_length(f)
+                if new_length < cur_length:
+                    self._write_data_length(f, new_length)
+                    # TODO: if we're going to shrink the share file when the
+                    # share data has shrunk, then call
+                    # self._change_container_size() here.

 def testv_compare(a, op, b):
     assert op in ("lt", "le", "eq", "ne", "ge", "gt")
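
Note: check_write_enabler() above compares secrets with timing_safe_compare()
rather than ==, because == returns at the first mismatching byte and so leaks,
through response timing, how much of a guessed write-enabler was correct. Tahoe
ships its own helper; the sketch below is only the standard-library equivalent
(assuming Python >= 2.7.7, when hmac.compare_digest appeared):

    import hmac

    def constant_time_equal(a, b):
        # unlike ==, this examines every byte no matter where the first
        # difference occurs, so the comparison time reveals nothing useful
        return hmac.compare_digest(a, b)
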
diff --git a/src/allmydata/storage/server.py b/src/allmydata/storage/server.py
index 7741e0c18..5cc4be95c 100644
--- a/src/allmydata/storage/server.py
+++ b/src/allmydata/storage/server.py
@@ -317,9 +317,8 @@ class StorageServer(service.MultiService, Referenceable):

     def _iter_share_files(self, storage_index):
         for shnum, filename in self._get_bucket_shares(storage_index):
-            f = open(filename, 'rb')
-            header = f.read(32)
-            f.close()
+            with open(filename, 'rb') as f:
+                header = f.read(32)
             if header[:32] == MutableShareFile.MAGIC:
                 sf = MutableShareFile(filename, self)
                 # note: if the share has been migrated, the renew_lease()
@@ -682,15 +681,14 @@ class StorageServer(service.MultiService, Referenceable):
         # windows can't handle colons in the filename
         fn = os.path.join(self.corruption_advisory_dir,
                           "%s--%s-%d" % (now, si_s, shnum)).replace(":","")
-        f = open(fn, "w")
-        f.write("report: Share Corruption\n")
-        f.write("type: %s\n" % share_type)
-        f.write("storage_index: %s\n" % si_s)
-        f.write("share_number: %d\n" % shnum)
-        f.write("\n")
-        f.write(reason)
-        f.write("\n")
-        f.close()
+        with open(fn, "w") as f:
+            f.write("report: Share Corruption\n")
+            f.write("type: %s\n" % share_type)
+            f.write("storage_index: %s\n" % si_s)
+            f.write("share_number: %d\n" % shnum)
+            f.write("\n")
+            f.write(reason)
+            f.write("\n")
         log.msg(format=("client claims corruption in (%(share_type)s) " +
                         "%(si)s-%(shnum)d: %(reason)s"),
                 share_type=share_type, si=si_s, shnum=shnum, reason=reason,
diff --git a/src/allmydata/storage/shares.py b/src/allmydata/storage/shares.py
index 558bddc19..bd94c0f3f 100644
--- a/src/allmydata/storage/shares.py
+++ b/src/allmydata/storage/shares.py
@@ -4,9 +4,8 @@ from allmydata.storage.mutable import MutableShareFile
 from allmydata.storage.immutable import ShareFile

 def get_share_file(filename):
-    f = open(filename, "rb")
-    prefix = f.read(32)
-    f.close()
+    with open(filename, "rb") as f:
+        prefix = f.read(32)
     if prefix == MutableShareFile.MAGIC:
         return MutableShareFile(filename)
     # otherwise assume it's immutable
diff --git a/src/allmydata/util/configutil.py b/src/allmydata/util/configutil.py
index d58bc4217..3699db35d 100644
--- a/src/allmydata/util/configutil.py
+++ b/src/allmydata/util/configutil.py
@@ -13,15 +13,12 @@ class UnknownConfigError(Exception):

 def get_config(tahoe_cfg):
     config = SafeConfigParser()
-    f = open(tahoe_cfg, "rb")
-    try:
+    with open(tahoe_cfg, "rb") as f:
         # Skip any initial Byte Order Mark. Since this is an ordinary file, we
         # don't need to handle incomplete reads, and can assume seekability.
         if f.read(3) != '\xEF\xBB\xBF':
             f.seek(0)
         config.readfp(f)
-    finally:
-        f.close()
     return config

 def set_config(config, section, option, value):
@@ -31,11 +28,8 @@ def set_config(config, section, option, value):
     assert config.get(section, option) == value

 def write_config(tahoe_cfg, config):
-    f = open(tahoe_cfg, "wb")
-    try:
+    with open(tahoe_cfg, "wb") as f:
         config.write(f)
-    finally:
-        f.close()

 def validate_config(fname, cfg, valid_config):
     """
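
Note: the two configutil.py hunks are the cleanest wins in the patch — a whole
try/finally scaffold collapses into the with line. The BOM check in
get_config() compares the first three bytes against the literal '\xEF\xBB\xBF';
those bytes already have a standard name, which may read better than the magic
string. An equivalent spelling, assuming the same Python 2 SafeConfigParser API
the patch uses:

    import codecs
    from ConfigParser import SafeConfigParser  # Python 2, as in the patch

    def get_config(tahoe_cfg):
        config = SafeConfigParser()
        with open(tahoe_cfg, "rb") as f:
            if f.read(3) != codecs.BOM_UTF8:  # the same '\xEF\xBB\xBF'
                f.seek(0)                     # no BOM: rewind to the start
            config.readfp(f)
        return config
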
diff --git a/src/allmydata/version_checks.py b/src/allmydata/version_checks.py
index 9f084ff4b..aa42c2eee 100644
--- a/src/allmydata/version_checks.py
+++ b/src/allmydata/version_checks.py
@@ -235,18 +235,18 @@ def _get_linux_distro():
         return (_distname, _version)

     try:
-        etclsbrel = open("/etc/lsb-release", "rU")
-        for line in etclsbrel:
-            m = _distributor_id_file_re.search(line)
-            if m:
-                _distname = m.group(1).strip()
-                if _distname and _version:
-                    return (_distname, _version)
-            m = _release_file_re.search(line)
-            if m:
-                _version = m.group(1).strip()
-                if _distname and _version:
-                    return (_distname, _version)
+        with open("/etc/lsb-release", "rU") as etclsbrel:
+            for line in etclsbrel:
+                m = _distributor_id_file_re.search(line)
+                if m:
+                    _distname = m.group(1).strip()
+                    if _distname and _version:
+                        return (_distname, _version)
+                m = _release_file_re.search(line)
+                if m:
+                    _version = m.group(1).strip()
+                    if _distname and _version:
+                        return (_distname, _version)
     except EnvironmentError:
         pass
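
Note: the final hunk keeps the 'rU' mode string, which is meaningful on
Python 2 (universal-newline translation). On Python 3 the 'U' flag is
deprecated and has no effect, because text mode translates newlines by default,
so a later py3 port can spell this plain 'r' without changing behaviour:

    # On Python 3, open(path, "rU") behaves identically to open(path, "r"):
    # universal-newline translation is the default for text-mode files.
    with open("/etc/lsb-release", "r") as etclsbrel:
        for line in etclsbrel:
            pass  # apply _distributor_id_file_re / _release_file_re per line
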