/*
 * Decompiled with CFR 0.152.
 */
package io.crate.execution.engine.indexing;

import io.crate.action.FutureActionListener;
import io.crate.action.LimitedExponentialBackoff;
import io.crate.data.BatchIterator;
import io.crate.data.BatchIterators;
import io.crate.data.CollectionBucket;
import io.crate.data.Row;
import io.crate.data.Row1;
import io.crate.exceptions.Exceptions;
import io.crate.execution.dml.ShardRequest;
import io.crate.execution.dml.ShardResponse;
import io.crate.execution.engine.collect.CollectExpression;
import io.crate.execution.engine.indexing.BatchIteratorBackpressureExecutor;
import io.crate.execution.jobs.NodeJobsCounter;
import io.crate.execution.support.RetryListener;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collector;
import javax.annotation.Nullable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.bulk.BackoffPolicy;
import org.elasticsearch.cluster.service.ClusterService;

/**
 * Turns a {@link BatchIterator} of source rows into shard-level DML requests and executes them.
 *
 * <p>Rows are grouped into requests via {@link BatchIterators#partition} (batch size {@code bulkSize}),
 * each row contributing one item built from the value of {@code uidExpression}. Every request is sent
 * through {@code operation} wrapped in a {@link RetryListener}; each {@link ShardResponse} is folded
 * into an accumulator by {@code collector}, per-request accumulators are merged with the collector's
 * combiner, and the final result is produced by the collector's finisher.
 *
 * <p>NOTE(review): {@code numItems} is mutated without synchronization; this assumes
 * {@link #addRowToRequest} is only invoked from the single thread consuming the source iterator.
 *
 * @param <TReq>    the shard request type built per batch
 * @param <TItem>   the item type added to each request
 * @param <TAcc>    the collector's intermediate accumulation type
 * @param <TResult> the final result type produced by the collector
 */
public class ShardDMLExecutor<TReq extends ShardRequest<TReq, TItem>, TItem extends ShardRequest.Item, TAcc, TResult extends Iterable<? extends Row>>
implements Function<BatchIterator<Row>, CompletableFuture<? extends Iterable<? extends Row>>> {

    private static final Logger LOGGER = LogManager.getLogger(ShardDMLExecutor.class);
    private static final BackoffPolicy BACKOFF_POLICY = LimitedExponentialBackoff.limitedExponential(1000);

    /**
     * Number of in-progress jobs on the local node at or above which consumption of a lazy source
     * iterator is paused.
     */
    private static final int MAX_NODE_CONCURRENT_OPERATIONS = 5;

    public static final int DEFAULT_BULK_SIZE = 10000;

    /** Sums {@link ShardResponse#successRowCount()} over all responses into a single-row result. */
    public static final Collector<ShardResponse, long[], Iterable<Row>> ROW_COUNT_COLLECTOR = Collector.of(
        () -> new long[]{0L},
        (acc, response) -> acc[0] += toRowCount(response),
        (acc, other) -> {
            acc[0] += other[0];
            return acc;
        },
        acc -> List.of(new Row1(acc[0]))
    );

    /** Concatenates the result rows of all responses into a {@link CollectionBucket}. */
    public static final Collector<ShardResponse, List<Object[]>, Iterable<Row>> RESULT_ROW_COLLECTOR = Collector.of(
        ArrayList::new,
        (acc, response) -> acc.addAll(toResultRows(response)),
        (acc, other) -> {
            acc.addAll(other);
            return acc;
        },
        CollectionBucket::new
    );

    private final UUID jobId;
    private final int bulkSize;
    private final ScheduledExecutorService scheduler;
    private final Executor executor;
    private final CollectExpression<Row, ?> uidExpression;
    private final NodeJobsCounter nodeJobsCounter;
    private final Supplier<TReq> requestFactory;
    private final Function<String, TItem> itemFactory;
    private final BiConsumer<TReq, ActionListener<ShardResponse>> operation;
    @Nullable
    private final String localNodeId;
    private final Collector<ShardResponse, TAcc, TResult> collector;

    /** Location of the most recently added item; starts at -1 so the first item gets location 0. */
    private int numItems = -1;

    public ShardDMLExecutor(UUID jobId,
                            int bulkSize,
                            ScheduledExecutorService scheduler,
                            Executor executor,
                            CollectExpression<Row, ?> uidExpression,
                            ClusterService clusterService,
                            NodeJobsCounter nodeJobsCounter,
                            Supplier<TReq> requestFactory,
                            Function<String, TItem> itemFactory,
                            BiConsumer<TReq, ActionListener<ShardResponse>> transportAction,
                            Collector<ShardResponse, TAcc, TResult> collector) {
        this.jobId = jobId;
        this.bulkSize = bulkSize;
        this.scheduler = scheduler;
        this.executor = executor;
        this.uidExpression = uidExpression;
        this.nodeJobsCounter = nodeJobsCounter;
        this.requestFactory = requestFactory;
        this.itemFactory = itemFactory;
        this.operation = transportAction;
        this.localNodeId = getLocalNodeId(clusterService);
        this.collector = collector;
    }

    /**
     * Adds one item, derived from the row's uid value, to {@code req}.
     * Locations increase monotonically across all requests built by this executor.
     */
    private void addRowToRequest(TReq req, Row row) {
        numItems++;
        uidExpression.setNextRow(row);
        req.add(numItems, itemFactory.apply((String) uidExpression.value()));
    }

    /**
     * Sends {@code request} (with retries) and returns a future of its accumulated response.
     *
     * <p>The local node's job counter is incremented before sending and decremented once the
     * future completes, successfully or not.
     */
    private CompletableFuture<TAcc> executeBatch(TReq request) {
        FutureActionListener<ShardResponse, TAcc> listener = new FutureActionListener<ShardResponse, TAcc>(response -> {
            TAcc acc = collector.supplier().get();
            collector.accumulator().accept(acc, response);
            return acc;
        });
        nodeJobsCounter.increment(localNodeId);
        CompletableFuture<TAcc> result = listener.whenComplete((r, f) -> nodeJobsCounter.decrement(localNodeId));
        try {
            operation.accept(request, withRetry(request, listener));
        } catch (RuntimeException e) {
            // A synchronous failure would otherwise never complete the listener, leaking the
            // counter increment above; failing the listener triggers the decrement instead.
            listener.onFailure(e);
        }
        return result;
    }

    /** Wraps {@code listener} so that failed attempts re-issue {@code request} using {@link #BACKOFF_POLICY}. */
    private RetryListener<ShardResponse> withRetry(TReq request, FutureActionListener<ShardResponse, TAcc> listener) {
        return new RetryListener<ShardResponse>(
            scheduler,
            retryListener -> operation.accept(request, retryListener),
            listener,
            BACKOFF_POLICY
        );
    }

    /**
     * Consumes {@code batchIterator}, executing one shard request per batch of rows, and returns a
     * future of the collector's finished result.
     */
    @Override
    public CompletableFuture<TResult> apply(BatchIterator<Row> batchIterator) {
        BatchIterator<TReq> reqBatchIterator =
            BatchIterators.partition(batchIterator, bulkSize, requestFactory, this::addRowToRequest, r -> false);

        // Non-lazy sources always get a pause predicate of `true`; only lazy sources are throttled
        // on the local node's in-progress job count.
        Predicate<TReq> shouldPause = ignored -> true;
        if (batchIterator.hasLazyResultSet()) {
            shouldPause = ignored ->
                nodeJobsCounter.getInProgressJobsForNode(localNodeId) >= MAX_NODE_CONCURRENT_OPERATIONS;
        }

        return new BatchIteratorBackpressureExecutor<TReq, TAcc>(
            jobId,
            scheduler,
            executor,
            reqBatchIterator,
            this::executeBatch,
            collector.combiner(),
            collector.supplier().get(),
            shouldPause,
            BACKOFF_POLICY
        ).consumeIteratorAndExecute()
            .thenApply(collector.finisher());
    }

    /**
     * Returns the local node's id, or {@code null} if {@link ClusterService#localNode()} throws
     * {@link IllegalStateException} (the failure is logged at debug level).
     */
    @Nullable
    private static String getLocalNodeId(ClusterService clusterService) {
        try {
            return clusterService.localNode().getId();
        } catch (IllegalStateException e) {
            LOGGER.debug("Unable to get local node id", e);
            return null;
        }
    }

    /**
     * Applies {@code f} to the response, or rethrows the response's failure as a runtime exception.
     */
    private static <A> A processResponse(ShardResponse shardResponse, Function<ShardResponse, A> f) {
        Exception failure = shardResponse.failure();
        if (failure != null) {
            throw Exceptions.toRuntimeException(failure);
        }
        return f.apply(shardResponse);
    }

    /** Returns the response's successful row count; throws if the response carries a failure. */
    private static long toRowCount(ShardResponse shardResponse) {
        return (long) processResponse(shardResponse, ShardResponse::successRowCount);
    }

    /** Returns the response's result rows (empty if absent); throws if the response carries a failure. */
    private static List<Object[]> toResultRows(ShardResponse shardResponse) {
        List<Object[]> result = processResponse(shardResponse, ShardResponse::getResultRows);
        return result == null ? List.of() : result;
    }
}

