/*
 * Decompiled with CFR 0.152.
 */
package io.crate.planner.operators;

import io.crate.analyze.OrderBy;
import io.crate.common.collections.Lists2;
import io.crate.data.Row;
import io.crate.execution.dsl.phases.ExecutionPhases;
import io.crate.execution.dsl.phases.MergePhase;
import io.crate.execution.dsl.projection.GroupProjection;
import io.crate.execution.dsl.projection.Projection;
import io.crate.execution.dsl.projection.builder.ProjectionBuilder;
import io.crate.expression.symbol.AggregateMode;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.ScopedSymbol;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolVisitors;
import io.crate.metadata.Reference;
import io.crate.metadata.RowGranularity;
import io.crate.metadata.doc.DocTableInfo;
import io.crate.planner.ExecutionPlan;
import io.crate.planner.Merge;
import io.crate.planner.PlannerContext;
import io.crate.planner.distribution.DistributionInfo;
import io.crate.planner.node.dql.GroupByConsumer;
import io.crate.planner.operators.Collect;
import io.crate.planner.operators.ForwardingLogicalPlan;
import io.crate.planner.operators.LogicalPlan;
import io.crate.planner.operators.LogicalPlanVisitor;
import io.crate.planner.operators.PrintContext;
import io.crate.planner.operators.SubQueryAndParamBinder;
import io.crate.planner.operators.SubQueryResults;
import io.crate.statistics.ColumnStats;
import io.crate.statistics.Stats;
import io.crate.statistics.TableStats;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.function.Consumer;
import javax.annotation.Nullable;

/**
 * Logical operator that groups the rows produced by its source by {@code groupKeys} and evaluates
 * {@code aggregates} for every group.
 *
 * <p>The operator's outputs are the group keys followed by the aggregates. Duplicate aggregates
 * passed to the constructor are collapsed (first occurrence wins).
 */
public class GroupHashAggregate
extends ForwardingLogicalPlan {
    private static final String DISTRIBUTED_MERGE_PHASE_NAME = "distributed merge";

    /** Limit value passed wherever no row limit is applied (the source is built with it, too). */
    private static final int NO_LIMIT = -1;

    /** Aggregates to compute per group; de-duplicated, in first-seen order. */
    final List<Function> aggregates;
    /** Symbols the source rows are grouped by. */
    final List<Symbol> groupKeys;
    /** {@link #groupKeys} followed by {@link #aggregates}. */
    private final List<Symbol> outputs;
    /** Estimated number of result rows, as supplied at construction. */
    private final long numExpectedRows;

    /**
     * Estimates how many groups result from grouping {@code numSourceRows} rows by {@code groupKeys}.
     *
     * <p>For each key that is a column ({@link Reference} or {@link ScopedSymbol}) with usable column
     * statistics, the column's distinct ratio {@code approxDistinct / numDocs} is scaled to
     * {@code numSourceRows}. Every other key pessimistically contributes {@code numSourceRows}
     * distinct values. The per-key estimates are multiplied, saturating instead of overflowing.
     * If all keys are columns the product is returned, capped at {@code numSourceRows}; otherwise
     * {@code numSourceRows} itself is returned.
     *
     * @param numSourceRows estimated number of rows entering the aggregation
     * @param tableStats    statistics used to look up per-relation and per-column stats
     * @param groupKeys     the symbols rows are grouped by
     * @return the estimated number of groups
     */
    static long approximateDistinctValues(long numSourceRows, TableStats tableStats, List<Symbol> groupKeys) {
        long distinctValues = 1L;
        int numColumnKeys = 0;
        for (Symbol groupKey : groupKeys) {
            Stats stats = null;
            ColumnStats columnStats = null;
            if (groupKey instanceof Reference) {
                Reference ref = (Reference) groupKey;
                stats = tableStats.getStats(ref.ident().tableIdent());
                columnStats = stats.statsByColumn().get(ref.column());
                numColumnKeys++;
            } else if (groupKey instanceof ScopedSymbol) {
                ScopedSymbol scopedSymbol = (ScopedSymbol) groupKey;
                stats = tableStats.getStats(scopedSymbol.relation());
                columnStats = stats.statsByColumn().get(scopedSymbol.column());
                numColumnKeys++;
            }
            long keyDistinctValues;
            if (columnStats == null || stats.numDocs() <= 0) {
                // No usable statistics: assume the worst case, i.e. every source row is distinct.
                // numDocs <= 0 is handled here because the ratio below would be NaN or infinite.
                keyDistinctValues = numSourceRows;
            } else {
                // approxDistinct is relative to stats.numDocs(), whereas numSourceRows presumably is the
                // source operator's own estimate (which may differ, e.g. after filtering), so scale the ratio.
                double cardinalityRatio = columnStats.approxDistinct() / (double) stats.numDocs();
                keyDistinctValues = (long) ((double) numSourceRows * cardinalityRatio);
            }
            // Saturate rather than wrap: with several keys the product easily exceeds a long and a
            // wrapped (possibly negative) value would survive the Math.min below.
            distinctValues = saturatedMultiply(distinctValues, keyDistinctValues);
        }
        if (numColumnKeys == groupKeys.size()) {
            return Math.min(distinctValues, numSourceRows);
        }
        return numSourceRows;
    }

    /**
     * Multiplies {@code a} and {@code b}, clamping to {@link Long#MAX_VALUE} or {@link Long#MIN_VALUE}
     * instead of silently wrapping around on overflow.
     */
    private static long saturatedMultiply(long a, long b) {
        long low = a * b;
        long high = Math.multiplyHigh(a, b);
        // The exact 128-bit product fits into a long iff its high half is the sign extension of the low half.
        if (high == (low >> 63)) {
            return low;
        }
        return high < 0 ? Long.MIN_VALUE : Long.MAX_VALUE;
    }

    /**
     * @param source          plan producing the rows to group
     * @param groupKeys       symbols to group by
     * @param aggregates      aggregates to compute per group; duplicates are removed
     * @param numExpectedRows estimated number of result rows
     */
    public GroupHashAggregate(LogicalPlan source, List<Symbol> groupKeys, List<Function> aggregates, long numExpectedRows) {
        super(source);
        this.numExpectedRows = numExpectedRows;
        this.aggregates = List.copyOf(new LinkedHashSet<Function>(aggregates));
        this.outputs = Lists2.concat(groupKeys, this.aggregates);
        this.groupKeys = groupKeys;
    }

    /** Returns the row estimate supplied at construction. */
    @Override
    public long numExpectedRows() {
        return this.numExpectedRows;
    }

    /** Returns the de-duplicated aggregates computed by this operator. */
    public List<Function> aggregates() {
        return this.aggregates;
    }

    /**
     * Builds the execution plan for the grouping.
     *
     * <p>{@code limit}, {@code offset}, {@code order} and {@code pageSizeHint} are not used here. Three
     * layouts are produced:
     * <ol>
     *   <li>every group is fully contained in one shard: a single {@code ITER_FINAL} projection;</li>
     *   <li>the source already executes on the handler: final (or shard-partial + node-final)
     *       projections on top of it;</li>
     *   <li>otherwise: partial aggregation on the source, {@code DEFAULT_MODULO} distribution and a
     *       {@code PARTIAL_FINAL} merge phase on the source's nodes.</li>
     * </ol>
     */
    @Override
    public ExecutionPlan build(PlannerContext plannerContext, ProjectionBuilder projectionBuilder, int limit, int offset, @Nullable OrderBy order, @Nullable Integer pageSizeHint, Row params, SubQueryResults subQueryResults) {
        // Grouping consumes all source rows, so the source is built without limit, offset, order or page size.
        ExecutionPlan executionPlan = this.source.build(plannerContext, projectionBuilder, NO_LIMIT, 0, null, null, params, subQueryResults);
        if (executionPlan.resultDescription().hasRemainingLimitOrOffset()) {
            // The source still has a limit/offset pending; bring its result to the handler first so
            // the grouping is applied on top of the limited rows.
            executionPlan = Merge.ensureOnHandler(executionPlan, plannerContext);
        }
        SubQueryAndParamBinder paramBinder = new SubQueryAndParamBinder(params, subQueryResults);
        List<Symbol> sourceOutputs = this.source.outputs();
        var searchPath = plannerContext.transactionContext().sessionContext().searchPath();

        if (this.shardsContainAllGroupKeyValues()) {
            // All rows of a group live on the same shard: aggregate to final values in a single step.
            GroupProjection groupProjection = projectionBuilder.groupProjection(
                sourceOutputs, this.groupKeys, this.aggregates, paramBinder, AggregateMode.ITER_FINAL,
                this.source.preferShardProjections() ? RowGranularity.SHARD : RowGranularity.CLUSTER,
                searchPath);
            executionPlan.addProjection(groupProjection, NO_LIMIT, 0, null);
            return executionPlan;
        }

        if (ExecutionPhases.executesOnHandler(plannerContext.handlerNode(), executionPlan.resultDescription().nodeIds())) {
            if (this.source.preferShardProjections()) {
                // Pre-aggregate per shard, then combine the partial results on the node.
                executionPlan.addProjection(projectionBuilder.groupProjection(
                    sourceOutputs, this.groupKeys, this.aggregates, paramBinder,
                    AggregateMode.ITER_PARTIAL, RowGranularity.SHARD, searchPath));
                executionPlan.addProjection(projectionBuilder.groupProjection(
                    this.outputs, this.groupKeys, this.aggregates, paramBinder,
                    AggregateMode.PARTIAL_FINAL, RowGranularity.NODE, searchPath), NO_LIMIT, 0, null);
            } else {
                executionPlan.addProjection(projectionBuilder.groupProjection(
                    sourceOutputs, this.groupKeys, this.aggregates, paramBinder,
                    AggregateMode.ITER_FINAL, RowGranularity.NODE, searchPath), NO_LIMIT, 0, null);
            }
            return executionPlan;
        }

        // Distributed case: partial aggregation on the source, then redistribute and finish the
        // aggregation in a merge phase running on the source's nodes.
        GroupProjection toPartial = projectionBuilder.groupProjection(
            sourceOutputs, this.groupKeys, this.aggregates, paramBinder, AggregateMode.ITER_PARTIAL,
            this.source.preferShardProjections() ? RowGranularity.SHARD : RowGranularity.NODE,
            searchPath);
        executionPlan.addProjection(toPartial);
        executionPlan.setDistributionInfo(DistributionInfo.DEFAULT_MODULO);
        GroupProjection toFinal = projectionBuilder.groupProjection(
            this.outputs, this.groupKeys, this.aggregates, paramBinder, AggregateMode.PARTIAL_FINAL,
            RowGranularity.CLUSTER, searchPath);
        return this.createMerge(plannerContext, executionPlan, Collections.singletonList(toFinal), executionPlan.resultDescription().nodeIds());
    }

    /** Returns the group keys followed by the aggregates. */
    @Override
    public List<Symbol> outputs() {
        return this.outputs;
    }

    /**
     * Drops aggregates not referenced by {@code outputsToKeep} and lets the source prune everything
     * not needed by the group keys or the remaining aggregates.
     *
     * @return {@code this} if neither the source nor the aggregates changed, otherwise a new operator
     */
    @Override
    public LogicalPlan pruneOutputsExcept(TableStats tableStats, Collection<Symbol> outputsToKeep) {
        List<Symbol> sourceOutputs = this.source.outputs();
        HashSet<Symbol> toKeep = new HashSet<Symbol>();
        for (Symbol groupKey : this.groupKeys) {
            SymbolVisitors.intersection(groupKey, sourceOutputs, toKeep::add);
        }
        // A set, because one aggregate may be referenced by several kept outputs; with a list the
        // size comparison below could report "unchanged" although an aggregate was dropped.
        LinkedHashSet<Function> newAggregates = new LinkedHashSet<Function>();
        for (Symbol outputToKeep : outputsToKeep) {
            SymbolVisitors.intersection(outputToKeep, this.aggregates, newAggregates::add);
        }
        for (Function newAggregate : newAggregates) {
            SymbolVisitors.intersection(newAggregate, sourceOutputs, toKeep::add);
        }
        LogicalPlan newSource = this.source.pruneOutputsExcept(tableStats, toKeep);
        // newAggregates is a subset of the (already de-duplicated) aggregates, so equal sizes mean equal sets.
        if (newSource == this.source && newAggregates.size() == this.aggregates.size()) {
            return this;
        }
        return new GroupHashAggregate(newSource, this.groupKeys, new ArrayList<Function>(newAggregates), this.numExpectedRows);
    }

    /** Creates a copy of this operator on top of the single element of {@code sources}. */
    @Override
    public LogicalPlan replaceSources(List<LogicalPlan> sources) {
        return new GroupHashAggregate(Lists2.getOnlyElement(sources), this.groupKeys, this.aggregates, this.numExpectedRows);
    }

    /**
     * Wraps {@code executionPlan} in a {@link Merge} whose phase runs {@code projections} on
     * {@code nodeIds}, with one upstream per node of the given plan and broadcast output.
     */
    private ExecutionPlan createMerge(PlannerContext plannerContext, ExecutionPlan executionPlan, List<Projection> projections, Collection<String> nodeIds) {
        MergePhase mergePhase = new MergePhase(
            plannerContext.jobId(),
            plannerContext.nextExecutionPhaseId(),
            DISTRIBUTED_MERGE_PHASE_NAME,
            executionPlan.resultDescription().nodeIds().size(),
            1,
            nodeIds,
            executionPlan.resultDescription().streamOutputs(),
            projections,
            DistributionInfo.DEFAULT_BROADCAST,
            null);
        return new Merge(executionPlan, mergePhase, NO_LIMIT, 0, this.outputs.size(), NO_LIMIT, null);
    }

    /**
     * Returns true if the source is a {@link Collect} on a {@link DocTableInfo} and the grouping is
     * by the clustered column or primary keys (per {@link GroupByConsumer}).
     */
    private boolean shardsContainAllGroupKeyValues() {
        if (!(this.source instanceof Collect)) {
            return false;
        }
        Collect collect = (Collect) this.source;
        if (!(collect.tableInfo instanceof DocTableInfo)) {
            return false;
        }
        return GroupByConsumer.groupedByClusteredColumnOrPrimaryKeys((DocTableInfo) collect.tableInfo, collect.where, this.groupKeys);
    }

    @Override
    public <C, R> R accept(LogicalPlanVisitor<C, R> visitor, C context) {
        return visitor.visitGroupHashAggregate(this, context);
    }

    @Override
    public String toString() {
        return "GroupBy{src=" + this.source + ", keys=" + this.groupKeys + ", agg=" + this.aggregates + "}";
    }

    /** Prints {@code GroupHashAggregate[keys | aggregates]} and nests the source's output. */
    @Override
    public void print(PrintContext printContext) {
        printContext.text("GroupHashAggregate[").text(Lists2.joinOn(", ", this.groupKeys, Symbol::toString));
        if (!this.aggregates.isEmpty()) {
            printContext.text(" | ").text(Lists2.joinOn(", ", this.aggregates, Symbol::toString));
        }
        printContext.text("]").nest(this.source::print);
    }
}

