/*
 * Decompiled with CFR 0.152.
 */
package io.crate.planner.operators;

import io.crate.analyze.OrderBy;
import io.crate.analyze.relations.AbstractTableRelation;
import io.crate.analyze.relations.AnalyzedRelation;
import io.crate.common.annotations.VisibleForTesting;
import io.crate.common.collections.Lists2;
import io.crate.common.collections.Tuple;
import io.crate.data.Row;
import io.crate.execution.dsl.phases.HashJoinPhase;
import io.crate.execution.dsl.phases.MergePhase;
import io.crate.execution.dsl.projection.EvalProjection;
import io.crate.execution.dsl.projection.builder.InputColumns;
import io.crate.execution.dsl.projection.builder.ProjectionBuilder;
import io.crate.execution.engine.join.JoinOperations;
import io.crate.expression.symbol.SelectSymbol;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolVisitors;
import io.crate.expression.symbol.Symbols;
import io.crate.metadata.RelationName;
import io.crate.planner.ExecutionPlan;
import io.crate.planner.PlannerContext;
import io.crate.planner.ResultDescription;
import io.crate.planner.distribution.DistributionInfo;
import io.crate.planner.distribution.DistributionType;
import io.crate.planner.node.dql.join.Join;
import io.crate.planner.node.dql.join.JoinType;
import io.crate.planner.operators.FetchRewrite;
import io.crate.planner.operators.HashJoinConditionSymbolsExtractor;
import io.crate.planner.operators.LogicalPlan;
import io.crate.planner.operators.LogicalPlanVisitor;
import io.crate.planner.operators.PrintContext;
import io.crate.planner.operators.SubQueryAndParamBinder;
import io.crate.planner.operators.SubQueryResults;
import io.crate.statistics.TableStats;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

/**
 * Logical plan operator that joins {@code lhs} and {@code rhs} using a hash join.
 *
 * <p>The join type is always {@link JoinType#INNER}. The outputs of this operator are the
 * outputs of {@code lhs} followed by the outputs of {@code rhs}.
 */
public class HashJoin
implements LogicalPlan {

    /** Limit value used for the child plans and the resulting {@link Join}: no limit is applied. */
    private static final int NO_LIMIT = -1;

    private final Symbol joinCondition;
    // Within this class only its relationName() is inspected: hash symbols of this relation are
    // assigned to one side of the join, all other hash symbols to the other side.
    @VisibleForTesting
    final AnalyzedRelation concreteRelation;
    private final List<Symbol> outputs;
    final LogicalPlan rhs;
    final LogicalPlan lhs;

    /**
     * Creates a hash join between two plans.
     *
     * @param lhs               left input plan
     * @param rhs               right input plan
     * @param joinCondition     the join condition; hash symbols are extracted from it at build time
     * @param concreteRelation  relation whose hash symbols form the right side of the unswitched join
     */
    public HashJoin(LogicalPlan lhs, LogicalPlan rhs, Symbol joinCondition, AnalyzedRelation concreteRelation) {
        this.outputs = Lists2.concat(lhs.outputs(), rhs.outputs());
        this.lhs = lhs;
        this.rhs = rhs;
        this.concreteRelation = concreteRelation;
        this.joinCondition = joinCondition;
    }

    /** Returns the join type, which is always {@link JoinType#INNER} for this operator. */
    public JoinType joinType() {
        return JoinType.INNER;
    }

    public Symbol joinCondition() {
        return this.joinCondition;
    }

    public LogicalPlan lhs() {
        return this.lhs;
    }

    public LogicalPlan rhs() {
        return this.rhs;
    }

    /**
     * Returns the union of the dependencies of both children. On duplicate keys the entry from
     * {@code rhs} wins.
     */
    @Override
    public Map<LogicalPlan, SelectSymbol> dependencies() {
        Map<LogicalPlan, SelectSymbol> leftDeps = this.lhs.dependencies();
        Map<LogicalPlan, SelectSymbol> rightDeps = this.rhs.dependencies();
        HashMap<LogicalPlan, SelectSymbol> deps = new HashMap<>(leftDeps.size() + rightDeps.size());
        deps.putAll(leftDeps);
        deps.putAll(rightDeps);
        return deps;
    }

    /**
     * Builds the execution plan for this hash join.
     *
     * <p>Both children are built without limit, offset, ordering or page size hint. The
     * {@code limit}, {@code offset}, {@code order} and {@code pageSizeHint} arguments are not
     * applied by this operator; the resulting {@link Join} is created with {@link #NO_LIMIT}.
     *
     * <p>The child with fewer expected rows ends up on the right-hand side (ties keep the original
     * order). One of three distribution strategies is then chosen:
     * <ul>
     *   <li>both sides run on the same single node and the right side has no remaining
     *       limit/offset: no merge phases, both sides use {@link DistributionInfo#DEFAULT_SAME_NODE};</li>
     *   <li>neither side has a remaining limit/offset: both sides are distributed by modulo on the
     *       first hash symbol of each side, and the join runs on the left side's nodes;</li>
     *   <li>otherwise: both sides are broadcast to the handler node, which runs the join.</li>
     * </ul>
     */
    @Override
    public ExecutionPlan build(PlannerContext plannerContext, ProjectionBuilder projectionBuilder, int limit, int offset, @Nullable OrderBy order, @Nullable Integer pageSizeHint, Row params, SubQueryResults subQueryResults) {
        ExecutionPlan leftExecutionPlan = this.lhs.build(plannerContext, projectionBuilder, NO_LIMIT, 0, null, null, params, subQueryResults);
        ExecutionPlan rightExecutionPlan = this.rhs.build(plannerContext, projectionBuilder, NO_LIMIT, 0, null, null, params, subQueryResults);
        LogicalPlan leftLogicalPlan = this.lhs;
        LogicalPlan rightLogicalPlan = this.rhs;

        // Keep the side with fewer expected rows on the right-hand side of the join.
        boolean tablesSwitched = this.lhs.numExpectedRows() < this.rhs.numExpectedRows();
        if (tablesSwitched) {
            leftLogicalPlan = this.rhs;
            rightLogicalPlan = this.lhs;
            ExecutionPlan tmp = leftExecutionPlan;
            leftExecutionPlan = rightExecutionPlan;
            rightExecutionPlan = tmp;
        }

        SubQueryAndParamBinder paramBinder = new SubQueryAndParamBinder(params, subQueryResults);
        Tuple<List<Symbol>, List<Symbol>> hashSymbols = this.extractHashJoinSymbolsFromJoinSymbolsAndSplitPerSide(tablesSwitched);
        // Bind parameters/sub-query results once; the bound symbols are used both for the
        // distribution and for the hash join phase.
        List<Symbol> leftHashSymbols = Lists2.map(hashSymbols.v1(), paramBinder);
        List<Symbol> rightHashSymbols = Lists2.map(hashSymbols.v2(), paramBinder);

        ResultDescription leftResultDesc = leftExecutionPlan.resultDescription();
        ResultDescription rightResultDesc = rightExecutionPlan.resultDescription();
        Collection<String> joinExecutionNodes = leftResultDesc.nodeIds();
        List<Symbol> leftOutputs = leftLogicalPlan.outputs();
        List<Symbol> rightOutputs = rightLogicalPlan.outputs();
        MergePhase leftMerge = null;
        MergePhase rightMerge = null;

        // Modulo distribution splits each side's rows across nodes, so it is only used when no
        // side still has to apply a limit or offset to its complete result.
        boolean isDistributed = !leftResultDesc.hasRemainingLimitOrOffset()
                                && !rightResultDesc.hasRemainingLimitOrOffset();

        if (joinExecutionNodes.size() == 1
            && joinExecutionNodes.equals(rightResultDesc.nodeIds())
            && !rightResultDesc.hasRemainingLimitOrOffset()) {
            // Both sides are produced on the same single node: join there without merge phases.
            leftExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_SAME_NODE);
            rightExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_SAME_NODE);
        } else {
            if (isDistributed) {
                leftOutputs = this.setModuloDistribution(leftHashSymbols, leftLogicalPlan.outputs(), leftExecutionPlan);
                rightOutputs = this.setModuloDistribution(rightHashSymbols, rightLogicalPlan.outputs(), rightExecutionPlan);
            } else {
                // Run the join non-distributed on the handler node.
                joinExecutionNodes = Collections.singletonList(plannerContext.handlerNode());
                leftExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_BROADCAST);
                rightExecutionPlan.setDistributionInfo(DistributionInfo.DEFAULT_BROADCAST);
            }
            leftMerge = JoinOperations.buildMergePhaseForJoin(plannerContext, leftResultDesc, joinExecutionNodes);
            rightMerge = JoinOperations.buildMergePhaseForJoin(plannerContext, rightResultDesc, joinExecutionNodes);
        }

        List<Symbol> joinOutputs = Lists2.concat(leftOutputs, rightOutputs);
        HashJoinPhase joinPhase = new HashJoinPhase(
            plannerContext.jobId(),
            plannerContext.nextExecutionPhaseId(),
            "hash-join",
            Collections.singletonList(JoinOperations.createJoinProjection(this.outputs, joinOutputs)),
            leftMerge,
            rightMerge,
            leftOutputs.size(),
            rightOutputs.size(),
            joinExecutionNodes,
            InputColumns.create(paramBinder.apply(this.joinCondition), joinOutputs),
            InputColumns.create(leftHashSymbols, new InputColumns.SourceSymbols(leftOutputs)),
            InputColumns.create(rightHashSymbols, new InputColumns.SourceSymbols(rightOutputs)),
            Symbols.typeView(leftOutputs),
            leftLogicalPlan.estimatedRowSize(),
            leftLogicalPlan.numExpectedRows());
        return new Join(joinPhase, leftExecutionPlan, rightExecutionPlan, NO_LIMIT, 0, NO_LIMIT, this.outputs.size(), null);
    }

    /** Returns the outputs of {@code lhs} followed by the outputs of {@code rhs}. */
    @Override
    public List<Symbol> outputs() {
        return this.outputs;
    }

    /** Returns the base tables of {@code lhs} followed by those of {@code rhs}. */
    @Override
    public List<AbstractTableRelation<?>> baseTables() {
        return Lists2.concat(this.lhs.baseTables(), this.rhs.baseTables());
    }

    /** Returns {@code [lhs, rhs]}. */
    @Override
    public List<LogicalPlan> sources() {
        return List.of(this.lhs, this.rhs);
    }

    /**
     * Returns a new hash join with the same condition and concrete relation over new sources.
     * Index 0 becomes the new {@code lhs}, index 1 the new {@code rhs}.
     */
    @Override
    public LogicalPlan replaceSources(List<LogicalPlan> sources) {
        return new HashJoin(sources.get(0), sources.get(1), this.joinCondition, this.concreteRelation);
    }

    /**
     * Prunes each child's outputs down to the symbols required by {@code outputsToKeep} plus the
     * symbols required by the join condition.
     *
     * @return this instance if neither child changed, otherwise a new hash join over the pruned children
     */
    @Override
    public LogicalPlan pruneOutputsExcept(TableStats tableStats, Collection<Symbol> outputsToKeep) {
        LinkedHashSet<Symbol> lhsToKeep = new LinkedHashSet<>();
        LinkedHashSet<Symbol> rhsToKeep = new LinkedHashSet<>();
        for (Symbol outputToKeep : outputsToKeep) {
            SymbolVisitors.intersection(outputToKeep, this.lhs.outputs(), lhsToKeep::add);
            SymbolVisitors.intersection(outputToKeep, this.rhs.outputs(), rhsToKeep::add);
        }
        // The join condition must stay evaluable after pruning.
        SymbolVisitors.intersection(this.joinCondition, this.lhs.outputs(), lhsToKeep::add);
        SymbolVisitors.intersection(this.joinCondition, this.rhs.outputs(), rhsToKeep::add);
        LogicalPlan newLhs = this.lhs.pruneOutputsExcept(tableStats, lhsToKeep);
        LogicalPlan newRhs = this.rhs.pruneOutputsExcept(tableStats, rhsToKeep);
        if (newLhs == this.lhs && newRhs == this.rhs) {
            return this;
        }
        return new HashJoin(newLhs, newRhs, this.joinCondition, this.concreteRelation);
    }

    /**
     * Attempts a fetch rewrite of both children. Each child is asked to keep the symbols used by
     * {@code usedColumns} and by the join condition.
     *
     * @return {@code null} if either child cannot be rewritten; otherwise a rewrite whose replaced
     *         outputs combine both children's (right entries win on duplicate keys) and whose plan
     *         is a hash join over the rewritten children
     */
    @Override
    @Nullable
    public FetchRewrite rewriteToFetch(TableStats tableStats, Collection<Symbol> usedColumns) {
        LinkedHashSet<Symbol> usedFromLeft = new LinkedHashSet<>();
        LinkedHashSet<Symbol> usedFromRight = new LinkedHashSet<>();
        for (Symbol usedColumn : usedColumns) {
            SymbolVisitors.intersection(usedColumn, this.lhs.outputs(), usedFromLeft::add);
            SymbolVisitors.intersection(usedColumn, this.rhs.outputs(), usedFromRight::add);
        }
        SymbolVisitors.intersection(this.joinCondition, this.lhs.outputs(), usedFromLeft::add);
        SymbolVisitors.intersection(this.joinCondition, this.rhs.outputs(), usedFromRight::add);
        FetchRewrite lhsFetchRewrite = this.lhs.rewriteToFetch(tableStats, usedFromLeft);
        if (lhsFetchRewrite == null) {
            return null;
        }
        FetchRewrite rhsFetchRewrite = this.rhs.rewriteToFetch(tableStats, usedFromRight);
        if (rhsFetchRewrite == null) {
            return null;
        }
        LinkedHashMap<Symbol, Symbol> allReplacedOutputs = new LinkedHashMap<>(lhsFetchRewrite.replacedOutputs());
        allReplacedOutputs.putAll(rhsFetchRewrite.replacedOutputs());
        return new FetchRewrite(
            allReplacedOutputs,
            new HashJoin(lhsFetchRewrite.newPlan(), rhsFetchRewrite.newPlan(), this.joinCondition, this.concreteRelation));
    }

    /**
     * Extracts the hash symbols from the join condition and splits them into (left, right).
     *
     * <p>Symbols of {@link #concreteRelation} go to the right side and all other symbols to the
     * left side; when {@code switchedTables} is true the two sides are exchanged.
     *
     * <p>NOTE(review): if the condition yields no symbols for {@code concreteRelation}, the
     * corresponding side is {@code null}; callers currently do not guard against this.
     */
    private Tuple<List<Symbol>, List<Symbol>> extractHashJoinSymbolsFromJoinSymbolsAndSplitPerSide(boolean switchedTables) {
        Map<RelationName, List<Symbol>> hashJoinSymbols = HashJoinConditionSymbolsExtractor.extract(this.joinCondition);
        List<Symbol> hashJoinSymbolsForConcreteRelation = hashJoinSymbols.remove(this.concreteRelation.relationName());
        List<Symbol> hashJoinSymbolsForJoinTree = hashJoinSymbols.values().stream()
            .flatMap(Collection::stream)
            .collect(Collectors.toList());
        if (switchedTables) {
            return new Tuple<>(hashJoinSymbolsForConcreteRelation, hashJoinSymbolsForJoinTree);
        }
        return new Tuple<>(hashJoinSymbolsForJoinTree, hashJoinSymbolsForConcreteRelation);
    }

    /** Estimates the row count as the larger of the two children's estimates. */
    @Override
    public long numExpectedRows() {
        return Math.max(this.lhs.numExpectedRows(), this.rhs.numExpectedRows());
    }

    /** Estimates the row size as the sum of the two children's estimated row sizes. */
    @Override
    public long estimatedRowSize() {
        return this.lhs.estimatedRowSize() + this.rhs.estimatedRowSize();
    }

    @Override
    public <C, R> R accept(LogicalPlanVisitor<C, R> visitor, C context) {
        return visitor.visitHashJoin(this, context);
    }

    /** Prints {@code HashJoin[<condition>]} followed by the nested output of {@code lhs} and {@code rhs}. */
    @Override
    public void print(PrintContext printContext) {
        printContext
            .text("HashJoin[")
            .text(this.joinCondition.toString())
            .text("]")
            .nest(this.lhs::print, this.rhs::print);
    }

    /**
     * Configures {@code executionPlan} to be distributed by modulo on the first of
     * {@code joinSymbols}.
     *
     * <p>If that symbol is not already among {@code planOutputs}, an eval projection appending it
     * is added to the plan and the distribution column is the appended position.
     *
     * @return the plan's outputs after this call: {@code planOutputs} itself, or a new list with
     *         the first join symbol appended
     */
    private List<Symbol> setModuloDistribution(List<Symbol> joinSymbols, List<Symbol> planOutputs, ExecutionPlan executionPlan) {
        List<Symbol> outputs = planOutputs;
        Symbol firstJoinSymbol = joinSymbols.get(0);
        int distributeBySymbolPos = planOutputs.indexOf(firstJoinSymbol);
        if (distributeBySymbolPos < 0) {
            outputs = this.createEvalProjectionForDistributionJoinSymbol(firstJoinSymbol, planOutputs, executionPlan);
            // The symbol was appended after all existing outputs.
            distributeBySymbolPos = planOutputs.size();
        }
        executionPlan.setDistributionInfo(new DistributionInfo(DistributionType.MODULO, distributeBySymbolPos));
        return outputs;
    }

    /**
     * Adds an eval projection to {@code executionPlan} that emits all of {@code outputs} followed
     * by {@code firstJoinSymbol} (evaluated against {@code outputs}).
     *
     * @return the new output list: {@code outputs} plus {@code firstJoinSymbol}
     */
    private List<Symbol> createEvalProjectionForDistributionJoinSymbol(Symbol firstJoinSymbol, List<Symbol> outputs, ExecutionPlan executionPlan) {
        ArrayList<Symbol> projectionOutputs = new ArrayList<>(outputs.size() + 1);
        projectionOutputs.addAll(outputs);
        projectionOutputs.add(firstJoinSymbol);
        EvalProjection evalProjection = new EvalProjection(InputColumns.create(projectionOutputs, new InputColumns.SourceSymbols(outputs)));
        executionPlan.addProjection(evalProjection);
        return projectionOutputs;
    }
}

