/*
 * Decompiled with CFR 0.152.
 */
package io.crate.analyze;

import io.crate.expression.operator.AndOperator;
import io.crate.expression.operator.Operators;
import io.crate.expression.symbol.Function;
import io.crate.expression.symbol.FunctionCopyVisitor;
import io.crate.expression.symbol.RefReplacer;
import io.crate.expression.symbol.Symbol;
import io.crate.expression.symbol.SymbolType;
import io.crate.expression.symbol.SymbolVisitors;
import io.crate.expression.symbol.Symbols;
import io.crate.metadata.FunctionImplementation;
import io.crate.metadata.GeneratedReference;
import io.crate.metadata.NodeContext;
import io.crate.metadata.Reference;
import io.crate.metadata.Scalar;
import io.crate.metadata.SearchPath;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;

/**
 * Adds derived comparisons on generated columns to a condition.
 *
 * <p>For every comparison operator (see {@link Operators#COMPARISON_OPERATORS}) between a column
 * {@code x} and a value that does not itself contain generated columns, and for every generated
 * column {@code g} whose generation expression references only {@code x}, an extra comparison on
 * {@code g} is AND-ed to the original one. Example: with {@code g AS floor(x)}, the condition
 * {@code x > 10} becomes {@code x > 10 AND g >= floor(10)}.
 *
 * <p>Non-equality comparisons are only expanded if the generation function declares
 * {@link Scalar.Feature#COMPARISON_REPLACEMENT}. For rounding functions, strict operators are
 * relaxed ({@code >} to {@code >=}, {@code <} to {@code <=}) because rounding can make values
 * equal.
 */
public final class GeneratedColumnExpander {

    /** Strict operators relaxed to their inclusive form when the generated column rounds its input. */
    private static final Map<String, String> ROUNDING_FUNCTION_MAPPING = Map.of("op_>", "op_>=", "op_<", "op_<=");

    /** Generation functions whose output may collapse distinct inputs onto the same value. */
    private static final Set<String> ROUNDING_FUNCTIONS = Set.of("ceil", "floor", "round", "date_trunc");

    /**
     * Mirrored operators: {@code a < b} is equivalent to {@code b > a}, etc. Used when the column
     * is the right-hand operand, because derived comparisons are always built as
     * {@code generatedColumn op value}.
     */
    private static final Map<String, String> SWAPPED_OPERATORS = Map.of(
        "op_<", "op_>",
        "op_<=", "op_>=",
        "op_>", "op_<",
        "op_>=", "op_<=");

    private static final String EQ_OPERATOR = "op_=";

    /** Stateless visitor; per-call state lives in {@link ComparisonReplaceVisitor.Context}. */
    private static final ComparisonReplaceVisitor COMPARISON_REPLACE_VISITOR = new ComparisonReplaceVisitor();

    private GeneratedColumnExpander() {
    }

    /**
     * Returns {@code symbol} with derived comparisons on generated columns added.
     *
     * <p>Returns {@code symbol} itself if none of {@code generatedCols} qualifies. A column
     * qualifies if it references exactly one column and is contained in {@code expansionCandidates}.
     *
     * @param symbol              the condition to expand
     * @param generatedCols       generated columns that may be used for the expansion
     * @param expansionCandidates only generated columns contained in this list are used
     *                            (presumably partition columns; confirm at call sites)
     * @param nodeCtx             used to resolve the comparison operator implementations
     * @return the expanded condition, or {@code symbol} if nothing applies
     */
    public static Symbol maybeExpand(Symbol symbol, List<GeneratedReference> generatedCols, List<Reference> expansionCandidates, NodeContext nodeCtx) {
        return COMPARISON_REPLACE_VISITOR.addComparisons(symbol, generatedCols, expansionCandidates, nodeCtx);
    }

    /**
     * Copies a symbol tree, replacing each qualifying comparison {@code ref op value} with
     * {@code ref op value AND gen op genExpr(value)} for every generated column {@code gen}
     * derived from {@code ref}.
     */
    private static final class ComparisonReplaceVisitor
        extends FunctionCopyVisitor<ComparisonReplaceVisitor.Context> {

        ComparisonReplaceVisitor() {
        }

        /** Entry point: builds the per-call context and visits {@code symbol}. */
        Symbol addComparisons(Symbol symbol, List<GeneratedReference> generatedCols, List<Reference> expansionCandidates, NodeContext nodeCtx) {
            Map<Reference, List<GeneratedReference>> refToGeneratedCols =
                extractGeneratedReferences(generatedCols, expansionCandidates);
            if (refToGeneratedCols.isEmpty()) {
                // Nothing to derive from; skip copying the tree.
                return symbol;
            }
            return symbol.accept(this, new Context(refToGeneratedCols, nodeCtx));
        }

        @Override
        public Symbol visitFunction(Function function, Context context) {
            if (Operators.COMPARISON_OPERATORS.contains(function.name())) {
                Reference reference = null;
                Symbol otherSide = null;
                // Position of the (possibly cast-wrapped) reference among the arguments.
                int referenceIndex = -1;
                for (int i = 0; i < function.arguments().size(); i++) {
                    Symbol arg = Symbols.unwrapReferenceFromCast(function.arguments().get(i));
                    if (arg instanceof Reference) {
                        reference = (Reference) arg;
                        referenceIndex = i;
                    } else {
                        otherSide = arg;
                    }
                }
                if (reference != null
                    && otherSide != null
                    && !SymbolVisitors.any(Symbols.IS_GENERATED_COLUMN, otherSide)) {
                    // The derived comparison is always "generatedColumn op value"; for
                    // "value op reference" the operator has to be mirrored to stay correct.
                    String operatorName = referenceIndex > 0
                        ? SWAPPED_OPERATORS.getOrDefault(function.name(), function.name())
                        : function.name();
                    return addComparison(function, operatorName, reference, otherSide, context);
                }
            }
            return super.visitFunction(function, context);
        }

        /**
         * AND-joins {@code function} with one derived comparison per generated column that is
         * based on {@code reference}.
         *
         * @param operatorName the operator to use for derived comparisons, already oriented as
         *                     {@code reference op comparedAgainst}
         */
        private Symbol addComparison(Function function,
                                     String operatorName,
                                     Reference reference,
                                     Symbol comparedAgainst,
                                     Context context) {
            // Read-only lookup: do not insert empty entries into the shared context map.
            List<GeneratedReference> genColInfos =
                context.referencedRefsToGeneratedColumn.getOrDefault(reference, List.of());
            List<Function> comparisonsToAdd = new ArrayList<>(genColInfos.size() + 1);
            comparisonsToAdd.add(function);
            for (GeneratedReference genColInfo : genColInfos) {
                Function comparison = createAdditionalComparison(operatorName, genColInfo, comparedAgainst, context.nodeCtx);
                if (comparison != null) {
                    comparisonsToAdd.add(comparison);
                }
            }
            return AndOperator.join(comparisonsToAdd);
        }

        /**
         * Builds {@code generatedReference op genExpr(comparedAgainst)}.
         *
         * <p>Returns {@code null} in two cases: the generation expression is not a plain function,
         * or the operator is not equality and the function lacks
         * {@link Scalar.Feature#COMPARISON_REPLACEMENT}.
         */
        @Nullable
        private static Function createAdditionalComparison(String operatorName,
                                                           GeneratedReference generatedReference,
                                                           Symbol comparedAgainst,
                                                           NodeContext nodeCtx) {
            Symbol generatedExpression = generatedReference.generatedExpression();
            if (generatedExpression.symbolType() != SymbolType.FUNCTION) {
                return null;
            }
            Function generatedFunction = (Function) generatedExpression;
            String comparisonName = operatorName;
            if (!EQ_OPERATOR.equals(operatorName)) {
                if (!generatedFunction.hasFeature(Scalar.Feature.COMPARISON_REPLACEMENT)) {
                    return null;
                }
                if (ROUNDING_FUNCTIONS.contains(generatedFunction.name())) {
                    // e.g. x > 5.5 implies floor(x) >= floor(5.5), but not floor(x) > floor(5.5).
                    comparisonName = ROUNDING_FUNCTION_MAPPING.getOrDefault(operatorName, operatorName);
                }
            }
            Symbol wrapped = wrapInGenerationExpression(comparedAgainst, generatedReference);
            List<Symbol> arguments = List.of(generatedReference, wrapped);
            FunctionImplementation funcImpl = nodeCtx.functions().get(null, comparisonName, arguments, SearchPath.pathWithPGCatalogAndDoc());
            return new Function(funcImpl.signature(), arguments, funcImpl.boundSignature().getReturnType().createType());
        }

        /**
         * Returns the generation expression of {@code generatedReference} with its single
         * referenced column replaced by {@code wrapMeLikeItsHot}.
         */
        private static Symbol wrapInGenerationExpression(Symbol wrapMeLikeItsHot, GeneratedReference generatedReference) {
            // extractGeneratedReferences guarantees exactly one referenced column.
            ReplaceIfMatch replaceIfMatch = new ReplaceIfMatch(wrapMeLikeItsHot, generatedReference.referencedReferences().get(0));
            return RefReplacer.replaceRefs(generatedReference.generatedExpression(), replaceIfMatch);
        }

        /**
         * Groups usable generated columns by the single column they reference.
         *
         * <p>A generated column is usable if it references exactly one column and is contained in
         * {@code expansionCandidates}.
         */
        private static Map<Reference, List<GeneratedReference>> extractGeneratedReferences(List<GeneratedReference> generatedCols,
                                                                                          Collection<Reference> expansionCandidates) {
            Map<Reference, List<GeneratedReference>> map = new HashMap<>();
            for (GeneratedReference generatedColumn : generatedCols) {
                if (generatedColumn.referencedReferences().size() == 1 && expansionCandidates.contains(generatedColumn)) {
                    map.computeIfAbsent(generatedColumn.referencedReferences().get(0), k -> new ArrayList<>())
                        .add(generatedColumn);
                }
            }
            return map;
        }

        /** Per-invocation state of the visitor. */
        static final class Context {
            private final Map<Reference, List<GeneratedReference>> referencedRefsToGeneratedColumn;
            private final NodeContext nodeCtx;

            Context(Map<Reference, List<GeneratedReference>> referencedRefsToGeneratedColumn, NodeContext nodeCtx) {
                this.referencedRefsToGeneratedColumn = referencedRefsToGeneratedColumn;
                this.nodeCtx = nodeCtx;
            }
        }
    }

    /**
     * Reference mapper for {@link RefReplacer}.
     *
     * <p>Returns {@code replaceWith} for references equal to {@code toReplace}, and any other
     * reference unchanged.
     */
    static class ReplaceIfMatch
        implements java.util.function.Function<Reference, Symbol> {

        private final Symbol replaceWith;
        private final Reference toReplace;

        ReplaceIfMatch(Symbol replaceWith, Reference toReplace) {
            this.replaceWith = replaceWith;
            this.toReplace = toReplace;
        }

        @Override
        public Symbol apply(Reference ref) {
            return ref.equals(this.toReplace) ? this.replaceWith : ref;
        }
    }
}

