Bug 562741: Allow |Shmem|s to be shared across different protocol trees. r=bent
authorChris Jones <jones.chris.g@gmail.com>
Sat, 22 May 2010 14:35:33 -0500
changeset 42562 2390945942d68fb47b7357a7012ba778bb977b2f
parent 42561 4254363b9c635df43e2fff78392cdff960a35785
child 42563 c0e8fc4ceb57795b78a8eaa7e6df98374d7a6480
push idunknown
push userunknown
push dateunknown
reviewersbent
bugs562741
milestone1.9.3a5pre
Bug 562741: Allow |Shmem|s to be shared across different protocol trees. r=bent
ipc/chromium/src/base/id_map.h
ipc/glue/ProtocolUtils.h
ipc/glue/Shmem.h
ipc/ipdl/ipdl/lower.py
--- a/ipc/chromium/src/base/id_map.h
+++ b/ipc/chromium/src/base/id_map.h
@@ -71,16 +71,24 @@ class IDMap {
   bool IsEmpty() const {
     return data_.empty();
   }
 
 #if defined(CHROMIUM_MOZILLA_BUILD)
   void Clear() {
     data_.clear();
   }
+
+  bool HasData(const T* data) const {
+    // XXX would like to use <algorithm> here ...
+    for (const_iterator it = begin(); it != end(); ++it)
+      if (data == it->second)
+        return true;
+    return false;
+  }
 #endif
 
   T* Lookup(int32 id) const {
     const_iterator i = data_.find(id);
     if (i == data_.end())
       return NULL;
     return i->second;
   }
--- a/ipc/glue/ProtocolUtils.h
+++ b/ipc/glue/ProtocolUtils.h
@@ -88,17 +88,19 @@ public:
     virtual int32 Register(ListenerT*) = 0;
     virtual int32 RegisterID(ListenerT*, int32) = 0;
     virtual ListenerT* Lookup(int32) = 0;
     virtual void Unregister(int32) = 0;
     virtual void RemoveManagee(int32, ListenerT*) = 0;
 
     virtual Shmem::SharedMemory* CreateSharedMemory(
         size_t, SharedMemory::SharedMemoryType, int32*) = 0;
+    virtual bool AdoptSharedMemory(Shmem::SharedMemory*, int32*) = 0;
     virtual Shmem::SharedMemory* LookupSharedMemory(int32) = 0;
+    virtual bool IsTrackingSharedMemory(Shmem::SharedMemory*) = 0;
     virtual bool DestroySharedMemory(Shmem&) = 0;
 
     // XXX odd duck, acknowledged
     virtual ProcessHandle OtherProcess() const = 0;
 };
 
 
 inline bool
--- a/ipc/glue/Shmem.h
+++ b/ipc/glue/Shmem.h
@@ -184,20 +184,24 @@ public:
     AssertAligned<T>();
 
     return mSize / sizeof(T);
   }
 
   int GetSysVID() const;
 
   // These shouldn't be used directly, use the IPDL interface instead.
-  id_t Id(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead) {
+  id_t Id(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead) const {
     return mId;
   }
 
+  SharedMemory* Segment(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead) const {
+    return mSegment;
+  }
+
 #ifndef DEBUG
   void RevokeRights(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead)
   {
   }
 #else
   void RevokeRights(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead);
 #endif
 
--- a/ipc/ipdl/ipdl/lower.py
+++ b/ipc/ipdl/ipdl/lower.py
@@ -197,17 +197,17 @@ def _lookupActor(idexpr, outactor, actor
 
 def _lookupActorHandle(handle, outactor, actortype, cxxactortype, errfn):
     return _lookupActor(_actorHId(handle), outactor, actortype, cxxactortype,
                         errfn)
 
 def _lookupListener(idexpr):
     return ExprCall(ExprVar('Lookup'), args=[ idexpr ])
 
-def _shmemType(ptr=0, ref=0):
+def _shmemType(ptr=0, const=1, ref=0):
     return Type('Shmem', ptr=ptr, ref=ref)
 
 def _rawShmemType(ptr=0):
     return Type('Shmem::SharedMemory', ptr=ptr)
 
 def _shmemIdType(ptr=0):
     return Type('Shmem::id_t', ptr=ptr)
 
@@ -221,16 +221,20 @@ def _shmemBackstagePass():
 def _shmemCtor(rawmem, idexpr):
     return ExprCall(ExprVar('Shmem'),
                     args=[ _shmemBackstagePass(), rawmem, idexpr ])
 
 def _shmemId(shmemexpr):
     return ExprCall(ExprSelect(shmemexpr, '.', 'Id'),
                     args=[ _shmemBackstagePass() ])
 
+def _shmemSegment(shmemexpr):
+    return ExprCall(ExprSelect(shmemexpr, '.', 'Segment'),
+                    args=[ _shmemBackstagePass() ])
+
 def _shmemAlloc(size, type):
     # starts out UNprotected
     return ExprCall(ExprVar('Shmem::Alloc'),
                     args=[ _shmemBackstagePass(), size, type ])
 
 def _shmemDealloc(rawmemvar):
     return ExprCall(ExprVar('Shmem::Dealloc'),
                     args=[ _shmemBackstagePass(), rawmemvar ])
@@ -1382,20 +1386,26 @@ class Protocol(ipdl.ast.Protocol):
             return ExprSelect(actorThis, '->', 'Unregister')
         return ExprVar('Unregister')
 
     def removeManageeMethod(self):
         return ExprVar('RemoveManagee')
 
     def createSharedMemory(self):
         return ExprVar('CreateSharedMemory')
+
+    def adoptSharedMemory(self):
+        return ExprVar('AdoptSharedMemory')
  
     def lookupSharedMemory(self):
         return ExprVar('LookupSharedMemory')
 
+    def isTrackingSharedMemory(self):
+        return ExprVar('IsTrackingSharedMemory')
+
     def destroySharedMemory(self):
         return ExprVar('DestroySharedMemory')
 
     def otherProcessMethod(self):
         return ExprVar('OtherProcess')
 
     def shouldContinueFromTimeoutVar(self):
         assert self.decl.type.isToplevel()
@@ -3270,16 +3280,17 @@ class _GenerateProtocolActorCode(ipdl.as
                     p.managedVarType(managed, self.side),
                     p.managedVar(managed, self.side).name)) ])
 
     def implementManagerIface(self):
         p = self.protocol
         routedvar = ExprVar('aRouted')
         idvar = ExprVar('aId')
         shmemvar = ExprVar('aShmem')
+        rawvar = ExprVar('segment')
         sizevar = ExprVar('aSize')
         typevar = ExprVar('type')
         listenertype = Type('ChannelListener', ptr=1)
 
         register = MethodDefn(MethodDecl(
             p.registerMethod().name,
             params=[ Decl(listenertype, routedvar.name) ],
             ret=_actorIdType(), virtual=1))
@@ -3300,26 +3311,37 @@ class _GenerateProtocolActorCode(ipdl.as
 
         createshmem = MethodDefn(MethodDecl(
             p.createSharedMemory().name,
             ret=_rawShmemType(ptr=1),
             params=[ Decl(Type.SIZE, sizevar.name),
                      Decl(_shmemTypeType(), typevar.name),
                      Decl(_shmemIdType(ptr=1), idvar.name) ],
             virtual=1))
+        adoptshmem = MethodDefn(MethodDecl(
+            p.adoptSharedMemory().name,
+            ret=Type.BOOL,
+            params=[ Decl(_rawShmemType(ptr=1), rawvar.name),
+                     Decl(_shmemIdType(ptr=1), idvar.name) ],
+            virtual=1))
         lookupshmem = MethodDefn(MethodDecl(
             p.lookupSharedMemory().name,
             ret=_rawShmemType(ptr=1),
             params=[ Decl(_shmemIdType(), idvar.name) ],
             virtual=1))
         destroyshmem = MethodDefn(MethodDecl(
             p.destroySharedMemory().name,
             ret=Type.BOOL,
             params=[ Decl(_shmemType(ref=1), shmemvar.name) ],
             virtual=1))
+        istracking = MethodDefn(MethodDecl(
+            p.isTrackingSharedMemory().name,
+            ret=Type.BOOL,
+            params=[ Decl(_rawShmemType(ptr=1), rawvar.name) ],
+            virtual=1))
 
         otherprocess = MethodDefn(MethodDecl(
             p.otherProcessMethod().name,
             ret=Type('ProcessHandle'),
             const=1,
             virtual=1))
 
         if p.decl.type.isToplevel():
@@ -3349,26 +3371,24 @@ class _GenerateProtocolActorCode(ipdl.as
             # SharedMemory* CreateSharedMemory(size, type, id_t*):
             #   nsAutoPtr<SharedMemory> seg(Shmem::Alloc(size, type));
             #   if (!shmem)
             #     return false
             #   Shmem s(seg, [nextshmemid]);
             #   Message descriptor;
             #   if (!s->ShareTo(subprocess, mId, descriptor) ||
             #       !Send(descriptor))
-            #     return false;
+            #     return null;
             #   mShmemMap.Add(seg, id);
             #   return shmem.forget();
-            rawvar = ExprVar('segment')
-
             createshmem.addstmt(StmtDecl(
                 Decl(_autoptr(_rawShmemType()), rawvar.name),
                 initargs=[ _shmemAlloc(sizevar, typevar) ]))
             failif = StmtIf(ExprNot(rawvar))
-            failif.addifstmt(StmtReturn(ExprLiteral.FALSE))
+            failif.addifstmt(StmtReturn(ExprLiteral.NULL))
             createshmem.addstmt(failif)
 
             descriptorvar = ExprVar('descriptor')
             createshmem.addstmts([
                 StmtDecl(
                     Decl(_shmemType(), shmemvar.name),
                     initargs=[ _shmemBackstagePass(),
                                _autoptrGet(rawvar),
@@ -3390,21 +3410,68 @@ class _GenerateProtocolActorCode(ipdl.as
             createshmem.addstmts([
                 StmtExpr(ExprAssn(ExprDeref(idvar), _shmemId(shmemvar))),
                 StmtExpr(ExprCall(
                     ExprSelect(p.shmemMapVar(), '.', 'AddWithID'),
                     args=[ rawvar, ExprDeref(idvar) ])),
                 StmtReturn(_autoptrForget(rawvar))
             ])
 
+            # SharedMemory* AdoptSharedMemory(SharedMemory*, id_t*):
+            #   Shmem s(seg, [nextshmemid]);
+            #   Message descriptor;
+            #   if (!s->ShareTo(subprocess, mId, descriptor) ||
+            #       !Send(descriptor))
+            #     return false;
+            #   mShmemMap.Add(seg, id);
+            #   seg->AddRef();
+            #   return true;
+
+            # XXX this is close to the same code as above, could be
+            # refactored
+            descriptorvar = ExprVar('descriptor')
+            adoptshmem.addstmts([
+                StmtDecl(
+                    Decl(_shmemType(), shmemvar.name),
+                    initargs=[ _shmemBackstagePass(),
+                               rawvar,
+                               p.nextShmemIdExpr(self.side) ]),
+                StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name),
+                         init=_shmemShareTo(shmemvar,
+                                            ExprCall(p.otherProcessMethod()),
+                                            p.routingId()))
+            ])
+            failif = StmtIf(ExprNot(descriptorvar))
+            failif.addifstmt(StmtReturn(ExprLiteral.FALSE))
+            adoptshmem.addstmt(failif)
+
+            failif = StmtIf(ExprNot(ExprCall(
+                ExprSelect(p.channelVar(), p.channelSel(), 'Send'),
+                args=[ descriptorvar ])))
+            adoptshmem.addstmt(failif)
+
+            adoptshmem.addstmts([
+                StmtExpr(ExprAssn(ExprDeref(idvar), _shmemId(shmemvar))),
+                StmtExpr(ExprCall(
+                    ExprSelect(p.shmemMapVar(), '.', 'AddWithID'),
+                    args=[ rawvar, ExprDeref(idvar) ])),
+                StmtExpr(ExprCall(ExprSelect(rawvar, '->', 'AddRef'))),
+                StmtReturn(ExprLiteral.TRUE)
+            ])
+
             # SharedMemory* Lookup(id)
             lookupshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.shmemMapVar(), '.', 'Lookup'),
                 args=[ idvar ])))
 
+            # bool IsTrackingSharedMemory(mem)
+            istracking.addstmt(StmtReturn(ExprCall(
+                ExprSelect(p.shmemMapVar(), '.', 'HasData'),
+                args=[ rawvar ])))
+
             # bool DestroySharedMemory(shmem):
             #   id = shmem.Id()
             #   SharedMemory* rawmem = Lookup(id)
             #   if (!rawmem)
             #     return false;
             #   Message descriptor = UnShare(subprocess, mId, descriptor)
             #   mShmemMap.Remove(id)
             #   Shmem::Dealloc(rawmem)
@@ -3470,19 +3537,26 @@ class _GenerateProtocolActorCode(ipdl.as
                 ExprSelect(p.managerVar(), '->', p.lookupIDMethod().name),
                 [ idvar ])))
             unregister.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.unregisterMethod().name),
                 [ idvar ])))
             createshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.createSharedMemory().name),
                 [ sizevar, typevar, idvar ])))
+            adoptshmem.addstmt(StmtReturn(ExprCall(
+                ExprSelect(p.managerVar(), '->', p.adoptSharedMemory().name),
+                [ rawvar, idvar ])))
             lookupshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.lookupSharedMemory().name),
                 [ idvar ])))
+            istracking.addstmt(StmtReturn(ExprCall(
+                ExprSelect(p.managerVar(), '->',
+                           p.isTrackingSharedMemory().name),
+                [ rawvar ])))
             destroyshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.destroySharedMemory().name),
                 [ shmemvar ])))
             otherprocess.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->',
                            p.otherProcessMethod().name))))
 
         # all protocols share the "same" RemoveManagee() implementation
@@ -3525,35 +3599,38 @@ class _GenerateProtocolActorCode(ipdl.as
         removemanagee.addstmt(switchontype)
 
         return [ register,
                  registerid,
                  lookup,
                  unregister,
                  removemanagee,
                  createshmem,
+                 adoptshmem,
                  lookupshmem,
+                 istracking,
                  destroyshmem,
                  otherprocess,
                  Whitespace.NL ]
 
     def makeShmemIface(self):
         p = self.protocol
         idvar = ExprVar('aId')
         sizevar = ExprVar('aSize')
         typevar = ExprVar('aType')
         memvar = ExprVar('aMem')
+        outmemvar = ExprVar('aOutMem')
         rawvar = ExprVar('rawmem')
 
         # bool AllocShmem(size_t size, Shmem* outmem):
         #   id_t id;
         #   nsAutoPtr<SharedMemory> mem(CreateSharedMemory(&id));
         #   if (!mem)
         #     return false;
-        #   *outmem = Shmem(shmem, id)
+        #   *outmem = Shmem(mem, id)
         #   return true;
         allocShmem = MethodDefn(MethodDecl(
             'AllocShmem',
             params=[ Decl(Type.SIZE, sizevar.name),
                      Decl(_shmemTypeType(), typevar.name),
                      Decl(_shmemType(ptr=1), memvar.name) ],
             ret=Type.BOOL))
 
@@ -3569,16 +3646,53 @@ class _GenerateProtocolActorCode(ipdl.as
                                                 ExprAddrOf(idvar) ]) ]),
             ifallocfails,
             Whitespace.NL,
             StmtExpr(ExprAssn(
                 ExprDeref(memvar), _shmemCtor(_autoptrForget(rawvar), idvar))),
             StmtReturn(ExprLiteral.TRUE)
         ])
 
+        # bool AdoptShmem(const Shmem& mem, Shmem* outmem):
+        #   SharedMemory* raw = mem.mSegment;
+        #   if (!raw || IsTrackingSharedMemory(raw))
+        #     RUNTIMEABORT()
+        #   id_t id
+        #   if (!AdoptSharedMemory(raw, &id))
+        #     return false
+        #   *outmem = Shmem(raw, id);
+        #   return true;
+        adoptShmem = MethodDefn(MethodDecl(
+            'AdoptShmem',
+            params=[ Decl(_shmemType(const=1, ref=1), memvar.name),
+                     Decl(_shmemType(ptr=1), outmemvar.name) ],
+            ret=Type.BOOL))
+
+        adoptShmem.addstmt(StmtDecl(Decl(_rawShmemType(ptr=1), rawvar.name),
+                                    init=_shmemSegment(memvar)))
+        ifbad = StmtIf(ExprBinary(
+            ExprNot(rawvar), '||',
+            ExprCall(ExprVar('IsTrackingSharedMemory'), args=[ rawvar ])))
+        ifbad.addifstmt(_runtimeAbort('bad Shmem'))
+        adoptShmem.addstmt(ifbad)
+
+        ifadoptfails = StmtIf(ExprNot(ExprCall(
+            p.adoptSharedMemory(), args=[ rawvar, ExprAddrOf(idvar) ])))
+        ifadoptfails.addifstmt(StmtReturn(ExprLiteral.FALSE))
+
+        adoptShmem.addstmts([
+            Whitespace.NL,
+            StmtDecl(Decl(_shmemIdType(), idvar.name)),
+            ifadoptfails,
+            Whitespace.NL,
+            StmtExpr(ExprAssn(ExprDeref(outmemvar),
+                              _shmemCtor(rawvar, idvar))),
+            StmtReturn(ExprLiteral.TRUE)
+        ])
+
         # bool DeallocShmem(Shmem& mem):
         #   bool ok = DestroySharedMemory(mem);
         #   mem.forget();
         #   return ok;
         deallocShmem = MethodDefn(MethodDecl(
             'DeallocShmem',
             params=[ Decl(_shmemType(ref=1), memvar.name) ],
             ret=Type.BOOL))
@@ -3590,16 +3704,18 @@ class _GenerateProtocolActorCode(ipdl.as
                                    args=[ memvar ])),
             StmtExpr(_shmemForget(memvar)),
             StmtReturn(okvar)
         ])
 
         return [ Whitespace('// Methods for managing shmem\n', indent=1),
                  allocShmem,
                  Whitespace.NL,
+                 adoptShmem,
+                 Whitespace.NL,
                  deallocShmem,
                  Whitespace.NL ]
 
     def genShmemCreatedHandler(self):
         p = self.protocol
         assert p.decl.type.isToplevel()
         
         case = StmtBlock()