Bug 775777: Check dynamic actor type when deserializing. r=bent
authorChris Jones <jones.chris.g@gmail.com>
Thu, 20 Sep 2012 12:30:52 -0700
changeset 107768 7c9af3e022dac973a22cd0c42b7a1d59a08774d8
parent 107767 ae0ea12a8109530c029c6e77d3f8e8ead2a93676
child 107769 b68455ccdd9a5b87035b71a64a653e307fb5b372
push id82
push usershu@rfrn.org
push dateFri, 05 Oct 2012 13:20:22 +0000
reviewersbent
bugs775777
milestone18.0a1
Bug 775777: Check dynamic actor type when deserializing. r=bent
ipc/glue/AsyncChannel.h
ipc/glue/RPCChannel.h
ipc/glue/SyncChannel.h
ipc/ipdl/ipdl/lower.py
--- a/ipc/glue/AsyncChannel.h
+++ b/ipc/glue/AsyncChannel.h
@@ -66,17 +66,20 @@ public:
     {
     public:
         virtual ~AsyncListener() { }
 
         virtual void OnChannelClose() = 0;
         virtual void OnChannelError() = 0;
         virtual Result OnMessageReceived(const Message& aMessage) = 0;
         virtual void OnProcessingError(Result aError) = 0;
-        virtual void OnChannelConnected(int32_t peer_pid) {};
+        // FIXME/bug 792652: this doesn't really belong here, but a
+        // large refactoring is needed to put it where it belongs.
+        virtual int32_t GetProtocolTypeId() = 0;
+        virtual void OnChannelConnected(int32_t peer_pid) {}
     };
 
     enum Side { Parent, Child, Unknown };
 
 public:
     //
     // These methods are called on the "worker" thread
     //
--- a/ipc/glue/RPCChannel.h
+++ b/ipc/glue/RPCChannel.h
@@ -41,22 +41,23 @@ public:
     {
     public:
         virtual ~RPCListener() { }
 
         virtual void OnChannelClose() = 0;
         virtual void OnChannelError() = 0;
         virtual Result OnMessageReceived(const Message& aMessage) = 0;
         virtual void OnProcessingError(Result aError) = 0;
+        virtual int32_t GetProtocolTypeId() = 0;
         virtual bool OnReplyTimeout() = 0;
         virtual Result OnMessageReceived(const Message& aMessage,
                                          Message*& aReply) = 0;
         virtual Result OnCallReceived(const Message& aMessage,
                                       Message*& aReply) = 0;
-        virtual void OnChannelConnected(int32_t peer_pid) {};
+        virtual void OnChannelConnected(int32_t peer_pid) {}
 
         virtual void OnEnteredCxxStack()
         {
             NS_RUNTIMEABORT("default impl shouldn't be invoked");
         }
 
         virtual void OnExitedCxxStack()
         {
--- a/ipc/glue/SyncChannel.h
+++ b/ipc/glue/SyncChannel.h
@@ -27,20 +27,21 @@ public:
     {
     public:
         virtual ~SyncListener() { }
 
         virtual void OnChannelClose() = 0;
         virtual void OnChannelError() = 0;
         virtual Result OnMessageReceived(const Message& aMessage) = 0;
         virtual void OnProcessingError(Result aError) = 0;
+        virtual int32_t GetProtocolTypeId() = 0;
         virtual bool OnReplyTimeout() = 0;
         virtual Result OnMessageReceived(const Message& aMessage,
                                          Message*& aReply) = 0;
-        virtual void OnChannelConnected(int32_t peer_pid) {};
+        virtual void OnChannelConnected(int32_t peer_pid) {}
     };
 
     SyncChannel(SyncListener* aListener);
     virtual ~SyncChannel();
 
     virtual bool Send(Message* msg) MOZ_OVERRIDE {
         return AsyncChannel::Send(msg);
     }
--- a/ipc/ipdl/ipdl/lower.py
+++ b/ipc/ipdl/ipdl/lower.py
@@ -107,16 +107,19 @@ def _actorName(pname, side):
     """|pname| is the protocol name. |side| is 'Parent' or 'Child'."""
     tag = side
     if not tag[0].isupper():  tag = side.title()
     return pname + tag
 
 def _actorIdType():
     return Type.INT32
 
+def _actorTypeTagType():
+    return Type.INT32
+
 def _actorId(actor=None):
     if actor is not None:
         return ExprSelect(actor, '->', 'mId')
     return ExprVar('mId')
 
 def _actorHId(actorhandle):
     return ExprSelect(actorhandle, '.', 'mId')
 
@@ -3023,16 +3026,22 @@ class _GenerateProtocolActorCode(ipdl.as
         if ptype.isToplevel():
             onprocessingerror.addstmt(StmtReturn(
                 ExprCall(p.processingErrorVar(), args=[ codevar ])))
         else:
             onprocessingerror.addstmt(
                 _runtimeAbort("`OnProcessingError' called on non-toplevel actor"))
         self.cls.addstmts([ onprocessingerror, Whitespace.NL ])
 
+        # int32_t GetProtocolTypeId() { return PFoo; }
+        gettypetag = MethodDefn(
+            MethodDecl('GetProtocolTypeId', ret=_actorTypeTagType()))
+        gettypetag.addstmt(StmtReturn(_protocolId(ptype)))
+        self.cls.addstmts([ gettypetag, Whitespace.NL ])
+
         # OnReplyTimeout()
         if toplevel.talksSync() or toplevel.talksRpc():
             ontimeout = MethodDefn(
                 MethodDecl('OnReplyTimeout', ret=Type.BOOL))
 
             if ptype.isToplevel():
                 ontimeout.addstmt(StmtReturn(
                     ExprCall(p.shouldContinueFromTimeoutVar())))
@@ -4123,38 +4132,54 @@ class _GenerateProtocolActorCode(ipdl.as
         ifbadid.addifstmts([
                 _protocolErrorBreakpoint('bad ID for '+ self.protocol.name),
                 StmtReturn.FALSE
         ])
         read.addstmts([ ifbadid, Whitespace.NL ])
         
         # if (NULL_ID == id)
         #   *var = null
-        # else
-        #   *var = Lookup(id)
-        #   if (!*var)
-        #     return false
+        #   return true
         outactor = ExprDeref(var)
         ifnull = StmtIf(ExprBinary(_NULL_ACTOR_ID, '==', idvar))
-        ifnull.addifstmt(StmtExpr(ExprAssn(outactor, ExprLiteral.NULL)))
-
-        ifnull.addelsestmt(StmtExpr(ExprAssn(
-            outactor,
-            ExprCast(_lookupListener(idvar), cxxtype, static=1))))
-
-        ifnotfound = StmtIf(ExprNot(outactor))
+        ifnull.addifstmts([ StmtExpr(ExprAssn(outactor, ExprLiteral.NULL)),
+                            StmtReturn.TRUE ])
+        read.addstmts([ ifnull, Whitespace.NL ])
+
+        # Listener* listener = Lookup(id)
+        # if (!listener)
+        #   return false
+        listenervar = ExprVar('listener')
+        read.addstmt(StmtDecl(Decl(Type('ChannelListener', ptr=1),
+                                   listenervar.name),
+                              init=_lookupListener(idvar)))
+        ifnotfound = StmtIf(ExprNot(listenervar))
         ifnotfound.addifstmts([
-                _protocolErrorBreakpoint('could not look up '+ self.protocol.name),
+                _protocolErrorBreakpoint('could not look up '+ actortype.name()),
                 StmtReturn.FALSE
         ])
-        ifnull.addelsestmt(ifnotfound)
-
+        read.addstmts([ ifnotfound, Whitespace.NL ])
+
+        # if listener->GetProtocolTypeId() != [expected protocol type]
+        #   return false
+        ifbadtype = StmtIf(ExprBinary(
+                _protocolId(actortype), '!=',
+                ExprCall(ExprSelect(listenervar, '->', 'GetProtocolTypeId'))))
+        ifbadtype.addifstmts([
+                _protocolErrorBreakpoint('actor that should be of type '+ actortype.name() +' has different type'),
+                StmtReturn.FALSE
+        ])
+        read.addstmts([ ifbadtype, Whitespace.NL ])
+
+        # *outactor = static_cast<ExpectedType>(listener)
+        # return true
         read.addstmts([
-            ifnull,
-            StmtReturn.TRUE
+                StmtExpr(ExprAssn(outactor,
+                                  ExprCast(listenervar, cxxtype, static=1))),
+                StmtReturn.TRUE
         ])
 
         self.cls.addstmts([ write, Whitespace.NL, read, Whitespace.NL ])
 
 
     def implementSpecialArrayPickling(self, arraytype):
         var = self.var
         msgvar = self.msgvar