Bug 1342974 - Add code to generate a Merkle tree summary of a release draft
authorRichard Barnes <rbarnes@mozilla.com>
Mon, 27 Feb 2017 12:56:42 -0500
changeset 490487 340873225cecd34f69c2bcbb8230d1422f95b535
parent 490032 106a96755d3bcebe64bbbc3b521d65d262ba9c02
child 547274 49462bd4412c8175c025f1828b36215c2e1e465a
push id47115
push userrlb@ipv.sx
push dateTue, 28 Feb 2017 16:38:05 +0000
bugs1342974
milestone54.0a1
Bug 1342974 - Add code to generate a Merkle tree summary of a release MozReview-Commit-ID: COav6N88SOt
testing/mozharness/mozharness/mozilla/merkle.py
testing/mozharness/scripts/release/generate-checksums.py
new file mode 100644
--- /dev/null
+++ b/testing/mozharness/mozharness/mozilla/merkle.py
@@ -0,0 +1,275 @@
+#!/usr/bin/env python
+
+import struct
+
+def _round2(n):
+    k = 1
+    while k < n:
+        k <<= 1
+    return k >> 1
+
+def _leaf_hash(hash_fn, leaf):
+    return hash_fn(b'\x00' + leaf).digest()
+
+def _pair_hash(hash_fn, left, right):
+    return hash_fn(b'\x01' + left + right).digest()
+
+
class InclusionProof:
    """
    Represents a Merkle inclusion proof for purposes of serialization,
    deserialization, and verification of the proof.  The format for inclusion
    proofs in RFC 6962-bis is as follows:

        opaque LogID<2..127>;
        opaque NodeHash<32..2^8-1>;

        struct {
            LogID log_id;
            uint64 tree_size;
            uint64 leaf_index;
            NodeHash inclusion_path<1..2^16-1>;
        } InclusionProofDataV2;

    In other words:
      - 1 + N octets of log_id (currently zero)
      - 8 octets of tree_size = self.n
      - 8 octets of leaf_index = m
      - 2 octets of path length, followed by
      * 1 + N octets of NodeHash
    """

    # Pre-generated 'log ID'.  Not used by Firefox; it is only needed because
    # there's a slot in the RFC 6962-bis format that requires a value at least
    # two bytes long (plus a length byte).
    LOG_ID = b'\x02\x00\x00'

    def __init__(self, tree_size, leaf_index, path_elements):
        # tree_size: number of leaves in the tree this proof refers to
        # leaf_index: 0-based index of the leaf being proven
        # path_elements: sibling node hashes, in leaf-to-root order
        self.tree_size = tree_size
        self.leaf_index = leaf_index
        self.path_elements = path_elements

    @staticmethod
    def from_rfc6962_bis(serialized):
        """
        Parse a serialized InclusionProofDataV2 structure (see class docstring)
        and return an InclusionProof.  Raises Exception on truncated or
        malformed input.
        """
        start = 0
        read = 1
        if len(serialized) < start + read:
            raise Exception('Inclusion proof too short for log ID header')
        log_id_len, = struct.unpack('B', serialized[start:start+read])
        start += read
        start += log_id_len  # Ignore the log ID itself

        read = 8 + 8 + 2
        if len(serialized) < start + read:
            raise Exception('Inclusion proof too short for middle section')
        tree_size, leaf_index, path_len = struct.unpack('!QQH', serialized[start:start+read])
        start += read

        # The path section ends path_len octets after the fixed-size header.
        path_elements = []
        end = 1 + log_id_len + 8 + 8 + 2 + path_len
        while start < end:
            read = 1
            if len(serialized) < start + read:
                raise Exception('Inclusion proof too short for path element header')
            elem_len, = struct.unpack('!B', serialized[start:start+read])
            start += read

            read = elem_len
            if len(serialized) < start + read:
                raise Exception('Inclusion proof too short for path element')
            if end < start + read:
                raise Exception('Inclusion proof element exceeds declared length')
            path_elements.append(serialized[start:start+read])
            start += read

        return InclusionProof(tree_size, leaf_index, path_elements)

    def to_rfc6962_bis(self):
        """Serialize this proof as an InclusionProofDataV2 structure."""
        inclusion_path = b''
        for step in self.path_elements:
            step_len = struct.pack('B', len(step))
            inclusion_path += step_len + step

        middle = struct.pack('!QQH', self.tree_size, self.leaf_index, len(inclusion_path))
        return self.LOG_ID + middle + inclusion_path

    def _expected_head(self, hash_fn, leaf, leaf_index, tree_size):
        """
        Recompute the tree head implied by this proof for the given leaf.

        Raises ValueError when the number of path elements does not match the
        depth implied by (leaf_index, tree_size).
        """
        # Leaf hash per RFC 6962-bis: H(0x00 || leaf)
        node = hash_fn(b'\x00' + leaf).digest()

        # Compute indicators of which direction the pair hashes should be done.
        # Derived from the PATH logic in draft-ietf-trans-rfc6962-bis
        lr = []
        while tree_size > 1:
            # k = largest power of two strictly below tree_size: the split
            # point between the left and right subtrees.
            k = 1
            while k < tree_size:
                k <<= 1
            k >>= 1

            left = leaf_index < k
            lr = [left] + lr

            if left:
                tree_size = k
            else:
                tree_size = tree_size - k
                leaf_index = leaf_index - k

        # Explicit check rather than assert: the proof may be attacker
        # supplied, and asserts are stripped under "python -O".
        if len(lr) != len(self.path_elements):
            raise ValueError('Inclusion proof has wrong number of path elements')

        for i in range(len(self.path_elements)):
            # Interior node hash per RFC 6962-bis: H(0x01 || left || right)
            if lr[i]:
                node = hash_fn(b'\x01' + node + self.path_elements[i]).digest()
            else:
                node = hash_fn(b'\x01' + self.path_elements[i] + node).digest()

        return node

    def verify(self, hash_fn, leaf, leaf_index, tree_size, tree_head):
        """
        Return True iff this proof links `leaf` (at `leaf_index` in a tree of
        `tree_size` leaves) to `tree_head`.  A structurally invalid proof
        returns False instead of raising.
        """
        try:
            return self._expected_head(hash_fn, leaf, leaf_index, tree_size) == tree_head
        except ValueError:
            return False
+
+
class MerkleTree:
    """
    Implements a Merkle tree on a set of data items following the
    structure defined in RFC 6962-bis.  This allows us to create a
    single hash value that summarizes the data (the 'head'), and an
    'inclusion proof' for each element that connects it to the head.

    https://tools.ietf.org/html/draft-ietf-trans-rfc6962-bis-24
    """

    def __init__(self, hash_fn, data):
        """
        hash_fn: a hashlib-style hash constructor, e.g. hashlib.sha256
        data: a sequence of byte strings, one per leaf
        """
        self.n = len(data)
        self.hash_fn = hash_fn

        # We cache intermediate node values, as a dictionary of dictionaries,
        # where the node representing data elements data[m:n] is represented by
        # nodes[m][n]. This corresponds to the 'D[m:n]' notation in RFC
        # 6962-bis.  In particular, the leaves are stored in nodes[i][i+1] and
        # the head is nodes[0][n].
        self.nodes = {}
        for i in range(self.n):
            # Leaf hash per RFC 6962-bis: H(0x00 || leaf)
            self.nodes[i] = {i + 1: self.hash_fn(b'\x00' + data[i]).digest()}

    @staticmethod
    def _round2(n):
        # Largest power of two strictly below n; for n >= 2 this is the split
        # point between the left and right subtrees, per RFC 6962-bis.
        k = 1
        while k < n:
            k <<= 1
        return k >> 1

    def _node(self, start, end):
        # Hash of the subtree over data[start:end], computed lazily and cached.
        assert(start in self.nodes)
        if end in self.nodes[start]:
            return self.nodes[start][end]

        k = self._round2(end - start)
        left = self._node(start, start + k)
        right = self._node(start + k, end)
        # Interior node hash per RFC 6962-bis: H(0x01 || left || right)
        node = self.hash_fn(b'\x01' + left + right).digest()

        self.nodes[start][end] = node
        return node

    def head(self):
        """Return the tree head (root hash)."""
        # RFC 6962 defines the head of an empty tree as the hash of the empty
        # string; the cached recursion cannot produce that case on its own.
        if self.n == 0:
            return self.hash_fn(b'').digest()
        return self._node(0, self.n)

    def _relative_proof(self, target, start, end):
        # Sibling hashes linking leaf `target` to the head of the subtree over
        # data[start:end], in leaf-to-root order.
        n = end - start
        if n == 1:
            return []

        k = self._round2(n)
        if target - start < k:
            return self._relative_proof(target, start, start + k) + [self._node(start + k, end)]
        else:
            return self._relative_proof(target, start + k, end) + [self._node(start, start + k)]

    def inclusion_proof(self, leaf_index):
        """Return an InclusionProof connecting leaf `leaf_index` to the head."""
        path_elements = self._relative_proof(leaf_index, 0, self.n)
        return InclusionProof(self.n, leaf_index, path_elements)
+
if __name__ == '__main__':
    # Self-test with known-answer vectors.  binascii is used for hex
    # conversion (instead of the str 'hex' codec, which is Python 2 only)
    # so the test runs under both Python 2 and Python 3.
    import binascii
    import hashlib
    import random

    def _unhex(s):
        return binascii.unhexlify(s)

    def _hex(b):
        return binascii.hexlify(b).decode('ascii')

    # Pre-computed tree on 7 inputs
    #
    #         ______F_____
    #        /            \
    #     __D__           _E_
    #    /     \         /   \
    #   A       B       C     |
    #  / \     / \     / \    |
    # 0   1   2   3   4   5   6

    data = [
        _unhex('fbc459361fc111024c6d1fd83d23a9ff'),
        _unhex('ae3a44925afec860451cd8658b3cadde'),
        _unhex('418903fe6ef29fc8cab93d778a7b018b'),
        _unhex('3d1c53c00b2e137af8c4c23a06388c6b'),
        _unhex('e656ebd8e2758bc72599e5896be357be'),
        _unhex('81aae91cf90be172eedd1c75c349bf9e'),
        _unhex('00c262edf8b0bc345aca769e8733e25e'),
    ]

    hash_fn = hashlib.sha256
    leaves = [_leaf_hash(hash_fn, leaf) for leaf in data]

    nodeA = _unhex('06447a7baa079cb0b4b6119d0f575bec508915403fdc30923eba982b63759805')
    nodeB = _unhex('3db98027c655ead4fe897bef3a4b361839a337941a9e624b475580c9d4e882ee')
    nodeC = _unhex('17524f8b0169b2745c67846925d55449ae80a8022ef8189dcf4cbb0ec7fcc470')
    nodeD = _unhex('380d0dc6fd7d4f37859a12dbfc7171b3cce29ab0688c6cffd2b15f3e0b21af49')
    nodeE = _unhex('3a9c2886a5179a6e1948876034f99d52a8f393f47a09887adee6d1b4a5c5fbd6')
    nodeF = _unhex('d1a0d3947db4ae8305f2ac32985957e02659b2ea3c10da52a48d2526e9af3bbc')

    # Expected inclusion path for each leaf, in leaf-to-root order.
    proofs = [
        [leaves[1], nodeB, nodeE],
        [leaves[0], nodeB, nodeE],
        [leaves[3], nodeA, nodeE],
        [leaves[2], nodeA, nodeE],
        [leaves[5], leaves[6], nodeD],
        [leaves[4], leaves[6], nodeD],
        [nodeC, nodeD],
    ]

    # Known-answer serialization of the proof for leaf 5:
    # log ID (3 octets), tree_size=7, leaf_index=5, 0x63 = 99 octets of path
    # (three length-prefixed 32-octet hashes).
    known_proof5 = _unhex('020000' +
                          '0000000000000007' + '0000000000000005' +
                          '0063' +
                          '20' + _hex(leaves[4]) +
                          '20' + _hex(leaves[6]) +
                          '20' + _hex(nodeD))

    tree = MerkleTree(hash_fn, data)
    head = tree.head()
    assert head == nodeF

    for i in range(len(data)):
        proof = tree.inclusion_proof(i)

        assert proof.verify(hash_fn, data[i], i, len(data), head)
        assert proof.leaf_index == i
        assert proof.tree_size == tree.n
        assert proof.path_elements == proofs[i]

    # Inclusion proof encode/decode round trip test
    proof5 = tree.inclusion_proof(5)
    serialized5 = proof5.to_rfc6962_bis()
    deserialized5 = InclusionProof.from_rfc6962_bis(serialized5)
    assert deserialized5.to_rfc6962_bis() == serialized5

    # Inclusion proof encode known answer test
    assert serialized5 == known_proof5

    # Inclusion proof decode known answer test
    known_deserialized5 = InclusionProof.from_rfc6962_bis(known_proof5)
    assert proof5.leaf_index == known_deserialized5.leaf_index
    assert proof5.tree_size == known_deserialized5.tree_size
    assert proof5.path_elements == known_deserialized5.path_elements

    # Create a Merkle tree on a larger set of random values
    TEST_SIZE = 5000
    ELEM_SIZE_BYTES = 16
    data = [bytearray(random.getrandbits(8) for _ in range(ELEM_SIZE_BYTES))
            for _ in range(TEST_SIZE)]
    tree = MerkleTree(hash_fn, data)
    head = tree.head()

    for i in range(len(data)):
        proof = tree.inclusion_proof(i)

        assert proof.verify(hash_fn, data[i], i, len(data), head)
        assert proof.leaf_index == i
        assert proof.tree_size == tree.n
--- a/testing/mozharness/scripts/release/generate-checksums.py
+++ b/testing/mozharness/scripts/release/generate-checksums.py
@@ -1,23 +1,25 @@
 from multiprocessing.pool import ThreadPool
 import os
 from os import path
 import re
 import sys
 import posixpath
+import hashlib
 
 sys.path.insert(1, os.path.dirname(os.path.dirname(sys.path[0])))
 
 from mozharness.base.python import VirtualenvMixin, virtualenv_config_options
 from mozharness.base.script import BaseScript
 from mozharness.base.vcs.vcsbase import VCSMixin
 from mozharness.mozilla.checksums import parse_checksums_file
 from mozharness.mozilla.signing import SigningMixin
 from mozharness.mozilla.buildbot import BuildbotMixin
+from mozharness.mozilla.merkle import MerkleTree
 
 class ChecksumsGenerator(BaseScript, VirtualenvMixin, SigningMixin, VCSMixin, BuildbotMixin):
     config_options = [
         [["--stage-product"], {
             "dest": "stage_product",
             "help": "Name of product used in file server's directory structure, eg: firefox, mobile",
         }],
         [["--version"], {
@@ -147,16 +149,25 @@ class ChecksumsGenerator(BaseScript, Vir
     def _get_file_prefix(self):
         return "pub/{}/candidates/{}-candidates/build{}/".format(
             self.config["stage_product"], self.config["version"], self.config["build_number"]
         )
 
     def _get_sums_filename(self, format_):
         return "{}SUMS".format(format_.upper())
 
+    def _get_summary_filename(self, format_):
+        return "{}SUMMARY".format(format_.upper())
+
+    def _get_hash_function(self, format_):
+        if format_ in ("sha256", "sha384", "sha512"):
+            return getattr(hashlib, format_)
+        else:
+            self.fatal("Unsupported format {}".format(format_))
+
     def _get_bucket(self):
         if not self.bucket:
             self.activate_virtualenv()
             from boto.s3.connection import S3Connection
 
             self.info("Connecting to S3")
             conn = S3Connection()
             self.debug("Successfully connected to S3")
@@ -211,16 +222,40 @@ class ChecksumsGenerator(BaseScript, Vir
                         if not set(self.config["formats"]) <= set(info["hashes"]):
                             self.fatal("Missing necessary format for file {}".format(f))
                         self.debug("Adding checksums for file: {}".format(f))
                         self.checksums[f] = info
                         break
                 else:
                     self.debug("Ignoring checksums for file: {}".format(f))
 
+    def compute_summary_info(self):
+        """
+        This step computes a Merkle tree over the checksums for each format
+        and writes a file containing the head of the tree and inclusion proofs
+        for each file.
+        """
+        for fmt in self.config["formats"]:
+            hash_fn = self._get_hash_function(fmt)
+            files = [fn for fn in sorted(self.checksums)]
+            data = [self.checksums[fn]["hashes"][fmt] for fn in files]
+
+            tree = MerkleTree(hash_fn, data)
+            head = tree.head().encode("hex")
+            proofs = [tree.inclusion_proof(i).to_rfc6962_bis().encode("hex") for i in range(len(files))]
+
+            summary = self._get_summary_filename(fmt)
+            self.info("Creating summary file: {}".format(summary))
+
+            content = "{} TREE_HEAD\n".format(head)
+            for i in range(len(files)):
+                content += "{} {}\n".format(proofs[i], files[i])
+
+            self.write_to_file(summary, content)
+
     def create_big_checksums(self):
         for fmt in self.config["formats"]:
             sums = self._get_sums_filename(fmt)
             self.info("Creating big checksums file: {}".format(sums))
             with open(sums, "w+") as output_file:
                 for fn in sorted(self.checksums):
                     output_file.write("{}  {}\n".format(self.checksums[fn]["hashes"][fmt], fn))