Bug 1262671 - IPC sentinel checking (r=froydnj)
authorBill McCloskey <billm@mozilla.com>
Wed, 27 Apr 2016 11:13:53 -0700
changeset 340378 4dc70010565c32f332c7f37eb34706571d690e18
parent 340377 63f6395614e8085c33d552e8c56e312df5c763a3
child 340379 c3b21c100d396a74f2ef871a0e05fc3120052ab9
push id1183
push userraliiev@mozilla.com
push dateMon, 05 Sep 2016 20:01:49 +0000
treeherdermozilla-release@3148731bed45 [default view] [failures only]
perfherder[talos] [build metrics] [platform microbench] (compared to previous push)
reviewersfroydnj
bugs1262671
milestone49.0a1
first release with
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
last release without
nightly linux32
nightly linux64
nightly mac
nightly win32
nightly win64
Bug 1262671 - IPC sentinel checking (r=froydnj)
ipc/chromium/src/base/pickle.cc
ipc/chromium/src/base/pickle.h
ipc/ipdl/ipdl/lower.py
--- a/ipc/chromium/src/base/pickle.cc
+++ b/ipc/chromium/src/base/pickle.cc
@@ -13,16 +13,20 @@
 #include <stdlib.h>
 
 #include <limits>
 #include <string>
 #include <algorithm>
 
 #include "nsDebug.h"
 
+#if !defined(RELEASE_BUILD) || defined(DEBUG)
+#define SENTINEL_CHECKING
+#endif
+
 //------------------------------------------------------------------------------
 
 static_assert(MOZ_ALIGNOF(Pickle::memberAlignmentType) >= MOZ_ALIGNOF(uint32_t),
               "Insufficient alignment");
 
 // static
 const int Pickle::kPayloadUnit = 64;
 
@@ -437,16 +441,36 @@ bool Pickle::ReadData(PickleIterator* it
   DCHECK(length);
 
   if (!ReadLength(iter, length))
     return false;
 
   return ReadBytes(iter, data, *length);
 }
 
+bool Pickle::ReadSentinel(PickleIterator* iter, uint32_t sentinel) const {
+#ifdef SENTINEL_CHECKING
+  uint32_t found;
+  if (!ReadUInt32(iter, &found)) {
+    return false;
+  }
+  return found == sentinel;
+#else
+  return true;
+#endif
+}
+
+bool Pickle::WriteSentinel(uint32_t sentinel) {
+#ifdef SENTINEL_CHECKING
+  return WriteUInt32(sentinel);
+#else
+  return true;
+#endif
+}
+
 char* Pickle::BeginWrite(uint32_t length, uint32_t alignment) {
   DCHECK(alignment % 4 == 0) << "Must be at least 32-bit aligned!";
 
   // write at an alignment-aligned offset from the beginning of the header
   uint32_t offset = AlignInt(header_->payload_size);
   uint32_t padding = (header_size_ + offset) %  alignment;
   uint32_t new_size = offset + padding + AlignInt(length);
   uint32_t needed_size = header_size_ + new_size;
--- a/ipc/chromium/src/base/pickle.h
+++ b/ipc/chromium/src/base/pickle.h
@@ -117,16 +117,18 @@ class Pickle {
   MOZ_MUST_USE bool ReadData(PickleIterator* iter, const char** data, int* length) const;
   MOZ_MUST_USE bool ReadBytes(PickleIterator* iter, const char** data, int length,
                               uint32_t alignment = sizeof(memberAlignmentType)) const;
 
   // Safer version of ReadInt() checks for the result not being negative.
   // Use it for reading the object sizes.
   MOZ_MUST_USE bool ReadLength(PickleIterator* iter, int* result) const;
 
+  MOZ_WARN_UNUSED_RESULT bool ReadSentinel(PickleIterator* iter, uint32_t sentinel) const;
+
   void EndRead(PickleIterator& iter) const {
     DCHECK(iter.iter_ == end_of_payload());
   }
 
   // Methods for adding to the payload of the Pickle.  These values are
   // appended to the end of the Pickle's payload.  When reading values from a
   // Pickle, it is important to read them in the order in which they were added
   // to the Pickle.
@@ -181,16 +183,18 @@ class Pickle {
     return WriteBytes(&value, sizeof(value));
   }
   bool WriteString(const std::string& value);
   bool WriteWString(const std::wstring& value);
   bool WriteData(const char* data, int length);
   bool WriteBytes(const void* data, int data_len,
                   uint32_t alignment = sizeof(memberAlignmentType));
 
+  bool WriteSentinel(uint32_t sentinel);
+
   // Same as WriteData, but allows the caller to write directly into the
   // Pickle. This saves a copy in cases where the data is not already
   // available in a buffer. The caller should take care to not write more
   // than the length it declares it will. Use ReadData to get the data.
   // Returns NULL on failure.
   //
   // The returned pointer will only be valid until the next write operation
   // on this Pickle.
--- a/ipc/ipdl/ipdl/lower.py
+++ b/ipc/ipdl/ipdl/lower.py
@@ -46,16 +46,21 @@ lowered form of |tu|'''
 
         return headers, cpps
 
 
 ##-----------------------------------------------------------------------------
 ## Helper code
 ##
 
+def hashfunc(value):
+    h = hash(value) % 2**32
+    if h < 0: h += 2**32
+    return h
+
 _NULL_ACTOR_ID = ExprLiteral.ZERO
 _FREED_ACTOR_ID = ExprLiteral.ONE
 
 _DISCLAIMER = Whitespace('''//
 // Automatically generated by ipdlc.
 // Edit at your own risk
 //
 
@@ -691,16 +696,17 @@ class StructDecl(ipdl.ast.StructDecl, Ha
     @staticmethod
     def upgrade(structDecl):
         assert isinstance(structDecl, ipdl.ast.StructDecl)
         structDecl.__class__ = StructDecl
         return structDecl
 
 class _StructField(_CompoundTypeComponent):
     def __init__(self, ipdltype, name, sd, side=None):
+        self.basename = name
         fname = name
         special = _hasVisibleActor(ipdltype)
         if special:
             fname += side.title()
 
         _CompoundTypeComponent.__init__(self, ipdltype, fname, side, sd)
 
     def getMethod(self, thisexpr=None, sel='.'):
@@ -4398,42 +4404,45 @@ class _GenerateProtocolActorCode(ipdl.as
         intype = _cxxConstRefType(arraytype, self.side)
         outtype = _cxxPtrToType(arraytype, self.side)
 
         write = MethodDefn(self.writeMethodDecl(intype, var))
         forwrite = StmtFor(init=ExprAssn(Decl(Type.UINT32, ivar.name),
                                          ExprLiteral.ZERO),
                            cond=ExprBinary(ivar, '<', lenvar),
                            update=ExprPrefixUnop(ivar, '++'))
-        forwrite.addstmt(StmtExpr(
-            self.write(eltipdltype, ExprIndex(var, ivar), msgvar)))
+        forwrite.addstmt(
+            self.checkedWrite(eltipdltype, ExprIndex(var, ivar), msgvar,
+                              sentinelKey=arraytype.name()))
         write.addstmts([
             StmtDecl(Decl(Type.UINT32, lenvar.name),
                      init=_callCxxArrayLength(var)),
-            StmtExpr(self.write(None, lenvar, msgvar)),
+            self.checkedWrite(None, lenvar, msgvar, sentinelKey=('length', arraytype.name())),
             Whitespace.NL,
             forwrite
         ])
 
         read = MethodDefn(self.readMethodDecl(outtype, var))
         favar = ExprVar('fa')
         forread = StmtFor(init=ExprAssn(Decl(Type.UINT32, ivar.name),
                                         ExprLiteral.ZERO),
                           cond=ExprBinary(ivar, '<', lenvar),
                           update=ExprPrefixUnop(ivar, '++'))
         forread.addstmt(
             self.checkedRead(eltipdltype, ExprAddrOf(ExprIndex(favar, ivar)),
                              msgvar, itervar, errfnRead,
-                             '\'' + eltipdltype.name() + '[i]\''))
+                             '\'' + eltipdltype.name() + '[i]\'',
+                             sentinelKey=arraytype.name()))
         read.addstmts([
             StmtDecl(Decl(_cxxArrayType(_cxxBareType(arraytype.basetype, self.side)), favar.name)),
             StmtDecl(Decl(Type.UINT32, lenvar.name)),
             self.checkedRead(None, ExprAddrOf(lenvar),
                              msgvar, itervar, errfnArrayLength,
-                             [ arraytype.name() ]),
+                             [ arraytype.name() ],
+                             sentinelKey=('length', arraytype.name())),
             Whitespace.NL,
             StmtExpr(_callCxxArraySetLength(favar, lenvar)),
             forread,
             StmtExpr(_callCxxSwapArrayElements(var, favar, '->')),
             StmtReturn.TRUE
         ])
 
         self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ])
@@ -4555,20 +4564,20 @@ class _GenerateProtocolActorCode(ipdl.as
         read = MethodDefn(self.readMethodDecl(outtype, var))
 
         def get(sel, f):
             return ExprCall(f.getMethod(thisexpr=var, sel=sel))
 
         for f in sd.fields:
             desc = '\'' + f.getMethod().name + '\' (' + f.ipdltype.name() + \
                    ') member of \'' + intype.name + '\''
-            writefield = StmtExpr(self.write(f.ipdltype, get('.', f), msgvar))
+            writefield = self.checkedWrite(f.ipdltype, get('.', f), msgvar, sentinelKey=f.basename)
             readfield = self.checkedRead(f.ipdltype,
                                          ExprAddrOf(get('->', f)),
-                                         msgvar, itervar, errfnRead, desc)
+                                         msgvar, itervar, errfnRead, desc, sentinelKey=f.basename)
             if f.special and f.side != self.side:
                 writefield = Whitespace(
                     "// skipping actor field that's meaningless on this side\n", indent=1)
                 readfield = Whitespace(
                     "// skipping actor field that's meaningless on this side\n", indent=1)
             write.addstmt(writefield)
             read.addstmt(readfield)
 
@@ -4591,23 +4600,24 @@ class _GenerateProtocolActorCode(ipdl.as
         typevar = ExprVar('type')
         writeswitch = StmtSwitch(ud.callType(var))
         readswitch = StmtSwitch(typevar)
 
         for c in ud.components:
             ct = c.ipdltype
             isactor = (ct.isIPDL() and ct.isActor())
             caselabel = CaseLabel(typename +'::'+ c.enum())
+            origenum = c.enum()
 
             writecase = StmtBlock()
             if c.special and c.side != self.side:
                 writecase.addstmt(_fatalError('wrong side!'))
             else:
                 wexpr = ExprCall(ExprSelect(var, '.', c.getTypeName()))
-                writecase.addstmt(StmtExpr(self.write(ct, wexpr, msgvar)))
+                writecase.addstmt(self.checkedWrite(ct, wexpr, msgvar, sentinelKey=c.enum()))
 
             writecase.addstmt(StmtReturn())
             writeswitch.addcase(caselabel, writecase)
 
             readcase = StmtBlock()
             if c.special and c.side == self.side:
                 # the type comes across flipped from what the actor
                 # will be on this side; i.e. child->parent messages
@@ -4617,47 +4627,50 @@ class _GenerateProtocolActorCode(ipdl.as
             else:
                 if c.special:
                     c = c.other       # see above
                 tmpvar = ExprVar('tmp')
                 ct = c.bareType()
                 readcase.addstmts([
                     StmtDecl(Decl(ct, tmpvar.name), init=c.defaultValue()),
                     StmtExpr(ExprAssn(ExprDeref(var), tmpvar)),
-                    StmtReturn(self.read(
+                    self.checkedRead(
                         c.ipdltype,
                         ExprAddrOf(ExprCall(ExprSelect(var, '->',
                                                        c.getTypeName()))),
-                        msgvar, itervar))
+                        msgvar, itervar, errfnRead, 'Union type', sentinelKey=origenum),
+                    StmtReturn(ExprLiteral.TRUE)
                 ])
 
             readswitch.addcase(caselabel, readcase)
 
         unknowntype = 'unknown union type'
         writeswitch.addcase(DefaultLabel(),
                             StmtBlock([ _fatalError(unknowntype),
                                         StmtReturn() ]))
         readswitch.addcase(DefaultLabel(), StmtBlock(errfnRead(unknowntype)))
 
         write = MethodDefn(self.writeMethodDecl(intype, var))
         write.addstmts([
             uniontdef,
-            StmtExpr(self.write(
-                None, ExprCall(Type.INT, args=[ ud.callType(var) ]), msgvar)),
+            self.checkedWrite(
+                None, ExprCall(Type.INT, args=[ ud.callType(var) ]), msgvar,
+                sentinelKey=uniontype.name()),
             Whitespace.NL,
             writeswitch
         ])
 
         read = MethodDefn(self.readMethodDecl(outtype, var))
         read.addstmts([
             uniontdef,
             StmtDecl(Decl(Type.INT, typevar.name)),
             self.checkedRead(
                 None, ExprAddrOf(typevar), msgvar, itervar, errfnUnionType,
-                [ uniontype.name() ]),
+                [ uniontype.name() ],
+                sentinelKey=uniontype.name()),
             Whitespace.NL,
             readswitch,
         ])
 
         self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ])
 
 
     def writeMethodDecl(self, intype, var, template=None):
@@ -4693,16 +4706,30 @@ class _GenerateProtocolActorCode(ipdl.as
                                            ExprCall(write, args=[ expr, to ]))
 
     def read(self, ipdltype, expr, from_, iterexpr, this=None):
         read = ExprVar('Read')
         if this:  read = ExprSelect(this, '->', read.name)
         return self.maybeAddNullabilityArg(
             ipdltype, ExprCall(read, args=[ expr, from_, iterexpr ]))
 
+    def checkedWrite(self, ipdltype, expr, msgvar, sentinelKey, this=None):
+        assert sentinelKey
+
+        write = StmtExpr(self.write(ipdltype, expr, msgvar, this))
+
+        sentinel = StmtExpr(ExprCall(ExprSelect(msgvar, '->', 'WriteSentinel'),
+                                     args=[ ExprLiteral.Int(hashfunc(sentinelKey)) ]))
+        block = Block()
+        block.addstmts([
+            write,
+            Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1),
+            sentinel ])
+        return block
+
 
     def visitMessageDecl(self, md):
         isctor = md.decl.type.isCtor()
         isdtor = md.decl.type.isDtor()
         decltype = md.decl.type
         sendmethod = None
         helpermethod = None
         recvlbl, recvcase = None, None
@@ -5070,17 +5097,17 @@ class _GenerateProtocolActorCode(ipdl.as
         routingId = self.protocol.routingId(fromActor)
         this = None
         if md.decl.type.isDtor():  this = md.actorDecl().var()
 
         stmts = ([ StmtDecl(Decl(Type('IPC::Message', ptr=1), msgvar.name),
                             init=ExprCall(ExprVar(md.pqMsgCtorFunc()),
                                           args=[ routingId ])) ]
                  + [ Whitespace.NL ]
-                 + [ StmtExpr(self.write(p.ipdltype, p.var(), msgvar, this))
+                 + [ self.checkedWrite(p.ipdltype, p.var(), msgvar, sentinelKey=p.name, this=this)
                      for p in md.params ]
                  + [ Whitespace.NL ]
                  + self.setMessageFlags(md, msgvar, reply=0))
         return msgvar, stmts
 
 
     def makeReply(self, md, errfn, routingId):
         if routingId is None:
@@ -5089,17 +5116,17 @@ class _GenerateProtocolActorCode(ipdl.as
         if not md.decl.type.hasReply():
             return [ ]
 
         replyvar = self.replyvar
         return (
             [ StmtExpr(ExprAssn(
                 replyvar, ExprCall(ExprVar(md.pqReplyCtorFunc()), args=[ routingId ]))),
               Whitespace.NL ]
-            + [ StmtExpr(self.write(r.ipdltype, r.var(), replyvar))
+            + [ self.checkedWrite(r.ipdltype, r.var(), replyvar, sentinelKey=r.name)
                 for r in md.returns ]
             + self.setMessageFlags(md, replyvar, reply=1)
             + [ self.logMessage(md, replyvar, 'Sending reply ') ])
 
 
     def setMessageFlags(self, md, var, reply):
         stmts = [ ]
 
@@ -5145,29 +5172,31 @@ class _GenerateProtocolActorCode(ipdl.as
         if isctor:
             # return the raw actor handle so that its ID can be used
             # to construct the "real" actor
             handlevar = self.handlevar
             handletype = Type('ActorHandle')
             decls = [ StmtDecl(Decl(handletype, handlevar.name)) ]
             reads = [ self.checkedRead(None, ExprAddrOf(handlevar), msgexpr,
                                        ExprAddrOf(self.itervar),
-                                       errfn, "'%s'" % handletype.name) ]
+                                       errfn, "'%s'" % handletype.name,
+                                       sentinelKey='actor') ]
             start = 1
 
         stmts.extend((
             [ StmtDecl(Decl(_iterType(ptr=0), self.itervar.name),
                      init=ExprCall(ExprVar('PickleIterator'),
                                    args=[ msgvar ])) ]
             + decls + [ StmtDecl(Decl(p.bareType(side), p.var().name))
                       for p in md.params ]
             + [ Whitespace.NL ]
             + reads + [ self.checkedRead(p.ipdltype, ExprAddrOf(p.var()),
                                          msgexpr, ExprAddrOf(itervar),
-                                         errfn, "'%s'" % p.bareType(side).name)
+                                         errfn, "'%s'" % p.bareType(side).name,
+                                         sentinelKey=p.name)
                         for p in md.params[start:] ]
             + [ self.endRead(msgvar, itervar) ]))
 
         return stmts
 
 
     def deserializeReply(self, md, replyexpr, side, errfn, actor=None):
         stmts = [ Whitespace.NL,
@@ -5180,17 +5209,18 @@ class _GenerateProtocolActorCode(ipdl.as
         stmts.extend(
             [ Whitespace.NL,
               StmtDecl(Decl(_iterType(ptr=0), itervar.name),
                        init=ExprCall(ExprVar('PickleIterator'),
                                      args=[ self.replyvar ])) ]
             + [ self.checkedRead(r.ipdltype, r.var(),
                                  ExprAddrOf(self.replyvar),
                                  ExprAddrOf(self.itervar),
-                                 errfn, "'%s'" % r.bareType(side).name)
+                                 errfn, "'%s'" % r.bareType(side).name,
+                                 sentinelKey=r.name)
                 for r in md.returns ]
             + [ self.endRead(self.replyvar, itervar) ])
 
         return stmts
 
     def sendAsync(self, md, msgexpr, actor=None):
         sendok = ExprVar('sendok__')
         return (
@@ -5331,24 +5361,38 @@ class _GenerateProtocolActorCode(ipdl.as
                 ExprVar(self.protocol.name +'::Transition'),
                 args=[ stateexpr,
                        ExprCall(ExprVar('Trigger'),
                                 args=[ action, ExprVar(msgid) ]),
                        ExprAddrOf(stateexpr) ])))
         ifbad.addifstmts(_badTransition())
         return [ ifbad ]
 
-    def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype):
+    def checkedRead(self, ipdltype, expr, msgexpr, iterexpr, errfn, paramtype, sentinelKey, sentinel=True):
         ifbad = StmtIf(ExprNot(self.read(ipdltype, expr, msgexpr, iterexpr)))
         if isinstance(paramtype, list):
             errorcall = errfn(*paramtype)
         else:
             errorcall = errfn('Error deserializing ' + paramtype)
         ifbad.addifstmts(errorcall)
-        return ifbad
+
+        block = Block()
+        block.addstmt(ifbad)
+
+        if sentinel:
+            assert sentinelKey
+
+            block.addstmt(Whitespace('// Sentinel = ' + repr(sentinelKey) + '\n', indent=1))
+            read = ExprCall(ExprSelect(msgexpr, '->', 'ReadSentinel'),
+                                  args=[ iterexpr, ExprLiteral.Int(hashfunc(sentinelKey)) ])
+            ifsentinel = StmtIf(ExprNot(read))
+            ifsentinel.addifstmts(errorcall)
+            block.addstmt(ifsentinel)
+
+        return block
 
     def endRead(self, msgexpr, iterexpr):
         return StmtExpr(ExprCall(ExprSelect(msgexpr, '.', 'EndRead'),
                                  args=[ iterexpr ]))
 
 class _GenerateProtocolParentCode(_GenerateProtocolActorCode):
     def __init__(self):
         _GenerateProtocolActorCode.__init__(self, 'parent')