Bug 1494800 [wpt PR 13246] - [wptserve] Rework header encoding/decoding in Request, a=testonly
author Ms2ger <Ms2ger@gmail.com>
Thu, 11 Oct 2018 10:03:48 +0000
changeset 489287 17db37b424452e5444426bf5c3b178efe9762182
parent 489286 0d8a90abecca5e3fbf461b84613638ef44f046eb
child 489288 38d3c0a4f59f791d956c6ce19e78299f904be66d
push id 247
push user fmarier@mozilla.com
push date Sat, 27 Oct 2018 01:06:44 +0000
reviewers testonly
bugs 1494800, 13246, 11769, 13204
milestone 64.0a1
Bug 1494800 [wpt PR 13246] - [wptserve] Rework header encoding/decoding in Request, a=testonly Automatic update from web-platform-tests Reland "Various test fixes for python3 support. (#11769)" -- Add a Unicode smoke test to wptserve The test sends a request to wptserve with non-ASCII characters in a header and sets up a simple handler to return the value of that header. The server shouldn't crash in either Python 2 or 3, and the response should not be garbled. The server crashes in Python 2 (#13204). -- Rework encode/decode of headers in Request This change changes the encoding/decoding of headers in Request and username/password in Authentication: now all keys and values have binary type (because of an implementation detail in the Python 3 standard library, we actually need to re-encode the headers back to bytes in Python 3). Documentation and comments are also improved to clarify the encoding situation. Also add test cases for non-ASCII characters in auth headers. -- wpt-commits: f0686bdb720f8b9d28c7e774e6f7ec1a8ccde0c7, d79ae192cdfb2133cf52d9da4f92fa9d14259745, 8b09eceee87d1001d261536e20901290c7c1b6d1 wpt-pr: 13246
testing/web-platform/tests/tools/wptserve/tests/functional/base.py
testing/web-platform/tests/tools/wptserve/tests/functional/test_pipes.py
testing/web-platform/tests/tools/wptserve/tests/functional/test_request.py
testing/web-platform/tests/tools/wptserve/wptserve/pipes.py
testing/web-platform/tests/tools/wptserve/wptserve/ranges.py
testing/web-platform/tests/tools/wptserve/wptserve/request.py
--- a/testing/web-platform/tests/tools/wptserve/tests/functional/base.py
+++ b/testing/web-platform/tests/tools/wptserve/tests/functional/base.py
@@ -70,17 +70,17 @@ class TestUsingServer(unittest.TestCase)
 
         for name, value in iteritems(headers):
             req.add_header(name, value)
 
         if body is not None:
             req.add_data(body)
 
         if auth is not None:
-            req.add_header("Authorization", "Basic %s" % base64.b64encode('%s:%s' % auth))
+            req.add_header("Authorization", b"Basic %s" % base64.b64encode((b"%s:%s" % auth)))
 
         return urlopen(req)
 
 
 @pytest.mark.skipif(not wptserve.utils.http2_compatible(), reason="h2 server only works in python 2.7.15")
 class TestUsingH2Server:
     def setup_method(self, test_method):
         self.server = wptserve.server.WebTestHttpd(host="localhost",
--- a/testing/web-platform/tests/tools/wptserve/tests/functional/test_pipes.py
+++ b/testing/web-platform/tests/tools/wptserve/tests/functional/test_pipes.py
@@ -52,84 +52,75 @@ class TestSlice(TestUsingServer):
         self.assertEqual(resp.read(), expected[1:])
 
     def test_no_lower(self):
         resp = self.request("/document.txt", query="pipe=slice(null,10)")
         expected = open(os.path.join(doc_root, "document.txt"), 'rb').read()
         self.assertEqual(resp.read(), expected[:10])
 
 class TestSub(TestUsingServer):
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_config(self):
         resp = self.request("/sub.txt", query="pipe=sub")
-        expected = "localhost localhost %i" % self.server.port
+        expected = b"localhost localhost %i" % self.server.port
         self.assertEqual(resp.read().rstrip(), expected)
 
     @pytest.mark.xfail(sys.platform == "win32",
                        reason="https://github.com/web-platform-tests/wpt/issues/12949")
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_file_hash(self):
         resp = self.request("/sub_file_hash.sub.txt")
-        expected = """
+        expected = b"""
 md5: JmI1W8fMHfSfCarYOSxJcw==
 sha1: nqpWqEw4IW8NjD6R375gtrQvtTo=
 sha224: RqQ6fMmta6n9TuA/vgTZK2EqmidqnrwBAmQLRQ==
 sha256: G6Ljg1uPejQxqFmvFOcV/loqnjPTW5GSOePOfM/u0jw=
 sha384: lkXHChh1BXHN5nT5BYhi1x67E1CyYbPKRKoF2LTm5GivuEFpVVYtvEBHtPr74N9E
-sha512: r8eLGRTc7ZznZkFjeVLyo6/FyQdra9qmlYCwKKxm3kfQAswRS9+3HsYk3thLUhcFmmWhK4dXaICz
-JwGFonfXwg=="""
+sha512: r8eLGRTc7ZznZkFjeVLyo6/FyQdra9qmlYCwKKxm3kfQAswRS9+3HsYk3thLUhcFmmWhK4dXaICzJwGFonfXwg=="""
         self.assertEqual(resp.read().rstrip(), expected.strip())
 
     def test_sub_file_hash_unrecognized(self):
         with self.assertRaises(urllib.error.HTTPError):
             self.request("/sub_file_hash_unrecognized.sub.txt")
 
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_headers(self):
         resp = self.request("/sub_headers.txt", query="pipe=sub", headers={"X-Test": "PASS"})
-        expected = "PASS"
+        expected = b"PASS"
         self.assertEqual(resp.read().rstrip(), expected)
 
     @pytest.mark.xfail(sys.platform == "win32",
                        reason="https://github.com/web-platform-tests/wpt/issues/12949")
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_location(self):
         resp = self.request("/sub_location.sub.txt?query_string")
         expected = """
 host: localhost:{0}
 hostname: localhost
 path: /sub_location.sub.txt
 pathname: /sub_location.sub.txt
 port: {0}
 query: ?query_string
 scheme: http
-server: http://localhost:{0}""".format(self.server.port)
+server: http://localhost:{0}""".format(self.server.port).encode("ascii")
         self.assertEqual(resp.read().rstrip(), expected.strip())
 
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_params(self):
         resp = self.request("/sub_params.txt", query="test=PASS&pipe=sub")
-        expected = "PASS"
+        expected = b"PASS"
         self.assertEqual(resp.read().rstrip(), expected)
 
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_url_base(self):
         resp = self.request("/sub_url_base.sub.txt")
-        self.assertEqual(resp.read().rstrip(), "Before / After")
+        self.assertEqual(resp.read().rstrip(), b"Before / After")
 
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_uuid(self):
         resp = self.request("/sub_uuid.sub.txt")
-        self.assertRegexpMatches(resp.read().rstrip(), r"Before [a-f0-9-]+ After")
+        self.assertRegexpMatches(resp.read().rstrip(), b"Before [a-f0-9-]+ After")
 
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_sub_var(self):
         resp = self.request("/sub_var.sub.txt")
         port = self.server.port
-        expected = "localhost %s A %s B localhost C" % (port, port)
+        expected = b"localhost %d A %d B localhost C" % (port, port)
         self.assertEqual(resp.read().rstrip(), expected)
 
 class TestTrickle(TestUsingServer):
     def test_trickle(self):
         #Actually testing that the response trickles in is not that easy
         t0 = time.time()
         resp = self.request("/document.txt", query="pipe=trickle(1:d2:5:d1:r2)")
         t1 = time.time()
--- a/testing/web-platform/tests/tools/wptserve/tests/functional/test_request.py
+++ b/testing/web-platform/tests/tools/wptserve/tests/functional/test_request.py
@@ -1,10 +1,9 @@
-import sys
-
+# -*- coding: utf-8 -*-
 import pytest
 
 wptserve = pytest.importorskip("wptserve")
 from .base import TestUsingServer
 from wptserve.request import InputFile
 
 
 class TestInputFile(TestUsingServer):
@@ -110,21 +109,45 @@ class TestRequest(TestUsingServer):
         def handler(request, response):
             return request.route_match["match"] + " " + request.route_match["*"]
 
         route = ("GET", "/test/{match}_*", handler)
         self.server.router.register(*route)
         resp = self.request("/test/some_route")
         self.assertEqual(b"some route", resp.read())
 
+    def test_non_ascii_in_headers(self):
+        @wptserve.handlers.handler
+        def handler(request, response):
+            return request.headers["foo"]
+
+        route = ("GET", "/test/test_unicode_in_headers", handler)
+        self.server.router.register(*route)
+
+        # Try some non-ASCII characters and the server shouldn't crash.
+        encoded_text = u"你好".encode("utf-8")
+        resp = self.request(route[1], headers={"foo": encoded_text})
+        self.assertEqual(encoded_text, resp.read())
+
+        # Try a different encoding from utf-8 to make sure the binary value is
+        # returned verbatim.
+        encoded_text = u"どうも".encode("shift-jis")
+        resp = self.request(route[1], headers={"foo": encoded_text})
+        self.assertEqual(encoded_text, resp.read())
+
 
 class TestAuth(TestUsingServer):
-    @pytest.mark.xfail(sys.version_info >= (3,), reason="wptserve only works on Py2")
     def test_auth(self):
         @wptserve.handlers.handler
         def handler(request, response):
-            return " ".join((request.auth.username, request.auth.password))
+            return b" ".join((request.auth.username, request.auth.password))
 
         route = ("GET", "/test/test_auth", handler)
         self.server.router.register(*route)
-        resp = self.request(route[1], auth=("test", "PASS"))
+
+        resp = self.request(route[1], auth=(b"test", b"PASS"))
         self.assertEqual(200, resp.getcode())
-        self.assertEqual(["test", "PASS"], resp.read().split(" "))
+        self.assertEqual([b"test", b"PASS"], resp.read().split(b" "))
+
+        encoded_text = u"どうも".encode("shift-jis")
+        resp = self.request(route[1], auth=(encoded_text, encoded_text))
+        self.assertEqual(200, resp.getcode())
+        self.assertEqual([encoded_text, encoded_text], resp.read().split(b" "))
--- a/testing/web-platform/tests/tools/wptserve/wptserve/pipes.py
+++ b/testing/web-platform/tests/tools/wptserve/wptserve/pipes.py
@@ -1,10 +1,11 @@
 from cgi import escape
 from collections import deque
+import base64
 import gzip as gzip_module
 import hashlib
 import os
 import re
 import time
 import uuid
 from six.moves import StringIO
 
@@ -388,35 +389,35 @@ class SubFunctions(object):
     # are available on all platforms [1]. This ensures that test authors do not
     # unknowingly introduce platform-specific tests.
     #
     # [1] https://docs.python.org/2/library/hashlib.html
     supported_algorithms = ("md5", "sha1", "sha224", "sha256", "sha384", "sha512")
 
     @staticmethod
     def file_hash(request, algorithm, path):
-        algorithm = algorithm.decode("ascii")
+        assert isinstance(algorithm, text_type)
         if algorithm not in SubFunctions.supported_algorithms:
             raise ValueError("Unsupported encryption algorithm: '%s'" % algorithm)
 
         hash_obj = getattr(hashlib, algorithm)()
         absolute_path = os.path.join(request.doc_root, path)
 
         try:
-            with open(absolute_path) as f:
+            with open(absolute_path, "rb") as f:
                 hash_obj.update(f.read())
         except IOError:
             # In this context, an unhandled IOError will be interpreted by the
             # server as an indication that the template file is non-existent.
             # Although the generic "Exception" is less precise, it avoids
             # triggering a potentially-confusing HTTP 404 error in cases where
             # the path to the file to be hashed is invalid.
             raise Exception('Cannot open file for hash computation: "%s"' % absolute_path)
 
-        return hash_obj.digest().encode('base64').strip()
+        return base64.b64encode(hash_obj.digest()).strip()
 
 def template(request, content, escape_type="html"):
     #TODO: There basically isn't any error handling here
     tokenizer = ReplacementTokenizer()
 
     variables = {}
 
     def config_replacement(match):
@@ -485,19 +486,24 @@ def template(request, content, escape_ty
         assert isinstance(value, (int, (binary_type, text_type))), tokens
 
         if variable is not None:
             variables[variable] = value
 
         escape_func = {"html": lambda x:escape(x, quote=True),
                        "none": lambda x:x}[escape_type]
 
-        #Should possibly support escaping for other contexts e.g. script
-        #TODO: read the encoding of the response
-        return escape_func(text_type(value)).encode("utf-8")
+        # Should possibly support escaping for other contexts e.g. script
+        # TODO: read the encoding of the response
+        # cgi.escape() only takes text strings in Python 3.
+        if isinstance(value, binary_type):
+            value = value.decode("utf-8")
+        elif isinstance(value, int):
+            value = text_type(value)
+        return escape_func(value).encode("utf-8")
 
     template_regexp = re.compile(br"{{([^}]*)}}")
     new_content = template_regexp.sub(config_replacement, content)
 
     return new_content
 
 @pipe()
 def gzip(request, response):
--- a/testing/web-platform/tests/tools/wptserve/wptserve/ranges.py
+++ b/testing/web-platform/tests/tools/wptserve/wptserve/ranges.py
@@ -1,13 +1,17 @@
 from .utils import HTTPException
 
 
 class RangeParser(object):
     def __call__(self, header, file_size):
+        try:
+            header = header.decode("ascii")
+        except UnicodeDecodeError:
+            raise HTTPException(400, "Non-ASCII range header value")
         prefix = "bytes="
         if not header.startswith(prefix):
             raise HTTPException(416, message="Unrecognised range type %s" % (header,))
 
         parts = header[len(prefix):].split(",")
         ranges = []
         for item in parts:
             components = item.split("-")
--- a/testing/web-platform/tests/tools/wptserve/wptserve/request.py
+++ b/testing/web-platform/tests/tools/wptserve/wptserve/request.py
@@ -1,12 +1,12 @@
 import base64
 import cgi
 from six.moves.http_cookies import BaseCookie
-from six import BytesIO
+from six import BytesIO, binary_type, text_type
 import tempfile
 
 from six.moves.urllib.parse import parse_qsl, urlsplit
 
 from . import stash
 from .utils import HTTPException
 
 missing = object()
@@ -303,27 +303,27 @@ class Request(object):
     @property
     def POST(self):
         if self._POST is None:
             #Work out the post parameters
             pos = self.raw_input.tell()
             self.raw_input.seek(0)
             fs = cgi.FieldStorage(fp=self.raw_input,
                                   environ={"REQUEST_METHOD": self.method},
-                                  headers=self.headers,
+                                  headers=self.raw_headers,
                                   keep_blank_values=True)
             self._POST = MultiDict.from_field_storage(fs)
             self.raw_input.seek(pos)
         return self._POST
 
     @property
     def cookies(self):
         if self._cookies is None:
             parser = BaseCookie()
-            cookie_headers = self.headers.get("cookie", "")
+            cookie_headers = self.headers.get("cookie", b"")
             parser.load(cookie_headers)
             cookies = Cookies()
             for key, value in parser.iteritems():
                 cookies[key] = CookieValue(value)
             self._cookies = cookies
         return self._cookies
 
     @property
@@ -350,39 +350,64 @@ class Request(object):
 
 class H2Request(Request):
     def __init__(self, request_handler):
         self.h2_stream_id = request_handler.h2_stream_id
         self.frames = []
         super(H2Request, self).__init__(request_handler)
 
 
+def _maybe_encode(s):
+    """Encodes a text-type string into binary data using iso-8859-1.
+
+    Returns `str` in Python 2 and `bytes` in Python 3. The function is a no-op
+    if the argument already has a binary type.
+    """
+    if isinstance(s, binary_type):
+        return s
+
+    # Python 3 assumes iso-8859-1 when parsing headers, which will garble text
+    # with non-ASCII characters. We try to encode the text back to binary.
+    # https://github.com/python/cpython/blob/273fc220b25933e443c82af6888eb1871d032fb8/Lib/http/client.py#L213
+    if isinstance(s, text_type):
+        return s.encode("iso-8859-1")
+
+    raise TypeError("Unexpected value in RequestHeaders: %r" % s)
+
+
 class RequestHeaders(dict):
-    """Dictionary-like API for accessing request headers."""
+    """Read-only dictionary-like API for accessing request headers.
+
+    Unlike BaseHTTPRequestHandler.headers, this class always returns all
+    headers with the same name (separated by commas). And it ensures all keys
+    (i.e. names of headers) and values have binary type.
+    """
     def __init__(self, items):
         for header in items.keys():
-            key = header.lower()
+            key = _maybe_encode(header).lower()
             # get all headers with the same name
             values = items.getallmatchingheaders(header)
             if len(values) > 1:
                 # collect the multiple variations of the current header
                 multiples = []
                 # loop through the values from getallmatchingheaders
                 for value in values:
                     # getallmatchingheaders returns raw header lines, so
                     # split to get name, value
-                    multiples.append(value.split(':', 1)[1].strip())
-                dict.__setitem__(self, key, multiples)
+                    multiples.append(_maybe_encode(value).split(b':', 1)[1].strip())
+                headers = multiples
             else:
-                dict.__setitem__(self, key, [items[header]])
+                headers = [_maybe_encode(items[header])]
+            dict.__setitem__(self, key, headers)
 
 
     def __getitem__(self, key):
         """Get all headers of a certain (case-insensitive) name. If there is
         more than one, the values are returned comma separated"""
+        key = _maybe_encode(key)
         values = dict.__getitem__(self, key.lower())
         if len(values) == 1:
             return values[0]
         else:
             return ", ".join(values)
 
     def __setitem__(self, name, value):
         raise Exception
@@ -398,25 +423,27 @@ class RequestHeaders(dict):
         try:
             return self[key]
         except KeyError:
             return default
 
     def get_list(self, key, default=missing):
         """Get all the header values for a particular field name as
         a list"""
+        key = _maybe_encode(key)
         try:
             return dict.__getitem__(self, key.lower())
         except KeyError:
             if default is not missing:
                 return default
             else:
                 raise
 
     def __contains__(self, key):
+        key = _maybe_encode(key)
         return dict.__contains__(self, key.lower())
 
     def iteritems(self):
         for item in self:
             yield item, self[item]
 
     def itervalues(self):
         for item in self:
@@ -585,26 +612,33 @@ class Authentication(object):
 
     The username supplied in the HTTP Authorization
     header, or None
 
     .. attribute:: password
 
     The password supplied in the HTTP Authorization
     header, or None
+
+    Both attributes are binary strings (`str` in Py2, `bytes` in Py3), since
+    RFC7617 Section 2.1 does not specify the encoding for username & password
+    (as long as it's compatible with ASCII). UTF-8 should be a relatively safe
+    choice if callers need to decode them as most browsers use it.
     """
     def __init__(self, headers):
         self.username = None
         self.password = None
 
-        auth_schemes = {"Basic": self.decode_basic}
+        auth_schemes = {b"Basic": self.decode_basic}
 
         if "authorization" in headers:
             header = headers.get("authorization")
-            auth_type, data = header.split(" ", 1)
+            assert isinstance(header, binary_type)
+            auth_type, data = header.split(b" ", 1)
             if auth_type in auth_schemes:
                 self.username, self.password = auth_schemes[auth_type](data)
             else:
                 raise HTTPException(400, "Unsupported authentication scheme %s" % auth_type)
 
     def decode_basic(self, data):
-        decoded_data = base64.decodestring(data)
-        return decoded_data.split(":", 1)
+        assert isinstance(data, binary_type)
+        decoded_data = base64.b64decode(data)
+        return decoded_data.split(b":", 1)