Bug 1211503: Support for Marionette protocol level 3 in the Python client
authorAndreas Tolfsen <ato@mozilla.com>
Fri, 09 Oct 2015 11:15:24 +0100
changeset 303794 4cf388b1cf3a16beeac808cae5fd8373b4e05bef
parent 303793 ff9a6aa825f5a2217fc227effdd920bbee5a98a6
child 303795 ebeeeacb56e590cd5cac9906fef007eae9f3ed6d
push id1001
push userraliiev@mozilla.com
push dateMon, 18 Jan 2016 19:06:03 +0000
treeherdermozilla-release@8b89261f3ac4 [default view] [failures only]
perfherder[talos] [build metrics] [platform microbench] (compared to previous push)
bugs1211503, 1211489
milestone44.0a1
first release with
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
last release without
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
Bug 1211503: Support for Marionette protocol level 3 in the Python client Futures-proofs the Marionette Python client to have support for level 3 of the Marionette protocol outlined in bug 1211489. This patch changes the marionette-transport API most notably by renaming the MarionetteTransport class to TcpTransport and by splitting the receive-capabilities of TcpTransport.send into a new function called request. Furthermore it introduces a message data structure for dealing with incoming responses and commands, and for marshaling messages to send in order to support all three protocol levels. r=dburns r=jgriffin
testing/marionette/client/marionette/__init__.py
testing/marionette/client/marionette/marionette_test.py
testing/marionette/client/marionette/tests/unit/test_emulator.py
testing/marionette/client/marionette/tests/unit/test_transport.py
testing/marionette/client/marionette/tests/unit/unit-tests.ini
testing/marionette/driver/marionette_driver/marionette.py
testing/marionette/transport/marionette_transport/__init__.py
testing/marionette/transport/marionette_transport/transport.py
--- a/testing/marionette/client/marionette/__init__.py
+++ b/testing/marionette/client/marionette/__init__.py
@@ -1,28 +1,36 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
 
 __version__ = '1.0.0'
 
-from .marionette_test import MarionetteTestCase, MarionetteJSTestCase, CommonTestCase, expectedFailure, skip, SkipTest
+from .marionette_test import (
+    CommonTestCase,
+    expectedFailure,
+    MarionetteJSTestCase,
+    MarionetteTestCase,
+    skip,
+    SkipTest,
+    skip_unless_protocol,
+)
 from .runner import (
-        B2GTestCaseMixin,
-        B2GTestResultMixin,
-        BaseMarionetteArguments,
-        BaseMarionetteTestRunner,
-        BrowserMobProxyTestCaseMixin,
-        EnduranceArguments,
-        EnduranceTestCaseMixin,
-        HTMLReportingArguments,
-        HTMLReportingTestResultMixin,
-        HTMLReportingTestRunnerMixin,
-        Marionette,
-        MarionetteTest,
-        MarionetteTestResult,
-        MarionetteTextTestRunner,
-        MemoryEnduranceTestCaseMixin,
-        TestManifest,
-        TestResult,
-        TestResultCollection
+    B2GTestCaseMixin,
+    B2GTestResultMixin,
+    BaseMarionetteArguments,
+    BaseMarionetteTestRunner,
+    BrowserMobProxyTestCaseMixin,
+    EnduranceArguments,
+    EnduranceTestCaseMixin,
+    HTMLReportingArguments,
+    HTMLReportingTestResultMixin,
+    HTMLReportingTestRunnerMixin,
+    Marionette,
+    MarionetteTest,
+    MarionetteTestResult,
+    MarionetteTextTestRunner,
+    MemoryEnduranceTestCaseMixin,
+    TestManifest,
+    TestResult,
+    TestResultCollection,
 )
--- a/testing/marionette/client/marionette/marionette_test.py
+++ b/testing/marionette/client/marionette/marionette_test.py
@@ -51,19 +51,17 @@ class _ExpectedFailure(Exception):
 
 class _UnexpectedSuccess(Exception):
     """
     The test was supposed to fail, but it didn't!
     """
     pass
 
 def skip(reason):
-    """
-    Unconditionally skip a test.
-    """
+    """Unconditionally skip a test."""
     def decorator(test_item):
         if not isinstance(test_item, (type, types.ClassType)):
             @functools.wraps(test_item)
             def skip_wrapper(*args, **kwargs):
                 raise SkipTest(reason)
             test_item = skip_wrapper
 
         test_item.__unittest_skip__ = True
@@ -76,22 +74,28 @@ def expectedFailure(func):
     def wrapper(*args, **kwargs):
         try:
             func(*args, **kwargs)
         except Exception:
             raise _ExpectedFailure(sys.exc_info())
         raise _UnexpectedSuccess
     return wrapper
 
+def skip_if_desktop(target):
+    def wrapper(self, *args, **kwargs):
+        if self.marionette.session_capabilities.get('b2g') is None:
+            raise SkipTest('skipping due to desktop')
+        return target(self, *args, **kwargs)
+    return wrapper
+
 def skip_if_b2g(target):
     def wrapper(self, *args, **kwargs):
         if self.marionette.session_capabilities.get('b2g') == True:
             raise SkipTest('skipping due to b2g')
         return target(self, *args, **kwargs)
-
     return wrapper
 
 def skip_if_e10s(target):
     def wrapper(self, *args, **kwargs):
         with self.marionette.using_context('chrome'):
             multi_process_browser = self.marionette.execute_script("""
             try {
               return Services.appinfo.browserTabsRemoteAutostart;
@@ -99,16 +103,29 @@ def skip_if_e10s(target):
               return false;
             }""")
 
         if multi_process_browser:
             raise SkipTest('skipping due to e10s')
         return target(self, *args, **kwargs)
     return wrapper
 
+def skip_unless_protocol(predicate):
+    """Given a predicate passed the current protocol level, skip the
+    test if the predicate does not match."""
+    def decorator(test_item):
+        @functools.wraps(test_item)
+        def skip_wrapper(self):
+            level = self.marionette.client.protocol
+            if not predicate(level):
+                raise SkipTest('skipping because protocol level is %s' % level)
+            return self
+        return skip_wrapper
+    return decorator
+
 def parameterized(func_suffix, *args, **kwargs):
     """
     A decorator that can generate methods given a base method and some data.
 
     **func_suffix** is used as a suffix for the new created method and must be
     unique given a base method. if **func_suffix** countains characters that
     are not allowed in normal python function name, these characters will be
     replaced with "_".
--- a/testing/marionette/client/marionette/tests/unit/test_emulator.py
+++ b/testing/marionette/client/marionette/tests/unit/test_emulator.py
@@ -1,60 +1,60 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
-from marionette import MarionetteTestCase
-from marionette_driver.errors import MarionetteException
+from unittest import skip
+
+from marionette.marionette_test import MarionetteTestCase, skip_if_desktop, skip_unless_protocol
+from marionette_driver.errors import MarionetteException, JavascriptException
 
 
 class TestEmulatorContent(MarionetteTestCase):
-
+    @skip_if_desktop
     def test_emulator_cmd(self):
         self.marionette.set_script_timeout(10000)
         expected = ["<build>",
                     "OK"]
         result = self.marionette.execute_async_script("""
         runEmulatorCmd("avd name", marionetteScriptFinished)
         """);
         self.assertEqual(result, expected)
 
+    @skip_if_desktop
     def test_emulator_shell(self):
         self.marionette.set_script_timeout(10000)
         expected = ["Hello World!"]
         result = self.marionette.execute_async_script("""
         runEmulatorShell(["echo", "Hello World!"], marionetteScriptFinished)
         """);
         self.assertEqual(result, expected)
 
+    @skip_if_desktop
     def test_emulator_order(self):
         self.marionette.set_script_timeout(10000)
         self.assertRaises(MarionetteException,
                           self.marionette.execute_async_script,
         """runEmulatorCmd("gsm status", function(result) {});
            marionetteScriptFinished(true);
         """);
 
 
 class TestEmulatorChrome(TestEmulatorContent):
-
     def setUp(self):
         super(TestEmulatorChrome, self).setUp()
         self.marionette.set_context("chrome")
 
 
 class TestEmulatorScreen(MarionetteTestCase):
-
-    def setUp(self):
-        MarionetteTestCase.setUp(self)
-
+    @skip_if_desktop
+    def test_emulator_orientation(self):
         self.screen = self.marionette.emulator.screen
         self.screen.initialize()
 
-    def test_emulator_orientation(self):
         self.assertEqual(self.screen.orientation, self.screen.SO_PORTRAIT_PRIMARY,
                          'Orientation has been correctly initialized.')
 
         self.screen.orientation = self.screen.SO_PORTRAIT_SECONDARY
         self.assertEqual(self.screen.orientation, self.screen.SO_PORTRAIT_SECONDARY,
                          'Orientation has been set to portrait-secondary')
 
         self.screen.orientation = self.screen.SO_LANDSCAPE_PRIMARY
@@ -63,8 +63,75 @@ class TestEmulatorScreen(MarionetteTestC
 
         self.screen.orientation = self.screen.SO_LANDSCAPE_SECONDARY
         self.assertEqual(self.screen.orientation, self.screen.SO_LANDSCAPE_SECONDARY,
                          'Orientation has been set to landscape-secondary')
 
         self.screen.orientation = self.screen.SO_PORTRAIT_PRIMARY
         self.assertEqual(self.screen.orientation, self.screen.SO_PORTRAIT_PRIMARY,
                          'Orientation has been set to portrait-primary')
+
+
+class TestEmulatorCallbacks(MarionetteTestCase):
+    def setUp(self):
+        MarionetteTestCase.setUp(self)
+        self.original_emulator_cmd = self.marionette._emulator_cmd
+        self.original_emulator_shell = self.marionette._emulator_shell
+        self.marionette._emulator_cmd = self.mock_emulator_cmd
+        self.marionette._emulator_shell = self.mock_emulator_shell
+
+    def tearDown(self):
+        self.marionette._emulator_cmd = self.original_emulator_cmd
+        self.marionette._emulator_shell = self.original_emulator_shell
+
+    def mock_emulator_cmd(self, *args):
+        return self.marionette._send_emulator_result("cmd response")
+
+    def mock_emulator_shell(self, *args):
+        return self.marionette._send_emulator_result("shell response")
+
+    def _execute_emulator(self, action, args):
+        script = "%s(%s, function(res) { marionetteScriptFinished(res); })" % (action, args)
+        return self.marionette.execute_async_script(script)
+
+    def emulator_cmd(self, cmd):
+        return self._execute_emulator("runEmulatorCmd", escape(cmd))
+
+    def emulator_shell(self, *args):
+        js_args = ", ".join(map(escape, args))
+        js_args = "[%s]" % js_args
+        return self._execute_emulator("runEmulatorShell", js_args)
+
+    def test_emulator_cmd_content(self):
+        with self.marionette.using_context("content"):
+            res = self.emulator_cmd("yo")
+            self.assertEqual("cmd response", res)
+
+    def test_emulator_shell_content(self):
+        with self.marionette.using_context("content"):
+            res = self.emulator_shell("first", "second")
+            self.assertEqual("shell response", res)
+
+    @skip_unless_protocol(lambda level: level >= 3)
+    def test_emulator_result_error_content(self):
+        with self.marionette.using_context("content"):
+            with self.assertRaisesRegexp(JavascriptException, "TypeError"):
+                self.marionette.execute_async_script("runEmulatorCmd()")
+
+    def test_emulator_cmd_chrome(self):
+        with self.marionette.using_context("chrome"):
+            res = self.emulator_cmd("yo")
+            self.assertEqual("cmd response", res)
+
+    def test_emulator_shell_chrome(self):
+        with self.marionette.using_context("chrome"):
+            res = self.emulator_shell("first", "second")
+            self.assertEqual("shell response", res)
+
+    @skip_unless_protocol(lambda level: level >= 3)
+    def test_emulator_result_error_chrome(self):
+        with self.marionette.using_context("chrome"):
+            with self.assertRaisesRegexp(JavascriptException, "TypeError"):
+                self.marionette.execute_async_script("runEmulatorCmd()")
+
+
+def escape(word):
+    return "'%s'" % word
new file mode 100644
--- /dev/null
+++ b/testing/marionette/client/marionette/tests/unit/test_transport.py
@@ -0,0 +1,181 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+import json
+from marionette import MarionetteTestCase, skip_unless_protocol
+from marionette_transport import (
+    Command,
+    Proto2Command,
+    Proto2Response,
+    Response
+)
+
+get_current_url = ("getCurrentUrl", None)
+execute_script = ("executeScript", {"script": "return 42"})
+
+
+class TestMessageSequencing(MarionetteTestCase):
+    @property
+    def last_id(self):
+        return self.marionette.client.last_id
+
+    @last_id.setter
+    def last_id(self, new_id):
+        self.marionette.client.last_id = new_id
+
+    def send(self, name, params):
+        self.last_id = self.last_id + 1
+        cmd = Command(self.last_id, name, params)
+        self.marionette.client.send(cmd)
+        return self.last_id
+
+    @skip_unless_protocol(lambda level: level >= 3)
+    def test_discard_older_messages(self):
+        first = self.send(*get_current_url)
+        second = self.send(*execute_script)
+        resp = self.marionette.client.receive()
+        self.assertEqual(second, resp.id)
+
+    @skip_unless_protocol(lambda level: level >= 3)
+    def test_last_id_incremented(self):
+        before = self.last_id
+        self.send(*get_current_url)
+        self.assertGreater(self.last_id, before)
+
+
+class MessageTestCase(MarionetteTestCase):
+    def assert_attr(self, obj, attr):
+        self.assertTrue(hasattr(obj, attr),
+                        "object does not have attribute %s" % attr)
+
+
+class TestCommand(MessageTestCase):
+    def create(self, msgid="msgid", name="name", params="params"):
+        return Command(msgid, name, params)
+
+    def test_initialise(self):
+        cmd = self.create()
+        self.assert_attr(cmd, "id")
+        self.assert_attr(cmd, "name")
+        self.assert_attr(cmd, "params")
+        self.assertEqual("msgid", cmd.id)
+        self.assertEqual("name", cmd.name)
+        self.assertEqual("params", cmd.params)
+
+    def test_stringify(self):
+        cmd = self.create()
+        string = str(cmd)
+        self.assertIn("Command", string)
+        self.assertIn("id=msgid", string)
+        self.assertIn("name=name", string)
+        self.assertIn("params=params", string)
+
+    def test_to_msg(self):
+        cmd = self.create()
+        msg = json.loads(cmd.to_msg())
+        self.assertEquals(msg[0], Command.TYPE)
+        self.assertEquals(msg[1], "msgid")
+        self.assertEquals(msg[2], "name")
+        self.assertEquals(msg[3], "params")
+
+    def test_from_msg(self):
+        msg = [Command.TYPE, "msgid", "name", "params"]
+        payload = json.dumps(msg)
+        cmd = Command.from_msg(payload)
+        self.assertEquals(msg[1], cmd.id)
+        self.assertEquals(msg[2], cmd.name)
+        self.assertEquals(msg[3], cmd.params)
+
+
+class TestResponse(MessageTestCase):
+    def create(self, msgid="msgid", error="error", result="result"):
+        return Response(msgid, error, result)
+
+    def test_initialise(self):
+        resp = self.create()
+        self.assert_attr(resp, "id")
+        self.assert_attr(resp, "error")
+        self.assert_attr(resp, "result")
+        self.assertEqual("msgid", resp.id)
+        self.assertEqual("error", resp.error)
+        self.assertEqual("result", resp.result)
+
+    def test_stringify(self):
+        resp = self.create()
+        string = str(resp)
+        self.assertIn("Response", string)
+        self.assertIn("id=msgid", string)
+        self.assertIn("error=error", string)
+        self.assertIn("result=result", string)
+
+    def test_to_msg(self):
+        resp = self.create()
+        msg = json.loads(resp.to_msg())
+        self.assertEquals(msg[0], Response.TYPE)
+        self.assertEquals(msg[1], "msgid")
+        self.assertEquals(msg[2], "error")
+        self.assertEquals(msg[3], "result")
+
+    def test_from_msg(self):
+        msg = [Response.TYPE, "msgid", "error", "result"]
+        payload = json.dumps(msg)
+        resp = Response.from_msg(payload)
+        self.assertEquals(msg[1], resp.id)
+        self.assertEquals(msg[2], resp.error)
+        self.assertEquals(msg[3], resp.result)
+
+
+class TestProto2Command(MessageTestCase):
+    def create(self, name="name", params="params"):
+        return Proto2Command(name, params)
+
+    def test_initialise(self):
+        cmd = self.create()
+        self.assert_attr(cmd, "id")
+        self.assert_attr(cmd, "name")
+        self.assert_attr(cmd, "params")
+        self.assertEqual(None, cmd.id)
+        self.assertEqual("name", cmd.name)
+        self.assertEqual("params", cmd.params)
+
+    def test_from_data_emulator_cmd(self):
+        data = {"emulator_cmd": "emulator_cmd"}
+        cmd = Proto2Command.from_data(data)
+        self.assertEqual("runEmulatorCmd", cmd.name)
+        self.assertEqual(data, cmd.params)
+
+    def test_from_data_emulator_shell(self):
+        data = {"emulator_shell": "emulator_shell"}
+        cmd = Proto2Command.from_data(data)
+        self.assertEqual("runEmulatorShell", cmd.name)
+        self.assertEqual(data, cmd.params)
+
+    def test_from_data_unknown(self):
+        with self.assertRaises(ValueError):
+            cmd = Proto2Command.from_data({})
+
+
+class TestProto2Response(MessageTestCase):
+    def create(self, error="error", result="result"):
+        return Proto2Response(error, result)
+
+    def test_initialise(self):
+        resp = self.create()
+        self.assert_attr(resp, "id")
+        self.assert_attr(resp, "error")
+        self.assert_attr(resp, "result")
+        self.assertEqual(None, resp.id)
+        self.assertEqual("error", resp.error)
+        self.assertEqual("result", resp.result)
+
+    def test_from_data_error(self):
+        data = {"error": "error"}
+        resp = Proto2Response.from_data(data)
+        self.assertEqual(data, resp.error)
+        self.assertEqual(None, resp.result)
+
+    def test_from_data_result(self):
+        resp = Proto2Response.from_data("result")
+        self.assertEqual(None, resp.error)
+        self.assertEqual("result", resp.result)
--- a/testing/marionette/client/marionette/tests/unit/unit-tests.ini
+++ b/testing/marionette/client/marionette/tests/unit/unit-tests.ini
@@ -46,17 +46,16 @@ b2g = false
 [test_text_chrome.py]
 disabled = "Bug 896046"
 
 [test_clearing.py]
 [test_typing.py]
 
 [test_log.py]
 [test_emulator.py]
-browser = false
 qemu = true
 
 [test_about_pages.py]
 b2g = false
 
 [test_execute_async_script.py]
 [test_execute_script.py]
 [test_simpletest_fail.js]
--- a/testing/marionette/driver/marionette_driver/marionette.py
+++ b/testing/marionette/driver/marionette_driver/marionette.py
@@ -10,26 +10,27 @@ import socket
 import StringIO
 import traceback
 import warnings
 
 from contextlib import contextmanager
 
 from decorators import do_crash_check
 from keys import Keys
-from marionette_transport import MarionetteTransport
+import marionette_transport as transport
 
 from mozrunner import B2GEmulatorRunner
 
 import geckoinstance
 import errors
 
 WEBELEMENT_KEY = "ELEMENT"
 W3C_WEBELEMENT_KEY = "element-6066-11e4-a52e-4f735466cecf"
 
+
 class HTMLElement(object):
     """
     Represents a DOM Element.
     """
 
     def __init__(self, marionette, id):
         self.marionette = marionette
         assert(id is not None)
@@ -618,26 +619,25 @@ class Marionette(object):
             self.runner = B2GEmulatorRunner(b2g_home=homedir,
                                             logdir=logdir,
                                             process_args=process_args)
             self.emulator = self.runner.device
             self.emulator.connect()
             self.port = self.emulator.setup_port_forwarding(remote_port=self.port)
             assert(self.emulator.wait_for_port(self.port)), "Timed out waiting for port!"
 
-        self.client = MarionetteTransport(
-            self.host,
-            self.port,
-            self.socket_timeout)
-
         if emulator:
             if busybox:
                 self.emulator.install_busybox(busybox=busybox)
             self.emulator.wait_for_system_message(self)
 
+        # for callbacks from a protocol level 2 or lower remote,
+        # we store the callback ID so it can be used by _send_emulator_result
+        self.emulator_callback_id = None
+
     def cleanup(self):
         if self.session:
             try:
                 self.delete_session()
             except (errors.MarionetteException, socket.error, IOError):
                 # These exceptions get thrown if the Marionette server
                 # hit an exception/died or the connection died. We can
                 # do no further server-side cleanup in this case.
@@ -662,116 +662,120 @@ class Marionette(object):
             s.bind((host, port))
             return True
         except socket.error:
             return False
         finally:
             s.close()
 
     def wait_for_port(self, timeout=60):
-        return MarionetteTransport.wait_for_port(self.host,
-                                                 self.port,
-                                                 timeout=timeout)
+        return transport.wait_for_port(self.host, self.port, timeout=timeout)
 
     @do_crash_check
-    def _send_message(self, command, body=None, key=None):
-        if not self.session_id and command != "newSession":
+    def _send_message(self, name, params=None, key=None):
+        if not self.session_id and name != "newSession":
             raise errors.MarionetteException("Please start a session")
 
-        message = {"name": command}
-        if body:
-            message["parameters"] = body
+        try:
+            if self.protocol < 3:
+                data = {"name": name}
+                if params:
+                    data["parameters"] = params
+                self.client.send(data)
+                msg = self.client.receive()
 
-        packet = json.dumps(message)
+            else:
+                msg = self.client.request(name, params)
 
-        try:
-            resp = self.client.send(packet)
         except IOError:
             if self.instance and not hasattr(self.instance, 'detached'):
                 # If we've launched the binary we've connected to, wait
                 # for it to shut down.
                 returncode = self.instance.runner.wait(timeout=self.DEFAULT_STARTUP_TIMEOUT)
                 raise IOError("process died with returncode %s" % returncode)
             raise
         except socket.timeout:
             self.session = None
             self.window = None
             self.client.close()
             raise errors.TimeoutException("Connection timed out")
 
-        # Process any emulator commands that are sent from a script
-        # while it's executing
-        if isinstance(resp, dict) and any (k in resp for k in ("emulator_cmd", "emulator_shell")):
-            while True:
-                id = resp.get("id")
-                cmd = resp.get("emulator_cmd")
-                shell = resp.get("emulator_shell")
-                if cmd:
-                    resp = self._emulator_cmd(id, cmd)
-                    continue
-                if shell:
-                    resp = self._emulator_shell(id, shell)
-                    continue
-                break
+        if isinstance(msg, transport.Command):
+            if msg.name == "runEmulatorCmd":
+                self.emulator_callback_id = msg.params.get("id")
+                msg = self._emulator_cmd(msg.params["emulator_cmd"])
+            elif msg.name == "runEmulatorShell":
+                self.emulator_callback_id = msg.params.get("id")
+                msg = self._emulator_shell(msg.params["emulator_shell"])
+            else:
+                raise IOError("Unknown command: %s" % msg)
 
-        if "error" in resp:
-            self._handle_error(resp)
+        res, err = msg.result, msg.error
+        if err:
+            self._handle_error(err)
 
         if key is not None:
-            return self._unwrap_response(resp.get(key))
+            return self._unwrap_response(res.get(key))
         else:
-            return self._unwrap_response(resp)
+            return self._unwrap_response(res)
 
     def _unwrap_response(self, value):
         if isinstance(value, dict) and \
         (WEBELEMENT_KEY in value or W3C_WEBELEMENT_KEY in value):
             if value.get(WEBELEMENT_KEY):
                 return HTMLElement(self, value.get(WEBELEMENT_KEY))
             else:
                 return HTMLElement(self, value.get(W3C_WEBELEMENT_KEY))
         elif isinstance(value, list):
             return list(self._unwrap_response(item) for item in value)
         else:
             return value
 
-    def _emulator_cmd(self, id, cmd):
+    def _emulator_cmd(self, cmd):
         if not self.emulator:
             raise errors.MarionetteException(
                 "No emulator in this test to run command against")
         payload = cmd.encode("ascii")
         result = self.emulator._run_telnet(payload)
-        return self._send_emulator_result(id, result)
+        return self._send_emulator_result(result)
 
-    def _emulator_shell(self, id, args):
+    def _emulator_shell(self, args):
         if not isinstance(args, list) or not self.emulator:
             raise errors.MarionetteException(
                 "No emulator in this test to run shell command against")
         buf = StringIO.StringIO()
         self.emulator.dm.shell(args, buf)
         result = str(buf.getvalue()[0:-1]).rstrip().splitlines()
         buf.close()
-        return self._send_emulator_result(id, result)
+        return self._send_emulator_result(result)
 
-    def _send_emulator_result(self, id, result):
-        return self.client.send(json.dumps({"name": "emulatorCmdResult",
-                                            "id": id,
-                                            "result": result}))
+    def _send_emulator_result(self, result):
+        if self.protocol < 3:
+            body = {"name": "emulatorCmdResult",
+                    "id": self.emulator_callback_id,
+                    "result": result}
+            self.client.send(body)
+            return self.client.receive()
+        else:
+            return self.client.respond(result)
 
-    def _handle_error(self, resp):
+    def _handle_error(self, obj):
         if self.protocol == 1:
-            if "error" not in resp or not isinstance(resp["error"], dict):
+            if "error" not in obj or not isinstance(obj["error"], dict):
                 raise errors.MarionetteException(
-                    "Malformed packet, expected key 'error' to be a dict: %s" % resp)
-            error = resp["error"].get("status")
-            message = resp["error"].get("message")
-            stacktrace = resp["error"].get("stacktrace")
+                    "Malformed packet, expected key 'error' to be a dict: %s" % obj)
+            error = obj["error"].get("status")
+            message = obj["error"].get("message")
+            stacktrace = obj["error"].get("stacktrace")
+
         else:
-            error = resp["error"]
-            message = resp["message"]
-            stacktrace = resp["stacktrace"]
+            error = obj["error"]
+            message = obj["message"]
+            stacktrace = obj["stacktrace"]
+
         raise errors.lookup(error)(message, stacktrace=stacktrace)
 
     def _reset_timeouts(self):
         if self.timeout is not None:
             self.timeouts(self.TIMEOUT_SEARCH, self.timeout)
             self.timeouts(self.TIMEOUT_SCRIPT, self.timeout)
             self.timeouts(self.TIMEOUT_PAGE, self.timeout)
         else:
@@ -1127,16 +1131,20 @@ class Marionette(object):
 
         :returns: A dict of the capabilities offered."""
         if self.instance:
             returncode = self.instance.runner.process_handler.proc.returncode
             if returncode is not None:
                 # We're managing a binary which has terminated, so restart it.
                 self.instance.restart()
 
+        self.client = transport.TcpTransport(
+            self.host,
+            self.port,
+            self.socket_timeout)
         self.protocol, _ = self.client.connect()
         self.wait_for_port(timeout=timeout)
 
         body = {"capabilities": desired_capabilities, "sessionId": session_id}
         resp = self._send_message("newSession", body)
 
         self.session_id = resp["sessionId"]
         self.session = resp["value"] if self.protocol == 1 else resp["capabilities"]
@@ -1303,17 +1311,16 @@ class Marionette(object):
 
         :param context: Context, may be one of the class properties
             `CONTEXT_CHROME` or `CONTEXT_CONTENT`.
 
         Usage example::
 
             marionette.set_context(marionette.CONTEXT_CHROME)
         """
-        assert(context == self.CONTEXT_CHROME or context == self.CONTEXT_CONTENT)
         if context not in [self.CONTEXT_CHROME, self.CONTEXT_CONTENT]:
             raise ValueError("Unknown context: %s" % context)
         self._send_message("setContext", {"value": context})
 
     @contextmanager
     def using_context(self, context):
         """Sets the context that Marionette commands are running in using
         a `with` statement. The state of the context on the server is
--- a/testing/marionette/transport/marionette_transport/__init__.py
+++ b/testing/marionette/transport/marionette_transport/__init__.py
@@ -1,8 +1,7 @@
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/
 
 __version__ = '0.7.1'
 
-
-from transport import MarionetteTransport
+from transport import *
--- a/testing/marionette/transport/marionette_transport/transport.py
+++ b/testing/marionette/transport/marionette_transport/transport.py
@@ -2,136 +2,309 @@
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
 import datetime
 import errno
 import json
 import socket
 import time
+import types
+
+
+class Message(object):
+    def __init__(self, msgid):
+        self.id = msgid
+
+    def __eq__(self, other):
+        return self.id == other.id
+
+
+class Command(Message):
+    TYPE = 0
+
+    def __init__(self, msgid, name, params):
+        Message.__init__(self, msgid)
+        self.name = name
+        self.params = params
+
+    def __str__(self):
+        return "<Command id=%s, name=%s, params=%s>" % (self.id, self.name, self.params)
+
+    def to_msg(self):
+        msg = [Command.TYPE, self.id, self.name, self.params]
+        return json.dumps(msg)
+
+    @staticmethod
+    def from_msg(payload):
+        data = json.loads(payload)
+        assert data[0] == Command.TYPE
+        cmd = Command(data[1], data[2], data[3])
+        return cmd
 
 
-class MarionetteTransport(object):
-    """The Marionette socket client.  This speaks the same protocol
-    as the remote debugger inside Gecko, in which messages are always
-    preceded by the message length and a colon, e.g.:
+class Response(Message):
+    TYPE = 1
+
+    def __init__(self, msgid, error, result):
+        Message.__init__(self, msgid)
+        self.error = error
+        self.result = result
+
+    def __str__(self):
+        return "<Response id=%s, error=%s, result=%s>" % (self.id, self.error, self.result)
 
-        20:{"command": "test"}
+    def to_msg(self):
+       msg = [Response.TYPE, self.id, self.error, self.result]
+       return json.dumps(msg)
+
+    @staticmethod
+    def from_msg(payload):
+        data = json.loads(payload)
+        assert data[0] == Response.TYPE
+        return Response(data[1], data[2], data[3])
+
+
+class Proto2Command(Command):
+    """Compatibility shim that marshals messages from a protocol level
+    2 and below remote into ``Command`` objects.
     """
 
+    def __init__(self, name, params):
+        Command.__init__(self, None, name, params)
+
+    @staticmethod
+    def from_data(data):
+        if "emulator_cmd" in data:
+            name = "runEmulatorCmd"
+        elif "emulator_shell" in data:
+            name = "runEmulatorShell"
+        else:
+            raise ValueError
+        return Proto2Command(name, data)
+
+
+class Proto2Response(Response):
+    """Compatibility shim that marshals messages from a protocol level
+    2 and below remote into ``Response`` objects.
+    """
+
+    def __init__(self, error, result):
+        Response.__init__(self, None, error, result)
+
+    @staticmethod
+    def from_data(data):
+        err, res = None, None
+        if "error" in data:
+            err = data
+        else:
+            res = data
+        return Proto2Response(err, res)
+
+
+class TcpTransport(object):
+    """Socket client that communciates with Marionette via TCP.
+
+    It speaks the protocol of the remote debugger in Gecko, in which
+    messages are always preceded by the message length and a colon, e.g.:
+
+        7:MESSAGE
+
+    On top of this protocol it uses a Marionette message format, that
+    depending on the protocol level offered by the remote server, varies.
+    Supported protocol levels are 1 and above.
+    """
     max_packet_length = 4096
     connection_lost_msg = "Connection to Marionette server is lost. Check gecko.log (desktop firefox) or logcat (b2g) for errors."
 
     def __init__(self, addr, port, socket_timeout=360.0):
         self.addr = addr
         self.port = port
         self.socket_timeout = socket_timeout
-        self.sock = None
+
         self.protocol = 1
         self.application_type = None
+        self.last_id = 0
+        self.expected_responses = []
+
+        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.sock.settimeout(self.socket_timeout)
 
     def _recv_n_bytes(self, n):
-        """Convenience method for receiving exactly n bytes from self.sock
-        (assuming it's open and connected).
-        """
         data = ""
         while len(data) < n:
             chunk = self.sock.recv(n - len(data))
             if chunk == "":
                 break
             data += chunk
         return data
 
-    def receive(self):
-        """Receive the next complete response from the server, and
-        return it as a JSON structure.  Each response from the server
-        is prepended by len(message) + ":".
+    def _unmarshal(self, packet):
+        msg = None
+
+        # protocol 3 and above
+        if self.protocol >= 3:
+            typ = int(packet[1])
+            if typ == Command.TYPE:
+                msg = Command.from_msg(packet)
+            elif typ == Response.TYPE:
+                msg = Response.from_msg(packet)
+
+        # protocol 2 and below
+        else:
+            data = json.loads(packet)
+
+            # emulator callbacks
+            if isinstance(data, dict) and any(k in data for k in ("emulator_cmd", "emulator_shell")):
+                msg = Proto2Command.from_data(data)
+
+            # everything else
+            else:
+                msg = Proto2Response.from_data(data)
+
+        return msg
+
+    def receive(self, unmarshal=True):
+        """Wait for the next complete response from the remote.
+
+        :param unmarshal: Default is to deserialise the packet and
+            return a ``Message`` type.  Setting this to false will return
+            the raw packet.
         """
-        assert(self.sock)
         now = time.time()
-        response = ''
+        data = ""
         bytes_to_recv = 10
+
         while time.time() - now < self.socket_timeout:
             try:
-                data = self.sock.recv(bytes_to_recv)
-                response += data
+                chunk = self.sock.recv(bytes_to_recv)
+                data += chunk
             except socket.timeout:
                 pass
             else:
-                if not data:
+                if not chunk:
                     raise IOError(self.connection_lost_msg)
-            sep = response.find(':')
+
+            sep = data.find(":")
             if sep > -1:
-                length = response[0:sep]
-                remaining = response[sep + 1:]
+                length = data[0:sep]
+                remaining = data[sep + 1:]
+
                 if len(remaining) == int(length):
-                    return json.loads(remaining)
+                    if unmarshal:
+                        msg = self._unmarshal(remaining)
+                        self.last_id = msg.id
+
+                        if isinstance(msg, Response) and self.protocol >= 3:
+                            if msg not in self.expected_responses:
+                                raise Exception("Received unexpected response: %s" % msg)
+                            else:
+                                self.expected_responses.remove(msg)
+
+                        return msg
+                    else:
+                        return remaining
+
                 bytes_to_recv = int(length) - len(remaining)
-        raise socket.timeout('connection timed out after %d s' % self.socket_timeout)
+
+        raise socket.timeout("connection timed out after %ds" % self.socket_timeout)
 
     def connect(self):
         """Connect to the server and process the hello message we expect
         to receive in response.
 
-        Return a tuple of the protocol level and the application type.
+        Returns a tuple of the protocol level and the application type.
         """
-        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.sock.settimeout(self.socket_timeout)
         try:
             self.sock.connect((self.addr, self.port))
         except:
             # Unset self.sock so that the next attempt to send will cause
             # another connection attempt.
             self.sock = None
             raise
+
         self.sock.settimeout(2.0)
 
-        hello = self.receive()
+        # first packet is always a JSON Object
+        # which we can use to tell which protocol level we are at
+        raw = self.receive(unmarshal=False)
+        hello = json.loads(raw)
         self.protocol = hello.get("marionetteProtocol", 1)
         self.application_type = hello.get("applicationType")
 
         return (self.protocol, self.application_type)
 
-    def send(self, data):
-        """Send a message on the socket, prepending it with len(msg) + ":"."""
+    def send(self, obj):
+        """Send message to the remote server.  Allowed input is a
+        ``Message`` instance or a JSON serialisable object.
+        """
         if not self.sock:
             self.connect()
-        data = "%s:%s" % (len(data), data)
 
-        for packet in [data[i:i + self.max_packet_length] for i in
-                       range(0, len(data), self.max_packet_length)]:
+        if isinstance(obj, Message):
+            data = obj.to_msg()
+            self.expected_responses.append(obj)
+        else:
+            data = json.dumps(obj)
+        payload = "%s:%s" % (len(data), data)
+
+        for packet in [payload[i:i + self.max_packet_length] for i in
+                       range(0, len(payload), self.max_packet_length)]:
             try:
                 self.sock.send(packet)
             except IOError as e:
                 if e.errno == errno.EPIPE:
                     raise IOError("%s: %s" % (str(e), self.connection_lost_msg))
                 else:
                     raise e
 
+    def respond(self, obj):
+        """Send a response to a command.  This can be an arbitrary JSON
+        serialisable object or an ``Exception``.
+        """
+        res, err = None, None
+        if isinstance(obj, Exception):
+            err = obj
+        else:
+            res = obj
+        msg = Response(self.last_id, err, res)
+        self.send(msg)
+        return self.receive()
+
+    def request(self, name, params):
+        """Sends a message to the remote server and waits for a response
+        to come back.
+        """
+        self.last_id = self.last_id + 1
+        cmd = Command(self.last_id, name, params)
+        self.send(cmd)
         return self.receive()
 
     def close(self):
         """Close the socket."""
         if self.sock:
             self.sock.close()
+
+    def __del__(self):
+        self.close()
         self.sock = None
 
-    @staticmethod
-    def wait_for_port(host, port, timeout=60):
-        """ Wait for the specified Marionette host/port to be available."""
-        starttime = datetime.datetime.now()
-        poll_interval = 0.1
-        while datetime.datetime.now() - starttime < datetime.timedelta(seconds=timeout):
-            sock = None
-            try:
-                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-                sock.connect((host, port))
-                data = sock.recv(16)
+
+def wait_for_port(host, port, timeout=60):
+    """ Wait for the specified host/port to be available."""
+    starttime = datetime.datetime.now()
+    poll_interval = 0.1
+    while datetime.datetime.now() - starttime < datetime.timedelta(seconds=timeout):
+        sock = None
+        try:
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            sock.connect((host, port))
+            data = sock.recv(16)
+            sock.close()
+            if ':' in data:
+                return True
+        except socket.error:
+            pass
+        finally:
+            if sock:
                 sock.close()
-                if ':' in data:
-                    return True
-            except socket.error:
-                pass
-            finally:
-                if sock:
-                    sock.close()
-            time.sleep(poll_interval)
-        return False
+        time.sleep(poll_interval)
+    return False