Bug 1633593 [wpt PR 23281] - [wptserve] Fix RequestHeaders.get in Python 3, a=testonly
author: Robert Ma <robertma@chromium.org>
Wed, 13 May 2020 09:38:47 +0000
changeset 531048 50c1856fd3fa181f850f905bf273123ca03d2157
parent 531047 b3460cb584c67aefd04dc18948ea6f2cb4064b22
child 531049 99c05982e3897112b1def1695e0bb27aa40dc44d
push id: 37435
push user: apavel@mozilla.com
push date: Wed, 20 May 2020 15:28:23 +0000
treeherder: mozilla-central@5415da14ec9a [default view] [failures only]
perfherder: [talos] [build metrics] [platform microbench] (compared to previous push)
reviewers: testonly
bugs: 1633593, 23281
milestone: 78.0a1
first release with
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
last release without
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
Bug 1633593 [wpt PR 23281] - [wptserve] Fix RequestHeaders.get in Python 3, a=testonly Automatic update from web-platform-tests [wptserve] Fix RequestHeaders.get in Python 3 (#23281) When there are multiple headers with the same name, this method would crash in Python 3 when attempting to join them. Drive-by: * Clean up Request.__init__ to make the logic clearer and avoid an unnecessary access to the headers property. * Add some FIXMEs for more potential Py3 issues. * Add some tests to make sure Request.url can be constructed correctly as it uses some native strings intentionally. -- wpt-commits: c6704bb144ce7bb067315711ec1059e856e1edac wpt-pr: 23281
testing/web-platform/tests/tools/wptserve/tests/test_request.py
testing/web-platform/tests/tools/wptserve/wptserve/request.py
new file mode 100644
--- /dev/null
+++ b/testing/web-platform/tests/tools/wptserve/tests/test_request.py
@@ -0,0 +1,78 @@
+import mock
+from six import binary_type
+
+from wptserve.request import Request, RequestHeaders
+
+
+class MockHTTPMessage(dict):
+    """A minimal (and not completely correct) mock of HTTPMessage for testing.
+
+    Constructing HTTPMessage is annoying and different in Python 2 and 3. This
+    only implements the parts used by RequestHeaders.
+
+    Requirements for construction:
+    * Keys are header names and MUST be lower-case.
+    * Values are lists of header values (even if there's only one).
+    * Keys and values should be native strings to match stdlib's behaviours.
+    """
+    def __getitem__(self, key):
+        assert isinstance(key, str)
+        values = dict.__getitem__(self, key.lower())
+        assert isinstance(values, list)
+        return values[0]
+
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def getallmatchingheaders(self, key):
+        values = dict.__getitem__(self, key.lower())
+        return ["{}: {}\n".format(key, v) for v in values]
+
+
+def test_request_headers_get():
+    raw_headers = MockHTTPMessage({
+        'x-foo': ['foo'],
+        'x-bar': ['bar1', 'bar2'],
+    })
+    headers = RequestHeaders(raw_headers)
+    assert headers['x-foo'] == b'foo'
+    assert headers['X-Bar'] == b'bar1, bar2'
+    assert headers.get('x-bar') == b'bar1, bar2'
+
+
+def test_request_headers_encoding():
+    raw_headers = MockHTTPMessage({
+        'x-foo': ['foo'],
+        'x-bar': ['bar1', 'bar2'],
+    })
+    headers = RequestHeaders(raw_headers)
+    assert isinstance(headers['x-foo'], binary_type)
+    assert isinstance(headers['x-bar'], binary_type)
+    assert isinstance(headers.get_list('x-bar')[0], binary_type)
+
+
+def test_request_url_from_server_address():
+    request_handler = mock.Mock()
+    request_handler.server.scheme = 'http'
+    request_handler.server.server_address = ('localhost', '8000')
+    request_handler.path = '/demo'
+    request_handler.headers = MockHTTPMessage()
+
+    request = Request(request_handler)
+    assert request.url == 'http://localhost:8000/demo'
+    assert isinstance(request.url, str)
+
+
+def test_request_url_from_host_header():
+    request_handler = mock.Mock()
+    request_handler.server.scheme = 'http'
+    request_handler.server.server_address = ('localhost', '8000')
+    request_handler.path = '/demo'
+    request_handler.headers = MockHTTPMessage({'host': ['web-platform.test:8001']})
+
+    request = Request(request_handler)
+    assert request.url == 'http://web-platform.test:8001/demo'
+    assert isinstance(request.url, str)
--- a/testing/web-platform/tests/tools/wptserve/wptserve/request.py
+++ b/testing/web-platform/tests/tools/wptserve/wptserve/request.py
@@ -243,46 +243,45 @@ class Request(object):
 
     def __init__(self, request_handler):
         self.doc_root = request_handler.server.router.doc_root
         self.route_match = None  # Set by the router
 
         self.protocol_version = request_handler.protocol_version
         self.method = request_handler.command
 
+        # Keys and values in raw headers are native strings.
+        self._headers = None
+        self.raw_headers = request_handler.headers
+
         scheme = request_handler.server.scheme
-        host = request_handler.headers.get("Host")
+        host = self.raw_headers.get("Host")
         port = request_handler.server.server_address[1]
 
         if host is None:
             host = request_handler.server.server_address[0]
         else:
             if ":" in host:
                 host, port = host.split(":", 1)
 
         self.request_path = request_handler.path
         self.url_base = "/"
 
         if self.request_path.startswith(scheme + "://"):
-            self.url = request_handler.path
+            self.url = self.request_path
         else:
-            self.url = "%s://%s:%s%s" % (scheme,
-                                      host,
-                                      port,
-                                      self.request_path)
+            # TODO(#23362): Stop using native strings for URLs.
+            self.url = "%s://%s:%s%s" % (
+                scheme, host, port, self.request_path)
         self.url_parts = urlsplit(self.url)
 
-        self.raw_headers = request_handler.headers
-
         self.request_line = request_handler.raw_requestline
 
-        self._headers = None
-
         self.raw_input = InputFile(request_handler.rfile,
-                                   int(self.headers.get("Content-Length", 0)))
+                                   int(self.raw_headers.get("Content-Length", 0)))
 
         self._body = None
 
         self._GET = None
         self._POST = None
         self._cookies = None
         self._auth = None
 
@@ -298,19 +297,20 @@ class Request(object):
             self._GET = MultiDict()
             for key, value in params:
                 self._GET.add(key, value)
         return self._GET
 
     @property
     def POST(self):
         if self._POST is None:
-            #Work out the post parameters
+            # Work out the post parameters
             pos = self.raw_input.tell()
             self.raw_input.seek(0)
+            # FIXME: specify encoding in Python 3.
             fs = cgi.FieldStorage(fp=self.raw_input,
                                   environ={"REQUEST_METHOD": self.method},
                                   headers=self.raw_headers,
                                   keep_blank_values=True)
             self._POST = MultiDict.from_field_storage(fs)
             self.raw_input.seek(pos)
         return self._POST
 
@@ -395,26 +395,25 @@ class RequestHeaders(dict):
                     # getallmatchingheaders returns raw header lines, so
                     # split to get name, value
                     multiples.append(_maybe_encode(value).split(b':', 1)[1].strip())
                 headers = multiples
             else:
                 headers = [_maybe_encode(items[header])]
             dict.__setitem__(self, key, headers)
 
-
     def __getitem__(self, key):
         """Get all headers of a certain (case-insensitive) name. If there is
         more than one, the values are returned comma separated"""
         key = _maybe_encode(key)
         values = dict.__getitem__(self, key.lower())
         if len(values) == 1:
             return values[0]
         else:
-            return ", ".join(values)
+            return b", ".join(values)
 
     def __setitem__(self, name, value):
         raise Exception
 
     def get(self, key, default=None):
         """Get a string representing all headers with a particular value,
         with multiple headers separated by a comma. If no header is found
         return a default value
@@ -446,16 +445,17 @@ class RequestHeaders(dict):
     def iteritems(self):
         for item in self:
             yield item, self[item]
 
     def itervalues(self):
         for item in self:
             yield self[item]
 
+
 class CookieValue(object):
     """Representation of cookies.
 
     Note that cookies are considered read-only and the string value
     of the cookie will not change if you update the field values.
     However this is not enforced.
 
     .. attribute:: key