/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.result;

import edu.cmu.sphinx.decoder.search.AlternateHypothesisManager;
import edu.cmu.sphinx.decoder.search.Token;
import edu.cmu.sphinx.linguist.dictionary.Pronunciation;
import edu.cmu.sphinx.linguist.dictionary.Word;
import edu.cmu.sphinx.result.Edge;
import edu.cmu.sphinx.result.Node;
import edu.cmu.sphinx.result.Result;
import edu.cmu.sphinx.result.WordResult;
import edu.cmu.sphinx.util.LogMath;
import edu.cmu.sphinx.util.TimeFrame;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.LineNumberReader;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.StringTokenizer;

/**
 * A word lattice: a directed graph of {@link Node}s connected by scored {@link Edge}s, with a
 * distinguished initial node and terminal node.
 *
 * <p>A lattice can be built from a recognition {@link Result}, loaded from the line-oriented
 * text format produced by {@link #dump(String)}, or read from an SLF (HTK Standard Lattice
 * Format) file. Once built it can compute forward/backward scores, node posteriors and the
 * Viterbi path, and can be written out as SLF, Graphviz dot or aiSee graph descriptions.
 */
public class Lattice {
    /**
     * Language-model scale assumed when an SLF file has no {@code lmscale=} header, and the
     * scale by which language scores are divided when writing SLF.
     */
    private static final double DEFAULT_SLF_LM_SCALE = 9.5;

    /** Entry node of the lattice. */
    protected Node initialNode;
    /** Exit node of the lattice. */
    protected Node terminalNode;
    /** All edges currently in the lattice. */
    protected Set<Edge> edges = new HashSet<Edge>();
    /** All nodes currently in the lattice, keyed by node ID. */
    protected Map<String, Node> nodes = new HashMap<String, Node>();
    /** Log base read from a {@code logBase:} line when loading a dumped lattice. */
    protected double logBase;
    protected LogMath logMath = LogMath.getLogMath();
    /** Copied from the result; selects which time-frame computation is used for new nodes. */
    private boolean wordTokenFirst;
    /** Word tokens already collapsed into nodes while building from a {@link Result}. */
    private Set<Token> visitedWordTokens;
    /** Source of alternate (losing) predecessors of tokens; may be null. */
    private AlternateHypothesisManager loserManager;

    /** Creates an empty lattice with no nodes, no edges and no initial or terminal node. */
    public Lattice() {
    }

    /**
     * Builds a lattice from a recognition result by walking back from its best final token.
     *
     * <p>Every word token on the best path, and on any alternate-predecessor paths supplied by
     * the result's {@link AlternateHypothesisManager}, becomes a node; the non-word tokens
     * between two word tokens are collapsed into a single edge carrying their summed scores.
     *
     * @param result the recognition result; its best final token is expected to be a
     *     sentence-end word (checked by assertion only)
     */
    public Lattice(Result result) {
        this();
        this.visitedWordTokens = new HashSet<Token>();
        this.wordTokenFirst = result.getWordTokenFirst();
        this.loserManager = result.getAlternateHypothesisManager();
        if (this.loserManager != null) {
            this.loserManager.purge();
        }
        Token token = result.getBestFinalToken();
        assert (token != null && token.getWord().isSentenceEndWord());
        if (this.terminalNode == null) {
            // The terminal node doubles as the initial node until a sentence-start token is reached.
            this.initialNode = this.terminalNode = new Node(this.getNodeID(result.getBestToken()), token.getWord(), -1L, -1L);
            this.addNode(this.terminalNode);
        }
        this.collapseWordToken(token);
    }

    /**
     * Time frame for a word token when the decoder emits word tokens before their acoustic
     * tokens.
     *
     * <p>NOTE(review): this always returns (0, 0); real timing is not computed for this mode.
     */
    private TimeFrame getTimeFrameWordTokenFirst(Token token) {
        return new TimeFrame(0L, 0L);
    }

    /**
     * Time frame for a word token when the decoder emits word tokens after their acoustic
     * tokens.
     *
     * <p>Walks back through predecessors: the end time is the collect time of the first word
     * token met (normally {@code token} itself) and the start time is the collect time of the
     * earliest token reached before the preceding word token, or before the chain ends.
     *
     * @return the computed frame, or (0, 0) if no word token was found on the chain
     */
    private TimeFrame getTimeFrameWordTokenLast(Token token) {
        TimeFrame capTimeFrame = new TimeFrame(0L, 0L);
        Word word = null;
        long lastStartTime = -1L;
        long lastEndTime = -1L;
        for (Token dataToken = token; dataToken != null; dataToken = dataToken.getPredecessor()) {
            if (dataToken.isWord()) {
                // Reaching a second word token means we have walked past the start of this word.
                if (word != null && lastStartTime >= 0L) {
                    return new TimeFrame(lastStartTime, lastEndTime);
                }
                word = dataToken.getWord();
                lastEndTime = dataToken.getCollectTime();
            }
            lastStartTime = dataToken.getCollectTime();
        }
        if (lastEndTime >= 0L && lastStartTime >= 0L) {
            return new TimeFrame(lastStartTime, lastEndTime);
        }
        return capTimeFrame;
    }

    /** Returns the time frame of a word token, dispatching on the result's token ordering. */
    private TimeFrame getTimeFrame(Token token) {
        if (this.wordTokenFirst) {
            return this.getTimeFrameWordTokenFirst(token);
        }
        return this.getTimeFrameWordTokenLast(token);
    }

    /**
     * Returns the node for a word token, creating and registering it on first use.
     * Sentence-end tokens always map to the terminal node.
     */
    private Node getNode(Token token) {
        if (token.getWord().isSentenceEndWord()) {
            return this.terminalNode;
        }
        Node node = this.nodes.get(this.getNodeID(token));
        if (node == null) {
            TimeFrame timeFrame = this.getTimeFrame(token);
            node = new Node(this.getNodeID(token), token.getWord(), timeFrame.getStart(), timeFrame.getEnd());
            this.addNode(node);
        }
        return node;
    }

    /**
     * Collapses the paths leading into a word token into lattice edges, once per token.
     *
     * <p>NOTE(review): the best-predecessor path is seeded with acoustic + insertion score,
     * but alternate (loser) paths are seeded with the acoustic score only. This looks
     * inconsistent; confirm whether the insertion score should be included for losers too.
     */
    private void collapseWordToken(Token token) {
        assert (token != null);
        if (this.visitedWordTokens.contains(token)) {
            return;
        }
        this.visitedWordTokens.add(token);
        Node wordNode = this.getNode(token);
        this.collapseWordPath(wordNode, token.getPredecessor(), token.getAcousticScore() + token.getInsertionScore(), token.getLanguageScore());
        if (this.loserManager != null && this.loserManager.hasAlternatePredecessors(token)) {
            for (Token loser : this.loserManager.getAlternatePredecessors(token)) {
                this.collapseWordPath(wordNode, loser, token.getAcousticScore(), token.getLanguageScore());
            }
        }
    }

    /**
     * Walks back from {@code token} towards the previous word token, accumulating scores, and
     * links that word's node to {@code parentWordNode}.
     *
     * <p>Non-word tokens add their acoustic + insertion and language scores to the running
     * totals. When a word token is reached, an edge is added from its node to
     * {@code parentWordNode} with the accumulated scores, and that word token is collapsed in
     * turn. A word token with no predecessor becomes the initial node. At any token with
     * alternate predecessors the walk branches into each of them. A chain that ends without
     * reaching a word token adds no edge.
     */
    private void collapseWordPath(Node parentWordNode, Token token, float acousticScore, float languageScore) {
        if (token == null) {
            return;
        }
        if (token.isWord()) {
            Node fromNode = this.getNode(token);
            this.addEdge(fromNode, parentWordNode, acousticScore, languageScore);
            if (token.getPredecessor() != null) {
                this.collapseWordToken(token);
            } else {
                assert (token.getWord().isSentenceStartWord());
                this.initialNode = fromNode;
            }
            return;
        }
        while (true) {
            acousticScore += token.getAcousticScore() + token.getInsertionScore();
            languageScore += token.getLanguageScore();
            Token preToken = token.getPredecessor();
            if (preToken == null) {
                return;
            }
            // Stop at the next word token, or where the path branches into alternates.
            boolean branches = this.loserManager != null && this.loserManager.hasAlternatePredecessors(token);
            if (preToken.isWord() || branches) {
                break;
            }
            token = preToken;
        }
        this.collapseWordPath(parentWordNode, token.getPredecessor(), acousticScore, languageScore);
        if (this.loserManager != null && this.loserManager.hasAlternatePredecessors(token)) {
            for (Token loser : this.loserManager.getAlternatePredecessors(token)) {
                this.collapseWordPath(parentWordNode, loser, acousticScore, languageScore);
            }
        }
    }

    /**
     * Returns the node ID used for a token.
     *
     * <p>NOTE(review): IDs are derived from {@code hashCode()}, so two distinct tokens with
     * equal hash codes would share a node (and trip the {@code addNode} assertion).
     */
    private String getNodeID(Token token) {
        return Integer.toString(token.hashCode());
    }

    /**
     * Loads a lattice from a file in the text format written by {@link #dump(String)}.
     *
     * <p>Recognized line types are {@code edge:}, {@code node:}, {@code initialNode:},
     * {@code terminalNode:} and {@code logBase:}; blank lines are skipped. The file is always
     * closed, including on failure.
     *
     * @param fileName path of the file to load
     * @throws Error on an unrecognized line, or wrapping any exception raised while reading
     */
    public Lattice(String fileName) {
        this();
        try {
            System.err.println("Loading from " + fileName);
            LineNumberReader in = new LineNumberReader(new FileReader(fileName));
            try {
                String line;
                while ((line = in.readLine()) != null) {
                    StringTokenizer tokens = new StringTokenizer(line);
                    if (!tokens.hasMoreTokens()) {
                        continue;
                    }
                    String type = tokens.nextToken();
                    if (type.equals("edge:")) {
                        Edge.load(this, tokens);
                    } else if (type.equals("node:")) {
                        Node.load(this, tokens);
                    } else if (type.equals("initialNode:")) {
                        this.setInitialNode(this.getNode(tokens.nextToken()));
                    } else if (type.equals("terminalNode:")) {
                        this.setTerminalNode(this.getNode(tokens.nextToken()));
                    } else if (type.equals("logBase:")) {
                        this.logBase = Double.parseDouble(tokens.nextToken());
                    } else {
                        throw new Error("SYNTAX ERROR: " + fileName + '[' + in.getLineNumber() + "] " + line);
                    }
                }
            } finally {
                in.close();
            }
        }
        catch (Exception e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Reads a lattice in SLF (HTK Standard Lattice Format) from a stream.
     *
     * <p>Header lines {@code start=}, {@code end=} and {@code lmscale=} are honoured when they
     * appear outside the node/link sections. Node lines must have exactly the fields
     * {@code I= t= W=}; link lines exactly {@code J= S= E= a= l=}. Node times are converted
     * from seconds to milliseconds and language scores are multiplied by lmscale. The start
     * node (or {@code !ENTER}) becomes {@code <s>} and the initial node, the end node (or
     * {@code !EXIT}) becomes {@code </s>} and the terminal node, {@code !NULL} becomes
     * {@code <sil>}; these and words starting with {@code [} are marked as fillers. Blank
     * lines, surrounding whitespace and lines starting with {@code #} are ignored. Node end
     * times are derived afterwards from the begin times of their successors.
     *
     * <p>On a parse error the reader (and so {@code stream}) is closed before throwing; on
     * success {@code stream} is left open for the caller to close.
     *
     * @param stream SLF input
     * @return the parsed lattice
     * @throws NumberFormatException if a numeric field is malformed
     * @throws IOException on read failure, a malformed node/link line, or a link that refers
     *     to an undefined node
     */
    public static Lattice readSlf(InputStream stream) throws NumberFormatException, IOException {
        Lattice lattice = new Lattice();
        LineNumberReader in = new LineNumberReader(new InputStreamReader(stream));
        boolean readingNodes = false;
        boolean readingEdges = false;
        int startIdx = 0;
        int endIdx = 1;
        double lmscale = DEFAULT_SLF_LM_SCALE;
        String rawLine;
        while ((rawLine = in.readLine()) != null) {
            // Tolerate indentation, trailing whitespace and blank lines.
            String line = rawLine.trim();
            if (line.isEmpty()) {
                continue;
            }
            if (line.contains("Node definitions")) {
                readingEdges = false;
                readingNodes = true;
                continue;
            }
            if (line.contains("Link definitions")) {
                readingEdges = true;
                readingNodes = false;
                continue;
            }
            if (line.startsWith("#")) {
                continue;
            }
            if (readingNodes) {
                String[] parts = line.split("\\s+");
                if (!(parts.length == 3 && parts[0].startsWith("I=") && parts[1].startsWith("t=") && parts[2].startsWith("W="))) {
                    in.close();
                    throw new IOException("Unknown node definition: " + line);
                }
                int idx = Integer.parseInt(parts[0].substring(2));
                long beginTime = (long)(Double.parseDouble(parts[1].substring(2)) * 1000.0);
                String wordStr = parts[2].substring(2);
                boolean isFiller = false;
                if (idx == startIdx || wordStr.equals("!ENTER")) {
                    wordStr = "<s>";
                    isFiller = true;
                }
                if (idx == endIdx || wordStr.equals("!EXIT")) {
                    wordStr = "</s>";
                    isFiller = true;
                }
                if (wordStr.equals("!NULL")) {
                    wordStr = "<sil>";
                    isFiller = true;
                }
                if (wordStr.startsWith("[")) {
                    isFiller = true;
                }
                Word word = new Word(wordStr, new Pronunciation[0], isFiller);
                Node node = lattice.addNode(Integer.toString(idx), word, beginTime, -1L);
                if (wordStr.equals("<s>")) {
                    lattice.setInitialNode(node);
                }
                if (wordStr.equals("</s>")) {
                    lattice.setTerminalNode(node);
                }
                continue;
            }
            if (readingEdges) {
                String[] parts = line.split("\\s+");
                if (!(parts.length == 5 && parts[1].startsWith("S=") && parts[2].startsWith("E=") && parts[3].startsWith("a=") && parts[4].startsWith("l="))) {
                    in.close();
                    throw new IOException("Unknown edge definition: " + line);
                }
                Node fromNode = lattice.nodes.get(parts[1].substring(2));
                Node toNode = lattice.nodes.get(parts[2].substring(2));
                if (fromNode == null || toNode == null) {
                    in.close();
                    throw new IOException("Edge refers to an undefined node: " + line);
                }
                double ascore = Double.parseDouble(parts[3].substring(2));
                double lscore = Double.parseDouble(parts[4].substring(2)) * lmscale;
                lattice.addEdge(fromNode, toNode, ascore, lscore);
                continue;
            }
            if (line.startsWith("start=")) {
                startIdx = Integer.parseInt(line.replace("start=", ""));
            }
            if (line.startsWith("end=")) {
                endIdx = Integer.parseInt(line.replace("end=", ""));
            }
            if (line.startsWith("lmscale=")) {
                lmscale = Double.parseDouble(line.replace("lmscale=", ""));
            }
        }
        // SLF gives only begin times; derive each node's end time from its successors.
        for (Node node : lattice.nodes.values()) {
            for (Edge edge : node.getLeavingEdges()) {
                long successorBegin = edge.getToNode().getBeginTime();
                if (node.getEndTime() >= 0L && node.getEndTime() <= successorBegin) {
                    continue;
                }
                node.setEndTime(Math.max(successorBegin, node.getBeginTime()));
            }
        }
        return lattice;
    }

    /**
     * Reads a lattice in SLF format from a file; the file is always closed.
     *
     * @param fileName path of the SLF file
     * @return the parsed lattice
     * @throws IOException on read failure or malformed input
     */
    public static Lattice readSlf(String fileName) throws IOException {
        FileInputStream stream = new FileInputStream(fileName);
        try {
            return Lattice.readSlf(stream);
        } finally {
            stream.close();
        }
    }

    /**
     * Creates an edge between two nodes and registers it with both nodes and this lattice.
     *
     * @return the new edge
     */
    public Edge addEdge(Node fromNode, Node toNode, double acousticScore, double lmScore) {
        Edge e = new Edge(fromNode, toNode, acousticScore, lmScore);
        fromNode.addLeavingEdge(e);
        toNode.addEnteringEdge(e);
        this.edges.add(e);
        return e;
    }

    /** Creates a node for {@code word} and adds it to the lattice. */
    protected Node addNode(String id, Word word, long beginTime, long endTime) {
        Node n = new Node(id, word, beginTime, endTime);
        this.addNode(n);
        return n;
    }

    /** Creates a non-filler word with no pronunciations, wraps it in a node and adds it. */
    public Node addNode(String id, String word, long beginTime, long endTime) {
        Word w = new Word(word, new Pronunciation[0], false);
        return this.addNode(id, w, beginTime, endTime);
    }

    /** Returns whether {@code edge} is part of this lattice. */
    boolean hasEdge(Edge edge) {
        return this.edges.contains(edge);
    }

    /** Returns whether a node with the same ID as {@code node} is in this lattice. */
    boolean hasNode(Node node) {
        return this.hasNode(node.getId());
    }

    /** Returns whether a node with the given ID is in this lattice. */
    protected boolean hasNode(String ID) {
        return this.nodes.containsKey(ID);
    }

    /** Adds a node; its ID must not already be present (checked by assertion only). */
    protected void addNode(Node n) {
        assert (!this.hasNode(n.getId()));
        this.nodes.put(n.getId(), n);
    }

    /** Removes a node by ID without touching its edges. */
    protected void removeNode(Node n) {
        assert (this.hasNode(n.getId()));
        this.nodes.remove(n.getId());
    }

    /** Returns the node with the given ID, or null if absent. */
    protected Node getNode(String id) {
        return this.nodes.get(id);
    }

    /** Returns a snapshot copy of the nodes, safe to iterate while modifying the lattice. */
    protected Collection<Node> getCopyOfNodes() {
        return new ArrayList<Node>(this.nodes.values());
    }

    /** Returns a live view of the nodes. */
    public Collection<Node> getNodes() {
        return this.nodes.values();
    }

    /** Removes an edge from the lattice's edge set without detaching it from its nodes. */
    protected void removeEdge(Edge e) {
        this.edges.remove(e);
    }

    /** Returns the live set of edges. */
    public Collection<Edge> getEdges() {
        return this.edges;
    }

    /**
     * Writes the lattice as an aiSee graph description; the file is always closed.
     *
     * @throws Error wrapping any I/O failure
     */
    public void dumpAISee(String fileName, String title) {
        try {
            System.err.println("Dumping " + title + " to " + fileName);
            FileWriter f = new FileWriter(fileName);
            try {
                f.write("graph: {\n");
                f.write("title: \"" + title + "\"\n");
                f.write("display_edge_labels: yes\n");
                for (Node node : this.nodes.values()) {
                    node.dumpAISee(f);
                }
                for (Edge edge : this.edges) {
                    edge.dumpAISee(f);
                }
                f.write("}\n");
            } finally {
                f.close();
            }
        }
        catch (IOException e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Writes the lattice as a Graphviz dot digraph laid out left-to-right; the file is always
     * closed.
     *
     * @throws Error wrapping any I/O failure
     */
    public void dumpDot(String fileName, String title) {
        try {
            System.err.println("Dumping " + title + " to " + fileName);
            FileWriter f = new FileWriter(fileName);
            try {
                f.write("digraph \"" + title + "\" {\n");
                f.write("rankdir = LR\n");
                for (Node node : this.nodes.values()) {
                    node.dumpDot(f);
                }
                for (Edge edge : this.edges) {
                    edge.dumpDot(f);
                }
                f.write("}\n");
            } finally {
                f.close();
            }
        }
        catch (IOException e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Writes the lattice in SLF format.
     *
     * <p>The header is fixed ({@code UTTERANCE=test}, {@code base=1.0001},
     * {@code lmscale=9.5}, {@code start=0}, {@code end=1}). The initial node is written as
     * index 0 and the terminal node as 1; other nodes are numbered from 2 in map iteration
     * order. Words whose spelling starts with {@code <} are written as {@code !NULL}, and
     * language scores are divided by the lmscale so that {@link #readSlf(InputStream)}
     * restores them. The writer is flushed but not closed.
     *
     * @param w destination writer
     * @throws IOException on write failure
     */
    public void dumpSlf(Writer w) throws IOException {
        w.write("VERSION=1.1\n");
        w.write("UTTERANCE=test\n");
        w.write("base=1.0001\n");
        w.write("lmscale=9.5\n");
        w.write("start=0\n");
        w.write("end=1\n");
        w.write("#\n# Size line.\n#\n");
        w.write("NODES=" + this.nodes.size() + "    LINKS=" + this.edges.size() + "\n");
        HashMap<String, Integer> nodeIdMap = new HashMap<String, Integer>();
        nodeIdMap.put(this.initialNode.getId(), 0);
        nodeIdMap.put(this.terminalNode.getId(), 1);
        int count = 2;
        w.write("#\n# Node definitions.\n#\n");
        for (Node node : this.nodes.values()) {
            Integer slfId = nodeIdMap.get(node.getId());
            if (slfId == null) {
                slfId = count;
                nodeIdMap.put(node.getId(), slfId);
                ++count;
            }
            w.write("I=" + slfId);
            // Begin times are stored in milliseconds; SLF uses seconds.
            w.write("    t=" + (double)node.getBeginTime() / 1000.0);
            String spelling = node.getWord().getSpelling();
            if (spelling.startsWith("<")) {
                spelling = "!NULL";
            }
            w.write("    W=" + spelling);
            w.write("\n");
        }
        w.write("#\n# Link definitions.\n#\n");
        count = 0;
        for (Edge edge : this.edges) {
            w.write("J=" + count);
            w.write("    S=" + nodeIdMap.get(edge.getFromNode().getId()));
            w.write("    E=" + nodeIdMap.get(edge.getToNode().getId()));
            w.write("    a=" + edge.getAcousticScore());
            w.write("    l=" + edge.getLMScore() / DEFAULT_SLF_LM_SCALE);
            w.write("\n");
            ++count;
        }
        w.flush();
    }

    /**
     * Writes all nodes, all edges, the initial/terminal node IDs and the log base in the text
     * format read by {@link #Lattice(String)}. The writer is flushed but not closed.
     */
    protected void dump(PrintWriter out) throws IOException {
        for (Node node : this.nodes.values()) {
            node.dump(out);
        }
        for (Edge edge : this.edges) {
            edge.dump(out);
        }
        out.println("initialNode: " + this.initialNode.getId());
        out.println("terminalNode: " + this.terminalNode.getId());
        out.println("logBase: " + this.logMath.getLogBase());
        out.flush();
    }

    /**
     * Writes the lattice to a file in the text format read by {@link #Lattice(String)}.
     * The file is always closed.
     *
     * @throws Error wrapping any I/O failure while opening the file
     */
    public void dump(String file) {
        try {
            PrintWriter out = new PrintWriter(new FileWriter(file));
            try {
                this.dump(out);
            } finally {
                out.close();
            }
        }
        catch (IOException e) {
            throw new Error(e.toString(), e);
        }
    }

    /**
     * Removes a node and every edge entering or leaving it, detaching those edges from the
     * neighbouring nodes.
     */
    protected void removeNodeAndEdges(Node n) {
        for (Edge e : n.getLeavingEdges()) {
            e.getToNode().removeEnteringEdge(e);
            this.edges.remove(e);
        }
        for (Edge e : n.getEnteringEdges()) {
            e.getFromNode().removeLeavingEdge(e);
            this.edges.remove(e);
        }
        this.nodes.remove(n.getId());
        assert (this.checkConsistency());
    }

    /**
     * Removes a node after connecting each of its predecessors directly to each of its
     * successors.
     *
     * <p>NOTE(review): each new edge carries only the entering edge's acoustic and LM scores;
     * the leaving edge's scores are dropped. Confirm this is intended.
     */
    protected void removeNodeAndCrossConnectEdges(Node n) {
        System.err.println("Removing node " + n + " and cross connecting edges");
        for (Edge ei : n.getEnteringEdges()) {
            for (Edge ej : n.getLeavingEdges()) {
                this.addEdge(ei.getFromNode(), ej.getToNode(), ei.getAcousticScore(), ei.getLMScore());
            }
        }
        this.removeNodeAndEdges(n);
        assert (this.checkConsistency());
    }

    public Node getInitialNode() {
        return this.initialNode;
    }

    public void setInitialNode(Node initialNode) {
        this.initialNode = initialNode;
    }

    public Node getTerminalNode() {
        return this.terminalNode;
    }

    public void setTerminalNode(Node terminalNode) {
        this.terminalNode = terminalNode;
    }

    /** Prints every initial-to-terminal path to standard output, one per line. */
    public void dumpAllPaths() {
        for (String path : this.allPaths()) {
            System.out.println(path);
        }
    }

    /**
     * Enumerates every path from the initial node to the terminal node.
     *
     * @return one string per path, each word preceded by a space
     */
    public List<String> allPaths() {
        return this.allPathsFrom("", this.initialNode);
    }

    /**
     * Recursively enumerates the paths from {@code n} to the terminal node, each prefixed by
     * {@code path}. There is no cycle detection, so the lattice must be acyclic; the number of
     * paths can grow exponentially with lattice size.
     */
    protected List<String> allPathsFrom(String path, Node n) {
        String p = path + ' ' + n.getWord();
        LinkedList<String> l = new LinkedList<String>();
        if (n == this.terminalNode) {
            l.add(p);
        } else {
            for (Edge e : n.getLeavingEdges()) {
                l.addAll(this.allPathsFrom(p, e.getToNode()));
            }
        }
        return l;
    }

    /**
     * Verifies that node edge lists and the lattice's node/edge sets agree with each other.
     *
     * @return true if consistent
     * @throws Error describing the first inconsistency found
     */
    boolean checkConsistency() {
        for (Node n : this.nodes.values()) {
            for (Edge e : n.getEnteringEdges()) {
                if (!this.hasEdge(e)) {
                    throw new Error("Lattice has NODE with missing FROM edge: " + n + ',' + e);
                }
            }
            for (Edge e : n.getLeavingEdges()) {
                if (!this.hasEdge(e)) {
                    throw new Error("Lattice has NODE with missing TO edge: " + n + ',' + e);
                }
            }
        }
        for (Edge e : this.edges) {
            if (!this.hasNode(e.getFromNode())) {
                throw new Error("Lattice has EDGE with missing FROM node: " + e);
            }
            if (!this.hasNode(e.getToNode())) {
                throw new Error("Lattice has EDGE with missing TO node: " + e);
            }
            if (!e.getToNode().hasEdgeFromNode(e.getFromNode())) {
                throw new Error("Lattice has EDGE with TO node with no corresponding FROM edge: " + e);
            }
            if (!e.getFromNode().hasEdgeToNode(e.getToNode())) {
                throw new Error("Lattice has EDGE with FROM node with no corresponding TO edge: " + e);
            }
        }
        return true;
    }

    /**
     * Depth-first post-order traversal from {@code n}: each node is appended to
     * {@code sorted} after all of its successors.
     *
     * @throws Error if a null node is reached
     */
    protected void sortHelper(Node n, List<Node> sorted, Set<Node> visited) {
        if (n == null) {
            throw new Error("Node is null");
        }
        if (!visited.add(n)) {
            return;
        }
        for (Edge e : n.getLeavingEdges()) {
            this.sortHelper(e.getToNode(), sorted, visited);
        }
        sorted.add(n);
    }

    /**
     * Returns the nodes reachable from the initial node in topological order (reversed
     * depth-first post-order). The ordering is only topological if the lattice is acyclic.
     */
    public List<Node> sortNodes() {
        ArrayList<Node> sorted = new ArrayList<Node>(this.nodes.size());
        this.sortHelper(this.initialNode, sorted, new HashSet<Node>());
        Collections.reverse(sorted);
        return sorted;
    }

    /** Computes node posteriors using both acoustic and weighted language-model scores. */
    public void computeNodePosteriors(float languageModelWeightAdjustment) {
        this.computeNodePosteriors(languageModelWeightAdjustment, false);
    }

    /**
     * Runs forward and backward passes over the lattice in topological order and sets each
     * node's forward, backward, Viterbi score, best predecessor and posterior.
     *
     * <p>Scores are combined with {@link LogMath#addAsLinear}; the posterior of a node is
     * forward + backward minus the terminal node's forward score. Does nothing if there is no
     * initial node.
     *
     * @param languageModelWeightAdjustment multiplier applied to each edge's LM score
     * @param useAcousticScoresOnly if true, edge scores are acoustic scores only
     */
    public void computeNodePosteriors(float languageModelWeightAdjustment, boolean useAcousticScoresOnly) {
        if (this.initialNode == null) {
            return;
        }
        this.initialNode.setForwardScore(0.0);
        this.initialNode.setViterbiScore(0.0);
        List<Node> sortedNodes = this.sortNodes();
        assert (sortedNodes.get(0) == this.initialNode);
        // Forward pass: accumulate forward scores and track the best (Viterbi) predecessor.
        for (Node currentNode : sortedNodes) {
            for (Edge edge : currentNode.getLeavingEdges()) {
                Node toNode = edge.getToNode();
                double edgeScore = this.computeEdgeScore(edge, languageModelWeightAdjustment, useAcousticScoresOnly);
                double forwardProb = edge.getFromNode().getForwardScore() + edgeScore;
                toNode.setForwardScore(this.logMath.addAsLinear((float)forwardProb, (float)toNode.getForwardScore()));
                double vs = edge.getFromNode().getViterbiScore() + edgeScore;
                if (toNode.getBestPredecessor() == null || vs > toNode.getViterbiScore()) {
                    toNode.setBestPredecessor(currentNode);
                    toNode.setViterbiScore(vs);
                }
            }
        }
        // Backward pass: start just before the terminal node and walk towards the initial node.
        this.terminalNode.setBackwardScore(0.0);
        assert (sortedNodes.get(sortedNodes.size() - 1) == this.terminalNode);
        ListIterator<Node> it = sortedNodes.listIterator(sortedNodes.size() - 1);
        while (it.hasPrevious()) {
            Node currentNode = it.previous();
            for (Edge edge : currentNode.getLeavingEdges()) {
                double backwardProb = edge.getToNode().getBackwardScore()
                        + this.computeEdgeScore(edge, languageModelWeightAdjustment, useAcousticScoresOnly);
                edge.getFromNode().setBackwardScore(this.logMath.addAsLinear((float)backwardProb, (float)edge.getFromNode().getBackwardScore()));
            }
        }
        double normalizationFactor = this.terminalNode.getForwardScore();
        for (Node node : this.nodes.values()) {
            node.setPosterior(node.getForwardScore() + node.getBackwardScore() - normalizationFactor);
        }
    }

    /**
     * Returns the best path from the initial to the terminal node by following best
     * predecessors back from the terminal node.
     *
     * @throws IllegalStateException if the chain of best predecessors does not reach the
     *     initial node (for example, if node posteriors have not been computed)
     */
    public List<Node> getViterbiPath() {
        LinkedList<Node> path = new LinkedList<Node>();
        for (Node n = this.terminalNode; n != this.initialNode; n = n.getBestPredecessor()) {
            if (n == null) {
                throw new IllegalStateException("Viterbi path does not reach the initial node (were node posteriors computed?)");
            }
            path.addFirst(n);
        }
        path.addFirst(this.initialNode);
        return path;
    }

    /** Returns the Viterbi path as word results, excluding sentence-start and sentence-end words. */
    public List<WordResult> getWordResultPath() {
        List<Node> path = this.getViterbiPath();
        LinkedList<WordResult> wordResults = new LinkedList<WordResult>();
        for (Node node : path) {
            Word word = node.getWord();
            if (!word.isSentenceStartWord() && !word.isSentenceEndWord()) {
                wordResults.add(new WordResult(node));
            }
        }
        return wordResults;
    }

    /** Edge score: acoustic score, plus the LM score scaled by the weight unless acoustic-only. */
    private double computeEdgeScore(Edge edge, float languageModelWeightAdjustment, boolean useAcousticScoresOnly) {
        if (useAcousticScoresOnly) {
            return edge.getAcousticScore();
        }
        return edge.getAcousticScore() + edge.getLMScore() * (double)languageModelWeightAdjustment;
    }

    /** Returns whether {@code other} is structurally equivalent to this lattice from the initial node. */
    public boolean isEquivalent(Lattice other) {
        return this.checkNodesEquivalent(this.initialNode, other.getInitialNode());
    }

    /**
     * Recursively checks that two nodes are equivalent and that their leaving edges can be
     * matched one-to-one to equivalent edges leading to equivalent nodes. Prints diagnostics
     * to standard output.
     */
    private boolean checkNodesEquivalent(Node n1, Node n2) {
        assert (n1 != null && n2 != null);
        boolean equivalent = n1.isEquivalent(n2);
        if (equivalent) {
            Collection<Edge> leavingEdges = n1.getCopyOfLeavingEdges();
            Collection<Edge> leavingEdges2 = n2.getCopyOfLeavingEdges();
            System.out.println("# edges: " + leavingEdges.size() + ' ' + leavingEdges2.size());
            for (Edge edge : leavingEdges) {
                Edge e2 = n2.findEquivalentLeavingEdge(edge);
                if (e2 == null) {
                    System.out.println("Equivalent edge not found, lattices not equivalent.");
                    return false;
                }
                if (!leavingEdges2.remove(e2)) {
                    System.out.println("Equivalent edge already matched, lattices not equivalent.");
                    return false;
                }
                equivalent &= this.checkNodesEquivalent(edge.getToNode(), e2.getToNode());
                if (!equivalent) {
                    return false;
                }
            }
            if (!leavingEdges2.isEmpty()) {
                System.out.println("One lattice has too many edges.");
                return false;
            }
        }
        return equivalent;
    }

    /** Returns whether the node's word is a filler, treating sentence start/end words as non-fillers. */
    boolean isFillerNode(Node node) {
        Word word = node.getWord();
        if (word.isSentenceStartWord() || word.isSentenceEndWord()) {
            return false;
        }
        return word.isFiller();
    }
}

