Bug 1338600: Add support for COM asynchronous interfaces to mscom; r?jimm draft
authorAaron Klotz <aklotz@mozilla.com>
Wed, 15 Feb 2017 14:09:02 -0700
changeset 485455 20f128505b569b1421f9b7b75c2787570bb2ab67
parent 485454 dd4a91eab0e8efbd81335a40c35b6b7b6507d6fb
child 546018 e937fdb0c7c0118afc80ca0a4964a061a08a2741
push id45733
push useraklotz@mozilla.com
push dateThu, 16 Feb 2017 17:45:58 +0000
reviewersjimm
bugs1338600
milestone54.0a1
Bug 1338600: Add support for COM asynchronous interfaces to mscom; r?jimm MozReview-Commit-ID: EcbeH9KSZrQ
ipc/mscom/Aggregation.h
ipc/mscom/AsyncInvoker.h
ipc/mscom/moz.build
new file mode 100644
--- /dev/null
+++ b/ipc/mscom/Aggregation.h
@@ -0,0 +1,51 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_Aggregation_h
+#define mozilla_mscom_Aggregation_h
+
+#include "mozilla/Attributes.h"
+
+namespace mozilla {
+namespace mscom {
+
+/**
+ * This is used for stabilizing a COM object's reference count during
+ * construction when that object aggregates other objects. Since the aggregated
+ * object(s) may AddRef() or Release(), we need to artifically boost the
+ * refcount to prevent premature destruction. Note that we increment/decrement
+ * instead of AddRef()/Release() in this class because we want to adjust the
+ * refcount without causing any other side effects (like object destruction).
+ */
+template <typename RefCntT>
+class MOZ_RAII StabilizedRefCount
+{
+public:
+  explicit StabilizedRefCount(RefCntT& aRefCnt)
+    : mRefCnt(aRefCnt)
+  {
+    ++aRefCnt;
+  }
+
+  ~StabilizedRefCount()
+  {
+    --mRefCnt;
+  }
+
+  StabilizedRefCount(const StabilizedRefCount&) = delete;
+  StabilizedRefCount(StabilizedRefCount&&) = delete;
+  StabilizedRefCount& operator=(const StabilizedRefCount&) = delete;
+  StabilizedRefCount& operator=(StabilizedRefCount&&) = delete;
+
+private:
+  RefCntT& mRefCnt;
+};
+
+} // namespace mscom
+} // namespace mozilla
+
+#endif // mozilla_mscom_Aggregation_h
+
new file mode 100644
--- /dev/null
+++ b/ipc/mscom/AsyncInvoker.h
@@ -0,0 +1,299 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef mozilla_mscom_AsyncInvoker_h
+#define mozilla_mscom_AsyncInvoker_h
+
+#include "mozilla/Assertions.h"
+#include "mozilla/Attributes.h"
+#include "mozilla/DebugOnly.h"
+#include "mozilla/Move.h"
+#include "mozilla/mscom/Aggregation.h"
+#include "mozilla/mscom/Utils.h"
+#include "mozilla/Mutex.h"
+#include "nsISupportsImpl.h"
+
+#include <objidl.h>
+#include <windows.h>
+
+namespace mozilla {
+namespace mscom {
+namespace detail {
+
+template <typename AsyncInterface>
+class ForgettableAsyncCall : public ISynchronize
+{
+public:
+  explicit ForgettableAsyncCall(ICallFactory* aCallFactory)
+    : mRefCnt(0)
+    , mAsyncCall(nullptr)
+  {
+    StabilizedRefCount<Atomic<ULONG>> stabilizer(mRefCnt);
+
+    HRESULT hr =
+      aCallFactory->CreateCall(__uuidof(AsyncInterface), this,
+                               IID_IUnknown, getter_AddRefs(mInnerUnk));
+    if (FAILED(hr)) {
+      return;
+    }
+
+    hr = mInnerUnk->QueryInterface(__uuidof(AsyncInterface),
+                                   reinterpret_cast<void**>(&mAsyncCall));
+    if (SUCCEEDED(hr)) {
+      // Don't hang onto a ref. Because mAsyncCall is aggregated, its refcount
+      // is this->mRefCnt, so we'd create a cycle!
+      mAsyncCall->Release();
+    }
+  }
+
+  AsyncInterface* GetInterface() const
+  {
+    return mAsyncCall;
+  }
+
+  // IUnknown
+  STDMETHODIMP QueryInterface(REFIID aIid, void** aOutInterface) override
+  {
+    if (aIid == IID_IUnknown || aIid == IID_ISynchronize) {
+      RefPtr<ISynchronize> ptr(this);
+      ptr.forget(aOutInterface);
+      return S_OK;
+    }
+
+    return mInnerUnk->QueryInterface(aIid, aOutInterface);
+  }
+
+  STDMETHODIMP_(ULONG) AddRef() override
+  {
+    ULONG result = ++mRefCnt;
+    NS_LOG_ADDREF(this, result, "ForgettableAsyncCall", sizeof(*this));
+    return result;
+  }
+
+  STDMETHODIMP_(ULONG) Release() override
+  {
+    ULONG result = --mRefCnt;
+    NS_LOG_RELEASE(this, result, "ForgettableAsyncCall");
+    if (!result) {
+      delete this;
+    }
+    return result;
+  }
+
+  // ISynchronize
+  STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override
+  {
+    return E_NOTIMPL;
+  }
+
+  STDMETHODIMP Signal() override
+  {
+    // Even though this function is a no-op, we must return S_OK as opposed to
+    // E_NOTIMPL or else COM will consider the async call to have failed.
+    return S_OK;
+  }
+
+  STDMETHODIMP Reset() override
+  {
+    // Even though this function is a no-op, we must return S_OK as opposed to
+    // E_NOTIMPL or else COM will consider the async call to have failed.
+    return S_OK;
+  }
+
+protected:
+  virtual ~ForgettableAsyncCall() {}
+
+private:
+  Atomic<ULONG>     mRefCnt;
+  RefPtr<IUnknown>  mInnerUnk;
+  AsyncInterface*   mAsyncCall; // weak reference
+};
+
+template <typename AsyncInterface>
+class WaitableAsyncCall : public ForgettableAsyncCall<AsyncInterface>
+{
+public:
+  explicit WaitableAsyncCall(ICallFactory* aCallFactory)
+    : ForgettableAsyncCall(aCallFactory)
+    , mEvent(::CreateEventW(nullptr, FALSE, FALSE, nullptr))
+  {
+  }
+
+  STDMETHODIMP Wait(DWORD aFlags, DWORD aTimeoutMilliseconds) override
+  {
+    const DWORD waitStart = aTimeoutMilliseconds == INFINITE ? 0 :
+                            ::GetTickCount();
+    DWORD flags = aFlags;
+    if (XRE_IsContentProcess() && NS_IsMainThread()) {
+      flags |= COWAIT_ALERTABLE;
+    }
+
+    HRESULT hr;
+    DWORD signaledIdx;
+
+    DWORD elapsed = 0;
+
+    while (true) {
+      if (aTimeoutMilliseconds != INFINITE) {
+        elapsed = ::GetTickCount() - waitStart;
+      }
+      if (elapsed >= aTimeoutMilliseconds) {
+        return RPC_S_CALLPENDING;
+      }
+
+      ::SetLastError(ERROR_SUCCESS);
+
+      hr = ::CoWaitForMultipleHandles(flags, aTimeoutMilliseconds - elapsed, 1,
+                                      &mEvent, &signaledIdx);
+      if (hr == RPC_S_CALLPENDING || FAILED(hr)) {
+        return hr;
+      }
+
+      if (hr == S_OK && signaledIdx == 0) {
+        return hr;
+      }
+    }
+  }
+
+  STDMETHODIMP Signal() override
+  {
+    if (!::SetEvent(mEvent)) {
+      return HRESULT_FROM_WIN32(::GetLastError());
+    }
+    return S_OK;
+  }
+
+protected:
+  ~WaitableAsyncCall()
+  {
+    if (mEvent) {
+      ::CloseHandle(mEvent);
+    }
+  }
+
+private:
+  HANDLE mEvent;
+};
+
+template <typename AsyncInterface>
+class FireAndForgetInvoker
+{
+protected:
+  typedef ForgettableAsyncCall<AsyncInterface> AsyncCallType;
+
+  RefPtr<ForgettableAsyncCall<AsyncInterface>> mAsyncCall;
+};
+
+template <typename AsyncInterface>
+class WaitableInvoker
+{
+public:
+  HRESULT Wait(DWORD aTimeout = INFINITE) const
+  {
+    if (!mAsyncCall) {
+      return E_POINTER;
+    }
+
+    return mAsyncCall->Wait(0, aTimeout);
+  }
+
+protected:
+  typedef WaitableAsyncCall<AsyncInterface> AsyncCallType;
+
+  RefPtr<WaitableAsyncCall<AsyncInterface>> mAsyncCall;
+};
+
+} // namespace detail
+
+/**
+ * This class is intended for "fire-and-forget" asynchronous invocations of COM
+ * interfaces. This requires that an interface be annotated with the
+ * |async_uuid| attribute in midl.
+ *
+ * For example, let us suppose we have some IDL as such:
+ * [object, uuid(...), async_uuid(...)]
+ * interface IFoo : IUnknown
+ * {
+ *    HRESULT Bar(long baz);
+ * }
+ *
+ * Then, given an IFoo, we may construct an AsyncInvoker<AsyncIFoo>:
+ *
+ * IFoo* foo = ...;
+ * AsyncInvoker<AsyncIFoo> myInvoker(foo);
+ * HRESULT hr = myInvoker->Begin_Bar(7);
+ *
+ * Alternatively you may use the ASYNC_INVOKER_FOR macro, which automatically
+ * derives the name of the asynchronous interface from the name of the
+ * synchronous interface:
+ *
+ * ASYNC_INVOKER_FOR(IFoo) myInvoker(foo);
+ *
+ * This class may also be used when a synchronous COM call must be made that
+ * might reenter the content process. In this case, use the WaitableAsyncInvoker
+ * variant, or the WAITABLE_ASYNC_INVOKER_FOR macro:
+ *
+ * WAITABLE_ASYNC_INVOKER_FOR(Ifoo) myInvoker(foo);
+ * myInvoker->Begin_Bar(7);
+ * myInvoker.Wait(); // <-- Wait for the COM call to complete.
+ *
+ * In general you should avoid using the waitable version, but in some corner
+ * cases it is absolutely necessary in order to preserve correctness while
+ * avoiding deadlock.
+ */
+template <typename AsyncInterface,
+          template <typename Iface> class WaitPolicy = detail::FireAndForgetInvoker>
+class MOZ_RAII AsyncInvoker final : public WaitPolicy<AsyncInterface>
+{
+public:
+  /**
+   * @param aSyncProxy The COM object on which to invoke the asynchronous event.
+   *                   This object must be a proxy to the synchronous variant of
+   *                   AsyncInterface.
+   */
+  explicit AsyncInvoker(IUnknown* aSyncProxy)
+  {
+    MOZ_ASSERT(aSyncProxy);
+    MOZ_ASSERT(IsProxy(aSyncProxy));
+
+    RefPtr<ICallFactory> callFactory;
+    if (FAILED(aSyncProxy->QueryInterface(IID_ICallFactory,
+                                          getter_AddRefs(callFactory)))) {
+      return;
+    }
+
+    mAsyncCall = new AsyncCallType(callFactory);
+  }
+
+  explicit operator bool() const
+  {
+    return mAsyncCall && mAsyncCall->GetInterface();
+  }
+
+  AsyncInterface* operator->() const
+  {
+    return mAsyncCall->GetInterface();
+  }
+
+  AsyncInvoker(const AsyncInvoker& aOther) = delete;
+  AsyncInvoker(AsyncInvoker&& aOther) = delete;
+  AsyncInvoker& operator=(const AsyncInvoker& aOther) = delete;
+  AsyncInvoker& operator=(AsyncInvoker&& aOther) = delete;
+};
+
+template <typename AsyncInterface>
+using WaitableAsyncInvoker = AsyncInvoker<AsyncInterface, detail::WaitableInvoker>;
+
+} // namespace mscom
+} // namespace mozilla
+
+#define ASYNC_INVOKER_FOR(SyncIface) \
+  mozilla::mscom::AsyncInvoker<Async##SyncIface>
+
+#define WAITABLE_ASYNC_INVOKER_FOR(SyncIface) \
+  mozilla::mscom::WaitableAsyncInvoker<Async##SyncIface>
+
+#endif // mozilla_mscom_AsyncInvoker_h
--- a/ipc/mscom/moz.build
+++ b/ipc/mscom/moz.build
@@ -1,16 +1,18 @@
 # -*- Mode: python; indent-tabs-mode: nil; tab-width: 40 -*-
 # vim: set filetype=python:
 # This Source Code Form is subject to the terms of the Mozilla Public
 # License, v. 2.0. If a copy of the MPL was not distributed with this
 # file, You can obtain one at http://mozilla.org/MPL/2.0/.
 
 EXPORTS.mozilla.mscom += [
+    'Aggregation.h',
     'AgileReference.h',
+    'AsyncInvoker.h',
     'COMApartmentRegion.h',
     'COMPtrHolder.h',
     'EnsureMTA.h',
     'MainThreadRuntime.h',
     'ProxyStream.h',
     'Ptr.h',
     'Utils.h',
 ]