/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine.aggregation.impl;

import io.crate.breaker.RamAccounting;
import io.crate.common.annotations.VisibleForTesting;
import io.crate.data.Input;
import io.crate.execution.engine.aggregation.AggregationFunction;
import io.crate.execution.engine.aggregation.DocValueAggregator;
import io.crate.execution.engine.aggregation.impl.AggregationImplModule;
import io.crate.execution.engine.aggregation.impl.OverflowAwareMutableLong;
import io.crate.memory.MemoryManager;
import io.crate.metadata.functions.Signature;
import io.crate.types.DataType;
import io.crate.types.DataTypes;
import io.crate.types.NumericType;
import java.io.IOException;
import java.math.BigDecimal;
import java.util.List;
import javax.annotation.Nullable;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.Version;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.index.mapper.MappedFieldType;

/**
 * {@code sum} aggregation whose input, intermediate state and result are all
 * {@link BigDecimal} values.
 *
 * <p>A {@code null} state means "no non-null value seen yet"; the aggregation then
 * yields {@code null} instead of zero. Besides the row-based
 * {@link #iterate}/{@link #reduce} path, the function offers doc-value based
 * aggregators ({@link SumLong}, {@link SumFloat}, {@link SumDouble}) through
 * {@link #getDocValueAggregator(List, List)}.
 */
public class NumericSumAggregation extends AggregationFunction<BigDecimal, BigDecimal> {

    /** Name under which the function is registered. */
    public static final String NAME = "sum";

    /** Aggregate signature built from the type signature of {@link DataTypes#NUMERIC} only. */
    public static final Signature SIGNATURE = Signature.aggregate(
        NAME,
        DataTypes.NUMERIC.getTypeSignature(),
        DataTypes.NUMERIC.getTypeSignature()
    );

    /** Bytes accounted up-front for every new aggregation state: the size of {@link BigDecimal#ZERO}. */
    private static final long INIT_BIG_DECIMAL_SIZE = NumericType.size(BigDecimal.ZERO);

    private final Signature signature;
    private final Signature boundSignature;
    private final DataType<BigDecimal> returnType;

    /**
     * Registers this aggregation under {@link #SIGNATURE}.
     *
     * @param mod module collecting the aggregation implementations
     */
    public static void register(AggregationImplModule mod) {
        mod.register(SIGNATURE, NumericSumAggregation::new);
    }

    /**
     * Creates the aggregation and derives the return type from the first bound argument type.
     *
     * @param signature      declared signature of the function
     * @param boundSignature signature with concrete argument types; its first argument
     *                       type must carry the numeric type id
     */
    @VisibleForTesting
    private NumericSumAggregation(Signature signature, Signature boundSignature) {
        this.signature = signature;
        this.boundSignature = boundSignature;
        // The bound argument type (not DataTypes.NUMERIC itself) is used as the return type.
        // NOTE(review): presumably so a declared precision/scale carries over to the result — confirm.
        DataType<?> argumentType = boundSignature.getArgumentDataTypes().get(0);
        assert argumentType.id() == DataTypes.NUMERIC.id()
            : "Argument of '" + NAME + "' must be of type numeric, got: " + argumentType;
        // Safe as long as the id check above holds: a type with the numeric id handles BigDecimal values.
        @SuppressWarnings("unchecked")
        DataType<BigDecimal> numericType = (DataType<BigDecimal>) argumentType;
        this.returnType = numericType;
    }

    /**
     * Accounts {@link #INIT_BIG_DECIMAL_SIZE} bytes for the state to come.
     *
     * @return always {@code null}, i.e. "no value aggregated yet"
     */
    @Override
    @Nullable
    public BigDecimal newState(RamAccounting ramAccounting, Version indexVersionCreated, Version minNodeInCluster, MemoryManager memoryManager) {
        ramAccounting.addBytes(INIT_BIG_DECIMAL_SIZE);
        return null;
    }

    /**
     * Adds the value of {@code args[0]} to the running sum.
     *
     * <p>{@code null} inputs are skipped. The first non-null input becomes the state as is;
     * every later input is added and the size difference between the new and the old sum
     * is reported to {@code ramAccounting}.
     *
     * @param state current sum, or {@code null} if nothing has been aggregated yet
     * @param args  function arguments; only {@code args[0]} is read
     * @return the updated sum, or the unchanged {@code state} for a {@code null} input
     * @throws CircuitBreakingException declared for the RAM accounting call
     */
    @Override
    public BigDecimal iterate(RamAccounting ramAccounting, MemoryManager memoryManager, BigDecimal state, Input[] args) throws CircuitBreakingException {
        BigDecimal value = returnType.implicitCast(args[0].value());
        if (value == null) {
            return state;
        }
        if (state == null) {
            // First value: no extra accounting here, only the bytes booked in newState().
            return value;
        }
        BigDecimal newState = state.add(value);
        ramAccounting.addBytes(NumericType.sizeDiff(newState, state));
        return newState;
    }

    /**
     * Merges two partial sums.
     *
     * @return the other state if one of them is {@code null} (so {@code null} if both are),
     *         otherwise {@code state1 + state2}
     */
    @Override
    public BigDecimal reduce(RamAccounting ramAccounting, BigDecimal state1, BigDecimal state2) {
        if (state1 == null) {
            return state2;
        }
        if (state2 == null) {
            return state1;
        }
        return state1.add(state2);
    }

    /**
     * Accounts the size of a non-null {@code state} and returns the state unchanged.
     */
    @Override
    public BigDecimal terminatePartial(RamAccounting ramAccounting, BigDecimal state) {
        if (state != null) {
            ramAccounting.addBytes(NumericType.size(state));
        }
        return state;
    }

    /** @return the type of the partial state, which equals the return type */
    @Override
    public DataType<?> partialType() {
        return returnType;
    }

    @Override
    public Signature signature() {
        return signature;
    }

    @Override
    public Signature boundSignature() {
        return boundSignature;
    }

    /** @return always {@code true}; see {@link #removeFromAggregatedState} for the removal step */
    @Override
    public boolean isRemovableCumulative() {
        return true;
    }

    /**
     * Subtracts the value of {@code stateToRemove[0]} from a previously aggregated sum.
     *
     * @param previousAggState sum the value was part of, may be {@code null}
     * @param stateToRemove    inputs of the row to remove; only index 0 is read
     * @return {@code previousAggState - value}, or {@code previousAggState} unchanged if
     *         either the value or the previous state is {@code null}
     */
    @Override
    public BigDecimal removeFromAggregatedState(RamAccounting ramAccounting, BigDecimal previousAggState, Input[] stateToRemove) {
        BigDecimal value = returnType.implicitCast(stateToRemove[0].value());
        if (value != null && previousAggState != null) {
            return previousAggState.subtract(value);
        }
        return previousAggState;
    }

    /**
     * Picks a doc-value based aggregator matching the type id of the first argument.
     *
     * @param argumentTypes argument types; only index 0 is inspected
     * @param fieldTypes    field types; the name of index 0 is the column that is read
     * @return an aggregator for the supported type ids, {@code null} for every other type
     */
    @Override
    @Nullable
    public DocValueAggregator<?> getDocValueAggregator(List<DataType<?>> argumentTypes, List<MappedFieldType> fieldTypes) {
        // NOTE(review): the literals are DataType ids inlined by the decompiler; presumably
        // 2/8/9/10 = byte/short/integer/long, 7 = float, 6 = double — confirm against io.crate.types
        // and replace them with the named ID constants.
        return switch (argumentTypes.get(0).id()) {
            case 2, 8, 9, 10 -> new SumLong(returnType, fieldTypes.get(0).name());
            case 7 -> new SumFloat(returnType, fieldTypes.get(0).name());
            case 6 -> new SumDouble(returnType, fieldTypes.get(0).name());
            default -> null;
        };
    }

    /**
     * Doc-value aggregator that feeds raw {@code long} doc values into an
     * {@link OverflowAwareMutableLong}.
     */
    static class SumLong implements DocValueAggregator<OverflowAwareMutableLong> {

        private final DataType<BigDecimal> returnType;
        private final String columnName;
        // Doc values of the current segment; assigned by loadDocValues().
        private SortedNumericDocValues values;

        SumLong(DataType<BigDecimal> returnType, String columnName) {
            this.returnType = returnType;
            this.columnName = columnName;
        }

        /** Accounts {@link #INIT_BIG_DECIMAL_SIZE} bytes and starts the sum at {@code 0}. */
        @Override
        public OverflowAwareMutableLong initialState(RamAccounting ramAccounting) {
            ramAccounting.addBytes(INIT_BIG_DECIMAL_SIZE);
            return new OverflowAwareMutableLong(0L);
        }

        @Override
        public void loadDocValues(LeafReader reader) throws IOException {
            values = DocValues.getSortedNumeric(reader, columnName);
        }

        /**
         * Adds the value of {@code doc} to {@code state}; documents without a value or with
         * more than one value are ignored.
         */
        @Override
        public void apply(RamAccounting ramAccounting, int doc, OverflowAwareMutableLong state) throws IOException {
            if (values.advanceExact(doc) && values.docValueCount() == 1) {
                BigDecimal previous = state.value();
                state.add(values.nextValue());
                ramAccounting.addBytes(NumericType.sizeDiff(state.value(), previous));
            }
        }

        /** @return the sum cast to the return type, or {@code null} if the state holds no value */
        @Override
        @Nullable
        public BigDecimal partialResult(RamAccounting ramAccounting, OverflowAwareMutableLong state) {
            if (state.hasValue()) {
                return returnType.implicitCast(state.value());
            }
            return null;
        }
    }

    /**
     * Doc-value aggregator for columns whose doc values are decoded with
     * {@link NumericUtils#sortableIntToFloat(int)}.
     */
    static class SumFloat implements DocValueAggregator<BigDecimalValueWrapper> {

        private final DataType<BigDecimal> returnType;
        private final String columnName;
        // Doc values of the current segment; assigned by loadDocValues().
        private SortedNumericDocValues values;

        SumFloat(DataType<BigDecimal> returnType, String columnName) {
            this.returnType = returnType;
            this.columnName = columnName;
        }

        /** Accounts {@link #INIT_BIG_DECIMAL_SIZE} bytes and starts with a zero sum flagged as "no value". */
        @Override
        public BigDecimalValueWrapper initialState(RamAccounting ramAccounting) {
            ramAccounting.addBytes(INIT_BIG_DECIMAL_SIZE);
            return new BigDecimalValueWrapper(BigDecimal.ZERO);
        }

        @Override
        public void loadDocValues(LeafReader reader) throws IOException {
            values = DocValues.getSortedNumeric(reader, columnName);
        }

        /**
         * Decodes the value of {@code doc}, casts it to the return type and adds it to
         * {@code state}; documents without a value or with more than one value are ignored.
         */
        @Override
        public void apply(RamAccounting ramAccounting, int doc, BigDecimalValueWrapper state) throws IOException {
            if (values.advanceExact(doc) && values.docValueCount() == 1) {
                BigDecimal previous = state.value();
                // The doc value is narrowed to int before being decoded into a float.
                float decoded = NumericUtils.sortableIntToFloat((int) values.nextValue());
                BigDecimal fieldValue = returnType.implicitCast(Float.valueOf(decoded));
                state.setValue(previous.add(fieldValue));
                ramAccounting.addBytes(NumericType.sizeDiff(state.value(), previous));
            }
        }

        /** @return the sum cast to the return type, or {@code null} if no value was ever added */
        @Override
        @Nullable
        public Object partialResult(RamAccounting ramAccounting, BigDecimalValueWrapper state) {
            if (state.hasValue()) {
                return returnType.implicitCast(state.value());
            }
            return null;
        }
    }

    /**
     * Doc-value aggregator for columns whose doc values are decoded with
     * {@link NumericUtils#sortableLongToDouble(long)}.
     */
    static class SumDouble implements DocValueAggregator<BigDecimalValueWrapper> {

        private final DataType<BigDecimal> returnType;
        private final String columnName;
        // Doc values of the current segment; assigned by loadDocValues().
        private SortedNumericDocValues values;

        SumDouble(DataType<BigDecimal> returnType, String columnName) {
            this.returnType = returnType;
            this.columnName = columnName;
        }

        /** Accounts {@link #INIT_BIG_DECIMAL_SIZE} bytes and starts with a zero sum flagged as "no value". */
        @Override
        public BigDecimalValueWrapper initialState(RamAccounting ramAccounting) {
            ramAccounting.addBytes(INIT_BIG_DECIMAL_SIZE);
            return new BigDecimalValueWrapper(BigDecimal.ZERO);
        }

        @Override
        public void loadDocValues(LeafReader reader) throws IOException {
            values = DocValues.getSortedNumeric(reader, columnName);
        }

        /**
         * Decodes the value of {@code doc}, casts it to the return type and adds it to
         * {@code state}; documents without a value or with more than one value are ignored.
         */
        @Override
        public void apply(RamAccounting ramAccounting, int doc, BigDecimalValueWrapper state) throws IOException {
            if (values.advanceExact(doc) && values.docValueCount() == 1) {
                BigDecimal previous = state.value();
                double decoded = NumericUtils.sortableLongToDouble(values.nextValue());
                BigDecimal fieldValue = returnType.implicitCast(decoded);
                state.setValue(previous.add(fieldValue));
                ramAccounting.addBytes(NumericType.sizeDiff(state.value(), previous));
            }
        }

        /** @return the sum cast to the return type, or {@code null} if no value was ever added */
        @Override
        @Nullable
        public Object partialResult(RamAccounting ramAccounting, BigDecimalValueWrapper state) {
            if (state.hasValue()) {
                return returnType.implicitCast(state.value());
            }
            return null;
        }
    }

    /**
     * Mutable holder for a {@link BigDecimal} that also remembers whether
     * {@link #setValue(BigDecimal)} has ever been called.
     *
     * <p>The value passed to the constructor does not count as "set":
     * {@link #hasValue()} stays {@code false} until the first {@code setValue} call.
     */
    public static final class BigDecimalValueWrapper {

        private BigDecimal value;
        private boolean hasValue;

        /**
         * @param value initial value; does not flip {@link #hasValue()}
         */
        public BigDecimalValueWrapper(BigDecimal value) {
            this.value = value;
        }

        /** @return the current value (the initial one if nothing was set yet) */
        public BigDecimal value() {
            return value;
        }

        /** @return {@code true} once {@link #setValue(BigDecimal)} has been called */
        public boolean hasValue() {
            return hasValue;
        }

        /** Replaces the value and marks the wrapper as holding a value. */
        public void setValue(BigDecimal value) {
            this.hasValue = true;
            this.value = value;
        }
    }
}

