#!/usr/bin/python
#
#  File: helena-report
#
#  Print a Helena XML report (given as a file path or as a model name) to the
#  standard output, or copy it to an .xml output file.
#

from xml.dom.minidom import parse, parseString
import os
import shutil
import subprocess
import sys

def exitWithError(errMsg):
    """Write *errMsg* (converted with str, so exceptions are accepted) and a
    newline to standard error, then terminate the process with status 1.

    sys.exit is used rather than the exit() builtin: the latter is injected
    by the site module for interactive use and is absent under `python -S`.
    """
    sys.stderr.write(str(errMsg) + "\n")
    sys.exit(1)

def printDashes(n, width = 80):
    """Print a horizontal rule: *n* spaces followed by dashes up to column
    *width* (80 by default).

    Used to underline report titles, *n* being the title's indentation.
    When *n* >= *width* only the *n* spaces are printed; a negative *n*
    prints no spaces and width - n dashes.
    """
    # Sequence repetition with a non-positive count yields "", which gives
    # exactly the same output as empty range() loops.
    print(" " * n + "-" * (width - n))

def formatSearchResult(r):
    """Translate a search-result code into a human-readable sentence.

    *r* is the raw text of a <searchResult> element; surrounding whitespace
    is ignored.  The sentence comes from the module-level terminationStates
    table.  Codes missing from that table (or the "NA" placeholder) are
    returned stripped but otherwise unchanged, rather than aborting the
    whole report with a KeyError.
    """
    # No `global` needed: the table is only read, never rebound.
    code = r.strip()
    return terminationStates.get(code, code)

def _groupDigits(digits):
    """Insert "," every three characters of *digits*, counting from the right."""
    head = len(digits) % 3 or 3
    groups = [digits[:head]]
    groups.extend(digits[i:i + 3] for i in range(head, len(digits), 3))
    return ",".join(groups)

def formatNumber(n):
    """Return the numeric string *n* with thousands separators inserted.

    The integer part is grouped from the right ("1234567" -> "1,234,567").
    The fractional part, if any, is grouped from the left
    ("1234.56789" -> "1,234.567,89").  A leading "+" or "-" is kept out of
    the grouping, so "-123" stays "-123" instead of becoming "-,123".  Short
    placeholders such as "NA" pass through unchanged.
    """
    pos = n.find(".")
    if pos >= 0:
        intPart = formatNumber(n[:pos])
        # Group the fractional digits from the left: reverse, group from the
        # right, reverse back.
        fracPart = _groupDigits(n[pos + 1:][::-1])[::-1]
        return intPart + "." + fracPart
    sign = ""
    if n[:1] in ("-", "+"):
        sign, n = n[0], n[1:]
    return sign + _groupDigits(n)

def noFormat(val):
    """Identity formatter for getBasicNodeValue.

    Hands *val* back untouched, for fields that must be printed verbatim.
    """
    unchanged = val
    return unchanged

def nodeValue(n):
    """Return the nodeValue of the first child of DOM node *n*.

    For an element holding only text, this is its text.  When *n* has no
    children at all, the placeholder "NA" is returned.  If the first child
    is an element, its nodeValue (None in minidom) is returned as-is.
    """
    children = n.childNodes
    if not children:
        return "NA"
    first = children[0]
    return first.nodeValue

def getBasicNodeValue(n, fmt, space = 0, unit = None):
    """Render the value of report element *n* as text.

    n     -- DOM element holding either a plain value or a <list> of
             per-node values.
    fmt   -- callable applied to each raw text value (e.g. formatNumber,
             noFormat).
    space -- number of spaces prefixed to every line after the first, so
             that continuation lines align with the caller's value column.
    unit  -- optional suffix appended (after a space) to every value.

    Without a <list> descendant, returns fmt(value) [+ " " + unit].
    Otherwise the first <list> descendant is rendered as one
    "[node i] value" line per <item> (i counting from 1), followed by a
    "[sum   ] value" line when a <sum> element is present.  The sum line now
    carries the unit too, consistent with the per-node lines.
    """
    def withUnit(text):
        # Append the unit, if any, to one formatted value.
        if unit is None:
            return text
        return text + " " + unit

    lists = n.getElementsByTagName("list")
    if not lists:
        return withUnit(fmt(nodeValue(n)))

    values = lists[0]
    lines = []
    for idx, item in enumerate(values.getElementsByTagName("item"), 1):
        lines.append("[node " + str(idx) + "] " + withUnit(fmt(nodeValue(item))))
    sums = values.getElementsByTagName("sum")
    if sums:
        # "[sum   ]" is padded to the width of "[node i]" (single-digit i).
        lines.append("[sum   ] " + withUnit(fmt(nodeValue(sums[0]))))
    return ("\n" + space * " ").join(lines)
            
            
# ---------------------------------------------------------------------------
# Lookup tables mapping XML tag names to the text printed for them.
#
# Labels are padded so values line up in a column.  The padding must agree
# with the `space` argument the printers pass to getBasicNodeValue for
# continuation lines:
#   general infos : "    " + 16-char label + ": "   -> value column 22
#   search infos  : "    " + 21-char label + ": "   -> value column 27
#   statistics    : "      " + 26-char label + " : " -> value column 35
# ---------------------------------------------------------------------------

# <infoReport> children printed as "label: value".
generalInfos = {
    "model"                   : "Model analyzed  ",
    "date"                    : "Analysis date   ",
    "language"                : "Model language  ",
    "filePath"                : "File path       ",
    "host"                    : "Host machine    " }
# <searchReport> children printed verbatim as "label: value".
searchInfos = {
    "property"                : "Property analyzed    ",
    "errorMessage"            : "Error message        " }
# Text of a <searchResult> code (see formatSearchResult).
terminationStates = {
    "searchTerminated"        : "Search terminated",
    "stateLimitReached"       : "State limit reached",
    "memoryExhausted"         : "Memory limit reached",
    "timeElapsed"             : "Time limit reached",
    "noCounterExample"        : "No counter-example found",
    "propertyHolds"           : "Property holds",
    "propertyViolated"        : "Property violated (see the trace report)",
    "interruption"            : "Search interrupted",
    "error"                   : "An error occurred" }
# Text of the <searchAlgorithm> option value.
algorithms = {
    "breadthSearch"           : "Breadth-first search",
    "depthSearch"             : "Depth-first search",
    "distributedDepthSearch"  : "Distributed depth-first searches",
    "distributedBreadthSearch": "Distributed breadth-first searches",
    "randomWalk"              : "Random walk",
    "deltaDDD"                : "Delta-DDD" }
# Label of each <searchOptions> child.  Options listed in optionsWithVal are
# followed by their formatted value (hence the trailing " = ").
options = {
    "edgeLean"                : "Edge-lean reduction",
    "partialOrder"            : "Partial order reduction",
    "hashCompaction"          : "Hash compaction",
    "randomSuccs"             : "Random successor selection",
    "searchAlgorithm"         : "Search algorithm = ",
    "candidateSetSize"        : "Candidate set size (for Delta-DDD) = ",
    "hashTableSize"           : "Hash table size = ",
    "shmemHeapSize"           : "SHMEM heap size = ",
    "workers"                 : "Exploration threads = " }
# Formatter applied to the raw text of options that carry a value.
optionsWithVal = {
    "workers"                 : formatNumber,
    "shmemHeapSize"           : formatNumber,
    "candidateSetSize"        : formatNumber,
    "hashTableSize"           : formatNumber,
    "searchAlgorithm"         : lambda a: algorithms[a] }
# Unit appended after an option value.
optionsWithUnits = {
    "shmemHeapSize"           : "bytes",
    "candidateSetSize"        : "states",
    "hashTableSize"           : "states" }
# Label of each statistic.  Labels starting with spaces are sub-items of the
# preceding entry; "statesStored" embeds a newline to print a "States"
# heading first.
statistics = {
    "searchTime"              : "Search time               ",
    "duplicateDetectionTime"  : "  for duplicate detection ",
    "barrierTime"             : "  for local barrier wait  ",
    "sleepTime"               : "  for sleep               ",
    "compilationTime"         : "Model compilation         ",
    "places"                  : "Places                    ",
    "transitions"             : "Transitions               ",
    "netArcs"                 : "Arcs                      ",
    "inArcs"                  : "  in arcs                 ",
    "outArcs"                 : "  out arcs                ",
    "inhibArcs"               : "  inhibitor arcs          ",
    "stateVectorSize"         : "State vector size         ",
    "statesStored"            : ("States                   \n" +
                                 "        stored                  "),
    "statesProcessed"         : "  processed               ",
    "statesProcessedMax"      : "    max. over threads     ",
    "statesProcessedMin"      : "    min. over threads     ",
    "statesProcessedDev"      : "    std. dev. over threads",
    "statesReduced"           : "  reduced                 ",
    "statesAccepting"         : "  accepting               ",
    "statesTerminal"          : "  deadlock                ",
    "statesUnsafe"            : "  unsafe                  ",
    "arcs"                    : "Transitions               ",
    "bfsLevels"               : "BFS levels                ",
    "maxDFSStackSize"         : "Max DFS stack size        ",
    "eventsExecuted"          : "Events executed           ",
    "eventsExecutedDDD"       : "  for duplicate detection ",
    "eventsExecutedExpansion" : "  for state expansion     ",
    "eventExecPerSecond"      : "Event execution rate      ",
    "maxMemoryUsed"           : "Max. memory used          ",
    "bytesSent"               : "Bytes sent                ",
    "avgCPUUsage"             : "Average CPU Usage         ",
    "lvl1TotalCacheMiss"      : "Level 1 total cache miss  ",
    "lvl1TotalCacheHit"       : "Level 1 total cache hit   ",
    "lvl2TotalCacheMiss"      : "Level 2 total cache miss  ",
    "lvl2TotalCacheHit"       : "Level 2 total cache hit   ",
    "lvl3TotalCacheMiss"      : "Level 3 total cache miss  ",
    "lvl3TotalCacheHit"       : "Level 3 total cache hit   " }
# Unit appended after a statistic's value; statistics absent here get none.
statisticsUnits = {
    "searchTime"              : "s.",
    "duplicateDetectionTime"  : "s.",
    "compilationTime"         : "s.",
    "barrierTime"             : "s. (sum over all threads)",
    "sleepTime"               : "s. (sum over all threads)",
    "eventExecPerSecond"      : "exec. / s.",
    "maxMemoryUsed"           : "MB",
    "stateVectorSize"         : "bytes",
    "avgCPUUsage"             : "%" }
# Heading of each statistics category inside <statisticsReport>.
statisticsCategories = {
    "timeStatistics"          : "Time statistics",
    "modelStatistics"         : "Model statistics",
    "graphStatistics"         : "Exploration statistics",
    "papiStatistics"          : "PAPI statistics",
    "otherStatistics"         : "Other statistics" }
# Title of each top-level section of <helenaReport>.
subReports = {
    "infoReport"              : "General information",
    "searchReport"            : "Search report",
    "statisticsReport"        : "Statistics report",
    "traceReport"             : "Trace report" }

def printInfoReport(r):
    """Print the children of an <infoReport> element.

    Children whose tag is in generalInfos are printed as one aligned
    "label: value" line each.  A <modelParameters> child is expanded into
    one "name = value" line per <modelParameter>; the first line carries
    the "Model parameters:" label and the rest are indented to match.
    Other children are ignored.
    """
    for node in r.childNodes:
        tag = node.nodeName
        # Value column starts at 22: "    " + 16-char label + ": ".
        text = getBasicNodeValue(node, noFormat, 22) if node.childNodes else ""
        if tag in generalInfos:
            print("    " + generalInfos[tag] + ": " + text)
        elif tag == "modelParameters":
            prefix = "    Model parameters: "
            for param in node.childNodes:
                if param.nodeName != "modelParameter":
                    continue
                nameNode = param.getElementsByTagName("modelParameterName")[0]
                pName = getBasicNodeValue(nameNode, noFormat)
                valueNode = param.getElementsByTagName("modelParameterValue")[0]
                pValue = getBasicNodeValue(valueNode, noFormat)
                print(prefix + pName + " = " + pValue)
                prefix = "                      "

def printSearchReport(r):
    """Print the children of a <searchReport> element.

    <searchResult> is translated through formatSearchResult.  Children listed
    in searchInfos are printed verbatim.  <searchOptions> is expanded into
    one line per recognised option: its label from `options`, then (for
    options in optionsWithVal) the formatted value and, when known, its unit.
    The first option line carries the "Options" label; the others are
    indented under it.
    """
    for node in r.childNodes:
        tag = node.nodeName
        if tag == "searchResult":
            # Value column starts at 27: "    " + 21-char label + ": ".
            print("    Termination state    : " +
                  getBasicNodeValue(node, formatSearchResult, 27))
        elif tag in searchInfos:
            print("    " + searchInfos[tag] + ": " +
                  getBasicNodeValue(node, noFormat))
        elif tag == "searchOptions":
            prefix = "    Options              : "
            for opt in node.childNodes:
                optTag = opt.nodeName
                if optTag not in options:
                    continue
                text = prefix + options[optTag]
                if optTag in optionsWithVal:
                    text += optionsWithVal[optTag](nodeValue(opt))
                    if optTag in optionsWithUnits:
                        text += " " + optionsWithUnits[optTag]
                print(text)
                prefix = "                           "

def printStatisticsReport(r):
    """Print the children of a <statisticsReport> element.

    Each child whose tag is in statisticsCategories becomes a titled,
    underlined section.  Inside it, every statistic listed in `statistics`
    is printed as "label : value [unit]", numbers formatted with
    formatNumber.  Unknown tags are skipped.
    """
    for category in [c for c in r.childNodes
                     if c.nodeName in statisticsCategories]:
        print("")
        print("    " + statisticsCategories[category.nodeName])
        printDashes(4)
        for stat in category.childNodes:
            key = stat.nodeName
            if key not in statistics:
                continue
            # Value column starts at 35: "      " + 26-char label + " : ".
            value = getBasicNodeValue(stat, formatNumber, 35,
                                      statisticsUnits.get(key))
            print("      " + statistics[key] + " : " + value)

def printTraceReport(r):
    """Print the counter-example trace(s) held by a <traceReport> element.

    r -- the <traceReport> DOM element.

    Each child whose tag is in traceTypes (traceFull, traceEvents,
    traceState) is one trace.  If it contains a <traceTooLong> element, only
    an explanatory message is printed.  Otherwise a heading is printed,
    followed by its <state> and <event> children in document order: states
    as their description or as a "{ place = tokens }" marking, events as
    their description or "(transition, [bindings])", each followed by " ->".

    The nested printers share the signature (element, d), where d is the
    dispatch table below.  They recurse through it so expression values of
    any nesting depth are rendered.
    """
    def printExprList(e, d):
        # Comma-separated rendering of the recognised children of an
        # <exprList>; whitespace text nodes and unknown tags are skipped.
        return ", ".join([d[ex.nodeName](ex, d) for ex in e.childNodes
                          if ex.nodeName in d])
    def printEnum(e, d):
        return nodeValue(e)
    def printNum(e, d):
        return nodeValue(e)
    def printVector(e, d):
        lst = e.getElementsByTagName("exprList")[0]
        return "[" + printExprList(lst, d) + "]"
    def printStruct(e, d):
        lst = e.getElementsByTagName("exprList")[0]
        return "{" + printExprList(lst, d) + "}"
    def printContainer(e, d):
        # Emptiness is judged on the rendered items rather than on
        # len(childNodes), which also counts whitespace text nodes.
        items = printExprList(e.getElementsByTagName("exprList")[0], d)
        if items == "":
            return "empty"
        return "|" + items + "|"
    def printToken(e, d):
        mult = nodeValue(e.getElementsByTagName("mult")[0])
        items = printExprList(e.getElementsByTagName("exprList")[0], d)
        if items == "":
            tok = "epsilon"
        else:
            tok = "<( " + items + " )>"
        if mult != "1":
            return mult + " * " + tok
        return tok
    def printState(e, d):
        descElement = e.getElementsByTagName("stateDescription")
        if descElement:
            desc = nodeValue(descElement[0]).strip()
        else:
            lines = ["    {"]
            for s in e.getElementsByTagName("placeState"):
                place = nodeValue(s.getElementsByTagName("place")[0])
                tokens = " + ".join([printToken(t, d)
                                     for t in s.getElementsByTagName("token")])
                lines.append("      " + place + " = " + tokens)
            lines.append("    }")
            desc = "\n".join(lines)
        print(desc)
    def printBinding(e, d):
        parts = []
        for b in e.getElementsByTagName("varBinding"):
            var = nodeValue(b.getElementsByTagName("var")[0])
            # The value is the first child with a known printer; indexing
            # childNodes[1] broke as soon as whitespace text nodes appeared.
            vals = [c for c in b.childNodes if c.nodeName in d]
            if vals:
                val = d[vals[0].nodeName](vals[0], d)
            else:
                val = "NA"
            parts.append(var + " = " + val)
        return ", ".join(parts)
    def printEvent(e, d):
        descElement = e.getElementsByTagName("eventDescription")
        if descElement:
            desc = nodeValue(descElement[0]).strip()
        else:
            trans = nodeValue(e.getElementsByTagName("transition")[0])
            desc = "(" + trans
            binding = e.getElementsByTagName("binding")
            if binding:
                desc += ", [" + printBinding(binding[0], d) + "]"
            desc += ")"
        print(desc + " ->")
    traceItem = {
        "enum"      : printEnum,
        "num"       : printNum,
        "vector"    : printVector,
        "struct"    : printStruct,
        "container" : printContainer,
        "state"     : printState,
        "event"     : printEvent }
    traceTypes = {
        "traceFull"  : "The following run invalidates the property.",
        "traceEvents": "The following run invalidates the property.",
        "traceState" : "The following state invalidates the property." }
    for item in r.childNodes:
        name = item.nodeName
        if name not in traceTypes:
            continue
        if item.getElementsByTagName("traceTooLong"):
            print("    A run invalidating the property has been found but is too long to be\n"
                  "    displayed.  If depth-first search was used, try running the search\n"
                  "    again with option --random-succs to try finding another such run.")
        else:
            print("    " + traceTypes[name] + "\n")
            for sub in item.childNodes:
                if sub.nodeName in traceItem:
                    traceItem[sub.nodeName](sub, traceItem)

def printDoc(doc):
    """Print a whole Helena report document to standard output.

    doc -- a DOM document, e.g. as returned by xml.dom.minidom.parse.

    The first <helenaReport> element is located.  Each of its children whose
    tag is in subReports is printed, in document order, under its own
    underlined title; other children are ignored.  The output is framed by
    full-width dash lines.  Exits through exitWithError when the document
    has no <helenaReport> element.
    """
    printer = {
        "infoReport"      : printInfoReport,
        "searchReport"    : printSearchReport,
        "statisticsReport": printStatisticsReport,
        "traceReport"     : printTraceReport }
    reports = doc.getElementsByTagName("helenaReport")
    if not reports:
        # Checked before printing anything, so a wrong file yields a clean
        # error instead of a half-printed banner and an IndexError traceback.
        exitWithError("error: no helenaReport element found in report file")
    printDashes(0)
    print("Helena report")
    printDashes(0)
    for e in reports[0].childNodes:
        name = e.nodeName
        if name in subReports:
            print("")
            print("  " + subReports[name])
            printDashes(2)
            printer[name](e)
    printDashes(0)

def findReportFile(m, root = None):
    """Return the path of the report.xml file of model *m*.

    The search covers root/<language>/<m>/report.xml for every language
    directory under *root*.  *root* defaults to the module-level modelsDir
    set by the command-line section.  Entries of *root* that are not
    directories are skipped.

    Raises IOError if *root* is not a directory, if the model directory
    exists but has no report.xml, or if no language directory contains *m*.
    """
    if root is None:
        root = modelsDir
    if not os.path.isdir(root):
        raise IOError("error: model directory does not exist")
    for lang in os.listdir(root):
        langDir = os.path.join(root, lang)
        # os.listdir on a plain file raises OSError, which (on Python 2) is
        # not the IOError the caller handles.
        if not os.path.isdir(langDir):
            continue
        if m in os.listdir(langDir):
            f = os.path.join(langDir, m, "report.xml")
            if os.path.exists(f):
                return f
            raise IOError("error: report file \"" + f + "\" not found")
    raise IOError("error: model \"" + m + "\" not found")

# Command-line entry point:
#
#   helena-report <model-name | report.xml> [out-file]
#
# The first argument is used as a path if such a file exists.  Otherwise it
# is looked up as a model name under ~/.helena/models.  Without out-file the
# report is pretty-printed to stdout; with an out-file ending in ".xml" the
# raw report file is copied there; any other extension is rejected.
if not 2 <= len(sys.argv) <= 3:
    print("usage: helena-report model-name [out-file]")
    print("       helena-report report.xml [out-file]")
    sys.exit(1)
else:
    # expanduser also works when $HOME is unset (on POSIX it falls back to
    # the password database), where os.getenv("HOME") would return None and
    # make os.path.join raise TypeError.
    helenaDir = os.path.join(os.path.expanduser("~"), ".helena")
    modelsDir = os.path.join(helenaDir, "models")
    model = sys.argv[1]
    if len(sys.argv) > 2:
        out = sys.argv[2]
        (_, outType) = os.path.splitext(out)
    else:
        outType = "stdout"
    if os.path.exists(model):
        xml = model
    else:
        try:
            xml = findReportFile(model)
        except IOError as err:
            exitWithError(err)
    if outType == "stdout":
        try:
            doc = parse(xml)
        except Exception as err:
            # Malformed XML (ExpatError) or an unreadable file.  Unlike the
            # former bare except, this lets KeyboardInterrupt/SystemExit out.
            exitWithError("error: could not parse file " + xml + ": " + str(err))
        printDoc(doc)
    elif outType == ".xml":
        try:
            shutil.copyfile(xml, out)
        except (IOError, OSError, shutil.Error) as err:
            exitWithError("error: could not copy report to " + out + ": " +
                          str(err))
    else:
        msg = "error: \"" + outType + \
            "\" is not a valid extension for output file"
        exitWithError(msg)
sys.exit(0)
