/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.sphinx.fst.operations;

import edu.cmu.sphinx.fst.Arc;
import edu.cmu.sphinx.fst.Fst;
import edu.cmu.sphinx.fst.State;
import edu.cmu.sphinx.fst.operations.Determinize;
import edu.cmu.sphinx.fst.operations.ExtendFinal;
import edu.cmu.sphinx.fst.operations.Reverse;
import edu.cmu.sphinx.fst.semiring.Semiring;
import edu.cmu.sphinx.fst.utils.Pair;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.PriorityQueue;

/**
 * Static operations that compute per-state distances on an {@link Fst} and
 * build a new {@link Fst} by best-first expansion of an input one.
 */
public class NShortestPaths {
    /** Not instantiable: only static operations are exposed. */
    private NShortestPaths() {
    }

    /**
     * Computes one distance value per state of {@code Reverse.get(fst)}.
     *
     * <p>The start state of the reversed FST gets {@code semiring.one()}, every
     * other state starts at {@code semiring.zero()}, and weights are propagated
     * along arcs with {@code semiring.times} / {@code semiring.plus} until no
     * entry changes any more.
     *
     * <p>NOTE(review): presumably the result can be read as "distance from each
     * state of {@code fst} to a final state"; that depends on how
     * {@code Reverse} numbers its states - confirm.
     *
     * @param fst the FST whose reversal is traversed
     * @return an array indexed by state id of the reversed FST
     */
    public static float[] shortestDistance(Fst fst) {
        Fst reversed = Reverse.get(fst);
        int numStates = reversed.getNumStates();
        Semiring semiring = reversed.getSemiring();

        // distance[] is the result; pending[] holds the weight that reached a
        // state since it was last taken off the worklist.
        float[] distance = new float[numStates];
        float[] pending = new float[numStates];
        Arrays.fill(distance, semiring.zero());
        Arrays.fill(pending, semiring.zero());

        LinkedHashSet<State> worklist = new LinkedHashSet<State>();
        State origin = reversed.getStart();
        worklist.add(origin);
        distance[origin.getId()] = semiring.one();
        pending[origin.getId()] = semiring.one();

        while (!worklist.isEmpty()) {
            // Take the oldest entry: the set iterates in insertion order.
            State current = worklist.iterator().next();
            worklist.remove(current);

            float carried = pending[current.getId()];
            pending[current.getId()] = semiring.zero();

            for (int i = 0; i < current.getNumArcs(); i++) {
                Arc arc = current.getArc(i);
                State target = arc.getNextState();
                int targetId = target.getId();

                float contribution = semiring.times(carried, arc.getWeight());
                float before = distance[targetId];
                float after = semiring.plus(before, contribution);
                if (before == after) {
                    // Nothing changed for the target; no need to revisit it.
                    continue;
                }
                distance[targetId] = after;
                pending[targetId] = semiring.plus(pending[targetId], contribution);
                // Set semantics: a state that is already queued keeps its place.
                worklist.add(target);
            }
        }
        return distance;
    }

    /**
     * Builds a new {@link Fst} by expanding {@code fst} best-first from its
     * start state.
     *
     * <p>Queue entries are (state, accumulated weight) pairs ordered by
     * {@code semiring.times(accumulated weight, distance[state id])}, where the
     * distances come from {@link #shortestDistance(Fst)}. Every dequeued entry
     * becomes a fresh state in the result, linked from the copy of the entry it
     * was expanded from. Expansion stops once a state whose final weight is not
     * {@code semiring.zero()} has been dequeued {@code n} times; a state
     * dequeued more than {@code n} times is not expanded further.
     *
     * <p>NOTE(review): {@code ExtendFinal.apply} is invoked on the FST being
     * searched, which is {@code fst} itself when {@code determinize} is
     * {@code false} - looks like the caller's FST may be modified; confirm
     * against {@code ExtendFinal}.
     *
     * @param fst         the input FST
     * @param n           the dequeue count at which the search stops
     * @param determinize if {@code true}, the search runs on
     *                    {@code Determinize.get(fst)} instead of {@code fst}
     * @return the newly built FST, or {@code null} if {@code fst} or its
     *         semiring is {@code null}
     */
    public static Fst get(Fst fst, int n, boolean determinize) {
        if (fst == null || fst.getSemiring() == null) {
            return null;
        }

        Fst searched = determinize ? Determinize.get(fst) : fst;
        final Semiring semiring = searched.getSemiring();

        Fst result = new Fst(semiring);
        result.setIsyms(searched.getIsyms());
        result.setOsyms(searched.getOsyms());

        // Distances are taken before ExtendFinal.apply, visit counters after.
        final float[] distance = NShortestPaths.shortestDistance(searched);
        ExtendFinal.apply(searched);
        int[] visits = new int[searched.getNumStates()];

        PriorityQueue<Pair<State, Float>> frontier = new PriorityQueue<Pair<State, Float>>(
                10, new EstimatedWeightOrder(semiring, distance));
        // predecessor: entry -> entry it was expanded from (null for the root).
        HashMap<Pair<State, Float>, Pair<State, Float>> predecessor =
                new HashMap<Pair<State, Float>, Pair<State, Float>>(fst.getNumStates());
        // copies: entry -> state created for it in the result.
        HashMap<Pair<State, Float>, State> copies =
                new HashMap<Pair<State, Float>, State>(fst.getNumStates());

        Pair<State, Float> root =
                new Pair<State, Float>(searched.getStart(), Float.valueOf(semiring.one()));
        frontier.add(root);
        predecessor.put(root, null);

        while (!frontier.isEmpty()) {
            Pair<State, Float> entry = frontier.remove();
            State original = entry.getLeft();
            Float accumulated = entry.getRight();

            State copy = new State(original.getFinalWeight());
            result.addState(copy);
            copies.put(entry, copy);

            Pair<State, Float> parent = predecessor.get(entry);
            if (parent == null) {
                // No predecessor recorded: this entry is the start of the result.
                result.setStart(copy);
            } else {
                // Replicate every arc parent -> original onto the copies.
                State parentCopy = copies.get(parent);
                State parentOriginal = parent.getLeft();
                for (int j = 0; j < parentOriginal.getNumArcs(); j++) {
                    Arc arc = parentOriginal.getArc(j);
                    if (arc.getNextState().equals(original)) {
                        parentCopy.addArc(
                                new Arc(arc.getIlabel(), arc.getOlabel(), arc.getWeight(), copy));
                    }
                }
            }

            int id = original.getId();
            visits[id]++;
            if (visits[id] == n && original.getFinalWeight() != semiring.zero()) {
                break;
            }
            if (visits[id] <= n) {
                for (int j = 0; j < original.getNumArcs(); j++) {
                    Arc arc = original.getArc(j);
                    float extended = semiring.times(accumulated.floatValue(), arc.getWeight());
                    Pair<State, Float> successor =
                            new Pair<State, Float>(arc.getNextState(), Float.valueOf(extended));
                    predecessor.put(successor, entry);
                    frontier.add(successor);
                }
            }
        }
        return result;
    }

    /**
     * Orders (state, accumulated weight) entries by
     * {@code semiring.times(accumulated weight, distance[state id])}; an entry
     * whose value is {@code naturalLess} than the other's sorts first.
     */
    private static final class EstimatedWeightOrder implements Comparator<Pair<State, Float>> {
        private final Semiring semiring;
        private final float[] distance;

        EstimatedWeightOrder(Semiring semiring, float[] distance) {
            this.semiring = semiring;
            this.distance = distance;
        }

        /** Accumulated weight of the entry combined with the distance of its state. */
        private float estimate(Pair<State, Float> entry) {
            return semiring.times(entry.getRight().floatValue(), distance[entry.getLeft().getId()]);
        }

        @Override
        public int compare(Pair<State, Float> o1, Pair<State, Float> o2) {
            float first = estimate(o1);
            float second = estimate(o2);
            if (semiring.naturalLess(second, first)) {
                return 1;
            }
            return second == first ? 0 : -1;
        }
    }
}

