/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.fst.sequitur;

import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.semiring.TropicalSemiring;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Unmarshaller;
import javax.xml.bind.annotation.XmlAttribute;
import javax.xml.bind.annotation.XmlElement;
import javax.xml.bind.annotation.XmlMixed;
import javax.xml.bind.annotation.XmlRootElement;

public class SequiturImport {
    /**
     * Command-line entry point: unmarshals a Sequitur G2P XML FSA from the file named by
     * {@code args[0]}, converts it with {@link FSA#toFst()} and saves the result via
     * {@code Fst.saveModel(args[1])}.
     *
     * @param args {@code args[0]} is the input XML path, {@code args[1]} is the output model path
     * @throws IllegalArgumentException if fewer than two arguments are supplied
     * @throws JAXBException if the JAXB context cannot be created or unmarshalling fails
     * @throws IOException declared for the model-saving step
     */
    public static void main(String ... args) throws JAXBException, IOException {
        // Fail with a readable usage message instead of a bare ArrayIndexOutOfBoundsException.
        if (args == null || args.length < 2) {
            throw new IllegalArgumentException("Usage: SequiturImport <sequitur-fsa.xml> <output-model>");
        }
        String inputPath = args[0];
        String outputPath = args[1];

        JAXBContext context = JAXBContext.newInstance(FSA.class);
        Unmarshaller unmarshaller = context.createUnmarshaller();
        FSA fsa = (FSA)unmarshaller.unmarshal(new File(inputPath));

        Fst fst = fsa.toFst();
        fst.saveModel(outputPath);
        System.out.println("The Sequitur G2P XML-formatted FST " + inputPath + " has been converted to Sphinx' OpenFst binary format in the file " + outputPath);
    }

    /**
     * JAXB binding for an alphabet element: the ordered list of its {@code <symbol>} children.
     */
    public static class Alphabet {
        /** Symbols in document order; modified in place by {@link #afterUnmarshal}. */
        @XmlElement(name="symbol")
        List<Symbol> symbols;

        /**
         * JAXB class-defined callback run after this alphabet has been unmarshalled.
         * Removes symbols whose content matches {@code __\d+__}, checks (via assertions) that
         * every remaining symbol's index equals its list position, clears those indices, and
         * finally appends a {@code "<s>"} symbol at the end of the list.
         *
         * @param unmarshaller the unmarshaller driving the callback (unused)
         * @param parent the parent object (unused)
         */
        public void afterUnmarshal(Unmarshaller unmarshaller, Object parent) {
            // NOTE(review): Symbol.afterUnmarshal rewrites any content matching "__.+__" to
            // "<eps>"; if that callback has already run for each symbol, this filter may never
            // match -- confirm the intended ordering.
            for (Iterator<Symbol> cursor = this.symbols.iterator(); cursor.hasNext(); ) {
                Symbol candidate = cursor.next();
                if (candidate.content.matches("__\\d+__")) {
                    cursor.remove();
                }
            }

            // Position in the list now stands in for the explicit index, so verify and drop it.
            for (int pos = 0; pos < this.symbols.size(); pos++) {
                Symbol current = this.symbols.get(pos);
                assert (current.index != null);
                assert (current.index == pos);
                current.index = null;
            }

            // The extra "<s>" entry always ends up as the last symbol of the alphabet.
            Symbol startMarker = new Symbol();
            startMarker.content = "<s>";
            this.symbols.add(startMarker);
        }

        /**
         * Copies the content of every symbol, in list order, into a new array.
         *
         * @return an array with one entry per symbol
         */
        String[] toSymbols() {
            String[] names = new String[this.symbols.size()];
            int slot = 0;
            for (Symbol symbol : this.symbols) {
                names[slot++] = symbol.content;
            }
            return names;
        }
    }

    /**
     * JAXB binding for an arc: a target state plus input label, output label and weight.
     */
    public static class Arc {
        /** Target state id; incremented by one in {@link #afterUnmarshal}. */
        @XmlAttribute
        int target;
        /** Input label, passed unchanged to the converted arc. */
        @XmlElement
        int in;
        /** Output label, passed unchanged to the converted arc. */
        @XmlElement
        int out;
        /** Arc weight, passed unchanged to the converted arc. */
        @XmlElement
        float weight;

        /**
         * JAXB class-defined callback: shifts the target id up by one, the same offset that
         * {@code State.afterUnmarshal} applies to state ids.
         *
         * @param unmarshaller the unmarshaller driving the callback (unused)
         * @param parent the parent object (unused)
         */
        public void afterUnmarshal(Unmarshaller unmarshaller, Object parent) {
            this.target = this.target + 1;
        }

        /**
         * Builds the Sphinx arc equivalent of this arc.
         *
         * @param openFstStates converted states, looked up by this arc's (shifted) target id
         * @return a new Sphinx arc carrying this arc's labels and weight and pointing at the
         *         state found at position {@code target}
         */
        public edu.cmu.sphinx.fst.Arc toOpenFstArc(List<edu.cmu.sphinx.fst.State> openFstStates) {
            edu.cmu.sphinx.fst.State destination = openFstStates.get(this.target);
            return new edu.cmu.sphinx.fst.Arc(this.in, this.out, this.weight, destination);
        }
    }

    /**
     * JAXB binding for the root {@code <fsa>} element, and the converter that turns the
     * unmarshalled automaton into a Sphinx {@link Fst}.
     */
    @XmlRootElement(name="fsa")
    public static class FSA {
        /** Value of the {@code semiring} attribute; asserted to be {@code "tropical"}. */
        @XmlAttribute
        String semiring;
        /** Value of the {@code initial} attribute (unshifted state id of the initial state). */
        @XmlAttribute
        int initial;
        @XmlElement(name="input-alphabet")
        Alphabet inputAlphabet;
        @XmlElement(name="output-alphabet")
        Alphabet outputAlphabet;
        /** States of the automaton; sorted by id in {@link #afterUnmarshal}. */
        @XmlElement(name="state")
        List<State> states;
        /** Converted states, filled by {@link #toFst()} in the same order as {@link #states}. */
        transient List<edu.cmu.sphinx.fst.State> openFstStates;
        transient Semiring ring = new TropicalSemiring();

        /**
         * JAXB class-defined callback run after the whole automaton has been unmarshalled.
         * Adds a synthetic state with id 0 whose single arc leads to the declared initial
         * state, then sorts all states by id.
         *
         * @param unmarshaller the unmarshaller driving the callback (unused)
         * @param parent the parent object (unused)
         */
        public void afterUnmarshal(Unmarshaller unmarshaller, Object parent) {
            assert ("tropical".equals(this.semiring));
            State initialState = new State();
            // State.afterUnmarshal shifts every unmarshalled id up by one; this state is built
            // by hand and therefore keeps id 0.
            initialState.id = 0;
            Arc initialArc = new Arc();
            // The last entry of each alphabet is the "<s>" symbol appended by
            // Alphabet.afterUnmarshal.
            initialArc.in = this.inputAlphabet.symbols.size() - 1;
            initialArc.out = this.outputAlphabet.symbols.size() - 1;
            // Apply the +1 shift by hand: this arc is not unmarshalled, so Arc.afterUnmarshal
            // never runs for it.
            initialArc.target = this.initial + 1;
            initialArc.weight = this.ring.one();
            initialState.arcs = Collections.singletonList(initialArc);
            this.states.add(initialState);
            Collections.sort(this.states, new Comparator<State>(){

                @Override
                public int compare(State s1, State s2) {
                    // Integer.compare cannot overflow, unlike the subtraction s1.id - s2.id.
                    return Integer.compare(s1.id, s2.id);
                }
            });
        }

        /**
         * Converts this automaton into a Sphinx {@link Fst}.
         * Creates the transducer over {@link #ring}, installs both symbol tables, adds one
         * converted state per entry of {@link #states} (in list order), marks the first one as
         * the start state and finally attaches every state's arcs.
         *
         * @return the newly built transducer
         */
        public Fst toFst() {
            Fst openFst = new Fst(this.ring);
            openFst.setIsyms(this.inputAlphabet.toSymbols());
            openFst.setOsyms(this.outputAlphabet.toSymbols());
            this.openFstStates = new ArrayList<edu.cmu.sphinx.fst.State>(this.states.size());
            // First pass: create all states so that arcs can reference any of them later.
            for (State state : this.states) {
                edu.cmu.sphinx.fst.State openFstState = state.toUnconnectedOpenFstState();
                openFst.addState(openFstState);
                assert (openFstState.getId() == state.id);
                this.openFstStates.add(openFstState);
            }
            openFst.setStart(this.openFstStates.get(0));
            // Second pass: every target state now exists, so the arcs can be attached.
            for (State state : this.states) {
                state.connectStates(this.openFstStates);
            }
            return openFst;
        }
    }

    /**
     * JAXB binding for a state: its id, optional final marker and weight, and outgoing arcs.
     */
    public static class State {
        /** State id; incremented by one in {@link #afterUnmarshal}. */
        @XmlAttribute
        int id;
        /** Bound to the {@code <final>} child; not read by any method of this class. */
        @XmlElement(name="final")
        Object finalState;
        /** Optional weight; {@code null} when the element is absent. */
        @XmlElement
        Float weight;
        /** Outgoing arcs; {@code null} is tolerated by {@link #connectStates}. */
        @XmlElement(name="arc")
        List<Arc> arcs;

        /**
         * JAXB class-defined callback: shifts this state's id up by one.
         *
         * @param unmarshaller the unmarshaller driving the callback (unused)
         * @param parent the parent object (unused)
         */
        public void afterUnmarshal(Unmarshaller unmarshaller, Object parent) {
            this.id += 1;
        }

        /**
         * Creates the Sphinx counterpart of this state without any arcs.
         *
         * @return a new Sphinx state constructed from this state's weight, or from
         *         {@code 0.0f} when no weight was given
         */
        public edu.cmu.sphinx.fst.State toUnconnectedOpenFstState() {
            float stateWeight = 0.0f;
            if (this.weight != null) {
                stateWeight = this.weight.floatValue();
            }
            return new edu.cmu.sphinx.fst.State(stateWeight);
        }

        /**
         * Converts each of this state's arcs and adds it to the converted state found at
         * position {@code id} of the given list. Does nothing when there are no arcs.
         *
         * @param openFstStates converted states, indexed by (shifted) state id
         */
        public void connectStates(List<edu.cmu.sphinx.fst.State> openFstStates) {
            if (this.arcs == null) {
                return;
            }
            for (Arc outgoing : this.arcs) {
                edu.cmu.sphinx.fst.Arc converted = outgoing.toOpenFstArc(openFstStates);
                edu.cmu.sphinx.fst.State source = openFstStates.get(this.id);
                source.addArc(converted);
            }
        }
    }

    /**
     * JAXB binding for a {@code <symbol>} element: an optional index attribute plus its text.
     */
    public static class Symbol {
        /** Value of the {@code index} attribute, if present. */
        @XmlAttribute
        Integer index;
        /** Mixed content of the element; asserted to hold exactly one entry. */
        @XmlMixed
        List<String> contentList;
        /** Symbol text after the rewriting done in {@link #afterUnmarshal}. */
        transient String content;

        /**
         * JAXB class-defined callback: derives {@link #content} from the single entry of
         * {@link #contentList}. {@code "__term__"} becomes {@code "</s>"}, any other text
         * matching {@code __.+__} becomes {@code "<eps>"}, and everything else is kept as is.
         *
         * @param unmarshaller the unmarshaller driving the callback (unused)
         * @param parent the parent object (unused)
         */
        public void afterUnmarshal(Unmarshaller unmarshaller, Object parent) {
            assert (this.contentList != null) : "Error with symbol " + this.index;
            assert (this.contentList.size() == 1) : "Error with symbol " + this.index;
            String raw = this.contentList.get(0);
            String mapped;
            if (raw.equals("__term__")) {
                mapped = "</s>";
            } else if (raw.matches("__.+__")) {
                mapped = "<eps>";
            } else {
                mapped = raw;
            }
            this.content = mapped;
        }
    }
}

