Bug 1174906 - Add a mozilla::Variant<T1, T2, ...> template class; r=Waldo
authorNick Fitzgerald <fitzgen@gmail.com>
Tue, 30 Jun 2015 10:01:00 -0700
changeset 250721 6509d3f6a91fd6856985adee8efd13a40f57de7a
parent 250720 6b9d80a126e59bf727a0856ab58070fc40c24cb3
child 250722 335268dfcd2cee23cb07f7c8497be80f3bcb0314
push idunknown
push userunknown
push dateunknown
reviewersWaldo
bugs1174906
milestone42.0a1
Bug 1174906 - Add a mozilla::Variant<T1, T2, ...> template class; r=Waldo
mfbt/Variant.h
mfbt/moz.build
mfbt/tests/TestVariant.cpp
mfbt/tests/moz.build
new file mode 100644
--- /dev/null
+++ b/mfbt/Variant.h
@@ -0,0 +1,354 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/* A template class for tagged unions. */
+
+#include <new>
+
+#include "mozilla/Alignment.h"
+#include "mozilla/Assertions.h"
+#include "mozilla/Move.h"
+
+#ifndef mozilla_Variant_h
+#define mozilla_Variant_h
+
+namespace mozilla {
+
+template<typename... Ts>
+class Variant;
+
+namespace detail {
+
+// MaxSizeOf computes the maximum sizeof(T) for each T in Ts.
+
+template<typename T, typename... Ts>
+struct MaxSizeOf
+{
+  static const size_t size = sizeof(T) > MaxSizeOf<Ts...>::size
+    ? sizeof(T)
+    : MaxSizeOf<Ts...>::size;
+};
+
+template<typename T>
+struct MaxSizeOf<T>
+{
+  static const size_t size = sizeof(T);
+};
+
+// The `IsVariant` helper is used in conjunction with static_assert and
+// `mozilla::EnableIf` to catch passing non-variant types to `Variant::is<T>()`
+// and friends at compile time, rather than at runtime. It ensures that the
+// given type `Needle` is one of the types in the set of types `Haystack`.
+
+template<typename Needle, typename... Haystack>
+struct IsVariant;
+
+template<typename Needle>
+struct IsVariant<Needle>
+{
+  static const bool value = false;
+};
+
+template<typename Needle, typename... Haystack>
+struct IsVariant<Needle, Needle, Haystack...>
+{
+  static const bool value = true;
+};
+
+template<typename Needle, typename T, typename... Haystack>
+struct IsVariant<Needle, T, Haystack...> : public IsVariant<Needle, Haystack...> { };
+
+// TagHelper gets the given sentinel tag value for the given type T. This has to
+// be split out from VariantImplementation because you can't nest a partial template
+// specialization within a template class.
+
+template<size_t N, typename T, typename U, typename Next, bool isMatch>
+struct TagHelper;
+
+// In the case where T != U, we continue recursion.
+template<size_t N, typename T, typename U, typename Next>
+struct TagHelper<N, T, U, Next, false>
+{
+  static size_t tag() { return Next::template tag<U>(); }
+};
+
+// In the case where T == U, return the tag number.
+template<size_t N, typename T, typename U, typename Next>
+struct TagHelper<N, T, U, Next, true>
+{
+  static size_t tag() { return N; }
+};
+
+// The VariantImplementation template provides the guts of mozilla::Variant. We create
+// an VariantImplementation for each T in Ts... which handles construction,
+// destruction, etc for when the Variant's type is T. If the Variant's type is
+// not T, it punts the request on to the next VariantImplementation.
+
+template<size_t N, typename... Ts>
+struct VariantImplementation;
+
+// The singly typed Variant / recursion base case.
+template<size_t N, typename T>
+struct VariantImplementation<N, T> {
+  template<typename U>
+  static size_t tag() {
+    static_assert(mozilla::IsSame<T, U>::value,
+                  "mozilla::Variant: tag: bad type!");
+    return N;
+  }
+
+  template<typename Variant>
+  static void copyConstruct(void* aLhs, const Variant& aRhs) {
+    new (aLhs) T(aRhs.template as<T>());
+  }
+
+  template<typename Variant>
+  static void moveConstruct(void* aLhs, Variant&& aRhs) {
+    new (aLhs) T(aRhs.template extract<T>());
+  }
+
+  template<typename Variant>
+  static void destroy(Variant& aV) {
+    aV.template as<T>().~T();
+  }
+};
+
+// VariantImplementation for some variant type T.
+template<size_t N, typename T, typename... Ts>
+struct VariantImplementation<N, T, Ts...>
+{
+  // The next recursive VariantImplementation.
+  using Next = VariantImplementation<N + 1, Ts...>;
+
+  template<typename U>
+  static size_t tag() {
+    return TagHelper<N, T, U, Next, IsSame<T, U>::value>::tag();
+  }
+
+  template<typename Variant>
+  static void copyConstruct(void* aLhs, const Variant& aRhs) {
+    if (aRhs.template is<T>()) {
+      new (aLhs) T(aRhs.template as<T>());
+    } else {
+      Next::copyConstruct(aLhs, aRhs);
+    }
+  }
+
+  template<typename Variant>
+  static void moveConstruct(void* aLhs, Variant&& aRhs) {
+    if (aRhs.template is<T>()) {
+      new (aLhs) T(aRhs.template extract<T>());
+    } else {
+      Next::moveConstruct(aLhs, aRhs);
+    }
+  }
+
+  template<typename Variant>
+  static void destroy(Variant& aV) {
+    if (aV.template is<T>()) {
+      aV.template as<T>().~T();
+    } else {
+      Next::destroy(aV);
+    }
+  }
+};
+
+} // namespace detail
+
+/**
+ * # mozilla::Variant
+ *
+ * A variant / tagged union / heterogenous disjoint union / sum-type template
+ * class. Similar in concept to (but not derived from) `boost::variant`.
+ *
+ * Sometimes, you may wish to use a C union with non-POD types. However, this is
+ * forbidden in C++ because it is not clear which type in the union should have
+ * its constructor and destructor run on creation and deletion
+ * respectively. This is the problem that `mozilla::Variant` solves.
+ *
+ * ## Usage
+ *
+ * A `mozilla::Variant` instance is constructed (via move or copy) from one of
+ * its variant types (ignoring const and references). It does *not* support
+ * construction from subclasses of variant types or types that coerce to one of
+ * the variant types.
+ *
+ *     Variant<char, uint32_t> v1('a');
+ *     Variant<UniquePtr<A>, B, C> v2(MakeUnique<A>());
+ *
+ * All access to the contained value goes through type-safe accessors.
+ *
+ *     void
+ *     Foo(Variant<A, B, C> v)
+ *     {
+ *       if (v.is<A>()) {
+ *         A& ref = v.as<A>();
+ *         ...
+ *       } else {
+ *         ...
+ *       }
+ *     }
+ *
+ * Attempting to use the contained value as type `T1` when the `Variant`
+ * instance contains a value of type `T2` causes an assertion failure.
+ *
+ *     A a;
+ *     Variant<A, B, C> v(a);
+ *     v.as<B>(); // <--- Assertion failure!
+ *
+ * Trying to use a `Variant<Ts...>` instance as some type `U` that is not a
+ * member of the set of `Ts...` is a compiler error.
+ *
+ *     A a;
+ *     Variant<A, B, C> v(a);
+ *     v.as<SomeRandomType>(); // <--- Compiler error!
+ *
+ * Additionally, you can turn a `Variant` that `is<T>` into a `T` by moving it
+ * out of the containing `Variant` instance with the `extract<T>` method:
+ *
+ *     Variant<UniquePtr<A>, B, C> v(MakeUnique<A>());
+ *     auto ptr = v.extract<UniquePtr<A>>();
+ *
+ * ## Examples
+ *
+ * A tree is either an empty leaf, or a node with a value and two children:
+ *
+ *     struct Leaf { };
+ *
+ *     template<typename T>
+ *     struct Node
+ *     {
+ *       T value;
+ *       Tree<T>* left;
+ *       Tree<T>* right;
+ *     };
+ *
+ *     template<typename T>
+ *     using Tree = Variant<Leaf, Node<T>>;
+ *
+ * A copy-on-write string is either a non-owning reference to some existing
+ * string, or an owning reference to our copy:
+ *
+ *     class CopyOnWriteString
+ *     {
+ *       Variant<const char*, UniquePtr<char[]>> string;
+ *
+ *       ...
+ *     };
+ */
+template<typename... Ts>
+class Variant
+{
+  using Impl = detail::VariantImplementation<0, Ts...>;
+  using RawData = AlignedStorage<detail::MaxSizeOf<Ts...>::size>;
+
+  // Each type is given a unique size_t sentinel. This tag lets us keep track of
+  // the contained variant value's type.
+  size_t tag;
+
+  // Raw storage for the contained variant value.
+  RawData raw;
+
+  void* ptr() {
+    return reinterpret_cast<void*>(&raw);
+  }
+
+public:
+  /** Perfect forwarding construction for some variant type T. */
+  template<typename RefT,
+           // RefT captures both const& as well as && (as intended, to support
+           // perfect forwarding), so we have to remove those qualifiers here
+           // when ensuring that T is a variant of this type, and getting T's
+           // tag, etc.
+           typename T = typename RemoveReference<typename RemoveConst<RefT>::Type>::Type,
+           typename = typename EnableIf<detail::IsVariant<T, Ts...>::value, void>::Type>
+  explicit Variant(RefT&& aT)
+    : tag(Impl::template tag<T>())
+  {
+    new (ptr()) T(Forward<T>(aT));
+  }
+
+  /** Copy construction. */
+  explicit Variant(const Variant& aRhs)
+    : tag(aRhs.tag)
+  {
+    Impl::copyConstruct(ptr(), aRhs);
+  }
+
+  /** Move construction. */
+  explicit Variant(Variant&& aRhs)
+    : tag(aRhs.tag)
+  {
+    Impl::moveConstruct(ptr(), Move(aRhs));
+  }
+
+  /** Copy assignment. */
+  Variant& operator=(const Variant& aRhs) {
+    MOZ_ASSERT(&aRhs != this, "self-assign disallowed");
+    this->~Variant();
+    new (this) Variant(aRhs);
+    return *this;
+  }
+
+  /** Move assignment. */
+  Variant& operator=(Variant&& aRhs) {
+    MOZ_ASSERT(&aRhs != this, "self-assign disallowed");
+    this->~Variant();
+    new (this) Variant(Move(aRhs));
+    return *this;
+  }
+
+  ~Variant()
+  {
+    Impl::destroy(*this);
+  }
+
+  /** Check which variant type is currently contained. */
+  template<typename T>
+  bool is() const {
+    static_assert(detail::IsVariant<T, Ts...>::value,
+                  "provided a type not found in this Variant's type list");
+    return Impl::template tag<T>() == tag;
+  }
+
+  // Accessors for working with the contained variant value.
+
+  /** Mutable reference. */
+  template<typename T>
+  T& as() {
+    static_assert(detail::IsVariant<T, Ts...>::value,
+                  "provided a type not found in this Variant's type list");
+    MOZ_ASSERT(is<T>());
+    return *reinterpret_cast<T*>(&raw);
+  }
+
+  /** Immutable const reference. */
+  template<typename T>
+  const T& as() const {
+    static_assert(detail::IsVariant<T, Ts...>::value,
+                  "provided a type not found in this Variant's type list");
+    MOZ_ASSERT(is<T>());
+    return *reinterpret_cast<const T*>(&raw);
+  }
+
+  /**
+   * Extract the contained variant value from this container into a temporary
+   * value.  On completion, the value in the variant will be in a
+   * safely-destructible state, as determined by the behavior of T's move
+   * constructor when provided the variant's internal value.
+   */
+  template<typename T>
+  T extract() {
+    static_assert(detail::IsVariant<T, Ts...>::value,
+                  "provided a type not found in this Variant's type list");
+    MOZ_ASSERT(is<T>());
+    return T(Move(as<T>()));
+  }
+};
+
+} // namespace mozilla
+
+#endif /* mozilla_Variant_h */
--- a/mfbt/moz.build
+++ b/mfbt/moz.build
@@ -81,16 +81,17 @@ EXPORTS.mozilla = [
     'TemplateLib.h',
     'ThreadLocal.h',
     'ToString.h',
     'Tuple.h',
     'TypedEnumBits.h',
     'Types.h',
     'TypeTraits.h',
     'UniquePtr.h',
+    'Variant.h',
     'Vector.h',
     'WeakPtr.h',
     'unused.h',
 ]
 
 if CONFIG['OS_ARCH'] == 'WINNT':
     EXPORTS.mozilla += [
         'WindowsVersion.h',
new file mode 100644
--- /dev/null
+++ b/mfbt/tests/TestVariant.cpp
@@ -0,0 +1,106 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=8 sts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "mozilla/UniquePtr.h"
+#include "mozilla/Variant.h"
+
+using mozilla::MakeUnique;
+using mozilla::UniquePtr;
+using mozilla::Variant;
+
+struct Destroyer {
+  static int destroyedCount;
+  ~Destroyer() {
+    destroyedCount++;
+  }
+};
+
+int Destroyer::destroyedCount = 0;
+
+static void
+testSimple()
+{
+  printf("testSimple\n");
+  Variant<uint32_t, uint64_t> v(uint64_t(1));
+  MOZ_RELEASE_ASSERT(v.is<uint64_t>());
+  MOZ_RELEASE_ASSERT(!v.is<uint32_t>());
+  MOZ_RELEASE_ASSERT(v.as<uint64_t>() == 1);
+}
+
+static void
+testCopy()
+{
+  printf("testCopy\n");
+  Variant<uint32_t, uint64_t> v1(uint64_t(1));
+  Variant<uint32_t, uint64_t> v2(v1);
+  MOZ_RELEASE_ASSERT(v2.is<uint64_t>());
+  MOZ_RELEASE_ASSERT(!v2.is<uint32_t>());
+  MOZ_RELEASE_ASSERT(v2.as<uint64_t>() == 1);
+
+  Variant<uint32_t, uint64_t> v3(uint32_t(10));
+  v3 = v2;
+  MOZ_RELEASE_ASSERT(v3.is<uint64_t>());
+  MOZ_RELEASE_ASSERT(v3.as<uint64_t>() == 1);
+}
+
+static void
+testMove()
+{
+  printf("testMove\n");
+  Variant<UniquePtr<int>, char> v1(MakeUnique<int>(5));
+  Variant<UniquePtr<int>, char> v2(Move(v1));
+
+  MOZ_RELEASE_ASSERT(v2.is<UniquePtr<int>>());
+  MOZ_RELEASE_ASSERT(*v2.as<UniquePtr<int>>() == 5);
+
+  MOZ_RELEASE_ASSERT(v1.is<UniquePtr<int>>());
+  MOZ_RELEASE_ASSERT(v1.as<UniquePtr<int>>() == nullptr);
+
+  Destroyer::destroyedCount = 0;
+  {
+    Variant<char, UniquePtr<Destroyer>> v3(MakeUnique<Destroyer>());
+    Variant<char, UniquePtr<Destroyer>> v4(Move(v3));
+
+    Variant<char, UniquePtr<Destroyer>> v5('a');
+    v5 = Move(v4);
+
+    auto ptr = v5.extract<UniquePtr<Destroyer>>();
+    MOZ_RELEASE_ASSERT(Destroyer::destroyedCount == 0);
+  }
+  MOZ_RELEASE_ASSERT(Destroyer::destroyedCount == 1);
+}
+
+static void
+testDestructor()
+{
+  printf("testDestructor\n");
+  Destroyer::destroyedCount = 0;
+
+  {
+    Destroyer d;
+
+    {
+      Variant<char, UniquePtr<char[]>, Destroyer> v(d);
+      MOZ_RELEASE_ASSERT(Destroyer::destroyedCount == 0); // None detroyed yet.
+    }
+
+    MOZ_RELEASE_ASSERT(Destroyer::destroyedCount == 1); // v's copy of d is destroyed.
+  }
+
+  MOZ_RELEASE_ASSERT(Destroyer::destroyedCount == 2); // d is destroyed.
+}
+
+int
+main()
+{
+  testSimple();
+  testCopy();
+  testMove();
+  testDestructor();
+
+  printf("TestVariant OK!\n");
+  return 0;
+}
--- a/mfbt/tests/moz.build
+++ b/mfbt/tests/moz.build
@@ -28,16 +28,17 @@ CppUnitTests([
     'TestSegmentedVector',
     'TestSHA1',
     'TestSplayTree',
     'TestTemplateLib',
     'TestTuple',
     'TestTypedEnum',
     'TestTypeTraits',
     'TestUniquePtr',
+    'TestVariant',
     'TestVector',
     'TestWeakPtr',
 ])
 
 if not CONFIG['MOZ_ASAN']:
     CppUnitTests([
         'TestPoisonArea',
     ])