add IPDL front-end support for transitioning to one of a set of states
authorChris Jones <jones.chris.g@gmail.com>
Wed, 19 Aug 2009 21:21:46 -0500
changeset 35868 c3b494310a9ff401ca9c2bbbf14666a5bc7a18a1
parent 35867 0f9546f174946f6238b0300eca3a44f787c03a54
child 35869 5be0cf05da79c683b6c25edeb3656add9890a5dc
push idunknown
push userunknown
push dateunknown
milestone1.9.2a1pre
add IPDL front-end support for transitioning to one of a set of states
ipc/ipdl/ipdl/ast.py
ipc/ipdl/ipdl/parser.py
ipc/ipdl/ipdl/type.py
--- a/ipc/ipdl/ipdl/ast.py
+++ b/ipc/ipdl/ipdl/ast.py
@@ -275,21 +275,21 @@ class MessageDecl(Node):
 
 class TransitionStmt(Node):
     def __init__(self, loc, state, transitions):
         Node.__init__(self, loc)
         self.state = state
         self.transitions = transitions
 
 class Transition(Node):
-    def __init__(self, loc, trigger, msg, toState):
+    def __init__(self, loc, trigger, msg, toStates):
         Node.__init__(self, loc)
         self.trigger = trigger
         self.msg = msg
-        self.toState = toState
+        self.toStates = toStates
 
     @staticmethod
     def nameToTrigger(name):
         return { 'send': SEND, 'recv': RECV, 'call': CALL, 'answer': ANSWER }[name]
 
 class SEND:
     pretty = 'send'
     @classmethod
@@ -316,24 +316,24 @@ class ANSWER:
     def direction(cls): return IN
 
 class State(Node):
     def __init__(self, loc, name, start=False):
         Node.__init__(self, loc)
         self.name = name
         self.start = start
     def __eq__(self, o):
-        return (isinstance(o, State)
-                and o.name == self.name
-                and o.start == self.start)
+         return (isinstance(o, State)
+                 and o.name == self.name
+                 and o.start == self.start)
     def __hash__(self):
         return hash(repr(self))
     def __ne__(self, o):
         return not (self == o)
-    def __repr__(self): return '<State %r start=%s>'% (self.name, self.start)
+    def __repr__(self): return '<State %r start=%r>'% (self.name, self.start)
     def __str__(self): return '<State %s start=%s>'% (self.name, self.start)
 
 class Param(Node):
     def __init__(self, loc, typespec, name):
         Node.__init__(self, loc)
         self.name = name
         self.typespec = typespec
 
--- a/ipc/ipdl/ipdl/parser.py
+++ b/ipc/ipdl/ipdl/parser.py
@@ -151,16 +151,17 @@ reserved = set((
         'both',
         'call',
         'child',
         'goto',
         'include',
         'manager',
         'manages',
         'namespace',
+        'or',
         'parent',
         'protocol',
         'recv',
         'returns',
         'rpc',
         'send',
         'share',
         'start',
@@ -416,27 +417,36 @@ def p_Transitions(p):
                    | Transition"""
     if 3 == len(p):
         p[1].append(p[2])
         p[0] = p[1]
     else:
         p[0] = [ p[1] ]
 
 def p_Transition(p):
-    """Transition : Trigger MessageId GOTO State ';'"""
+    """Transition : Trigger MessageId GOTO StateList ';'"""
     loc, trigger = p[1]
     p[0] = Transition(loc, trigger, p[2], p[4])
 
 def p_Trigger(p):
     """Trigger : SEND
                | RECV
                | CALL
                | ANSWER"""
     p[0] = [ locFromTok(p, 1), Transition.nameToTrigger(p[1]) ]
 
+def p_StateList(p):
+    """StateList : StateList OR State
+                 | State"""
+    if 2 == len(p):
+        p[0] = [ p[1] ]
+    else:
+        p[1].append(p[3])
+        p[0] = p[1]
+
 def p_State(p):
     """State : ID"""
     p[0] = State(locFromTok(p, 1), p[1])
 
 ##--------------------
 ## Minor stuff
 def p_OptionalSendSemanticsQual(p):
     """OptionalSendSemanticsQual : SendSemanticsQual
--- a/ipc/ipdl/ipdl/type.py
+++ b/ipc/ipdl/ipdl/type.py
@@ -129,18 +129,20 @@ class IPDLType(Type):
 
     def hasReply(self):  return self.isSync() or self.isRpc()
 
     def needsMoreJuiceThan(self, o):
         return (o.isAsync() and not self.isAsync()
                 or o.isSync() and self.isRpc())
 
 class StateType(IPDLType):
-    def __init__(self): pass
-    def isState(self): return True
+    def __init__(self, start=False):
+        self.start = start
+    def isState(self):
+        return True
 
 class MessageType(IPDLType):
     def __init__(self, sendSemantics, direction,
                  ctor=False, dtor=False, cdtype=None):
         assert not (ctor and dtor)
         assert not (ctor or dtor) or type is not None
 
         self.sendSemantics = sendSemantics
@@ -465,31 +467,33 @@ class GatherDecls(TcheckVisitor):
             if not(ctordecl and dtordecl
                    and ctordecl.type.isCtor() and dtordecl.type.isDtor()):
                 self.error(
                     managed.loc,
                     "constructor and destructor declarations are required for managed protocol `%s' (managed by protocol `%s')",
                     mgdname, p.name)
 
         p.states = { }
+        
         if len(p.transitionStmts):
             p.startStates = [ ts for ts in p.transitionStmts
                               if ts.state.start ]
             if 0 == len(p.startStates):
                 p.startStates = [ p.transitionStmts[0] ]
                 
         # declare each state before decorating their mention
         for trans in p.transitionStmts:
             p.states[trans.state] = trans
             trans.state.decl = self.declare(
                 loc=trans.state.loc,
-                type=StateType(),
+                type=StateType(trans.state.start),
                 progname=trans.state.name)
 
         for trans in p.transitionStmts:
+            self.seentriggers = set()
             trans.accept(self)
 
         # visit the message decls once more and resolve the state names
         # attached to actor params and returns
         def resolvestate(param):
             loc = param.loc
             statename = param.type.state.name
             statedecl = self.symtab.lookup(statename)
@@ -656,42 +660,62 @@ class GatherDecls(TcheckVisitor):
 
         md.decl = self.declare(
             loc=loc,
             type=msgtype,
             progname=msgname)
         md.protocolDecl = self.currentProtocolDecl
 
 
+    def visitTransitionStmt(self, ts):
+        self.seentriggers = set()
+        TcheckVisitor.visitTransitionStmt(self, ts)
+
     def visitTransition(self, t):
         loc = t.loc
 
-        sname = t.toState.name
-        sdecl = self.symtab.lookup(sname)
-        if sdecl is None:
-            self.error(loc, "state `%s' has not been declared", sname)
-        elif not sdecl.type.isState():
-            self.error(
-                loc, "`%s' should have state type, but instead has type `%s'",
-                sname, sdecl.type.typename())
-        else:
-            t.toState.decl = sdecl
-
+        # check the trigger message
         mname = t.msg
+        if mname in self.seentriggers:
+            self.error(loc, "trigger `%s' appears multiple times", mname)
+        self.seentriggers.add(mname)
+        
         mdecl = self.symtab.lookup(mname)
         if mdecl is None:
             self.error(loc, "message `%s' has not been declared", mname)
         elif not mdecl.type.isMessage():
             self.error(
                 loc,
                 "`%s' should have message type, but instead has type `%s'",
                 mname, mdecl.type.typename())
         else:
             t.msg = mdecl
 
+        # check the to-states
+        seenstates = set()
+        for toState in t.toStates:
+            sname = toState.name
+            sdecl = self.symtab.lookup(sname)
+
+            if sname in seenstates:
+                self.error(loc, "to-state `%s' appears multiple times", sname)
+            seenstates.add(sname)
+
+            if sdecl is None:
+                self.error(loc, "state `%s' has not been declared", sname)
+            elif not sdecl.type.isState():
+                self.error(
+                    loc, "`%s' should have state type, but instead has type `%s'",
+                    sname, sdecl.type.typename())
+            else:
+                toState.decl = sdecl
+                toState.start = sdecl.type.start
+
+        t.toStates = set(t.toStates)
+
 ##-----------------------------------------------------------------------------
 
 class CheckTypes(TcheckVisitor):
     def __init__(self, errors):
         # don't need the symbol table, we just want the error reporting
         TcheckVisitor.__init__(self, None, errors)
         self.visited = set()
 
@@ -803,16 +827,28 @@ class CheckTypes(TcheckVisitor):
             self.error(
                 loc, "%s %s message `%s' is not `%s'd",
                 mtype.sendSemantics.pretty, mtype.direction.pretty,
                 t.msg.progname,
                 t.trigger.pretty)
 
 ##-----------------------------------------------------------------------------
 
+def unique_pairs(s):
+    n = len(s)
+    for i, e1 in enumerate(s):
+        for j in xrange(i+1, n):
+            yield (e1, s[j])
+
+def cross_product(s1, s2):
+    for e1 in s1:
+        for e2 in s2:
+            yield (e1, e2)
+
+
 class CheckStateMachine(TcheckVisitor):
     def __init__(self, errors):
         # don't need the symbol table, we just want the error reporting
         TcheckVisitor.__init__(self, None, errors)
         self.p = None
 
     def visitProtocol(self, p):
         self.p = p
@@ -854,81 +890,99 @@ class CheckStateMachine(TcheckVisitor):
         #
         #   *Rule 2*: the "Diamond Rule".
         #   from a state S,
         #     for any pair of triggers t1 and t2,
         #         where t1 and t2 have opposite direction,
         #         and t1 transitions to state T1 and t2 to T2,
         #       then the following must be true:
         #         T2 allows the trigger t1, transitioning to state U
-        #         T1 allows the trigger t2, transitioning to state U"""
+        #         T1 allows the trigger t2, transitioning to state U
         #
         # This is a more formal way of expressing "it doesn't matter
         # in which order the triggers t1 and t2 occur / are processed."
+        #
+        # The presence of triggers with multiple out states complicates
+        # this check slightly, but doesn't fundamentally change it.
+        #
+        #   from a state S,
+        #     for any pair of triggers t1 and t2,
+        #         where t1 and t2 have opposite direction,
+        #       for each pair of states (T1, T2) \in t1_out x t2_out,
+        #           where t1_out is the set of outstates from t1
+        #                 t2_out is the set of outstates from t2
+        #                 t1_out x t2_out is their Cartesian product
+        #                 and t1 transitions to state T1 and t2 to T2,
+        #         then the following must be true:
+        #           T2 allows the trigger t1, with out-state set { U }
+        #           T1 allows the trigger t2, with out-state set { U }
+        #
         syncdirection = None
         syncok = True
         for trans in ts.transitions:
             if not trans.msg.type.isSync(): continue
             if syncdirection is None:
                 syncdirection = trans.trigger.direction()
             elif syncdirection is not trans.trigger.direction():
                 self.error(
                     trans.loc,
                     "sync trigger at state `%s' in protocol `%s' has different direction from earlier sync trigger at same state",
                     ts.state.name, self.p.name)
                 syncok = False
         # don't check the Diamond Rule if Rule 1 doesn't hold
         if not syncok:
             return
 
-        def triggerTarget(S, t):
-            '''Return the state transitioned to from state |S|
-upon trigger |t|, or None if |t| is not a trigger in |S|.'''
+        def triggerTargets(S, t):
+            '''Return the set of states transitioned to from state |S|
+upon trigger |t|, or { } if |t| is not a trigger in |S|.'''
             for trans in self.p.states[S].transitions:
                 if t.trigger is trans.trigger and t.msg is trans.msg:
-                    return trans.toState
-            return None
+                    return trans.toStates
+            return set()
+
+
+        for (t1, t2) in unique_pairs(ts.transitions):
+            # if the triggers have the same direction, they can't race,
+            # since only one endpoint can initiate either (and delivery
+            # is in-order)
+            if t1.trigger.direction() == t2.trigger.direction():
+                continue
 
-        ntrans = len(ts.transitions)
-        for i, t1 in enumerate(ts.transitions):
-            for j in xrange(i+1, ntrans):
-                t2 = ts.transitions[j]
-                # if the triggers have the same direction, they can't race,
-                # since only one endpoint can initiate either (and delivery
-                # is in-order)
-                if t1.trigger.direction() == t2.trigger.direction():
-                    continue
+            t1_out = t1.toStates
+            t2_out = t2.toStates
 
-                T1 = t1.toState
-                T2 = t2.toState
+            for (T1, T2) in cross_product(t1_out, t2_out):
+                U1 = triggerTargets(T1, t2)
+                U2 = triggerTargets(T2, t1)
 
-                U1 = triggerTarget(T1, t2)
-                U2 = triggerTarget(T2, t1)
-
-                if U1 is None or U1 != U2:
+                if (0 == len(U1)
+                    or 1 < len(U1) or 1 < len(U2)
+                    or U1 != U2):
                     self.error(
                         t2.loc,
-                        "trigger `%s' potentially races (does not commute) with `%s' at state `%s' in protocol `%s'",
-                        t1.msg.progname, t2.msg.progname,
-                        ts.state.name, self.p.name)
+                        "in protocol `%s' state `%s', trigger `%s' potentially races (does not commute) with `%s'",
+                        self.p.name, ts.state.name,
+                        t1.msg.progname, t2.msg.progname)
                     # don't report more than one Diamond Rule
-                    # violation per state. there may be O(n^2) total,
-                    # way too many for a human to parse
+                    # violation per state. there may be O(n^4)
+                    # total, way too many for a human to parse
                     #
-                    # XXX/cjones: could set a limit on #printed and stop after
-                    # that limit ...
+                    # XXX/cjones: could set a limit on #printed
+                    # and stop after that limit ...
                     return
 
     def checkReachability(self, p):
         visited = set()         # set(State)
         def explore(ts):
             if ts.state in visited:
                 return
             visited.add(ts.state)
             for outedge in ts.transitions:
-                explore(p.states[outedge.toState])
+                for toState in outedge.toStates:
+                    explore(p.states[toState])
 
         for root in p.startStates:
             explore(root)
         for ts in p.transitionStmts:
             if ts.state not in visited:
                 self.error(ts.loc, "unreachable state `%s' in protocol `%s'",
                            ts.state.name, p.name)