#!/usr/bin/python

import math
import os
import sys
from xml.dom.minidom import parse

def bitLength(n):
    """Return the number of bits needed to tell n values apart.

    This is ceil(log2(n)): 0 for n == 1, 1 for n == 2, 2 for n in 3..4, ...
    The result is computed with integer arithmetic only, because a floating
    point logarithm is not exact on every power of two and may yield one
    bit too many.

    n must be a positive integer; ValueError is raised otherwise (as
    math.log did for n <= 0).
    """
    if n < 1:
        raise ValueError("bitLength() expects a positive integer")
    bits = 0
    #  smallest number of bits such that 2 ** bits >= n
    while (1 << bits) < n:
        bits += 1
    return bits

class XmlAttributeNotFound(Exception):
    """Raised by getAttr when an element does not carry the attribute."""

class Node:
    """A node of the net (place or transition) together with its arcs.

    Attributes:
      num        -- number of the node; its string form (getSNum) is what
                    the code generators use to refer to the node
      nid        -- identifier of the node, matched against arc end points
      name       -- name of the node
      arcs       -- every arc attached to the node
      inputArcs  -- arcs whose source is another node
      outputArcs -- arcs whose source is this node
    """
    def __init__(self, num, nid, name):
        self.num = num
        self.nid = nid
        self.name = name
        self.arcs = []
        self.inputArcs = []
        self.outputArcs = []
    def setNum(self, num):
        """Change the number of the node."""
        self.num = num
    def getNum(self):
        """Return the number of the node."""
        return self.num
    def getSNum(self):
        """Return the number of the node as a string."""
        return str(self.num)
    def getId(self):
        """Return the identifier of the node."""
        return self.nid
    def getName(self):
        """Return the name of the node."""
        return self.name
    def getArcs(self):
        """Return the list of all arcs attached to the node."""
        return self.arcs
    def addArc(self, arc):
        """Attach arc to the node, as output arc if the node is its source
        and as input arc otherwise."""
        if arc.getSrc() == self.nid:
            self.outputArcs.append(arc)
        else:
            self.inputArcs.append(arc)
        self.arcs.append(arc)
    def deleteArc(self, arc):
        """Detach arc from the node (ValueError if it is not attached)."""
        if arc.getSrc() == self.nid:
            self.outputArcs.remove(arc)
        else:
            self.inputArcs.remove(arc)
        self.arcs.remove(arc)
    def getOutputArcs(self):
        """Return the arcs leaving the node."""
        return self.outputArcs
    def getInputArcs(self):
        """Return the arcs entering the node."""
        return self.inputArcs
    def getPreOrPostSet(self, net, arcs):
        """Return the (valuation, node) pairs found at the other end of arcs.

        The list is ordered by valuation, then by node number.  Ordering on
        the number (rather than on the node objects themselves) keeps the
        result identical from one run to the next and does not depend on
        node objects being comparable.
        """
        result = [ (a.getValuation(), net.getNode(a.getOtherEnd(self.nid)))
                   for a in arcs ]
        result.sort(key = lambda item: (item[0], item[1].getNum()))
        return result
    def getPreSet(self, net):
        """Return the (valuation, node) pairs of the input arcs."""
        return self.getPreOrPostSet(net, self.getInputArcs())
    def getPostSet(self, net):
        """Return the (valuation, node) pairs of the output arcs."""
        return self.getPreOrPostSet(net, self.getOutputArcs())
    def getUpdatedSet(self, net):
        """Return the signed (valuation, node) pairs linked to the node.

        For a place, input arcs count positively and output arcs
        negatively; for any other node it is the opposite.  The pairs are
        grouped by node (ordered by node number); two consecutive pairs on
        the same node whose valuations sum to 0 cancel out and are both
        dropped.
        """
        if isinstance(self, Place):
            dom = [ (self.getInputArcs(), 1), (self.getOutputArcs(), -1) ]
        else:
            dom = [ (self.getInputArcs(), -1), (self.getOutputArcs(), 1) ]
        entries = []
        for arcs, mult in dom:
            for a in arcs:
                entries.append((mult * a.getValuation(),
                                net.getNode(a.getOtherEnd(self.nid))))
        #  group the entries by node: the sort is stable, hence entries on
        #  the same node keep their relative order
        entries.sort(key = lambda item: item[1].getNum())
        result = []
        prev = None
        for val, n in entries:
            if prev is not None and prev[1] == n and prev[0] + val == 0:
                #  opposite updates on the same node: drop both
                result.pop()
                prev = None
            else:
                result.append((val, n))
                prev = (val, n)
        return result
        
    
class Place(Node):
    """A place of the net: a node that also stores an initial marking."""
    def __init__(self, num, nid, name, init):
        #  the initial marking has to be a plain int
        assert type(init) is int
        Node.__init__(self, num, nid, name)
        self.init = init
    def getInit(self):
        """Return the initial marking given at construction."""
        return self.init

class Trans(Node):
    """A transition of the net; it stores nothing beyond what Node does."""
    def __init__(self, num, nid, name):
        Node.__init__(self, num = num, nid = nid, name = name)
        
    
class Arc:
    """A directed arc between two nodes designated by their identifiers.

    Attributes:
      aid       -- identifier of the arc
      src       -- identifier of the source node
      target    -- identifier of the target node
      valuation -- weight of the arc (1 unless stated otherwise)
    """
    def __init__(self, aid, src, target, valuation = 1):
        self.aid = aid
        self.src = src
        self.target = target
        self.valuation = valuation
    def getSrc(self):
        """Return the identifier of the source node."""
        return self.src
    def getTarget(self):
        """Return the identifier of the target node."""
        return self.target
    def getValuation(self):
        """Return the weight of the arc."""
        return self.valuation
    def getOtherEnd(self, nid):
        """Return the end of the arc that is not nid: the target if nid is
        the source, the source in every other case."""
        if self.src == nid:
            return self.target
        return self.src

class Net:
    def __init__(self):
        self.N = []
        self.P = []
        self.T = []
        self.A = []
        self.capacity = 256
        self.name = ""
    def setCapacity(self, capacity):
        self.capacity = capacity
    def getCapacity(self):
        return self.capacity
    def setName(self, name):
        self.name =  name
    def addPlace(self, p):
        self.P.append(p)
        self.N.append(p)
    def addTrans(self, t):
        self.T.append(t)
        self.N.append(t)
    def addArc(self, a):
        self.A.append(a)
        for nid in [ a.getSrc(), a.getTarget() ]:
            n = self.getNode(nid)
            assert(n is not None)
            n.addArc(a)
    def getNode(self, nid):
        try:
            return next(n for n in self.N if n.getId() == nid)
        except:
            return None
    def printNet(self):
        print ">>>>> PLACES <<<<<"
        for p in sorted(self.P, key = lambda p: p.getName()):
            print p.getName() + " = " + str(p.getInit())
        print
        print ">>>>> TRANSITIONS <<<<<"
        for t in sorted(self.T, key = lambda t: t.getName()):
            line = ""
            for arcs in [ t.getInputArcs(), t.getOutputArcs() ]:
                if line != "":
                    line += " -> "
                fst = True
                for a in arcs:
                    val = a.getValuation()
                    if val > 1:
                        line += str(val) + "*"
                    if not fst:
                        line += " + "                    
                    line += self.getNode(a.getOtherEnd(t.getId())).getName()
                    fst = False
            line = t.getName() + ": " + line
            print line
    def deletePlace(self, p):
        transToDel = [ self.getNode(a.getOtherEnd(p.getId()))
                       for a in p.getArcs() ]
        arcsToDel = set(p.getArcs())
        for t in transToDel:
            for a in t.getArcs():
                arcsToDel.add(a)
        self.A = [ a for a in self.A if a not in arcsToDel ]
        for a in arcsToDel:
            for n in [ self.getNode(a.getSrc()), self.getNode(a.getTarget()) ]:
                n.deleteArc(a)
        self.N.remove(p)
        self.P.remove(p)
        for t in transToDel:
            self.N.remove(t)
            self.T.remove(t)
    def recomputeNums(self):
        for i, p in enumerate(self.P):
            p.setNum(i)
        for i, t in enumerate(self.T):
            t.setNum(i)
    def simplify(self):
        #  remove unmarked places that do not have input transitions
        for p in [ p for p in self.P if p.getInit() == 0 \
                   and len(p.getInputArcs()) == 0 ]:
            self.deletePlace(p)
        self.recomputeNums()
                
    
    ###
    #  model functions
    ###
    def compileModelFunctions(self, hw, cw):
        prot = "void init_model()"
        hw(prot + ";\n") 
        cw(prot + " {}\n")
        prot = "char * model_name()"
        hw(prot + ";\n") 
        cw(prot + " { return \"" + self.name + "\"; }\n")
        prot = "bool_t model_is_state_proposition(char * prop)"
        hw(prot + ";\n") 
        cw(prot + " { return FALSE; }\n")
        prot = "bool_t model_check_state_proposition(char * prop, mstate_t s)"
        hw(prot + ";\n") 
        cw(prot + " { return FALSE; }\n")
    
    ###
    #  enabling test
    ###
    def compileEnablingTestFunctions(self, hw, cw):
        prot = "list_t mstate_events_mem(mstate_t s, heap_t heap)"
        hw(prot + ";\n") 
        cw(prot + """ {
   list_t result = list_new(heap, sizeof(mevent_t), NULL);
   marked_place_t m;
   list_iter_t it;
   mevent_t t;
""")
        for t in self.T:
            cw("   uint32_t pre_t" + t.getSNum() + " = " +
               str(len(t.getInputArcs())) + ";\n")
        cw("""
   for(it = list_get_iter(s->marked);
       !list_iter_at_end(it);
       it = list_iter_next(it)) {
      m = * ((marked_place_t *) list_iter_item(it));
      switch(m.pid) {
""")
        for p in self.P:
            cw("      case " + p.getSNum() + ": {\n")
            for a in p.getOutputArcs():
                val = a.getValuation()
                t = self.getNode(a.getTarget())
                pret = "pre_t" + self.getNode(a.getTarget()).getSNum()
                l = "         "
                if val > 1:
                    l += "if(m.tok >= " + str(val) + ") { "
                l += "if(!(-- " + pret + ")) { " + \
                     "t = " + t.getSNum() + "; list_append(result, &t); }"
                if val > 1:
                    l += " }"
                cw(l + "\n")
            cw("         break;\n")
            cw("      }\n")
        cw("      default: assert(FALSE);\n")
        cw("      }\n")
        cw("   }\n")
        cw("   return result;\n")
        cw("}\n")
        prot = """list_t mstate_events(mstate_t s)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   return mstate_events_mem(s, SYSTEM_HEAP);
}
""")
        
        prot = "mevent_t mstate_event_mem(mstate_t s, mevent_t e, heap_t h)"
        hw(prot + ";\n") 
        cw(prot + """ {
   return e;
}
""")
        prot = "mevent_t mstate_event(mstate_t s, mevent_t e)"
        hw(prot + ";\n") 
        cw(prot + """ {
   return mstate_event_mem(s, e, SYSTEM_HEAP);
}
""")
    
    ###
    #  print functions
    ###
    def compilePrintFunctions(self, hw, cw):
        prot = "void mstate_print(mstate_t s, FILE * out)"
        hw(prot + ";\n") 
        cw(prot + """ {
   marked_place_t m;
   list_iter_t it;

   for(it = list_get_iter(s->marked);
       !list_iter_at_end(it);
       it = list_iter_next(it)) {
      m = * ((marked_place_t *) list_iter_item(it));
      switch(m.pid) {
""")
        for p in self.P:
            cw("      case " + p.getSNum() + ": ")
            cw("printf(\"" + p.getName() + " = %d\\n\", m.tok); break;\n")
        cw("""       default: assert(FALSE);
      }
   }
}
""")
        
        prot = "void mevent_print(mevent_t e, FILE * out)"
        hw(prot + ";\n") 
        cw(prot + """ {
  switch(e) {
""")
        for t in self.T:
            cw("  case " + t.getSNum() + ": " +
               "printf(\"" + t.getName() + "\\n\"); break;\n")
        cw("""
  default: assert(FALSE);
  }
}
""")
    
    ###
    #  transition relation
    ###
    def compileTransitionRelationFunctions(self, hw, cw):

        #  event execution/undoing
        cw("""char remove_empty_places(void * item, void * data) {
   marked_place_t m = * ((marked_place_t *) item);
   return (0 == m.tok) ? TRUE : FALSE;
}
int marked_place_cmp(void * a, void * b) {
   marked_place_t ma = * ((marked_place_t *) a);
   marked_place_t mb = * ((marked_place_t *) b);
   if(ma.pid < mb.pid) {
      return -1;
   } else if(ma.pid > mb.pid) {
      return 1;
   } else {
      return 0;
   }
}
""")
        for f, op in [ ("exec", 1), ("undo", -1) ]:
            prot = """void mevent_""" + f + """
(mevent_t e,
 mstate_t s)"""
            hw(prot + ";\n")
            cw(prot + """ {
   marked_place_t * m;
   list_iter_t it;
   switch(e) {
""")
            for t in self.T:
                tid = t.getId()
                updated = t.getUpdatedSet(self)
                cw("   case " + t.getSNum() + ": { \n")
                for val, p in updated:
                    cw("      bool_t b" + p.getSNum() + " = FALSE;\n")
                cw("      for(it = list_get_iter(s->marked);\n")
                cw("          !list_iter_at_end(it);\n")
                cw("          it = list_iter_next(it)) {\n")
                cw("         m = (marked_place_t *) list_iter_item(it);\n")
                for val, p in updated:
                    cw("         if(" + p.getSNum() + " == m->pid) { " +
                       "b" + p.getSNum() + " = TRUE; " +
                       "m->tok += " + str(op * val) + "; " +
                       "if(m->tok > " + str(self.getCapacity()) + ") { " +
                       "error_throw(\"capacity exceeded in place " +
                       p.getName() + "\");"
                       " } }\n")
                cw("      }\n")
                for val, p in [ (val, p) for val, p in updated
                                if op * val > 0 ]:
                    cw("      if(!b" + p.getSNum() + ") {\n")
                    cw("         marked_place_t m = { " + p.getSNum() +
                       ", " + str(op * val) + " };\n")
                    cw("         list_insert_sorted(s->marked, &m, " +
                       "marked_place_cmp);\n")
                    cw("      }\n")
                cw("      break;\n")
                cw("   }\n")
            cw("   default: assert(FALSE);\n")
            cw("   }\n")
            cw("   list_filter(s->marked, remove_empty_places, NULL);\n")
            cw("}\n")

        #  state successor/predecessor
        for f, op in [ ("succ", "exec"),
                       ("pred", "undo") ]:
            prot = "mstate_t mstate_""" + f + \
                   "_mem(mstate_t s, mevent_t e, heap_t heap)"
            hw(prot + ";\n") 
            cw(prot + """ {
   mstate_t result = mstate_copy_mem(s, heap);
   mevent_""" + op + """(e, result);
   return result;
}
""")
            prot = "mstate_t mstate_" + f + "(mstate_t s, mevent_t e)"
            hw(prot + ";\n") 
            cw(prot + """ {
   return mstate_""" + f + """_mem(s, e, FALSE);
}
""")
    
    ###
    #  initial marking
    ###
    def compileInitialMarkingFunctions(self, hw, cw):
        prot = "mstate_t mstate_initial_mem(heap_t heap)"
        hw(prot + ";\n") 
        cw(prot + """ {
   marked_place_t m;
   mstate_t result = mem_alloc(heap, sizeof(struct_mstate_t));
   result->heap = heap;
   result->marked = list_new(heap, sizeof(marked_place_t), NULL);
""")
        for p in self.P:
            init = p.getInit()
            if init > 0:
                cw("   m.pid = " + p.getSNum() + ";\n")
                cw("   m.tok = " + str(init) + ";\n")
                cw("   list_append(result->marked, &m);\n")
        cw("   return result;\n")
        cw("}\n") 
        prot = "mstate_t mstate_initial()"
        hw(prot + ";\n")
        cw(prot + """ {
   return mstate_initial_mem(SYSTEM_HEAP);
}
""")

    ###
    #  serialisation functions
    ###
    def compileSerialisationFunctions(self, hw, cw):

        # state char size
        prot = "unsigned int mstate_char_size(mstate_t s)"
        hw(prot + ";\n")
        cw(prot + """ {
   unsigned int result = MODEL_PLACES_WIDTH
      + (list_size(s->marked) * (MODEL_PID_WIDTH + MODEL_TOKENS_WIDTH));
   return (result / CHAR_BIT) + ((result % CHAR_BIT) ? 1 : 0);
}
""")
        
        # state serialisation
        prot = "void mstate_serialise(mstate_t s, bit_vector_t v)"
        hw(prot + ";\n")
        cw(prot + """ {
   const int len = list_size(s->marked);
   list_iter_t it;
   marked_place_t m;
   bit_stream_t stream;

   bit_stream_init(stream, v);
   bit_stream_set(stream, len, MODEL_PLACES_WIDTH);
   for(it = list_get_iter(s->marked);
       !list_iter_at_end(it);
       it = list_iter_next(it)) {
      m = * ((marked_place_t *) list_iter_item(it));
      bit_stream_set(stream, m.pid, MODEL_PID_WIDTH);
      bit_stream_set(stream, m.tok, MODEL_TOKENS_WIDTH);
   }
}
""")
        
        # state unserialisation
        prot = "mstate_t mstate_unserialise_mem(bit_vector_t v, heap_t heap)"
        hw(prot + ";\n")
        cw(prot + """ {
   int len;
   marked_place_t m;
   bit_stream_t stream;
   mstate_t result = mem_alloc(heap, sizeof(struct_mstate_t));

   result->heap = heap;
   result->marked = list_new(heap, sizeof(marked_place_t), NULL);
   bit_stream_init(stream, v);
   bit_stream_get(stream, len, MODEL_PLACES_WIDTH);
   while(len --) {
      bit_stream_get(stream, m.pid, MODEL_PID_WIDTH);
      bit_stream_get(stream, m.tok, MODEL_TOKENS_WIDTH);
      list_append(result->marked, &m);
   }
   return result;
""")
        for p in self.P:
            init = p.getInit()
            if init > 0:
                cw("   m.pid = " + p.getSNum() + ";\n")
                cw("   m.tok = " + str(init) + ";\n")
                cw("   list_append(result->marked, &m);\n")
        cw("   return result;\n")
        cw("}\n") 
        prot = "mstate_t mstate_unserialise(bit_vector_t v)"
        hw(prot + ";\n")
        cw(prot + """ {
   return mstate_unserialise_mem(v, SYSTEM_HEAP);
}
""")

        # event char size
        prot = "unsigned int mevent_char_size(mevent_t e)"
        hw(prot + ";\n")
        cw(prot + """ {
   return sizeof(mevent_t);
}
""")

        # state to vector comparison
        prot = "bool_t mstate_cmp_vector(mstate_t s, bit_vector_t v)"
        hw(prot + ";\n")
        cw(prot + """ {
   uint32_t not_empty = 0, pid, tokens;
   marked_place_t m;
   list_iter_t it = list_get_iter(s->marked);
   bit_stream_t stream;
   bit_stream_init(stream, v);
   bit_stream_get(stream, not_empty, MODEL_PLACES_WIDTH);
   if(not_empty != list_size(s->marked)) { return FALSE; }
   while(not_empty) {
      m = * ((marked_place_t *) list_iter_item(it));
      bit_stream_get(stream, pid, MODEL_PID_WIDTH);
      bit_stream_get(stream, tokens, MODEL_TOKENS_WIDTH);
      if(pid != m.pid || tokens != m.tok) { return FALSE; }
      it = list_iter_next(it);
      not_empty --;
   }   
   return TRUE;
}
""")
        
        # event serialisation
        prot = "void mevent_serialise(mevent_t e, bit_vector_t v)"
        hw(prot + ";\n")
        cw(prot + """ {
   memcpy(v, &e, sizeof(mevent_t));
}
""")
        
        # event unserialisation
        prot = "mevent_t mevent_unserialise_mem(bit_vector_t v, heap_t heap)"
        hw(prot + ";\n")
        cw(prot + """ {
   mevent_t result;
   memcpy(&result, v, sizeof(mevent_t));
   return result;
}
""")
        prot = "mevent_t mevent_unserialise(bit_vector_t v)"
        hw(prot + ";\n")
        cw(prot + """ {
   return mevent_unserialise_mem(v, SYSTEM_HEAP);
}
""")

    ###
    #  XML functions
    ###
    def compileXMLFunctions(self, hw, cw):
        prot = """void mstate_to_xml(mstate_t s, FILE * out)"""
        hw(prot + ";\n")
        cw(prot + " {\n")
        cw("   assert(FALSE);\n")
        cw("}\n")
        prot = """void mevent_to_xml(mevent_t e, FILE * out)"""
        hw(prot + ";\n")
        cw(prot + " {\n")
        cw("   assert(FALSE);\n")
        cw("}\n")
        prot = """void model_xml_statistics(FILE * out)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   fprintf(out, "<modelStatistics>\\n");
   fprintf(out, "<places>""" + str(len(self.P)) + """</places>\\n");
   fprintf(out, "<transitions>""" + str(len(self.T)) + """</transitions>\\n");
   fprintf(out, "<netArcs>""" + str(len(self.A)) + """</netArcs>\\n");
   fprintf(out, "</modelStatistics>\\n");
}\n
""")

    ###
    #  partial order reduction functions
    ###
    def compilePorFunctions(self, hw, cw):
        prot = """bool_t mevent_is_safe(mevent_t e)"""
        hw(prot + ";\n") 
        cw(prot + """ {
  return model_safe[e];
}
""")
        prot = """bool_t mevent_is_visible(mevent_t e)"""
        hw(prot + ";\n") 
        cw(prot + """ {
  return FALSE;
}
""")
        prot = """bool_t mevent_are_independent(mevent_t e, mevent_t f)"""
        hw(prot + ";\n") 
        cw(prot + """ {
  return FALSE;
}
""")
        prot = """unsigned int mevent_safe_set(mevent_t e)"""
        hw(prot + ";\n") 
        cw(prot + """ {
  return 0;
}
""")

        #  dynamic por reduction algorithm
        cw("""char mevent_is_t(void * t, void * u) {
   return (* (mevent_t *) t) == (* (mevent_t *) u) ? TRUE : FALSE;
}
char is_marked_enough(void * m, void * ref) {
  marked_place_t mm = * ((marked_place_t *) m);
  marked_place_t mref = * ((marked_place_t *) ref);
  return (mm.pid == mref.pid && mm.tok >= mref.tok) ? TRUE : FALSE;
}

void dynamic_por_reduction_add_stub
(mevent_t t, bool_t * stub, list_t en, list_t stub_en, list_t todo) {
  if(!stub[t]) {
    stub[t] = TRUE;
    list_append(todo, &t);
    if(list_find(en, mevent_is_t, &t)) {
      list_append(stub_en, &t);
    }
  }
}

bool_t dynamic_por_reduction_scapegoat_ok
(uint32_t pid, bool_t * stub, list_t en) {
  int i;

  switch(pid) {
""")
        for p in self.P:
            n = p.getSNum()
            cw("  case " + n + ": {\n")
            cw("    for(i = 0; i < model_act_nb" + n + "; i ++) { " +
               "if(!stub[model_act" + n + "[i]]) { return FALSE; } }\n")
            cw("    return TRUE;\n")
            cw("  }\n")
        cw("""  default: assert(FALSE);
  }
}

void dynamic_por_reduction_handle_scapegoat
(uint32_t pid, bool_t * stub, list_t en, list_t stub_en, list_t todo) {
  int i;

  switch(pid) {
""")
        for p in self.P:
            n = p.getSNum()
            cw("  case " + n + ": { " +
               "for(i = 0; i < model_act_nb" + n + "; i ++) { " +
               "dynamic_por_reduction_add_stub" +
               "(model_act" + n + "[i], stub, en, stub_en, todo); } " +
               "break; }\n")
        cw("""
  default: assert(FALSE);
  }
}
""")
        prot = "void dynamic_por_reduction(mstate_t s, list_t en)"
        hw(prot + ";\n") 
        cw(prot + """ {
  heap_t h;
  rseed_t rnd;
  list_iter_t it;
  list_t todo, stub_en;
  int i, sg, n;
  mevent_t t, u;
  bool_t stub[""" + str(len(self.T)) + """];
  mevent_t * conf;

  if(list_is_empty(en)) {
    return;
  }
  memset(stub, 0, sizeof(bool_t) * """ + str(len(self.T)) + """);
  h = local_heap_new();
  rnd = random_seed(0);
  todo = list_new(h, sizeof(mevent_t), NULL);
  stub_en = list_new(h, sizeof(mevent_t), NULL);
  i = random_int(&rnd) % list_size(en);
  i = 0;
  u = * ((mevent_t *) list_nth(en, i));
  dynamic_por_reduction_add_stub(u, stub, en, stub_en, todo);
  while(!list_is_empty(todo)) {
    list_pick_first(todo, &t);
    if(list_find(en, mevent_is_t, &t)) {
      switch(t) {
""")
        
        # enabled transition handling: put all transitions in conflict
        # with t in the stubborn set
        for t in self.T:
            n = t.getSNum()
            cw("      case " + n + ": { conf = model_conf" + n + "; n = " +
               " model_conf_nb" + n + "; break; }\n")
        cw("      }\n")
        cw("      for(i = 0; i < n; i ++) {\n")
        cw("        dynamic_por_reduction_add_stub" +
           "(conf[i], stub, en, stub_en, todo);\n")
        cw("      }\n")
        cw("    } else {\n")
        cw("      switch(t) {\n")
        
        #  disabled transition handling
        for t in self.T:
            n = t.getSNum()
            cw("      case " + n + ": {\n")
            cw("        for(sg = - 1, i = 0;" +
               " i < model_ip_nb" + n + "; i ++) {\n")
            cw("          if(!list_find(s->marked, is_marked_enough, " +
               "&model_ip" + n + "[i])) {\n")
            cw("            if(-1 == sg) { sg = model_ip" + n + "[i].pid; }\n")
            cw("            if(dynamic_por_reduction_scapegoat_ok" +
               "(model_ip" + n + "[i].pid, stub, en)) { " +
               "sg = model_ip" + n + "[i].pid; break; }\n")
            cw("          }\n")
            cw("        }\n")
            cw("        dynamic_por_reduction_handle_scapegoat" +
               "(sg, stub, en, stub_en, todo);\n")
            cw("        break;\n")
            cw("      }\n")
        cw("""      default: assert(FALSE);
      }
    }
  }
  list_free(todo);
  if(list_size(stub_en) != list_size(en)) {
    list_reset(en);
    for(it = list_get_iter(stub_en);
        !list_iter_at_end(it);
        it = list_iter_next(it)) {
      t = * ((mevent_t *) list_iter_item(it));
      list_append(en, &t);
    }
  }
  list_free(todo);
  list_free(stub_en);
  heap_free(h);
}
""")
    
    def compilePorConstants(self, hw, cw):
        safe = []
        for t in self.T:
            preT = t.getPreSet(self)
            
            # safe <=> t does not share its input places with other
            # transitions
            isSafe = all(len(p.getPostSet(self)) == 1 for (val, p) in preT)
            safe.append("1" if isSafe else "0")

            # input places
            n = t.getSNum()
            ip = t.getPreSet(self)
            l = str(len(ip))
            array = "{" + (", ".join(map(
                lambda (val, p): "{" + p.getSNum() + ", " + str(val) + "}",
                ip))) + "}"
            cw("marked_place_t model_ip" + n + "[] = " + array + ";\n")
            cw("uint32_t model_ip_nb" + n + " = " + l + ";\n")

            # transitions in conflict
            conf = []
            for val, p in t.getPreSet(self):
                for u in [ u for val, u in p.getPostSet(self) if u != t ]:
                    conf.append(u)
            array = ", ".join(map(lambda c: c.getSNum(), conf))
            l = str(len(conf))
            cw("mevent_t model_conf" + n + "[] = {" + array + "};\n")
            cw("uint32_t model_conf_nb" + n + " = " + l + ";\n")
        cw("bool_t model_safe[] = {" + ", ".join(safe) + "};\n")

        for p in self.P:
            # activating transitions
            act = [ t for (val, t) in p.getUpdatedSet(self) if val > 0 ]
            l = str(len(act))
            array = ", ".join(map(lambda t: t.getSNum(), act))
            cw("mevent_t model_act" + p.getSNum() + "[] = {" +
               array + "};\n")
            cw("uint32_t model_act_nb" + p.getSNum() + " = " + l + ";\n")
    
    def compileNet(self, outDir):
        f = open(outDir + os.sep + "SRC_FILES", "w")
        f.write("model\n")
        f.write("model_por\n")
        f.close()
        
        h = open(outDir + os.sep + "model.h", "w")
        hw = h.write
        hw("""#include "includes.h"
#include "common.h"
#include "errors.h"
#include "heap.h"
#include "config.h"
#include "list.h"
#include "bit_stream.h"

#ifndef LIB_MODEL
#define LIB_MODEL

#define MODEL_EVENT_UNDOABLE
#define MODEL_HAS_DYNAMIC_POR_REDUCTION

""")

        c = open(outDir + os.sep + "model.c", "w")
        cw = c.write
        cw("""#include "model.h"

""")

        ###
        #  type definitions
        ###
        if   self.capacity < 2 ** 8:
            hw("typedef uint8_t token_t;\n")
        elif self.capacity < 2 ** 16:
            hw("typedef uint16_t token_t;\n")
        else:
            hw("typedef uint32_t token_t;\n")
        hw("""
typedef struct {
   uint32_t pid;
   token_t tok;
} marked_place_t;
typedef struct {
   list_t marked;
   heap_t heap;
} struct_mstate_t;
typedef struct_mstate_t * mstate_t;
typedef uint32_t mevent_t;
typedef uint32_t mevent_id_t;

#include "model_por.h"
""")

        ###
        #  some constants
        ###
        hw("#define MODEL_PLACES_WIDTH " +
           str(bitLength(len(self.P) + 1)) + "\n")
        hw("#define MODEL_PID_WIDTH " +
           str(bitLength(len(self.P))) + "\n")
        hw("#define MODEL_TOKENS_WIDTH " +
           str(bitLength(self.capacity) + 1) + "\n")

        ###
        #  event id
        ###
        prot = """mevent_id_t mevent_id(mevent_t e)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   return e;
}
""")

        ###
        #  state free
        ###
        prot = """void mstate_free(mstate_t s)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   list_free(s->marked);
   mem_free(s->heap, s);
}
""")

        ###
        #  event free
        ###
        prot = """void mevent_free(mevent_t e)"""
        hw(prot + ";\n") 
        cw(prot + """ {
}
""")

        ###
        #  event comparison
        ###
        prot = """order_t mevent_cmp(mevent_t e, mevent_t f)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   if(e < f) {
      return LESS;
   } else if(e > f) {
      return GREATER; 
   } else {
      return EQUAL;
   }
}
""")

        ###
        #  state copy
        ###
        prot = """mstate_t mstate_copy_mem(mstate_t s, heap_t heap)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   mstate_t result = mem_alloc(heap, sizeof(struct_mstate_t));
   result->marked = list_copy(s->marked, heap, NULL);
   result->heap = heap;
   return result;
}
""")
        prot = """mstate_t mstate_copy(mstate_t s)"""
        hw(prot + ";\n")
        cw(prot + """ {
   return mstate_copy_mem(s, SYSTEM_HEAP);
}
""")

        ###
        #  event copy
        ###
        prot = """mevent_t mevent_copy_mem(mevent_t e, heap_t heap)"""
        hw(prot + ";\n") 
        cw(prot + """ {
   return e;
}
""") 
        prot = """mevent_t mevent_copy(mevent_t e)"""
        hw(prot + ";\n")
        cw(prot + """ {
   return e;
}
""")

        ###
        #  hash function
        ###
        prot = """hash_key_t mstate_hash
(mstate_t s)"""
        hw(prot + ";\n")
        cw(prot + """ {
   const int size = mstate_char_size(s);
   char buffer[size];
   buffer[size - 1] = 0;
   mstate_serialise(s, buffer);
   return bit_vector_hash(buffer, size);
}
""")

        for f in [ self.compileModelFunctions,
                   self.compileInitialMarkingFunctions,
                   self.compileEnablingTestFunctions,
                   self.compileTransitionRelationFunctions,
                   self.compileSerialisationFunctions,
                   self.compilePrintFunctions,
                   self.compileXMLFunctions ]:
            f(hw, cw)
        hw("#endif\n")
        c.close()
        h.close()

        #  por functions are put in a separate file
        h = open(outDir + os.sep + "model_por.h", "w")
        hw = h.write
        hw("""#include "includes.h"
#include "model.h"
#include "heap.h"
#include "list.h"

#ifndef LIB_MODEL_POR
#define LIB_MODEL_POR

""")

        c = open(outDir + os.sep + "model_por.c", "w")
        cw = c.write
        cw("""#include "model_por.h"

""")
        self.compilePorConstants(hw, cw)
        self.compilePorFunctions(hw, cw)
        hw("#endif\n")
        c.close()
        h.close()

def exitWithError(errMsg):
    """Write "error: <errMsg>" on the standard error and exit with
    status 1.

    sys.exit is used instead of the exit helper, which is only installed
    by the site module.
    """
    sys.stderr.write("error: " + str(errMsg) + "\n")
    sys.exit(1)

def exitWithUsage(code):
    """Write the usage line on the standard output and exit with status
    code.

    sys.exit is used instead of the exit helper, which is only installed
    by the site module.
    """
    sys.stdout.write("usage: helena-generate-pnml my-model.pnml out-dir\n")
    sys.exit(code)

def getAttr(e, name):
    """Return the value of attribute name of XML element e.

    Raises XmlAttributeNotFound if e has no such attribute.
    """
    for attrName, attrValue in e.attributes.items():
        if attrName == name:
            return attrValue
    raise XmlAttributeNotFound()

def getNodeAttr(nodeXML, attr):
    """Return the value of the first child node of the first <text>
    element found below the first <attr> element of nodeXML.

    Raises IndexError if one of these is missing.
    """
    attrElement = nodeXML.getElementsByTagName(attr)[0]
    textElement = attrElement.getElementsByTagName("text")[0]
    firstChild = textElement.childNodes[0]
    return firstChild.nodeValue

def parseNet(pnml):
    """Build a Net from the PNML file pnml.

    Places and transitions are numbered from 0 in document order.  The
    initial marking of a place is read from its <initialMarking> child and
    is 0 when that child is missing or is not an integer.  The valuation
    of an arc is read from its <inscription> child and stays at the default
    of the Arc class when that child is missing or is not an integer.

    Any other problem (unreadable or ill-formed file, missing id or name,
    arc referring to an unknown node, ...) ends the program through
    exitWithError, with the cause appended to the message.
    """
    #  errors raised when an optional <tag><text>int</text></tag> child is
    #  absent or does not hold an integer
    notAnInt = (IndexError, AttributeError, TypeError, ValueError)
    result = Net()
    try:
        doc = parse(pnml)
        net = doc.getElementsByTagName("net")[0]
        for num, xml in enumerate(net.getElementsByTagName("place")):
            pid = getAttr(xml, "id")
            name = getNodeAttr(xml, "name")
            try:
                init = int(getNodeAttr(xml, "initialMarking"))
            except notAnInt:
                init = 0
            result.addPlace(Place(num, pid, name, init))
        for num, xml in enumerate(net.getElementsByTagName("transition")):
            tid = getAttr(xml, "id")
            name = getNodeAttr(xml, "name")
            result.addTrans(Trans(num, tid, name))
        for xml in net.getElementsByTagName("arc"):
            aid = getAttr(xml, "id")
            src = getAttr(xml, "source")
            target = getAttr(xml, "target")
            a = Arc(aid, src, target)
            try:
                a.valuation = int(getNodeAttr(xml, "inscription"))
            except notAnInt:
                #  no usable inscription: keep the default valuation
                pass
            result.addArc(a)
    except Exception as e:
        exitWithError("could not parse pnml file " + pnml +
                      " (" + e.__class__.__name__ + ": " + str(e) + ")")
    return result

if __name__ == "__main__":
    #  command line: helena-generate-pnml my-model.pnml out-dir
    if len(sys.argv) != 3:
        exitWithUsage(1)
    pnml = sys.argv[1]
    outDir = sys.argv[2]
    net = parseNet(pnml)
    net.simplify()
    net.compileNet(outDir)
    #  sys.exit rather than exit: the latter is only installed by the
    #  site module
    sys.exit(0)
