factor out yamlutil.py

The yaml.SafeLoader.add_constructor() should probably only be done once,
and moving this all into a module gives us an opportunity to test it
directly.
This commit is contained in:
Brian Warner
2016-07-19 17:22:12 -07:00
parent fa28ed0730
commit 2c5f7ed425
5 changed files with 41 additions and 42 deletions

View File

@ -1,4 +1,4 @@
import os, stat, time, weakref, yaml import os, stat, time, weakref
from allmydata import node from allmydata import node
from base64 import urlsafe_b64encode from base64 import urlsafe_b64encode
@ -17,9 +17,10 @@ from allmydata.immutable.upload import Uploader
from allmydata.immutable.offloaded import Helper from allmydata.immutable.offloaded import Helper
from allmydata.control import ControlServer from allmydata.control import ControlServer
from allmydata.introducer.client import IntroducerClient from allmydata.introducer.client import IntroducerClient
from allmydata.util import hashutil, base32, pollmixin, log, keyutil, idlib from allmydata.util import (hashutil, base32, pollmixin, log, keyutil, idlib,
from allmydata.util.encodingutil import get_filesystem_encoding, \ yamlutil)
from_utf8_or_none from allmydata.util.encodingutil import (get_filesystem_encoding,
from_utf8_or_none)
from allmydata.util.fileutil import abspath_expanduser_unicode from allmydata.util.fileutil import abspath_expanduser_unicode
from allmydata.util.abbreviate import parse_abbreviated_size from allmydata.util.abbreviate import parse_abbreviated_size
from allmydata.util.time_format import parse_duration, parse_date from allmydata.util.time_format import parse_duration, parse_date
@ -189,17 +190,15 @@ class Client(node.Node, pollmixin.PollMixin):
Load the connections.yaml file if it exists, otherwise Load the connections.yaml file if it exists, otherwise
create a default configuration. create a default configuration.
""" """
connections_filepath = FilePath(os.path.join(self.basedir, "private", "connections.yaml")) fn = os.path.join(self.basedir, "private", "connections.yaml")
def construct_unicode(loader, node): connections_filepath = FilePath(fn)
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
try: try:
with connections_filepath.open() as f: with connections_filepath.open() as f:
self.connections_config = yaml.safe_load(f) self.connections_config = yamlutil.safe_load(f)
except EnvironmentError: except EnvironmentError:
self.connections_config = { 'servers' : {} } self.connections_config = { 'servers' : {} }
connections_filepath.setContent(yaml.safe_dump(self.connections_config)) content = yamlutil.safe_dump(self.connections_config)
connections_filepath.setContent(content)
def init_stats_provider(self): def init_stats_provider(self):
gatherer_furl = self.get_config("client", "stats_gatherer.furl", None) gatherer_furl = self.get_config("client", "stats_gatherer.furl", None)

View File

@ -1,5 +1,5 @@
import time, yaml import time
from zope.interface import implements from zope.interface import implements
from twisted.application import service from twisted.application import service
from foolscap.api import Referenceable, eventually from foolscap.api import Referenceable, eventually
@ -8,7 +8,7 @@ from allmydata.introducer.interfaces import IIntroducerClient, \
RIIntroducerSubscriberClient_v2 RIIntroducerSubscriberClient_v2
from allmydata.introducer.common import sign_to_foolscap, unsign_from_foolscap,\ from allmydata.introducer.common import sign_to_foolscap, unsign_from_foolscap,\
get_tubid_string_from_ann get_tubid_string_from_ann
from allmydata.util import log from allmydata.util import log, yamlutil
from allmydata.util.rrefutil import add_version_to_remote_reference from allmydata.util.rrefutil import add_version_to_remote_reference
from allmydata.util.keyutil import BadSignatureError from allmydata.util.keyutil import BadSignatureError
@ -89,15 +89,9 @@ class IntroducerClient(service.Service, Referenceable):
d.addErrback(connect_failed) d.addErrback(connect_failed)
def _load_announcements(self): def _load_announcements(self):
# Announcements contain unicode, because they come from JSON. We tell
# PyYAML to give us unicode instead of str/bytes.
def construct_unicode(loader, node):
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
try: try:
with self._cache_filepath.open() as f: with self._cache_filepath.open() as f:
servers = yaml.safe_load(f) servers = yamlutil.safe_load(f)
except EnvironmentError: except EnvironmentError:
return # no cache file return # no cache file
if not isinstance(servers, list): if not isinstance(servers, list):
@ -121,7 +115,7 @@ class IntroducerClient(service.Service, Referenceable):
"key_s" : key_s, "key_s" : key_s,
} }
announcements.append(server_params) announcements.append(server_params)
announcement_cache_yaml = yaml.safe_dump(announcements) announcement_cache_yaml = yamlutil.safe_dump(announcements)
self._cache_filepath.setContent(announcement_cache_yaml) self._cache_filepath.setContent(announcement_cache_yaml)
def _got_introducer(self, publisher): def _got_introducer(self, publisher):

View File

@ -1,5 +1,5 @@
import os, re, itertools, yaml import os, re, itertools
from base64 import b32decode from base64 import b32decode
import simplejson import simplejson
@ -20,7 +20,7 @@ from allmydata.introducer.common import get_tubid_string_from_ann, \
from allmydata.introducer import IntroducerNode from allmydata.introducer import IntroducerNode
from allmydata.web import introweb from allmydata.web import introweb
from allmydata.client import Client as TahoeClient from allmydata.client import Client as TahoeClient
from allmydata.util import pollmixin, keyutil, idlib, fileutil, iputil from allmydata.util import pollmixin, keyutil, idlib, fileutil, iputil, yamlutil
import allmydata.test.common_util as testutil import allmydata.test.common_util as testutil
class LoggingMultiService(service.MultiService): class LoggingMultiService(service.MultiService):
@ -719,12 +719,8 @@ class Announcements(unittest.TestCase):
self.failUnlessEqual(a[0].announcement["anonymous-storage-FURL"], furl1) self.failUnlessEqual(a[0].announcement["anonymous-storage-FURL"], furl1)
def _load_cache(self, cache_filepath): def _load_cache(self, cache_filepath):
def construct_unicode(loader, node):
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
with cache_filepath.open() as f: with cache_filepath.open() as f:
return yaml.safe_load(f) return yamlutil.safe_load(f)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_client_cache(self): def test_client_cache(self):
@ -808,18 +804,6 @@ class Announcements(unittest.TestCase):
self.failUnlessEqual(announcements[pub2]["anonymous-storage-FURL"], self.failUnlessEqual(announcements[pub2]["anonymous-storage-FURL"],
furl3) furl3)
class YAMLUnicode(unittest.TestCase):
def test_convert(self):
data = yaml.safe_dump(["str", u"unicode", u"\u1234nicode"])
def construct_unicode(loader, node):
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
back = yaml.safe_load(data)
self.failUnlessEqual(type(back[0]), unicode)
self.failUnlessEqual(type(back[1]), unicode)
self.failUnlessEqual(type(back[2]), unicode)
class ClientSeqnums(unittest.TestCase): class ClientSeqnums(unittest.TestCase):
def test_client(self): def test_client(self):
basedir = "introducer/ClientSeqnums/test_client" basedir = "introducer/ClientSeqnums/test_client"

View File

@ -1,7 +1,7 @@
def foo(): pass # keep the line number constant def foo(): pass # keep the line number constant
import os, time, sys import os, time, sys, yaml
from StringIO import StringIO from StringIO import StringIO
from datetime import timedelta from datetime import timedelta
from twisted.trial import unittest from twisted.trial import unittest
@ -13,7 +13,7 @@ from pycryptopp.hash.sha256 import SHA256 as _hash
from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil from allmydata.util import base32, idlib, humanreadable, mathutil, hashutil
from allmydata.util import assertutil, fileutil, deferredutil, abbreviate from allmydata.util import assertutil, fileutil, deferredutil, abbreviate
from allmydata.util import limiter, time_format, pollmixin, cachedir from allmydata.util import limiter, time_format, pollmixin, cachedir
from allmydata.util import statistics, dictutil, pipeline from allmydata.util import statistics, dictutil, pipeline, yamlutil
from allmydata.util import log as tahoe_log from allmydata.util import log as tahoe_log
from allmydata.util.spans import Spans, overlap, DataSpans from allmydata.util.spans import Spans, overlap, DataSpans
from allmydata.test.common_util import ReallyEqualMixin, TimezoneMixin from allmydata.test.common_util import ReallyEqualMixin, TimezoneMixin
@ -2412,3 +2412,11 @@ class StringSpans(unittest.TestCase):
length = max(1, int(what[5:6], 16)) length = max(1, int(what[5:6], 16))
d1 = s1.get(start, length); d2 = s2.get(start, length) d1 = s1.get(start, length); d2 = s2.get(start, length)
self.failUnlessEqual(d1, d2, "%d+%d" % (start, length)) self.failUnlessEqual(d1, d2, "%d+%d" % (start, length))
class YAML(unittest.TestCase):
def test_convert(self):
data = yaml.safe_dump(["str", u"unicode", u"\u1234nicode"])
back = yamlutil.safe_load(data)
self.failUnlessEqual(type(back[0]), unicode)
self.failUnlessEqual(type(back[1]), unicode)
self.failUnlessEqual(type(back[2]), unicode)

View File

@ -0,0 +1,14 @@
import yaml
# Announcements contain unicode, because they come from JSON. We tell PyYAML
# to give us unicode instead of str/bytes.
def construct_unicode(loader, node):
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
def safe_load(f):
return yaml.safe_load(f)
def safe_dump(obj):
return yaml.safe_dump(obj)