/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.fst.operations;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.ImmutableFst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.utils.Pair;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedList;

public class Compose {
    /** Static utility class; not instantiable. */
    private Compose() {
    }

    /**
     * Composes two FSTs by matching the output labels of {@code fst1} against
     * the input labels of {@code fst2}.
     * <p>
     * Result states are created lazily, one for every pair of operand states
     * reachable from the pair of start states. Pairs are expanded in FIFO
     * (breadth-first) order.
     * <ul>
     * <li>A result state's final weight is the semiring product of the two
     * operand states' final weights.</li>
     * <li>For every arc {@code a1} of the left state and {@code a2} of the
     * right state with {@code a1.olabel == a2.ilabel}, one arc is added. It
     * has input label {@code a1.ilabel}, output label {@code a2.olabel} and
     * weight {@code semiring.times(a1.weight, a2.weight)}.</li>
     * </ul>
     * Label 0 gets no special treatment here; it is matched like any other
     * label.
     *
     * @param fst1 the left operand
     * @param fst2 the right operand
     * @param semiring semiring of the result, also used to combine weights
     * @param sorted if {@code true}, the arcs of every {@code fst2} state must
     *        be sorted by ascending input label. The scan of those arcs then
     *        stops at the first input label greater than the current
     *        {@code fst1} output label.
     * @return the composed FST, or {@code null} in any of these cases:
     *         either operand is {@code null}; the output symbols of
     *         {@code fst1} differ from the input symbols of {@code fst2}; or
     *         either operand has no start state
     */
    public static Fst compose(Fst fst1, Fst fst2, Semiring semiring, boolean sorted) {
        if (fst1 == null || fst2 == null) {
            return null;
        }
        if (!Arrays.equals(fst1.getOsyms(), fst2.getIsyms())) {
            // The shared (middle) tape must use the same symbol table on both sides.
            return null;
        }
        State start1 = fst1.getStart();
        State start2 = fst2.getStart();
        if (start1 == null || start2 == null) {
            System.err.println("Cannot find initial state.");
            return null;
        }

        Fst res = new Fst(semiring);
        // Maps each visited (fst1 state, fst2 state) pair to its result state.
        HashMap<Pair<State, State>, State> stateMap = new HashMap<Pair<State, State>, State>();
        // Pairs whose outgoing arcs have not been expanded yet.
        Deque<Pair<State, State>> queue = new ArrayDeque<Pair<State, State>>();

        Pair<State, State> startPair = new Pair<State, State>(start1, start2);
        State start = new State(semiring.times(start1.getFinalWeight(), start2.getFinalWeight()));
        res.addState(start);
        res.setStart(start);
        stateMap.put(startPair, start);
        queue.add(startPair);

        while (!queue.isEmpty()) {
            Pair<State, State> pair = queue.remove();
            State s1 = pair.getLeft();
            State s2 = pair.getRight();
            State s = stateMap.get(pair);
            int numArcs1 = s1.getNumArcs();
            int numArcs2 = s2.getNumArcs();
            for (int i = 0; i < numArcs1; i++) {
                Arc a1 = s1.getArc(i);
                int olabel1 = a1.getOlabel();
                for (int j = 0; j < numArcs2; j++) {
                    Arc a2 = s2.getArc(j);
                    int ilabel2 = a2.getIlabel();
                    if (sorted && olabel1 < ilabel2) {
                        // Arcs are sorted by input label, so no later arc can match.
                        break;
                    }
                    if (olabel1 != ilabel2) {
                        continue;
                    }
                    State next1 = a1.getNextState();
                    State next2 = a2.getNextState();
                    Pair<State, State> nextPair = new Pair<State, State>(next1, next2);
                    State next = stateMap.get(nextPair);
                    if (next == null) {
                        // First time this pair is reached: create its state and schedule expansion.
                        next = new State(semiring.times(next1.getFinalWeight(), next2.getFinalWeight()));
                        res.addState(next);
                        stateMap.put(nextPair, next);
                        queue.add(nextPair);
                    }
                    s.addArc(new Arc(a1.getIlabel(), a2.getOlabel(),
                            semiring.times(a1.getWeight(), a2.getWeight()), next));
                }
            }
        }
        res.setIsyms(fst1.getIsyms());
        res.setOsyms(fst2.getOsyms());
        return res;
    }

    /**
     * Composes {@code fst1} with {@code fst2} through the filter FST built by
     * {@link #getFilter(String[], Semiring)}.
     * <p>
     * <b>Side effect:</b> both operands are modified in place by
     * {@link #augment(int, Fst, Semiring)}. The output side of {@code fst1} is
     * augmented (mode 1) and the input side of {@code fst2} is augmented
     * (mode 0). The operands are composed as
     * {@code compose(compose(fst1, filter), fst2)}, unsorted.
     * <p>
     * Label 0 is treated specially by the augmentation, presumably as epsilon.
     * NOTE(review): confirm against callers.
     *
     * @param fst1 the left operand; must carry both input and output symbol tables
     * @param fst2 the right operand; must carry both input and output symbol tables
     * @param semiring semiring used for the filter, the intermediate result and the result
     * @return the composed FST, or {@code null} if either operand is
     *         {@code null}, a symbol table is missing, the output symbols of
     *         {@code fst1} differ from the input symbols of {@code fst2}, or
     *         a composition step fails (e.g. no start state)
     */
    public static Fst get(Fst fst1, Fst fst2, Semiring semiring) {
        if (fst1 == null || fst2 == null) {
            return null;
        }
        if (!Arrays.equals(fst1.getOsyms(), fst2.getIsyms())) {
            return null;
        }
        // getFilter and augment read the lengths of all of these tables. Check
        // them before anything is mutated: previously, a missing table on fst2
        // threw only after fst1 had already been augmented in place.
        if (fst1.getIsyms() == null || fst1.getOsyms() == null
                || fst2.getIsyms() == null || fst2.getOsyms() == null) {
            return null;
        }
        Fst filter = getFilter(fst1.getOsyms(), semiring);
        augment(1, fst1, semiring);
        augment(0, fst2, semiring);
        Fst tmp = compose(fst1, filter, semiring, false);
        return compose(tmp, fst2, semiring, false);
    }

    /**
     * Builds the three-state filter FST used by {@link #get(Fst, Fst, Semiring)}.
     * <p>
     * With {@code n = syms.length}, the two extra labels are {@code e1 = n} and
     * {@code e2 = n + 1}. All three states are final with weight
     * {@code semiring.one()}, and every arc has weight {@code semiring.one()}.
     * <ul>
     * <li>s0 (start): e2:e1 to s0, e1:e1 to s1, e2:e2 to s2, and i:i to s0 for
     * every 1 &lt;= i &lt; n.</li>
     * <li>s1: e1:e1 to s1, and i:i to s0 for every 1 &lt;= i &lt; n.</li>
     * <li>s2: e2:e2 to s2, and i:i to s0 for every 1 &lt;= i &lt; n.</li>
     * </ul>
     * Symbol index 0 gets no identity arc.
     *
     * @param syms symbol table used as both the input and output symbols of the filter
     * @param semiring semiring of the filter
     * @return the filter FST
     */
    public static Fst getFilter(String[] syms, Semiring semiring) {
        Fst filter = new Fst(semiring);
        int e1index = syms.length;
        int e2index = syms.length + 1;
        filter.setIsyms(syms);
        filter.setOsyms(syms);

        // Constructor arguments are initial arc capacities.
        State s0 = new State(syms.length + 3);
        s0.setFinalWeight(semiring.one());
        State s1 = new State(syms.length);
        s1.setFinalWeight(semiring.one());
        State s2 = new State(syms.length);
        s2.setFinalWeight(semiring.one());

        // State 0 (start)
        filter.addState(s0);
        s0.addArc(new Arc(e2index, e1index, semiring.one(), s0));
        s0.addArc(new Arc(e1index, e1index, semiring.one(), s1));
        s0.addArc(new Arc(e2index, e2index, semiring.one(), s2));
        addIdentityArcs(s0, s0, syms.length, semiring);
        filter.setStart(s0);

        // State 1
        filter.addState(s1);
        s1.addArc(new Arc(e1index, e1index, semiring.one(), s1));
        addIdentityArcs(s1, s0, syms.length, semiring);

        // State 2
        filter.addState(s2);
        s2.addArc(new Arc(e2index, e2index, semiring.one(), s2));
        addIdentityArcs(s2, s0, syms.length, semiring);
        return filter;
    }

    /**
     * Adds an arc {@code i:i} with weight {@code semiring.one()} from
     * {@code from} to {@code to} for every symbol index 1 &lt;= i &lt; numSyms.
     */
    private static void addIdentityArcs(State from, State to, int numSyms, Semiring semiring) {
        for (int i = 1; i < numSyms; i++) {
            from.addArc(new Arc(i, i, semiring.one(), to));
        }
    }

    /**
     * Augments one side of {@code fst} in place, in preparation for filtered
     * composition.
     * <ul>
     * <li>{@code label == 0} (input side): every arc with input label 0 is
     * relabeled to input {@code isyms.length}. Every state then gets a
     * self-loop {@code (isyms.length + 1):0} with weight {@code semiring.one()}.</li>
     * <li>{@code label == 1} (output side): every arc with output label 0 is
     * relabeled to output {@code osyms.length + 1}. Every state then gets a
     * self-loop {@code 0:osyms.length} with weight {@code semiring.one()}.</li>
     * <li>Any other value leaves {@code fst} unchanged.</li>
     * </ul>
     * For an {@link ImmutableFst}, the last arc slot of every state is excluded
     * from relabeling. The self-loop overwrites that slot via {@code setArc}
     * instead of being appended; presumably the slot is reserved by
     * ImmutableFst for this purpose (confirm).
     *
     * @param label 0 to augment the input side, 1 to augment the output side
     * @param fst the FST to modify; must have non-null input and output symbol tables
     * @param semiring semiring supplying the self-loop weight
     */
    public static void augment(int label, Fst fst, Semiring semiring) {
        String[] isyms = fst.getIsyms();
        String[] osyms = fst.getOsyms();
        int e1inputIndex = isyms.length;
        int e2inputIndex = isyms.length + 1;
        int e1outputIndex = osyms.length;
        int e2outputIndex = osyms.length + 1;
        boolean immutable = fst instanceof ImmutableFst;
        int numStates = fst.getNumStates();
        for (int i = 0; i < numStates; i++) {
            State s = fst.getState(i);
            // Skip the reserved last slot of ImmutableFst states; it receives the self-loop below.
            int numArcs = immutable ? s.getNumArcs() - 1 : s.getNumArcs();
            for (int j = 0; j < numArcs; j++) {
                Arc a = s.getArc(j);
                if (label == 1 && a.getOlabel() == 0) {
                    a.setOlabel(e2outputIndex);
                } else if (label == 0 && a.getIlabel() == 0) {
                    a.setIlabel(e1inputIndex);
                }
            }
            Arc loop;
            if (label == 0) {
                loop = new Arc(e2inputIndex, 0, semiring.one(), s);
            } else if (label == 1) {
                loop = new Arc(0, e1outputIndex, semiring.one(), s);
            } else {
                continue;
            }
            if (immutable) {
                s.setArc(numArcs, loop);
            } else {
                s.addArc(loop);
            }
        }
    }
}

