Bug 555275: Implement a DeallocShmem() interface. r=bent
authorChris Jones <jones.chris.g@gmail.com>
Mon, 26 Apr 2010 20:11:40 -0500
changeset 41362 4a02a072129ba5e86badd9669f5306bb629c24ec
parent 41361 fa53318b87aa5a1fdc9a278a74ad110471d11ddb
child 41363 9e2c5334e23bd8550a516595528fae36d6ffb66f
push idunknown
push userunknown
push dateunknown
reviewersbent
bugs555275
milestone1.9.3a5pre
Bug 555275: Implement a DeallocShmem() interface. r=bent
ipc/glue/ProtocolUtils.h
ipc/glue/Shmem.cpp
ipc/glue/Shmem.h
ipc/ipdl/ipdl/cxx/ast.py
ipc/ipdl/ipdl/lower.py
ipc/ipdl/test/cxx/TestShmem.cpp
ipc/ipdl/test/cxx/TestSysVShmem.cpp
--- a/ipc/glue/ProtocolUtils.h
+++ b/ipc/glue/ProtocolUtils.h
@@ -47,16 +47,17 @@
 #include "prenv.h"
 
 #include "mozilla/ipc/Shmem.h"
 
 // WARNING: this takes into account the private, special-message-type
 // enum in ipc_channel.h.  They need to be kept in sync.
 namespace {
 enum {
+    SHMEM_DESTROYED_MESSAGE_TYPE = kuint16max - 5,
     UNBLOCK_CHILD_MESSAGE_TYPE = kuint16max - 4,
     BLOCK_CHILD_MESSAGE_TYPE   = kuint16max - 3,
     SHMEM_CREATED_MESSAGE_TYPE = kuint16max - 2,
     GOODBYE_MESSAGE_TYPE       = kuint16max - 1,
 };
 }
 
 namespace mozilla {
@@ -88,16 +89,17 @@ public:
     virtual int32 RegisterID(ListenerT*, int32) = 0;
     virtual ListenerT* Lookup(int32) = 0;
     virtual void Unregister(int32) = 0;
     virtual void RemoveManagee(int32, ListenerT*) = 0;
 
     virtual Shmem::SharedMemory* CreateSharedMemory(
         size_t, SharedMemory::SharedMemoryType, int32*) = 0;
     virtual Shmem::SharedMemory* LookupSharedMemory(int32) = 0;
+    virtual bool DestroySharedMemory(Shmem&) = 0;
 
     // XXX odd duck, acknowledged
     virtual ProcessHandle OtherProcess() const = 0;
 };
 
 
 inline bool
 LoggingEnabled()
--- a/ipc/glue/Shmem.cpp
+++ b/ipc/glue/Shmem.cpp
@@ -124,16 +124,31 @@ public:
 
   void Log(const std::string& aPrefix,
            FILE* aOutf) const
   {
     fputs("(special ShmemCreated msg)", aOutf);
   }
 };
 
+class ShmemDestroyed : public IPC::Message
+{
+private:
+  typedef Shmem::id_t id_t;
+
+public:
+  ShmemDestroyed(int32 routingId,
+                 const id_t& aIPDLId) :
+    IPC::Message(routingId, SHMEM_DESTROYED_MESSAGE_TYPE, PRIORITY_NORMAL)
+  {
+    IPC::WriteParam(this, aIPDLId);
+  }
+};
+
+
 #ifdef MOZ_HAVE_SHAREDMEMORYSYSV
 static Shmem::SharedMemory*
 CreateSegment(size_t aNBytes, SharedMemorySysV::Handle aHandle)
 {
   nsAutoPtr<SharedMemory> segment;
 
   if (SharedMemorySysV::IsHandleValid(aHandle)) {
     segment = new SharedMemorySysV(aHandle);
@@ -593,10 +608,19 @@ Shmem::ShareTo(IHadBetterBeIPDLCodeCalli
 #endif
   else {
     NS_RUNTIMEABORT("unknown shmem type (here?!)");
   }
 
   return 0;
 }
 
+IPC::Message*
+Shmem::UnshareFrom(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead,
+                   base::ProcessHandle aProcess,
+                   int32 routingId)
+{
+  AssertInvariants();
+  return new ShmemDestroyed(routingId, mId);
+}
+
 } // namespace ipc
 } // namespace mozilla
--- a/ipc/glue/Shmem.h
+++ b/ipc/glue/Shmem.h
@@ -219,16 +219,25 @@ public:
   // that contains enough information for the other process to map
   // this segment in OpenExisting() below.  Return a new message if
   // successful (owned by the caller), NULL if not.
   IPC::Message*
   ShareTo(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead,
           base::ProcessHandle aProcess,
           int32 routingId);
 
+  // Stop sharing this with |aProcess|.  Return an IPC message that
+  // contains enough information for the other process to unmap this
+  // segment.  Return a new message if successful (owned by the
+  // caller), NULL if not.
+  IPC::Message*
+  UnshareFrom(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead,
+              base::ProcessHandle aProcess,
+              int32 routingId);
+
   // Return a SharedMemory instance in this process using the
   // descriptor shared to us by the process that created the
   // underlying OS shmem resource.  The contents of the descriptor
   // depend on the type of SharedMemory that was passed to us.
   static SharedMemory*
   OpenExisting(IHadBetterBeIPDLCodeCallingThis_OtherwiseIAmADoodyhead,
                const IPC::Message& aDescriptor,
                id_t* aId,
--- a/ipc/ipdl/ipdl/cxx/ast.py
+++ b/ipc/ipdl/ipdl/cxx/ast.py
@@ -652,16 +652,17 @@ class StmtBlock(Block):
     def __init__(self, stmts=[ ]):
         Block.__init__(self)
         self.addstmts(stmts)
 
 class StmtDecl(Node):
     def __init__(self, decl, init=None, initargs=None):
         assert not (init and initargs)
         assert not isinstance(init, str) # easy to confuse with Decl
+        assert not isinstance(init, list)
         assert not isinstance(decl, tuple)
         
         Node.__init__(self)
         self.decl = decl
         self.init = init
         self.initargs = initargs
 
 class Label(Node):
--- a/ipc/ipdl/ipdl/lower.py
+++ b/ipc/ipdl/ipdl/lower.py
@@ -197,18 +197,18 @@ def _lookupActor(idexpr, outactor, actor
 
 def _lookupActorHandle(handle, outactor, actortype, cxxactortype, errfn):
     return _lookupActor(_actorHId(handle), outactor, actortype, cxxactortype,
                         errfn)
 
 def _lookupListener(idexpr):
     return ExprCall(ExprVar('Lookup'), args=[ idexpr ])
 
-def _shmemType(ptr=0):
-    return Type('Shmem', ptr=ptr)
+def _shmemType(ptr=0, ref=0):
+    return Type('Shmem', ptr=ptr, ref=ref)
 
 def _rawShmemType(ptr=0):
     return Type('Shmem::SharedMemory', ptr=ptr)
 
 def _shmemIdType(ptr=0):
     return Type('Shmem::id_t', ptr=ptr)
 
 def _shmemTypeType():
@@ -226,28 +226,37 @@ def _shmemId(shmemexpr):
     return ExprCall(ExprSelect(shmemexpr, '.', 'Id'),
                     args=[ _shmemBackstagePass() ])
 
 def _shmemAlloc(size, type):
     # starts out UNprotected
     return ExprCall(ExprVar('Shmem::Alloc'),
                     args=[ _shmemBackstagePass(), size, type ])
 
+def _shmemDealloc(rawmemvar):
+    return ExprCall(ExprVar('Shmem::Dealloc'),
+                    args=[ _shmemBackstagePass(), rawmemvar ])
+
 def _shmemShareTo(shmemvar, processvar, route):
     return ExprCall(ExprSelect(shmemvar, '.', 'ShareTo'),
                     args=[ _shmemBackstagePass(),
                            processvar, route ])
 
 def _shmemOpenExisting(descriptor, outid):
     # starts out protected
     return ExprCall(ExprVar('Shmem::OpenExisting'),
                     args=[ _shmemBackstagePass(),
                            # true => protect
                            descriptor, outid, ExprLiteral.TRUE ])
 
+def _shmemUnshareFrom(shmemvar, processvar, route):
+    return ExprCall(ExprSelect(shmemvar, '.', 'UnshareFrom'),
+                    args=[ _shmemBackstagePass(),
+                           processvar, route ])
+
 def _shmemForget(shmemexpr):
     return ExprCall(ExprSelect(shmemexpr, '.', 'forget'),
                     args=[ _shmemBackstagePass() ])
 
 def _shmemRevokeRights(shmemexpr):
     return ExprCall(ExprSelect(shmemexpr, '.', 'RevokeRights'),
                     args=[ _shmemBackstagePass() ])
 
@@ -1355,16 +1364,19 @@ class Protocol(ipdl.ast.Protocol):
         return ExprVar('RemoveManagee')
 
     def createSharedMemory(self):
         return ExprVar('CreateSharedMemory')
  
     def lookupSharedMemory(self):
         return ExprVar('LookupSharedMemory')
 
+    def destroySharedMemory(self):
+        return ExprVar('DestroySharedMemory')
+
     def otherProcessMethod(self):
         return ExprVar('OtherProcess')
 
     def shouldContinueFromTimeoutVar(self):
         assert self.decl.type.isToplevel()
         return ExprVar('ShouldContinueFromReplyTimeout')
 
     def enteredCxxStackVar(self):
@@ -1476,16 +1488,20 @@ class Protocol(ipdl.ast.Protocol):
         else: assert 0
 
     def nextShmemIdExpr(self, side):
         assert self.decl.type.isToplevel()
         if side is 'parent':   op = '++'
         elif side is 'child':  op = '--'
         return ExprPrefixUnop(self.lastShmemIdVar(), op)
 
+    def removeShmemId(self, idexpr):
+        return ExprCall(ExprSelect(self.shmemMapVar(), '.', 'Remove'),
+                        args=[ idexpr ])
+
     def usesShmem(self):
         for md in self.messageDecls:
             for param in md.inParams:
                 if ipdl.type.hasshmem(param.type):
                     return True
             for ret in md.outParams:
                 if ipdl.type.hasshmem(ret.type):
                     return True
@@ -3137,16 +3153,17 @@ class _GenerateProtocolActorCode(ipdl.as
                 StmtDecl(Decl(
                     p.managedVarType(managed, self.side),
                     p.managedVar(managed, self.side).name)) ])
 
     def implementManagerIface(self):
         p = self.protocol
         routedvar = ExprVar('aRouted')
         idvar = ExprVar('aId')
+        shmemvar = ExprVar('aShmem')
         sizevar = ExprVar('aSize')
         typevar = ExprVar('type')
         listenertype = Type('ChannelListener', ptr=1)
 
         register = MethodDefn(MethodDecl(
             p.registerMethod().name,
             params=[ Decl(listenertype, routedvar.name) ],
             ret=_actorIdType(), virtual=1))
@@ -3172,17 +3189,22 @@ class _GenerateProtocolActorCode(ipdl.as
                      Decl(_shmemTypeType(), typevar.name),
                      Decl(_shmemIdType(ptr=1), idvar.name) ],
             virtual=1))
         lookupshmem = MethodDefn(MethodDecl(
             p.lookupSharedMemory().name,
             ret=_rawShmemType(ptr=1),
             params=[ Decl(_shmemIdType(), idvar.name) ],
             virtual=1))
-        
+        destroyshmem = MethodDefn(MethodDecl(
+            p.destroySharedMemory().name,
+            ret=Type.BOOL,
+            params=[ Decl(_shmemType(ref=1), shmemvar.name) ],
+            virtual=1))
+
         otherprocess = MethodDefn(MethodDecl(
             p.otherProcessMethod().name,
             ret=Type('ProcessHandle'),
             const=1,
             virtual=1))
 
         if p.decl.type.isToplevel():
             tmpvar = ExprVar('tmp')
@@ -3209,32 +3231,30 @@ class _GenerateProtocolActorCode(ipdl.as
                          [ idvar ])))
 
             # SharedMemory* CreateSharedMemory(size, type, id_t*):
             #   nsAutoPtr<SharedMemory> seg(Shmem::Alloc(size, type));
             #   if (!shmem)
             #     return false
             #   Shmem s(seg, [nextshmemid]);
             #   Message descriptor;
-            #   if (!s->ShareTo(subprocess, mId, descriptor))
-            #     return false;
-            #   if (!Send(descriptor))
+            #   if (!s->ShareTo(subprocess, mId, descriptor) ||
+            #       !Send(descriptor))
             #     return false;
             #   mShmemMap.Add(seg, id);
             #   return shmem.forget();
             rawvar = ExprVar('segment')
 
             createshmem.addstmt(StmtDecl(
                 Decl(_autoptr(_rawShmemType()), rawvar.name),
                 initargs=[ _shmemAlloc(sizevar, typevar) ]))
             failif = StmtIf(ExprNot(rawvar))
             failif.addifstmt(StmtReturn(ExprLiteral.FALSE))
             createshmem.addstmt(failif)
 
-            shmemvar = ExprVar('shmem')
             descriptorvar = ExprVar('descriptor')
             createshmem.addstmts([
                 StmtDecl(
                     Decl(_shmemType(), shmemvar.name),
                     initargs=[ _shmemBackstagePass(),
                                _autoptrGet(rawvar),
                                p.nextShmemIdExpr(self.side) ]),
                 StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name),
@@ -3259,22 +3279,72 @@ class _GenerateProtocolActorCode(ipdl.as
                 StmtReturn(_autoptrForget(rawvar))
             ])
 
             # SharedMemory* Lookup(id)
             lookupshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.shmemMapVar(), '.', 'Lookup'),
                 args=[ idvar ])))
 
+            # bool DestroySharedMemory(shmem):
+            #   id = shmem.Id()
+            #   SharedMemory* rawmem = Lookup(id)
+            #   if (!rawmem)
+            #     return false;
+            #   Message descriptor = UnShare(subprocess, mId, descriptor)
+            #   mShmemMap.Remove(id)
+            #   Shmem::Dealloc(rawmem)
+            #   return descriptor && Send(descriptor)
+            destroyshmem.addstmts([
+                StmtDecl(Decl(_shmemIdType(), idvar.name),
+                         init=_shmemId(shmemvar)),
+                StmtDecl(Decl(_rawShmemType(ptr=1), rawvar.name),
+                         init=_lookupShmem(idvar))
+            ])
+
+            failif = StmtIf(ExprNot(rawvar))
+            failif.addifstmt(StmtReturn(ExprLiteral.FALSE))
+            destroyshmem.addstmts([
+                failif,
+                StmtDecl(Decl(Type('Message', ptr=1), descriptorvar.name),
+                         init=_shmemUnshareFrom(
+                             shmemvar,
+                             ExprCall(p.otherProcessMethod()),
+                             p.routingId())),
+                Whitespace.NL,
+                StmtExpr(p.removeShmemId(idvar)),
+                StmtExpr(_shmemDealloc(rawvar)),
+                Whitespace.NL,
+                StmtReturn(ExprBinary(
+                    descriptorvar, '&&',
+                    ExprCall(
+                        ExprSelect(p.channelVar(), p.channelSel(), 'Send'),
+                        args=[ descriptorvar ])))
+            ])
+
+
             # "private" message that passes shmem mappings from one process
             # to the other
             if p.usesShmem():
                 self.asyncSwitch.addcase(
                     CaseLabel('SHMEM_CREATED_MESSAGE_TYPE'),
                     self.genShmemCreatedHandler())
+                self.asyncSwitch.addcase(
+                    CaseLabel('SHMEM_DESTROYED_MESSAGE_TYPE'),
+                    self.genShmemDestroyedHandler())
+            else:
+                abort = StmtBlock()
+                abort.addstmts([
+                    _runtimeAbort('this protocol tree does not use shmem'),
+                    StmtReturn(_Result.NotKnown)
+                ])
+                self.asyncSwitch.addcase(
+                    CaseLabel('SHMEM_CREATED_MESSAGE_TYPE'), abort)
+                self.asyncSwitch.addcase(
+                    CaseLabel('SHMEM_DESTROYED_MESSAGE_TYPE'), abort)
             
             otherprocess.addstmt(StmtReturn(p.otherProcessVar()))
         else:
             # delegate registration to manager
             register.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.registerMethod().name),
                 [ routedvar ])))
             registerid.addstmt(StmtReturn(ExprCall(
@@ -3287,16 +3357,19 @@ class _GenerateProtocolActorCode(ipdl.as
                 ExprSelect(p.managerVar(), '->', p.unregisterMethod().name),
                 [ idvar ])))
             createshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.createSharedMemory().name),
                 [ sizevar, typevar, idvar ])))
             lookupshmem.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->', p.lookupSharedMemory().name),
                 [ idvar ])))
+            destroyshmem.addstmt(StmtReturn(ExprCall(
+                ExprSelect(p.managerVar(), '->', p.destroySharedMemory().name),
+                [ shmemvar ])))
             otherprocess.addstmt(StmtReturn(ExprCall(
                 ExprSelect(p.managerVar(), '->',
                            p.otherProcessMethod().name))))
 
         # all protocols share the "same" RemoveManagee() implementation
         pvar = ExprVar('aProtocolId')
         listenervar = ExprVar('aListener')
         removemanagee = MethodDefn(MethodDecl(
@@ -3337,16 +3410,17 @@ class _GenerateProtocolActorCode(ipdl.as
 
         return [ register,
                  registerid,
                  lookup,
                  unregister,
                  removemanagee,
                  createshmem,
                  lookupshmem,
+                 destroyshmem,
                  otherprocess,
                  Whitespace.NL ]
 
     def makeShmemIface(self):
         p = self.protocol
         idvar = ExprVar('aId')
         sizevar = ExprVar('aSize')
         typevar = ExprVar('aType')
@@ -3378,19 +3452,39 @@ class _GenerateProtocolActorCode(ipdl.as
                                                 typevar,
                                                 ExprAddrOf(idvar) ]) ]),
             ifallocfails,
             Whitespace.NL,
             StmtExpr(ExprAssn(
                 ExprDeref(memvar), _shmemCtor(_autoptrForget(rawvar), idvar))),
             StmtReturn(ExprLiteral.TRUE)
         ])
-                
+
+        # bool DeallocShmem(Shmem& mem):
+        #   bool ok = DestroySharedMemory(mem);
+        #   mem.forget();
+        #   return ok;
+        deallocShmem = MethodDefn(MethodDecl(
+            'DeallocShmem',
+            params=[ Decl(_shmemType(ref=1), memvar.name) ],
+            ret=Type.BOOL))
+        okvar = ExprVar('ok')
+
+        deallocShmem.addstmts([
+            StmtDecl(Decl(Type.BOOL, okvar.name),
+                     init=ExprCall(p.destroySharedMemory(),
+                                   args=[ memvar ])),
+            StmtExpr(_shmemForget(memvar)),
+            StmtReturn(okvar)
+        ])
+
         return [ Whitespace('// Methods for managing shmem\n', indent=1),
                  allocShmem,
+                 Whitespace.NL,
+                 deallocShmem,
                  Whitespace.NL ]
 
     def genShmemCreatedHandler(self):
         p = self.protocol
         assert p.decl.type.isToplevel()
         
         case = StmtBlock()                                          
 
@@ -3411,16 +3505,57 @@ class _GenerateProtocolActorCode(ipdl.as
                 ExprSelect(p.shmemMapVar(), '.', 'AddWithID'),
                 args=[ _autoptrForget(rawvar), idvar ])),
             Whitespace.NL,
             StmtReturn(_Result.Processed)
         ])
 
         return case
 
+    def genShmemDestroyedHandler(self):
+        p = self.protocol
+        assert p.decl.type.isToplevel()
+        
+        case = StmtBlock()                                          
+
+        rawvar = ExprVar('rawmem')
+        idvar = ExprVar('id')
+        itervar = ExprVar('iter')
+        case.addstmts([
+            StmtDecl(Decl(_shmemIdType(), idvar.name)),
+            StmtDecl(Decl(Type.VOIDPTR, itervar.name), init=ExprLiteral.NULL)
+        ])
+
+        failif = StmtIf(ExprNot(
+            ExprCall(ExprVar('IPC::ReadParam'),
+                     args=[ ExprAddrOf(self.msgvar), ExprAddrOf(itervar),
+                            ExprAddrOf(idvar) ])))
+        failif.addifstmt(StmtReturn(_Result.PayloadError))
+
+        case.addstmts([
+            failif,
+            StmtExpr(ExprCall(ExprSelect(self.msgvar, '.', 'EndRead'),
+                              args=[ itervar ])),
+            Whitespace.NL,
+            StmtDecl(Decl(_rawShmemType(ptr=1), rawvar.name),
+                     init=ExprCall(p.lookupSharedMemory(), args=[ idvar ]))
+        ])
+
+        failif = StmtIf(ExprNot(rawvar))
+        failif.addifstmt(StmtReturn(_Result.ValuError))
+
+        case.addstmts([
+            failif,
+            StmtExpr(p.removeShmemId(idvar)),
+            StmtExpr(_shmemDealloc(rawvar)),
+            StmtReturn(_Result.Processed)
+        ])
+
+        return case
+
 
     ##-------------------------------------------------------------------------
     ## The next few functions are the crux of the IPDL code generator.
     ## They generate code for all the nasty work of message
     ## serialization/deserialization and dispatching handlers for
     ## received messages.
     ##
     def visitMessageDecl(self, md):
--- a/ipc/ipdl/test/cxx/TestShmem.cpp
+++ b/ipc/ipdl/test/cxx/TestShmem.cpp
@@ -39,16 +39,19 @@ TestShmemParent::RecvTake(Shmem& mem, co
 {
     if (mem.Size<char>() != expectedSize)
         fail("expected shmem size %lu, but it has size %lu",
              expectedSize, mem.Size<char>());
 
     if (strcmp(mem.get<char>(), "And yourself!"))
         fail("expected message was not written");
 
+    if (!DeallocShmem(mem))
+        fail("DeallocShmem");
+
     Close();
 
     return true;
 }
 
 //-----------------------------------------------------------------------------
 // Child
 
--- a/ipc/ipdl/test/cxx/TestSysVShmem.cpp
+++ b/ipc/ipdl/test/cxx/TestSysVShmem.cpp
@@ -12,16 +12,19 @@ namespace _ipdltest {
 void
 TestSysVShmemParent::Main()
 {
     Shmem mem;
     size_t size = 12345;
     if (!AllocShmem(size, SharedMemory::TYPE_SYSV, &mem))
         fail("can't alloc shmem");
 
+    if (0 > mem.GetSysVID())
+        fail("invalid shmem ID");
+
     if (mem.Size<char>() != size)
         fail("shmem is wrong size: expected %lu, got %lu",
              size, mem.Size<char>());
 
     char* ptr = mem.get<char>();
     memcpy(ptr, "Hello!", sizeof("Hello!"));
     if (!SendGive(mem, size))
         fail("can't send Give()");
@@ -39,16 +42,19 @@ TestSysVShmemParent::RecvTake(Shmem& mem
 {
     if (mem.Size<char>() != expectedSize)
         fail("expected shmem size %lu, but it has size %lu",
              expectedSize, mem.Size<char>());
 
     if (strcmp(mem.get<char>(), "And yourself!"))
         fail("expected message was not written");
 
+    if (!DeallocShmem(mem))
+        fail("DeallocShmem");
+
     Close();
 
     return true;
 }
 
 //-----------------------------------------------------------------------------
 // Child