Merge remote-tracking branch 'origin/master' into 3671.more-test-utilities-python-3

This commit is contained in:
Itamar Turner-Trauring 2021-04-07 09:11:31 -04:00
commit 6f74bb7d88
12 changed files with 228 additions and 408 deletions

0
newsfragments/3657.minor Normal file
View File

View File

@ -11,9 +11,11 @@ if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
import re
from foolscap.furl import decode_furl
from allmydata.crypto.util import remove_prefix
from allmydata.crypto import ed25519
from allmydata.util import base32, rrefutil, jsonbytes as json
from allmydata.util import base32, jsonbytes as json
def get_tubid_string_from_ann(ann):
@ -123,10 +125,10 @@ class AnnouncementDescriptor(object):
self.service_name = ann_d["service-name"]
self.version = ann_d.get("my-version", "")
self.nickname = ann_d.get("nickname", u"")
(service_name, key_s) = index
(_, key_s) = index
self.serverid = key_s
furl = ann_d.get("anonymous-storage-FURL")
if furl:
self.connection_hints = rrefutil.connection_hints_for_furl(furl)
_, self.connection_hints, _ = decode_furl(furl)
else:
self.connection_hints = []

View File

@ -24,11 +24,12 @@ except ImportError:
from zope.interface import implementer
from twisted.application import service
from twisted.internet import defer
from twisted.internet.address import IPv4Address
from twisted.python.failure import Failure
from foolscap.api import Referenceable
import allmydata
from allmydata import node
from allmydata.util import log, rrefutil, dictutil
from allmydata.util import log, dictutil
from allmydata.util.i2p_provider import create as create_i2p_provider
from allmydata.util.tor_provider import create as create_tor_provider
from allmydata.introducer.interfaces import \
@ -148,6 +149,15 @@ class _IntroducerNode(node.Node):
ws = IntroducerWebishServer(self, webport, nodeurl_path, staticdir)
ws.setServiceParent(self)
def stringify_remote_address(rref):
remote = rref.getPeer()
if isinstance(remote, IPv4Address):
return "%s:%d" % (remote.host, remote.port)
# loopback is a non-IPv4Address
return str(remote)
@implementer(RIIntroducerPublisherAndSubscriberService_v2)
class IntroducerService(service.MultiService, Referenceable):
name = "introducer"
@ -216,7 +226,7 @@ class IntroducerService(service.MultiService, Referenceable):
# tubid will be None. Also, subscribers do not tell us which
# pubkey they use; only publishers do that.
tubid = rref.getRemoteTubID() or "?"
remote_address = rrefutil.stringify_remote_address(rref)
remote_address = stringify_remote_address(rref)
# these three assume subscriber_info["version"]==0, but
# should tolerate other versions
nickname = subscriber_info.get("nickname", u"?")

View File

@ -0,0 +1,84 @@
"""
Tests for allmydata.util.consumer.
Ported to Python 3.
"""
from __future__ import unicode_literals
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from future.utils import PY2
if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from zope.interface import implementer
from twisted.trial.unittest import TestCase
from twisted.internet.interfaces import IPushProducer, IPullProducer
from allmydata.util.consumer import MemoryConsumer
@implementer(IPushProducer)
@implementer(IPullProducer)
class Producer(object):
"""Can be used as either streaming or non-streaming producer.
If used as streaming, the test should call iterate() manually.
"""
def __init__(self, consumer, data):
self.data = data
self.consumer = consumer
self.done = False
def resumeProducing(self):
"""Kick off streaming."""
self.iterate()
def iterate(self):
"""Do another iteration of writing."""
if self.done:
raise RuntimeError(
"There's a bug somewhere, shouldn't iterate after being done"
)
if self.data:
self.consumer.write(self.data.pop(0))
else:
self.done = True
self.consumer.unregisterProducer()
class MemoryConsumerTests(TestCase):
"""Tests for MemoryConsumer."""
def test_push_producer(self):
"""
A MemoryConsumer accumulates all data sent by a streaming producer.
"""
consumer = MemoryConsumer()
producer = Producer(consumer, [b"abc", b"def", b"ghi"])
consumer.registerProducer(producer, True)
self.assertEqual(consumer.chunks, [b"abc"])
producer.iterate()
producer.iterate()
self.assertEqual(consumer.chunks, [b"abc", b"def", b"ghi"])
self.assertEqual(consumer.done, False)
producer.iterate()
self.assertEqual(consumer.chunks, [b"abc", b"def", b"ghi"])
self.assertEqual(consumer.done, True)
def test_pull_producer(self):
"""
A MemoryConsumer accumulates all data sent by a non-streaming producer.
"""
consumer = MemoryConsumer()
producer = Producer(consumer, [b"abc", b"def", b"ghi"])
consumer.registerProducer(producer, False)
self.assertEqual(consumer.chunks, [b"abc", b"def", b"ghi"])
self.assertEqual(consumer.done, True)
# download_to_data() is effectively tested by some of the filenode tests, e.g.
# test_immutable.py.

View File

@ -17,15 +17,17 @@ import yaml
import json
from twisted.trial import unittest
from foolscap.api import Violation, RemoteException
from allmydata.util import idlib, mathutil
from allmydata.util import fileutil
from allmydata.util import jsonbytes
from allmydata.util import pollmixin
from allmydata.util import yamlutil
from allmydata.util import rrefutil
from allmydata.util.fileutil import EncryptedTemporaryFile
from allmydata.test.common_util import ReallyEqualMixin
from .no_network import fireNow, LocalWrapper
if six.PY3:
long = int
@ -480,7 +482,12 @@ class EqButNotIs(object):
class YAML(unittest.TestCase):
def test_convert(self):
data = yaml.safe_dump(["str", u"unicode", u"\u1234nicode"])
"""
Unicode and (ASCII) native strings get roundtripped to Unicode strings.
"""
data = yaml.safe_dump(
[six.ensure_str("str"), u"unicode", u"\u1234nicode"]
)
back = yamlutil.safe_load(data)
self.assertIsInstance(back[0], str)
self.assertIsInstance(back[1], str)
@ -521,3 +528,38 @@ class JSONBytes(unittest.TestCase):
encoded = jsonbytes.dumps_bytes(x)
self.assertIsInstance(encoded, bytes)
self.assertEqual(json.loads(encoded, encoding="utf-8"), x)
class FakeGetVersion(object):
"""Emulate an object with a get_version."""
def __init__(self, result):
self.result = result
def remote_get_version(self):
if isinstance(self.result, Exception):
raise self.result
return self.result
class RrefUtilTests(unittest.TestCase):
"""Tests for rrefutil."""
def test_version_returned(self):
"""If get_version() succeeded, it is set on the rref."""
rref = LocalWrapper(FakeGetVersion(12345), fireNow)
result = self.successResultOf(
rrefutil.add_version_to_remote_reference(rref, "default")
)
self.assertEqual(result.version, 12345)
self.assertIdentical(result, rref)
def test_exceptions(self):
"""If get_version() failed, default version is set on the rref."""
for exception in (Violation(), RemoteException(ValueError())):
rref = LocalWrapper(FakeGetVersion(exception), fireNow)
result = self.successResultOf(
rrefutil.add_version_to_remote_reference(rref, "Default")
)
self.assertEqual(result.version, "Default")
self.assertIdentical(result, rref)

View File

@ -125,6 +125,8 @@ PORTED_MODULES = [
"allmydata.util.base62",
"allmydata.util.configutil",
"allmydata.util.connection_status",
"allmydata.util.consumer",
"allmydata.util.dbutil",
"allmydata.util.deferredutil",
"allmydata.util.dictutil",
"allmydata.util.eliotutil",
@ -145,10 +147,12 @@ PORTED_MODULES = [
"allmydata.util.observer",
"allmydata.util.pipeline",
"allmydata.util.pollmixin",
"allmydata.util.rrefutil",
"allmydata.util.spans",
"allmydata.util.statistics",
"allmydata.util.time_format",
"allmydata.util.tor_provider",
"allmydata.util.yamlutil",
"allmydata.web",
"allmydata.web.check_results",
"allmydata.web.common",
@ -201,6 +205,7 @@ PORTED_TEST_MODULES = [
"allmydata.test.test_configutil",
"allmydata.test.test_connections",
"allmydata.test.test_connection_status",
"allmydata.test.test_consumer",
"allmydata.test.test_crawler",
"allmydata.test.test_crypto",

View File

@ -1,11 +1,22 @@
"""This file defines a basic download-to-memory consumer, suitable for use in
a filenode's read() method. See download_to_data() for an example of its use.
"""
This file defines a basic download-to-memory consumer, suitable for use in
a filenode's read() method. See download_to_data() for an example of its use.
Ported to Python 3.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from future.utils import PY2
if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from zope.interface import implementer
from twisted.internet.interfaces import IConsumer
@implementer(IConsumer)
class MemoryConsumer(object):
@ -28,6 +39,7 @@ class MemoryConsumer(object):
def unregisterProducer(self):
self.done = True
def download_to_data(n, offset=0, size=None):
"""
Return Deferred that fires with results of reading from the given filenode.

View File

@ -1,9 +1,23 @@
"""
SQLite3 utilities.
Test coverage currently provided by test_backupdb.py.
Ported to Python 3.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from future.utils import PY2
if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
import os, sys
import sqlite3
from sqlite3 import IntegrityError
[IntegrityError]
class DBError(Exception):
@ -12,7 +26,7 @@ class DBError(Exception):
def get_db(dbfile, stderr=sys.stderr,
create_version=(None, None), updaters={}, just_create=False, dbname="db",
journal_mode=None, synchronous=None):
):
"""Open or create the given db file. The parent directory must exist.
create_version=(SCHEMA, VERNUM), and SCHEMA must have a 'version' table.
Updaters is a {newver: commands} mapping, where e.g. updaters[2] is used
@ -32,12 +46,6 @@ def get_db(dbfile, stderr=sys.stderr,
# The default is unspecified according to <http://www.sqlite.org/foreignkeys.html#fk_enable>.
c.execute("PRAGMA foreign_keys = ON;")
if journal_mode is not None:
c.execute("PRAGMA journal_mode = %s;" % (journal_mode,))
if synchronous is not None:
c.execute("PRAGMA synchronous = %s;" % (synchronous,))
if must_create:
c.executescript(schema)
c.execute("INSERT INTO version (version) VALUES (?)", (target_version,))

View File

@ -1,6 +1,16 @@
"""
Ported to Python 3.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from twisted.internet import address
from foolscap.api import Violation, RemoteException, SturdyRef
from future.utils import PY2
if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
from foolscap.api import Violation, RemoteException
def add_version_to_remote_reference(rref, default):
@ -18,24 +28,3 @@ def add_version_to_remote_reference(rref, default):
return rref
d.addCallbacks(_got_version, _no_get_version)
return d
def connection_hints_for_furl(furl):
hints = []
for h in SturdyRef(furl).locationHints:
# Foolscap-0.2.5 and earlier used strings in .locationHints, 0.2.6
# through 0.6.4 used tuples of ("ipv4",host,port), 0.6.5 through
# 0.8.0 used tuples of ("tcp",host,port), and >=0.9.0 uses strings
# again. Tolerate them all.
if isinstance(h, tuple):
hints.append(":".join([str(s) for s in h]))
else:
hints.append(h)
return hints
def stringify_remote_address(rref):
remote = rref.getPeer()
if isinstance(remote, address.IPv4Address):
return "%s:%d" % (remote.host, remote.port)
# loopback is a non-IPv4Address
return str(remote)

View File

@ -1,24 +0,0 @@
import os
import sys
from twisted.python.util import sibpath as tsibpath
def sibpath(path, sibling):
"""
Looks for a named sibling relative to the given path. If such a file
exists, its path will be returned, otherwise a second search will be
made for the named sibling relative to the path of the executable
currently running. This is useful in the case that something built
with py2exe, for example, needs to find data files relative to its
install. Note hence that care should be taken not to search for
private package files whose names might collide with files which might
be found installed alongside the python interpreter itself. If no
file is found in either place, the sibling relative to the given path
is returned, likely leading to a file not found error.
"""
sib = tsibpath(path, sibling)
if not os.path.exists(sib):
exe_sib = tsibpath(sys.executable, sibling)
if os.path.exists(exe_sib):
return exe_sib
return sib

View File

@ -1,336 +0,0 @@
"""
"Rational" version definition and parsing for DistutilsVersionFight
discussion at PyCon 2009.
Ported to Python 3.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from future.utils import PY2
if PY2:
from builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
import re
class IrrationalVersionError(Exception):
"""This is an irrational version."""
pass
class HugeMajorVersionNumError(IrrationalVersionError):
"""An irrational version because the major version number is huge
(often because a year or date was used).
See `error_on_huge_major_num` option in `NormalizedVersion` for details.
This guard can be disabled by setting that option False.
"""
pass
# A marker used in the second and third parts of the `parts` tuple, for
# versions that don't have those segments, to sort properly. An example
# of versions in sort order ('highest' last):
# 1.0b1 ((1,0), ('b',1), ('f',))
# 1.0.dev345 ((1,0), ('f',), ('dev', 345))
# 1.0 ((1,0), ('f',), ('f',))
# 1.0.post256.dev345 ((1,0), ('f',), ('f', 'post', 256, 'dev', 345))
# 1.0.post345 ((1,0), ('f',), ('f', 'post', 345, 'f'))
# ^ ^ ^
# 'b' < 'f' ---------------------/ | |
# | |
# 'dev' < 'f' < 'post' -------------------/ |
# |
# 'dev' < 'f' ----------------------------------------------/
# Other letters would do, but 'f' for 'final' is kind of nice.
FINAL_MARKER = ('f',)
VERSION_RE = re.compile(r'''
^
(?P<version>\d+\.\d+) # minimum 'N.N'
(?P<extraversion>(?:\.\d+)*) # any number of extra '.N' segments
(?:
(?P<prerel>[abc]|rc) # 'a'=alpha, 'b'=beta, 'c'=release candidate
# 'rc'= alias for release candidate
(?P<prerelversion>\d+(?:\.\d+)*)
)?
(?P<postdev>(\.post(?P<post>\d+))?(\.dev(?P<dev>\d+))?)?
$''', re.VERBOSE)
class NormalizedVersion(object):
"""A rational version.
Good:
1.2 # equivalent to "1.2.0"
1.2.0
1.2a1
1.2.3a2
1.2.3b1
1.2.3c1
1.2.3.4
TODO: fill this out
Bad:
1 # mininum two numbers
1.2a # release level must have a release serial
1.2.3b
"""
def __init__(self, s, error_on_huge_major_num=True):
"""Create a NormalizedVersion instance from a version string.
@param s {str} The version string.
@param error_on_huge_major_num {bool} Whether to consider an
apparent use of a year or full date as the major version number
an error. Default True. One of the observed patterns on PyPI before
the introduction of `NormalizedVersion` was version numbers like this:
2009.01.03
20040603
2005.01
This guard is here to strongly encourage the package author to
use an alternate version, because a release deployed into PyPI
and, e.g. downstream Linux package managers, will forever remove
the possibility of using a version number like "1.0" (i.e.
where the major number is less than that huge major number).
"""
self._parse(s, error_on_huge_major_num)
@classmethod
def from_parts(cls, version, prerelease=FINAL_MARKER,
devpost=FINAL_MARKER):
return cls(cls.parts_to_str((version, prerelease, devpost)))
def _parse(self, s, error_on_huge_major_num=True):
"""Parses a string version into parts."""
match = VERSION_RE.search(s)
if not match:
raise IrrationalVersionError(s)
groups = match.groupdict()
parts = []
# main version
block = self._parse_numdots(groups['version'], s, False, 2)
extraversion = groups.get('extraversion')
if extraversion not in ('', None):
block += self._parse_numdots(extraversion[1:], s)
parts.append(tuple(block))
# prerelease
prerel = groups.get('prerel')
if prerel is not None:
block = [prerel]
block += self._parse_numdots(groups.get('prerelversion'), s,
pad_zeros_length=1)
parts.append(tuple(block))
else:
parts.append(FINAL_MARKER)
# postdev
if groups.get('postdev'):
post = groups.get('post')
dev = groups.get('dev')
postdev = []
if post is not None:
postdev.extend([FINAL_MARKER[0], 'post', int(post)])
if dev is None:
postdev.append(FINAL_MARKER[0])
if dev is not None:
postdev.extend(['dev', int(dev)])
parts.append(tuple(postdev))
else:
parts.append(FINAL_MARKER)
self.parts = tuple(parts)
if error_on_huge_major_num and self.parts[0][0] > 1980:
raise HugeMajorVersionNumError("huge major version number, %r, "
"which might cause future problems: %r" % (self.parts[0][0], s))
def _parse_numdots(self, s, full_ver_str, drop_trailing_zeros=True,
pad_zeros_length=0):
"""Parse 'N.N.N' sequences, return a list of ints.
@param s {str} 'N.N.N...' sequence to be parsed
@param full_ver_str {str} The full version string from which this
comes. Used for error strings.
@param drop_trailing_zeros {bool} Whether to drop trailing zeros
from the returned list. Default True.
@param pad_zeros_length {int} The length to which to pad the
returned list with zeros, if necessary. Default 0.
"""
nums = []
for n in s.split("."):
if len(n) > 1 and n[0] == '0':
raise IrrationalVersionError("cannot have leading zero in "
"version number segment: '%s' in %r" % (n, full_ver_str))
nums.append(int(n))
if drop_trailing_zeros:
while nums and nums[-1] == 0:
nums.pop()
while len(nums) < pad_zeros_length:
nums.append(0)
return nums
def __str__(self):
return self.parts_to_str(self.parts)
@classmethod
def parts_to_str(cls, parts):
"""Transforms a version expressed in tuple into its string
representation."""
# XXX This doesn't check for invalid tuples
main, prerel, postdev = parts
s = '.'.join(str(v) for v in main)
if prerel is not FINAL_MARKER:
s += prerel[0]
s += '.'.join(str(v) for v in prerel[1:])
if postdev and postdev is not FINAL_MARKER:
if postdev[0] == 'f':
postdev = postdev[1:]
i = 0
while i < len(postdev):
if i % 2 == 0:
s += '.'
s += str(postdev[i])
i += 1
return s
def __repr__(self):
return "%s('%s')" % (self.__class__.__name__, self)
def _cannot_compare(self, other):
raise TypeError("cannot compare %s and %s"
% (type(self).__name__, type(other).__name__))
def __eq__(self, other):
if not isinstance(other, NormalizedVersion):
self._cannot_compare(other)
return self.parts == other.parts
def __lt__(self, other):
if not isinstance(other, NormalizedVersion):
self._cannot_compare(other)
return self.parts < other.parts
def __ne__(self, other):
return not self.__eq__(other)
def __gt__(self, other):
return not (self.__lt__(other) or self.__eq__(other))
def __le__(self, other):
return self.__eq__(other) or self.__lt__(other)
def __ge__(self, other):
return self.__eq__(other) or self.__gt__(other)
def suggest_normalized_version(s):
"""Suggest a normalized version close to the given version string.
If you have a version string that isn't rational (i.e. NormalizedVersion
doesn't like it) then you might be able to get an equivalent (or close)
rational version from this function.
This does a number of simple normalizations to the given string, based
on observation of versions currently in use on PyPI. Given a dump of
those version during PyCon 2009, 4287 of them:
- 2312 (53.93%) match NormalizedVersion without change
- with the automatic suggestion
- 3474 (81.04%) match when using this suggestion method
@param s {str} An irrational version string.
@returns A rational version string, or None, if couldn't determine one.
"""
try:
NormalizedVersion(s)
return s # already rational
except IrrationalVersionError:
pass
rs = s.lower()
# part of this could use maketrans
for orig, repl in (('-alpha', 'a'), ('-beta', 'b'), ('alpha', 'a'),
('beta', 'b'), ('rc', 'c'), ('-final', ''),
('-pre', 'c'),
('-release', ''), ('.release', ''), ('-stable', ''),
('+', '.'), ('_', '.'), (' ', ''), ('.final', ''),
('final', '')):
rs = rs.replace(orig, repl)
# if something ends with dev or pre, we add a 0
rs = re.sub(r"pre$", r"pre0", rs)
rs = re.sub(r"dev$", r"dev0", rs)
# if we have something like "b-2" or "a.2" at the end of the
# version, that is pobably beta, alpha, etc
# let's remove the dash or dot
rs = re.sub(r"([abc]|rc)[\-\.](\d+)$", r"\1\2", rs)
# 1.0-dev-r371 -> 1.0.dev371
# 0.1-dev-r79 -> 0.1.dev79
rs = re.sub(r"[\-\.](dev)[\-\.]?r?(\d+)$", r".\1\2", rs)
# Clean: 2.0.a.3, 2.0.b1, 0.9.0~c1
rs = re.sub(r"[.~]?([abc])\.?", r"\1", rs)
# Clean: v0.3, v1.0
if rs.startswith('v'):
rs = rs[1:]
# Clean leading '0's on numbers.
#TODO: unintended side-effect on, e.g., "2003.05.09"
# PyPI stats: 77 (~2%) better
rs = re.sub(r"\b0+(\d+)(?!\d)", r"\1", rs)
# Clean a/b/c with no version. E.g. "1.0a" -> "1.0a0". Setuptools infers
# zero.
# PyPI stats: 245 (7.56%) better
rs = re.sub(r"(\d+[abc])$", r"\g<1>0", rs)
# the 'dev-rNNN' tag is a dev tag
rs = re.sub(r"\.?(dev-r|dev\.r)\.?(\d+)$", r".dev\2", rs)
# clean the - when used as a pre delimiter
rs = re.sub(r"-(a|b|c)(\d+)$", r"\1\2", rs)
# a terminal "dev" or "devel" can be changed into ".dev0"
rs = re.sub(r"[\.\-](dev|devel)$", r".dev0", rs)
# a terminal "dev" can be changed into ".dev0"
rs = re.sub(r"(?![\.\-])dev$", r".dev0", rs)
# a terminal "final" or "stable" can be removed
rs = re.sub(r"(final|stable)$", "", rs)
# The 'r' and the '-' tags are post release tags
# 0.4a1.r10 -> 0.4a1.post10
# 0.9.33-17222 -> 0.9.33.post17222
# 0.9.33-r17222 -> 0.9.33.post17222
rs = re.sub(r"\.?(r|-|-r)\.?(\d+)$", r".post\2", rs)
# Clean 'r' instead of 'dev' usage:
# 0.9.33+r17222 -> 0.9.33.dev17222
# 1.0dev123 -> 1.0.dev123
# 1.0.git123 -> 1.0.dev123
# 1.0.bzr123 -> 1.0.dev123
# 0.1a0dev.123 -> 0.1a0.dev123
# PyPI stats: ~150 (~4%) better
rs = re.sub(r"\.?(dev|git|bzr)\.?(\d+)$", r".dev\2", rs)
# Clean '.pre' (normalized from '-pre' above) instead of 'c' usage:
# 0.2.pre1 -> 0.2c1
# 0.2-c1 -> 0.2c1
# 1.0preview123 -> 1.0c123
# PyPI stats: ~21 (0.62%) better
rs = re.sub(r"\.?(pre|preview|-c)(\d+)$", r"c\g<2>", rs)
# Tcl/Tk uses "px" for their post release markers
rs = re.sub(r"p(\d+)$", r".post\1", rs)
try:
NormalizedVersion(rs)
return rs # already rational
except IrrationalVersionError:
pass
return None

View File

@ -1,11 +1,39 @@
"""
Ported to Python 3.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from future.utils import PY2
if PY2:
from future.builtins import filter, map, zip, ascii, chr, hex, input, next, oct, open, pow, round, super, bytes, dict, list, object, range, str, max, min # noqa: F401
import yaml
# Announcements contain unicode, because they come from JSON. We tell PyYAML
# to give us unicode instead of str/bytes.
def construct_unicode(loader, node):
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
if PY2:
# On Python 2 the way pyyaml deals with Unicode strings is inconsistent.
#
# >>> yaml.safe_load(yaml.safe_dump(u"hello"))
# 'hello'
# >>> yaml.safe_load(yaml.safe_dump(u"hello\u1234"))
# u'hello\u1234'
#
# In other words, Unicode strings get roundtripped to byte strings, but
# only sometimes.
#
# In order to ensure unicode stays unicode, we add a configuration saying
# that the YAML String Language-Independent Type ("a sequence of zero or
# more Unicode characters") should be the underlying Unicode string object,
# rather than converting to bytes when possible.
#
# Reference: https://yaml.org/type/str.html
def construct_unicode(loader, node):
return node.value
yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str",
construct_unicode)
def safe_load(f):
return yaml.safe_load(f)