#!/usr/bin/python

import re
import sys
from xml.dom import Node
from xml.dom.minidom import parse, parseString

#  Command line: helena-pnml2lna in.pnml out.lna
if len(sys.argv) != 3:
    # Usage errors go to stderr; sys.stderr.write (unlike the old print
    # statement) is valid in both Python 2 and Python 3.
    sys.stderr.write("usage: helena-pnml2lna in.pnml out.lna\n")
    sys.exit(1)

pnml = sys.argv[1]
lna = sys.argv[2]

#  The input is parsed before the output file is opened, so a parse
#  error does not create or truncate the output file.
doc = parse(pnml)

out = open(lna, "w")
W = out.write
W("model {\n")

#  PNML sort id -> Helena type expression.  Filled while emitting the type
#  declarations; read when emitting place domains and <all> terms.
types = dict()

def childElements(n):
    """Return the direct children of DOM node *n* that are elements.

    Text, comment and other non-element child nodes are skipped.  The
    result is a new list in document order.
    """
    elements = []
    for child in n.childNodes:
        if child.nodeType != Node.ELEMENT_NODE:
            continue
        elements.append(child)
    return elements
def getFirstElementByTagName(e, tagName):
    """Return the first descendant element of *e* named *tagName*.

    The whole subtree below *e* is searched (not only direct children),
    in document order.

    Raises IndexError, naming the missing tag and the searched node, when
    no such element exists.  The exception type is the same as a failed
    ``[0]`` lookup, but the message now says what was missing.
    """
    found = e.getElementsByTagName(tagName)
    if not found:
        raise IndexError("no <" + tagName + "> element found under <"
                         + e.nodeName + ">")
    return found[0]
#  Characters replaced by "_" when turning a PNML id into a Helena name.
#  PNML ids are XML NCNames and may contain ".", "-" and non-ASCII letters.
#  NOTE(review): assumes Helena identifiers only allow ASCII letters,
#  digits and "_" -- confirm against the Helena language reference.
_NON_IDENT_CHAR = re.compile(r"[^A-Za-z0-9_]")


def _helenaIdent(prefix, name):
    """Return *prefix* + *name*, with every character of *name* outside
    [A-Za-z0-9_] replaced by "_"."""
    return prefix + _NON_IDENT_CHAR.sub("_", name)


def constName(c):
    """Helena name ("C_...") for the PNML constant id *c*."""
    return _helenaIdent("C_", c)


def varName(v):
    """Helena name ("V_...") for the PNML variable id *v*."""
    return _helenaIdent("V_", v)


def placeName(p):
    """Helena name ("P_...") for the PNML place id *p*."""
    return _helenaIdent("P_", p)


def transName(t):
    """Helena name ("T_...") for the PNML transition id *t*."""
    return _helenaIdent("T_", t)
#  PNML comparison operators and their Helena spelling.  Each is
#  translated from its first two subterms.
_COMPARISON_OPS = {
    "equality": "=",
    "inequality": "!=",
    "lessthan": "<",
    "lessthanorequal": "<=",
    "greaterthan": ">",
    "greaterthanorequal": ">=",
}

#  PNML boolean connectives.  They are associative, so all subterms
#  (at least two) are chained with the same operator.
_BOOLEAN_OPS = {
    "and": "and",
    "or": "or",
}

#  PNML successor/predecessor operators and their Helena spelling.
_UNARY_OPS = {
    "successor": "succ",
    "predecessor": "pred",
}


def getExpr(e, left="", right=""):
    """Translate the PNML term element *e* into a Helena expression string.

    e     -- DOM element of the term: <subterm>, a comparison or boolean
             operator, <successor>/<predecessor>, <variable>,
             <useroperator>, <dotconstant>, <numberof>, <numberconstant>,
             <add>, <tuple> or <all>.
    left  -- text placed before the translation.  <numberof> passes "<("
             so that its operand is written as a Helena token.
    right -- text placed after the translation (")>" for tokens).

    <subterm> forwards left/right to its first child element.
    <dotconstant> ("epsilon") and <all> return a complete expression and
    ignore left/right.

    Raises Exception for an unsupported tag, or for an and/or with fewer
    than two subterms.
    """
    tag = e.tagName
    if tag == "subterm":
        return getExpr(childElements(e)[0], left, right)
    if tag == "dotconstant":
        return "epsilon"
    if tag == "all":
        # one <(x)> for every x of the sort's Helena type
        s = getFirstElementByTagName(e, "usersort")
        return "for (x in " + types[s.getAttribute("declaration")] + ") <(x)>"

    if tag in _COMPARISON_OPS:
        c = childElements(e)
        result = ("(" + getExpr(c[0]) + " " + _COMPARISON_OPS[tag] + " "
                  + getExpr(c[1]) + ")")
    elif tag in _BOOLEAN_OPS:
        # previously only the first two subterms were translated, silently
        # dropping the rest of an n-ary and/or
        operands = [getExpr(c) for c in childElements(e)]
        if len(operands) < 2:
            raise Exception(tag + " expects at least two subterms")
        result = "(" + (" " + _BOOLEAN_OPS[tag] + " ").join(operands) + ")"
    elif tag in _UNARY_OPS:
        c = childElements(e)
        result = "(" + _UNARY_OPS[tag] + " " + getExpr(c[0]) + ")"
    elif tag == "variable":
        result = varName(e.getAttribute("refvariable"))
    elif tag == "useroperator":
        result = constName(e.getAttribute("declaration"))
    elif tag == "numberof":
        # children are [multiplicity,] term; a lone term has multiplicity 1
        c = childElements(e)
        if len(c) == 1:
            num, term = "1", c[0]
        else:
            num, term = getExpr(c[0]), c[1]
        result = getExpr(term, "<(", ")>")
        if num != "1":
            result = num + " * " + result
    elif tag == "numberconstant":
        result = e.getAttribute("value")
    elif tag == "add":
        result = " + ".join(getExpr(c) for c in childElements(e))
    elif tag == "tuple":
        result = ", ".join(getExpr(c) for c in childElements(e))
    else:
        raise Exception("unimplemented expression " + tag)
    return left + result + right

#  type declarations
#  Each enumeration namedsort is emitted as a Helena enum type.  Product
#  sorts are not emitted; their "A * B * ..." spelling is only recorded in
#  `types`.
W("\n/*\n *  type declarations\n */\n")
for sortDecl in doc.getElementsByTagName("namedsort"):
    sortId = sortDecl.getAttribute("id")
    sortName = sortDecl.getAttribute("name")
    for child in sortDecl.childNodes:
        kind = child.nodeName
        if kind in ("cyclicenumeration", "finiteenumeration"):
            # both enumeration kinds become the same Helena enum declaration
            types[sortId] = "T_" + sortName
            members = ",\n".join(
                "  " + constName(fe.getAttribute("id"))
                for fe in child.getElementsByTagName("feconstant"))
            W("type " + types[sortId] + ": enum (\n" + members + "\n);\n")
        elif kind == "productsort":
            types[sortId] = " * ".join(
                types[ref.getAttribute("declaration")]
                for ref in child.getElementsByTagName("usersort"))

#  place declarations
W("\n\
/*\n\
 *  place declarations\n\
 */\n")

#  Capacity written for every place.
#  NOTE(review): the reason for the value 256 is not visible in this script.
PLACE_CAPACITY = 256


def getPlaceDom(t):
    """Return the Helena domain for a place whose PNML sort id is *t*.

    The "dot" sort maps to epsilon; other ids are looked up in `types`.
    Raises Exception naming *t* when the sort was never declared.
    """
    if t == "dot":
        return "epsilon"
    if t in types:
        return types[t]
    raise Exception("undefined type " + t)


def getPlaceMarking(m):
    """Translate an <hlinitialMarking> element into a Helena init expression."""
    s = getFirstElementByTagName(m, "structure")
    return getExpr(childElements(s)[0])


#  The helpers above used to be redefined on every iteration of this loop.
for p in doc.getElementsByTagName("place"):
    pid = p.getAttribute("id")
    W("place " + placeName(pid) + " {\n")
    # the place sort comes from type / structure / usersort@declaration
    node = getFirstElementByTagName(p, "type")
    node = getFirstElementByTagName(node, "structure")
    node = getFirstElementByTagName(node, "usersort")
    W("  dom: " + getPlaceDom(node.getAttribute("declaration")) + ";\n")
    m = p.getElementsByTagName("hlinitialMarking")
    if m:
        W("  init: " + getPlaceMarking(m[0]) + ";\n")
    W("  capacity: " + str(PLACE_CAPACITY) + ";\n")
    W("}\n")

#  transition declarations
W("\n\
/*\n\
 *  transition declarations\n\
 */\n")


def getGuard(g):
    """Translate a <condition> element into a Helena guard expression."""
    s = getFirstElementByTagName(g, "structure")
    return getExpr(childElements(s)[0])


def getArc(a):
    """Translate the <hlinscription> of arc *a* into a Helena arc label."""
    i = getFirstElementByTagName(a, "hlinscription")
    s = getFirstElementByTagName(i, "structure")
    return getExpr(childElements(s)[0])


def indexArcs(allArcs, end):
    """Group *allArcs* by the value of their *end* attribute.

    Returns a dict mapping that value ("source" or "target" id) to the list
    of arcs carrying it.  Each list keeps document order.
    """
    index = dict()
    for a in allArcs:
        index.setdefault(a.getAttribute(end), []).append(a)
    return index


def getArcs(index, tid, pd):
    """Return the Helena arc list of transition *tid*.

    index -- arcs grouped by their transition end (see indexArcs)
    tid   -- PNML id of the transition
    pd    -- arc attribute naming the place end ("source" or "target")

    Each arc is written on a new line, indented by four spaces, as
    P_<place>: <label>;
    """
    return "".join("\n    " + placeName(a.getAttribute(pd)) + ": "
                   + getArc(a) + ";"
                   for a in index.get(tid, []))


arcs = doc.getElementsByTagName("arc")
#  Index the arcs once, so a transition's arcs are found by lookup instead
#  of rescanning every arc for every transition.
arcsByTarget = indexArcs(arcs, "target")
arcsBySource = indexArcs(arcs, "source")
for t in doc.getElementsByTagName("transition"):
    tid = t.getAttribute("id")
    W("transition " + transName(tid) + " {\n")
    W("  in  {" + getArcs(arcsByTarget, tid, "source") + "\n  }\n")
    W("  out {" + getArcs(arcsBySource, tid, "target") + "\n  }\n")
    g = t.getElementsByTagName("condition")
    if g:
        W("  guard: " + getGuard(g[0]) + ";\n")
    W("}\n")
    
W("}\n")  # closes the "model {" written at the start of the output
out.close()

# sys.exit rather than the site-module exit(), which is meant for the
# interactive interpreter only
sys.exit(0)
