/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine.collect;

import io.crate.breaker.RamAccounting;
import io.crate.breaker.StringSizeEstimator;
import io.crate.common.collections.RefCountedItem;
import io.crate.data.BatchIterator;
import io.crate.data.CollectingBatchIterator;
import io.crate.data.Row;
import io.crate.data.RowN;
import io.crate.exceptions.Exceptions;
import io.crate.exceptions.GroupByOnArrayUnsupportedException;
import io.crate.execution.dsl.phases.RoutedCollectPhase;
import io.crate.execution.dsl.projection.GroupProjection;
import io.crate.execution.dsl.projection.Projection;
import io.crate.execution.dsl.projection.Projections;
import io.crate.execution.engine.aggregation.AggregationContext;
import io.crate.execution.engine.aggregation.AggregationFunction;
import io.crate.execution.engine.collect.CollectExpression;
import io.crate.execution.engine.collect.CollectTask;
import io.crate.execution.engine.collect.DocInputFactory;
import io.crate.execution.engine.collect.LuceneShardCollectorProvider;
import io.crate.execution.engine.fetch.ReaderContext;
import io.crate.execution.jobs.SharedShardContext;
import io.crate.expression.InputCondition;
import io.crate.expression.InputFactory;
import io.crate.expression.InputRow;
import io.crate.expression.reference.doc.lucene.CollectorContext;
import io.crate.expression.reference.doc.lucene.LuceneCollectorExpression;
import io.crate.expression.symbol.AggregateMode;
import io.crate.expression.symbol.InputColumn;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.Symbols;
import io.crate.lucene.FieldTypeLookup;
import io.crate.lucene.LuceneQueryBuilder;
import io.crate.memory.MemoryManager;
import io.crate.metadata.DocReferences;
import io.crate.metadata.Reference;
import io.crate.metadata.doc.DocSysColumns;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.types.DataTypes;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Terms;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Version;
import org.elasticsearch.common.lucene.BytesRefs;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.index.engine.Engine;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;

/**
 * Shard-level GROUP BY on a single STRING key that groups via the key column's doc value ordinals.
 *
 * <p>For every segment, aggregation states are collected per ordinal of the key column's
 * {@link SortedSetDocValues}. Once a segment is exhausted, its ordinals are resolved to their
 * terms and the states are merged into one map keyed by term. Documents without a key value are
 * aggregated into a separate state that ends up under the {@code null} key.
 */
final class GroupByOptimizedIterator {

    /**
     * Highest ratio of distinct key terms to live documents (per segment) for which the
     * optimization is still applied; see {@link #hasHighCardinalityRatio(Supplier, String)}.
     */
    private static final double CARDINALITY_RATIO_THRESHOLD = 0.5;

    /** Bytes accounted per new map entry, on top of the estimated size of the key itself. */
    private static final long HASH_MAP_ENTRY_OVERHEAD = 32L;

    GroupByOptimizedIterator() {
    }

    /**
     * Creates an ordinal-based GROUP BY iterator for the given shard if the collect phase allows it.
     *
     * <p>Returns {@code null} (so a caller can fall back to another implementation) unless all of
     * the following hold: there is exactly one shard projection and it is a {@link GroupProjection}
     * with a single STRING key; that key points to a {@link Reference} in {@code toCollect}; the
     * key field exists and has doc values; {@code _score} is used neither in {@code toCollect} nor
     * in the WHERE clause; and the key does not have a high cardinality ratio.
     *
     * @return the grouping iterator, or {@code null} if the optimization is not applicable
     */
    @Nullable
    static BatchIterator<Row> tryOptimizeSingleStringKey(IndexShard indexShard,
                                                         DocTableInfo table,
                                                         LuceneQueryBuilder luceneQueryBuilder,
                                                         FieldTypeLookup fieldTypeLookup,
                                                         BigArrays bigArrays,
                                                         InputFactory inputFactory,
                                                         DocInputFactory docInputFactory,
                                                         RoutedCollectPhase collectPhase,
                                                         CollectTask collectTask) {
        Collection<? extends Projection> shardProjections = Projections.shardProjections(collectPhase.projections());
        GroupProjection groupProjection = getSingleStringKeyGroupProjection(shardProjections);
        if (groupProjection == null) {
            return null;
        }
        assert groupProjection.keys().size() == 1
            : "Must have 1 key if getSingleStringKeyGroupProjection returned a projection";

        Reference keyRef = getKeyRef(collectPhase.toCollect(), groupProjection.keys().get(0));
        if (keyRef == null) {
            return null;
        }
        // The key is read from doc values, so look up the field of the (inverse-source-looked-up) column.
        keyRef = (Reference) DocReferences.inverseSourceLookup(keyRef);
        MappedFieldType keyFieldType = fieldTypeLookup.get(keyRef.column().fqn());
        if (keyFieldType == null || !keyFieldType.hasDocValues()) {
            return null;
        }
        // Matching documents are collected with ScoreMode.COMPLETE_NO_SCORES, so _score is unavailable.
        if (Symbols.containsColumn(collectPhase.toCollect(), DocSysColumns.SCORE)
            || Symbols.containsColumn(collectPhase.where(), DocSysColumns.SCORE)) {
            return null;
        }
        if (hasHighCardinalityRatio(() -> indexShard.acquireSearcher("group-by-cardinality-check"), keyFieldType.name())) {
            return null;
        }

        ShardId shardId = indexShard.shardId();
        SharedShardContext sharedShardContext = collectTask.sharedShardContexts().getOrCreateContext(shardId);
        RefCountedItem<? extends IndexSearcher> searcher =
            sharedShardContext.acquireSearcher("group-by-ordinals:" + LuceneShardCollectorProvider.formatSource(collectPhase));
        collectTask.addSearcher(sharedShardContext.readerId(), searcher);
        QueryShardContext queryShardContext = sharedShardContext.indexService().newQueryShardContext();

        InputFactory.Context<LuceneCollectorExpression<?>> docCtx = docInputFactory.getCtx(collectTask.txnCtx());
        docCtx.add(collectPhase.toCollect().stream()::iterator);

        InputFactory.Context<CollectExpression<Row, ?>> ctxForAggregations =
            inputFactory.ctxForAggregations(collectTask.txnCtx());
        ctxForAggregations.add(groupProjection.values());
        List<CollectExpression<Row, ?>> aggExpressions = ctxForAggregations.expressions();
        List<AggregationContext> aggregations = ctxForAggregations.aggregations();

        List<? extends LuceneCollectorExpression<?>> expressions = docCtx.expressions();
        RamAccounting ramAccounting = collectTask.getRamAccounting();
        CollectorContext collectorContext = new CollectorContext(sharedShardContext.readerId());
        InputRow inputRow = new InputRow(docCtx.topLevelInputs());

        LuceneQueryBuilder.Context queryContext = luceneQueryBuilder.convert(
            collectPhase.where(),
            collectTask.txnCtx(),
            indexShard.mapperService(),
            indexShard.shardId().getIndexName(),
            queryShardContext,
            table,
            sharedShardContext.indexService().cache()
        );
        return getIterator(
            bigArrays,
            searcher.item(),
            keyRef.column().fqn(),
            aggregations,
            expressions,
            aggExpressions,
            ramAccounting,
            collectTask.memoryManager(),
            collectTask.minNodeVersion(),
            inputRow,
            queryContext.query(),
            collectorContext,
            groupProjection.mode()
        );
    }

    /**
     * Builds a {@link CollectingBatchIterator} that, when loaded, aggregates all documents matching
     * {@code query} grouped by the doc values of {@code keyColumnName}.
     *
     * <p>{@code startCollect} is invoked on every expression eagerly, before the iterator is returned.
     * Closing or killing the iterator is recorded in a shared reference that the aggregation loop
     * checks; any failure during aggregation is returned as a failed future.
     */
    static BatchIterator<Row> getIterator(BigArrays bigArrays,
                                          IndexSearcher indexSearcher,
                                          String keyColumnName,
                                          List<AggregationContext> aggregations,
                                          List<? extends LuceneCollectorExpression<?>> expressions,
                                          List<CollectExpression<Row, ?>> aggExpressions,
                                          RamAccounting ramAccounting,
                                          MemoryManager memoryManager,
                                          Version minNodeVersion,
                                          InputRow inputRow,
                                          Query query,
                                          CollectorContext collectorContext,
                                          AggregateMode aggregateMode) {
        for (int i = 0, size = expressions.size(); i < size; i++) {
            expressions.get(i).startCollect(collectorContext);
        }
        AtomicReference<Throwable> killed = new AtomicReference<>();
        return CollectingBatchIterator.newInstance(
            () -> killed.set(BatchIterator.CLOSED),
            killed::set,
            () -> {
                try {
                    Map<BytesRef, Object[]> statesByKey = applyAggregatesGroupedByKey(
                        bigArrays,
                        indexSearcher,
                        keyColumnName,
                        aggregations,
                        expressions,
                        aggExpressions,
                        ramAccounting,
                        memoryManager,
                        minNodeVersion,
                        inputRow,
                        query,
                        killed
                    );
                    return CompletableFuture.completedFuture(getRows(statesByKey, ramAccounting, aggregations, aggregateMode));
                } catch (Throwable t) {
                    return CompletableFuture.failedFuture(t);
                }
            },
            true
        );
    }

    /**
     * Lazily turns the grouped states into rows of the form {@code [key, agg1, agg2, ...]}.
     *
     * <p>Each call to {@code iterator()} creates one {@link RowN} that is returned for every entry;
     * its cells are overwritten on each step of the iteration.
     */
    private static Iterable<Row> getRows(Map<BytesRef, Object[]> groupedStates,
                                         final RamAccounting ramAccounting,
                                         final List<AggregationContext> aggregations,
                                         final AggregateMode mode) {
        return () -> groupedStates.entrySet().stream()
            .map(new Function<Map.Entry<BytesRef, Object[]>, Row>() {

                final Object[] cells = new Object[1 + aggregations.size()];
                final RowN row = new RowN(cells);

                @Override
                public Row apply(Map.Entry<BytesRef, Object[]> entry) {
                    cells[0] = BytesRefs.toString(entry.getKey());
                    Object[] states = entry.getValue();
                    for (int i = 0; i < states.length; i++) {
                        cells[i + 1] = mode.finishCollect(ramAccounting, aggregations.get(i).function(), states[i]);
                    }
                    return row;
                }
            })
            .iterator();
    }

    /**
     * Runs {@code query} over all segments and aggregates the matching, live documents by key.
     *
     * <p>Per segment, states are kept in an {@link ObjectArray} indexed by key ordinal; the array is
     * released after the segment has been merged into the result. Documents without a key value are
     * aggregated into a single state stored under the {@code null} key.
     *
     * @throws GroupByOnArrayUnsupportedException if a document has more than one value for the key
     * @throws IOException on Lucene read failures
     */
    private static Map<BytesRef, Object[]> applyAggregatesGroupedByKey(BigArrays bigArrays,
                                                                      IndexSearcher indexSearcher,
                                                                      String keyColumnName,
                                                                      List<AggregationContext> aggregations,
                                                                      List<? extends LuceneCollectorExpression<?>> expressions,
                                                                      List<CollectExpression<Row, ?>> aggExpressions,
                                                                      RamAccounting ramAccounting,
                                                                      MemoryManager memoryManager,
                                                                      Version minNodeVersion,
                                                                      InputRow inputRow,
                                                                      Query query,
                                                                      AtomicReference<Throwable> killed) throws IOException {
        HashMap<BytesRef, Object[]> statesByKey = new HashMap<>();
        Weight weight = indexSearcher.createWeight(indexSearcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        List<LeafReaderContext> leaves = indexSearcher.getTopReaderContext().leaves();
        Object[] nullStates = null;

        for (LeafReaderContext leaf : leaves) {
            raiseIfClosedOrKilled(killed);
            Scorer scorer = weight.scorer(leaf);
            if (scorer == null) {
                // No document in this segment matches the query.
                continue;
            }
            ReaderContext readerContext = new ReaderContext(leaf);
            for (int i = 0, size = expressions.size(); i < size; i++) {
                expressions.get(i).setNextReader(readerContext);
            }
            SortedSetDocValues values = DocValues.getSortedSet(leaf.reader(), keyColumnName);
            ObjectArray<Object[]> statesByOrd = bigArrays.newObjectArray(values.getValueCount());
            try {
                DocIdSetIterator docs = scorer.iterator();
                // The scorer may still return deleted documents; they are filtered via liveDocs.
                Bits liveDocs = leaf.reader().getLiveDocs();
                for (int doc = docs.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = docs.nextDoc()) {
                    raiseIfClosedOrKilled(killed);
                    if (docDeleted(liveDocs, doc)) {
                        continue;
                    }
                    for (int i = 0, size = expressions.size(); i < size; i++) {
                        expressions.get(i).setNextDocId(doc);
                    }
                    for (int i = 0, size = aggExpressions.size(); i < size; i++) {
                        aggExpressions.get(i).setNextRow(inputRow);
                    }
                    if (values.advanceExact(doc)) {
                        long ord = values.nextOrd();
                        Object[] states = statesByOrd.get(ord);
                        if (states == null) {
                            statesByOrd.set(ord, initStates(aggregations, ramAccounting, memoryManager, minNodeVersion));
                        } else {
                            aggregateValues(aggregations, ramAccounting, memoryManager, states);
                        }
                        // A second ordinal means the key column holds an array for this document.
                        if (values.nextOrd() != SortedSetDocValues.NO_MORE_ORDS) {
                            throw new GroupByOnArrayUnsupportedException(keyColumnName);
                        }
                    } else if (nullStates == null) {
                        nullStates = initStates(aggregations, ramAccounting, memoryManager, minNodeVersion);
                    } else {
                        aggregateValues(aggregations, ramAccounting, memoryManager, nullStates);
                    }
                }
                mergeLeafStates(statesByOrd, values, statesByKey, aggregations, ramAccounting, killed);
            } finally {
                statesByOrd.close();
            }
        }
        if (nullStates != null) {
            statesByKey.put(null, nullStates);
        }
        return statesByKey;
    }

    /**
     * Merges the per-ordinal states of one segment into {@code statesByKey}.
     *
     * <p>Each non-null ordinal slot is resolved to its term. Unseen keys are deep-copied (the
     * {@link BytesRef} returned by {@code lookupOrd} may be re-used across calls) and accounted
     * with their estimated size plus {@link #HASH_MAP_ENTRY_OVERHEAD}. Keys already present are
     * combined with the existing states via each aggregation's {@code reduce}.
     */
    private static void mergeLeafStates(ObjectArray<Object[]> statesByOrd,
                                        SortedSetDocValues values,
                                        Map<BytesRef, Object[]> statesByKey,
                                        List<AggregationContext> aggregations,
                                        RamAccounting ramAccounting,
                                        AtomicReference<Throwable> killed) throws IOException {
        for (long ord = 0L; ord < statesByOrd.size(); ord++) {
            raiseIfClosedOrKilled(killed);
            Object[] states = statesByOrd.get(ord);
            if (states == null) {
                continue;
            }
            BytesRef sharedKey = values.lookupOrd(ord);
            Object[] prevStates = statesByKey.get(sharedKey);
            if (prevStates == null) {
                ramAccounting.addBytes(StringSizeEstimator.estimateSize(sharedKey) + HASH_MAP_ENTRY_OVERHEAD);
                statesByKey.put(BytesRef.deepCopyOf(sharedKey), states);
                continue;
            }
            for (int i = 0; i < aggregations.size(); i++) {
                AggregationContext aggregation = aggregations.get(i);
                prevStates[i] = aggregation.function().reduce(ramAccounting, prevStates[i], states[i]);
            }
        }
    }

    /**
     * Returns {@code true} if grouping on {@code fieldName} should not use the ordinal approach.
     *
     * <p>That is the case if any segment (with at least one live document) has no terms for the
     * field, or has a ratio of distinct terms to live documents above
     * {@link #CARDINALITY_RATIO_THRESHOLD}. Segments without live documents are skipped: they
     * contribute nothing to the result and would otherwise divide by zero. An {@link IOException}
     * is treated conservatively as "high cardinality". The searcher obtained from
     * {@code acquireSearcher} is closed before returning.
     */
    static boolean hasHighCardinalityRatio(Supplier<Engine.Searcher> acquireSearcher, String fieldName) {
        try (Engine.Searcher searcher = acquireSearcher.get()) {
            for (LeafReaderContext leaf : searcher.getIndexReader().leaves()) {
                int numDocs = leaf.reader().numDocs();
                if (numDocs == 0) {
                    continue;
                }
                Terms terms = leaf.reader().terms(fieldName);
                if (terms == null) {
                    return true;
                }
                double cardinalityRatio = terms.size() / (double) numDocs;
                if (cardinalityRatio > CARDINALITY_RATIO_THRESHOLD) {
                    return true;
                }
            }
            return false;
        } catch (IOException e) {
            // Best effort check: on failure, do not attempt the optimization.
            return true;
        }
    }

    /** Returns {@code true} if {@code doc} is marked as deleted; a null {@code liveDocs} means none are. */
    private static boolean docDeleted(@Nullable Bits liveDocs, int doc) {
        return liveDocs != null && !liveDocs.get(doc);
    }

    /**
     * Feeds the current row into {@code states} for each aggregation whose filter matches,
     * replacing each state with the value returned by {@code iterate}.
     */
    private static void aggregateValues(List<AggregationContext> aggregations,
                                        RamAccounting ramAccounting,
                                        MemoryManager memoryManager,
                                        Object[] states) {
        for (int i = 0; i < aggregations.size(); i++) {
            AggregationContext aggregation = aggregations.get(i);
            if (InputCondition.matches(aggregation.filter())) {
                states[i] = aggregation.function().iterate(ramAccounting, memoryManager, states[i], aggregation.inputs());
            }
        }
    }

    /**
     * Creates one new state per aggregation and, where the aggregation's filter matches,
     * immediately feeds the current row into it.
     */
    private static Object[] initStates(List<AggregationContext> aggregations,
                                       RamAccounting ramAccounting,
                                       MemoryManager memoryManager,
                                       Version minNodeVersion) {
        Object[] states = new Object[aggregations.size()];
        for (int i = 0; i < aggregations.size(); i++) {
            AggregationContext aggregation = aggregations.get(i);
            AggregationFunction function = aggregation.function();
            Object newState = function.newState(ramAccounting, Version.CURRENT, minNodeVersion, memoryManager);
            states[i] = InputCondition.matches(aggregation.filter())
                ? function.iterate(ramAccounting, memoryManager, newState, aggregation.inputs())
                : newState;
        }
        return states;
    }

    /**
     * Resolves a group key given as an {@link InputColumn} to the {@link Reference} at that
     * position in {@code toCollect}; returns {@code null} for any other shape.
     */
    @Nullable
    private static Reference getKeyRef(List<Symbol> toCollect, Symbol key) {
        if (key instanceof InputColumn) {
            Symbol collected = toCollect.get(((InputColumn) key).index());
            if (collected instanceof Reference) {
                return (Reference) collected;
            }
        }
        return null;
    }

    /**
     * Returns the only shard projection if it is a {@link GroupProjection} with exactly one key of
     * type {@link DataTypes#STRING}; otherwise {@code null}.
     */
    @Nullable
    private static GroupProjection getSingleStringKeyGroupProjection(Collection<? extends Projection> shardProjections) {
        if (shardProjections.size() != 1) {
            return null;
        }
        Projection shardProjection = shardProjections.iterator().next();
        if (!(shardProjection instanceof GroupProjection)) {
            return null;
        }
        GroupProjection groupProjection = (GroupProjection) shardProjection;
        if (groupProjection.keys().size() != 1 || groupProjection.keys().get(0).valueType() != DataTypes.STRING) {
            return null;
        }
        return groupProjection;
    }

    /** Rethrows (unchecked) the recorded close/kill reason, if any. */
    private static void raiseIfClosedOrKilled(AtomicReference<Throwable> killed) {
        Throwable killedException = killed.get();
        if (killedException != null) {
            Exceptions.rethrowUnchecked(killedException);
        }
    }
}

