/*
 * Decompiled with CFR 0.152.
 */
package io.crate.protocols.postgres;

import io.crate.action.sql.BaseResultReceiver;
import io.crate.action.sql.DescribeResult;
import io.crate.action.sql.SQLOperations;
import io.crate.action.sql.Session;
import io.crate.action.sql.SessionContext;
import io.crate.auth.Authentication;
import io.crate.auth.AuthenticationMethod;
import io.crate.auth.Protocol;
import io.crate.auth.user.AccessControl;
import io.crate.auth.user.User;
import io.crate.common.annotations.VisibleForTesting;
import io.crate.common.collections.Lists2;
import io.crate.expression.symbol.Symbol;
import io.crate.protocols.SSL;
import io.crate.protocols.postgres.AuthenticationContext;
import io.crate.protocols.postgres.ClientInterrupted;
import io.crate.protocols.postgres.ConnectionProperties;
import io.crate.protocols.postgres.DelayableWriteChannel;
import io.crate.protocols.postgres.FormatCodes;
import io.crate.protocols.postgres.Messages;
import io.crate.protocols.postgres.QueryStringSplitter;
import io.crate.protocols.postgres.ResultSetReceiver;
import io.crate.protocols.postgres.RowCountReceiver;
import io.crate.protocols.postgres.SslReqHandler;
import io.crate.protocols.postgres.TransactionState;
import io.crate.protocols.postgres.types.PGType;
import io.crate.protocols.postgres.types.PGTypes;
import io.crate.protocols.ssl.SslContextProvider;
import io.crate.types.DataType;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.ByteToMessageDecoder;
import java.net.InetAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.BiConsumer;
import java.util.function.Function;
import javax.annotation.Nullable;
import javax.net.ssl.SSLSession;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.http.netty4.Netty4HttpServerTransport;

/**
 * Server side of the PostgreSQL frontend/backend wire protocol, built on Netty.
 *
 * <p>{@link MessageDecoder} frames the inbound byte stream. It handles the startup packet first
 * and then regular {@code type | int32 length | body} messages. {@link MessageHandler} dispatches
 * each framed body to one of the {@code handle*} methods, which drive a {@link Session}.
 *
 * <p>Lifecycle as implemented here:
 * <ol>
 *   <li>{@code PRE_STARTUP}: a possible SSL request is processed by {@link SslReqHandler}.</li>
 *   <li>{@code STARTUP_*}: startup parameters are read and authentication is initiated. For the
 *       {@code password} method a cleartext password is requested and handled in
 *       {@link #handlePassword}.</li>
 *   <li>{@code MSG_*}: the following are dispatched by message type:
 *       <ul>
 *         <li>simple query ('Q');</li>
 *         <li>extended-query messages ('P', 'B', 'D', 'E', 'H', 'S', 'C');</li>
 *         <li>password ('p') and terminate ('X').</li>
 *       </ul></li>
 * </ol>
 *
 * <p>After a message fails, all following messages except Sync ('S') are skipped until a Sync
 * arrives. The class uses no synchronization. NOTE(review): it presumably relies on all
 * callbacks running on the channel's event loop; confirm.
 */
public class PostgresWireProtocol {

    private static final Logger LOGGER = LogManager.getLogger(PostgresWireProtocol.class);

    /** Name of the authentication method that requires a cleartext password exchange. */
    private static final String PASSWORD_AUTH_NAME = "password";

    /** Numeric form of the emulated PostgreSQL version (10.5 -> 100500). */
    public static int SERVER_VERSION_NUM = 100500;

    /** Emulated PostgreSQL version, reported to clients as the {@code server_version} parameter. */
    public static String PG_SERVER_VERSION = "10.5";

    final MessageDecoder decoder;
    final MessageHandler handler;

    private final SQLOperations sqlOperations;
    private final Function<SessionContext, AccessControl> getAccessControl;
    private final Authentication authService;
    private final SslReqHandler sslReqHandler;

    /** Set on channel registration and cleared on unregistration. */
    private DelayableWriteChannel channel;
    /** Body length of the message currently being framed (header bytes excluded). */
    private int msgLength;
    /** Type byte of the message currently being framed. */
    private byte msgType;
    /** Created once authentication succeeds; {@code null} before that and after close. */
    private Session session;
    /** When {@code true}, every message except Sync is skipped (set after a failure). */
    private boolean ignoreTillSync = false;
    /** Only non-null while authentication is in progress. */
    private AuthenticationContext authContext;
    /** Parameters sent by the client in the startup packet. */
    private Properties properties;
    private State state = State.PRE_STARTUP;

    PostgresWireProtocol(SQLOperations sqlOperations,
                         Function<SessionContext, AccessControl> getAccessControl,
                         Authentication authService,
                         @Nullable SslContextProvider sslContextProvider) {
        this.sqlOperations = sqlOperations;
        this.getAccessControl = getAccessControl;
        this.authService = authService;
        this.sslReqHandler = new SslReqHandler(sslContextProvider);
        this.decoder = new MessageDecoder();
        this.handler = new MessageHandler();
    }

    /** Trace-logs the protocol version from the startup packet: major in the high 16 bits, minor in the low 16. */
    private static void traceLogProtocol(int protocol) {
        if (LOGGER.isTraceEnabled()) {
            int major = protocol >> 16;
            int minor = protocol & 0xFFFF;
            LOGGER.trace("protocol {}.{}", major, minor);
        }
    }

    /**
     * Reads a NUL-terminated UTF-8 string and consumes the terminator.
     *
     * @return the decoded string without the terminator, or {@code null} if no NUL byte remains
     *         in {@code buffer} (nothing is consumed in that case)
     */
    @Nullable
    static String readCString(ByteBuf buffer) {
        int length = buffer.bytesBefore((byte) 0);
        if (length < 0) {
            return null;
        }
        byte[] bytes = new byte[length + 1];
        buffer.readBytes(bytes);
        return new String(bytes, 0, length, StandardCharsets.UTF_8);
    }

    /**
     * Reads a NUL-terminated UTF-8 string as a {@code char[]} and consumes the terminator.
     * A char array is used so the password can be handed to the auth context as a secure value.
     *
     * <p>Only the bytes before the terminator are decoded. Exactly the decoded characters are
     * returned: {@link CharBuffer#array()} exposes the whole backing array, which may be longer
     * than the content (e.g. for multi-byte UTF-8). The temporary byte and char arrays are zeroed
     * afterwards so the plaintext does not linger in them.
     *
     * @return the decoded characters (possibly empty), or {@code null} if no NUL byte remains
     */
    @Nullable
    private static char[] readCharArray(ByteBuf buffer) {
        int length = buffer.bytesBefore((byte) 0);
        if (length < 0) {
            return null;
        }
        byte[] bytes = new byte[length];
        buffer.readBytes(bytes);
        buffer.skipBytes(1); // the NUL terminator
        CharBuffer decoded = StandardCharsets.UTF_8.decode(ByteBuffer.wrap(bytes));
        char[] chars = new char[decoded.remaining()];
        decoded.get(chars);
        Arrays.fill(bytes, (byte) 0);
        if (decoded.hasArray()) {
            Arrays.fill(decoded.array(), '\0');
        }
        return chars;
    }

    /**
     * Parses the startup payload: alternating NUL-terminated key/value strings, read until no
     * further terminated string is found. Pairs where the key or value is empty are ignored.
     */
    private Properties readStartupMessage(ByteBuf buffer) {
        Properties startupProperties = new Properties();
        ByteBuf payload = buffer.readSlice(msgLength);
        String key;
        while ((key = readCString(payload)) != null) {
            String value = readCString(payload);
            LOGGER.trace("payload: key={} value={}", key, value);
            if ("".equals(key) || "".equals(value)) {
                continue;
            }
            startupProperties.setProperty(key, value);
        }
        return startupProperties;
    }

    /** Stores the startup parameters and starts authentication. */
    private void handleStartupBody(ByteBuf buffer, Channel channel) {
        properties = readStartupMessage(buffer);
        initAuthentication(channel);
    }

    /**
     * Resolves the authentication method for the startup {@code user} and the connection.
     *
     * <p>There are three outcomes:
     * <ul>
     *   <li>If no method is resolved, an authentication error is sent.</li>
     *   <li>For the {@code password} method, a cleartext password is requested; authentication
     *       then completes in {@link #handlePassword}.</li>
     *   <li>For any other method, authentication is finished immediately.</li>
     * </ul>
     */
    private void initAuthentication(Channel channel) {
        String userName = properties.getProperty("user");
        InetAddress address = Netty4HttpServerTransport.getRemoteAddress(channel);
        SSLSession sslSession = SSL.getSession(channel);
        ConnectionProperties connProperties = new ConnectionProperties(address, Protocol.POSTGRES, sslSession);
        AuthenticationMethod authMethod = authService.resolveAuthenticationType(userName, connProperties);
        if (authMethod == null) {
            String errorMessage = String.format(
                Locale.ENGLISH,
                "No valid auth.host_based entry found for host \"%s\", user \"%s\"",
                address.getHostAddress(),
                userName);
            Messages.sendAuthenticationError(channel, errorMessage);
            return;
        }
        authContext = new AuthenticationContext(authMethod, connProperties, userName, LOGGER);
        if (PASSWORD_AUTH_NAME.equals(authMethod.name())) {
            Messages.sendAuthenticationCleartextPassword(channel);
            return;
        }
        finishAuthentication(channel);
    }

    /**
     * Authenticates via the pending {@link AuthenticationContext} and creates the session for the
     * startup {@code database}.
     *
     * <p>On success, AuthenticationOK is sent and then the parameter status messages and
     * ReadyForQuery. On any exception, an authentication error carrying the exception message is
     * sent instead. The auth context is closed and cleared in both cases.
     */
    private void finishAuthentication(Channel channel) {
        assert authContext != null : "finishAuthentication() requires an authContext instance";
        try {
            User authenticatedUser = authContext.authenticate();
            String database = properties.getProperty("database");
            session = sqlOperations.createSession(database, authenticatedUser);
            Messages.sendAuthenticationOK(channel)
                .addListener(f -> sendParamsAndRdyForQuery(channel));
        } catch (Exception e) {
            Messages.sendAuthenticationError(channel, e.getMessage());
        } finally {
            authContext.close();
            authContext = null;
        }
    }

    /** Sends the server parameter status messages followed by ReadyForQuery (idle). */
    private void sendParamsAndRdyForQuery(Channel channel) {
        Messages.sendParameterStatus(channel, "crate_version", Version.CURRENT.externalNumber());
        Messages.sendParameterStatus(channel, "server_version", PG_SERVER_VERSION);
        Messages.sendParameterStatus(channel, "server_encoding", "UTF8");
        Messages.sendParameterStatus(channel, "client_encoding", "UTF8");
        Messages.sendParameterStatus(channel, "datestyle", "ISO");
        Messages.sendParameterStatus(channel, "TimeZone", "UTC");
        Messages.sendParameterStatus(channel, "integer_datetimes", "on");
        Messages.sendReadyForQuery(channel, TransactionState.IDLE);
    }

    /**
     * Flush ('H'). If the session has deferred executions, the session is flushed; otherwise the
     * channel is flushed. Failures are reported as an error response.
     */
    private void handleFlush(Channel channel) {
        try {
            if (session.hasDeferredExecutions()) {
                session.flush();
            } else {
                channel.flush();
            }
        } catch (Throwable t) {
            Messages.sendErrorResponse(channel, getAccessControl.apply(session.sessionContext()), t);
        }
    }

    /**
     * Parse ('P'): statement name, query string, then an int16 count of parameter type OIDs.
     *
     * @throws IllegalArgumentException if an OID has no corresponding CrateDB type
     */
    private void handleParseMessage(ByteBuf buffer, Channel channel) {
        String statementName = readCString(buffer);
        String query = readCString(buffer);
        int numParams = buffer.readShort();
        List<DataType> paramTypes = new ArrayList<>(numParams);
        for (int i = 0; i < numParams; i++) {
            int oid = buffer.readInt();
            DataType<?> dataType = PGTypes.fromOID(oid);
            if (dataType == null) {
                throw new IllegalArgumentException(
                    String.format(Locale.ENGLISH, "Can't map PGType with oid=%d to Crate type", oid));
            }
            paramTypes.add(dataType);
        }
        session.parse(statementName, query, paramTypes);
        Messages.sendParseComplete(channel);
    }

    /**
     * Password ('p'). The password is passed to the pending authentication context, and
     * authentication is then finished.
     *
     * @throws IllegalStateException if no authentication is in progress; without this check an
     *         unexpected password message would fail with a NullPointerException
     */
    private void handlePassword(ByteBuf buffer, Channel channel) {
        if (authContext == null) {
            throw new IllegalStateException("Received a password message but no authentication is in progress");
        }
        char[] passwd = readCharArray(buffer);
        if (passwd != null) {
            authContext.setSecurePassword(passwd);
        }
        finishAuthentication(channel);
    }

    /**
     * Bind ('B'). The message is read in this order:
     * <ol>
     *   <li>portal name and statement name;</li>
     *   <li>parameter format codes;</li>
     *   <li>the parameter values, each prefixed by an int32 length, where -1 means NULL;</li>
     *   <li>result format codes.</li>
     * </ol>
     * Each value is decoded with the PG type of the statement's parameter, using the text or
     * binary format as indicated by its format code.
     */
    private void handleBindMessage(ByteBuf buffer, Channel channel) {
        String portalName = readCString(buffer);
        String statementName = readCString(buffer);
        FormatCodes.FormatCode[] formatCodes = FormatCodes.fromBuffer(buffer);
        int numParams = buffer.readShort();
        List<Object> params = createList((short) numParams);
        for (int i = 0; i < numParams; i++) {
            int valueLength = buffer.readInt();
            if (valueLength == -1) {
                params.add(null);
                continue;
            }
            DataType<?> paramType = session.getParamType(statementName, i);
            PGType pgType = PGTypes.get(paramType);
            FormatCodes.FormatCode formatCode = FormatCodes.getFormatCode(formatCodes, i);
            switch (formatCode) {
                case TEXT:
                    params.add(pgType.readTextValue(buffer, valueLength));
                    break;
                case BINARY:
                    params.add(pgType.readBinaryValue(buffer, valueLength));
                    break;
                default:
                    Messages.sendErrorResponse(
                        channel,
                        getAccessControl.apply(session.sessionContext()),
                        new UnsupportedOperationException(String.format(
                            Locale.ENGLISH,
                            "Unsupported format code '%d' for param '%s'",
                            formatCode.ordinal(),
                            paramType.getName())));
                    return;
            }
        }
        FormatCodes.FormatCode[] resultFormatCodes = FormatCodes.fromBuffer(buffer);
        session.bind(portalName, statementName, params, resultFormatCodes);
        Messages.sendBindComplete(channel);
    }

    /** Returns an immutable empty list for {@code size == 0}, otherwise a presized mutable list. */
    private <T> List<T> createList(short size) {
        return size == 0 ? Collections.<T>emptyList() : new ArrayList<T>(size);
    }

    /**
     * Describe ('D'). The type byte is 'S' (statement) or 'P' (portal).
     *
     * <p>For statements, a ParameterDescription is sent first. Then either NoData (no result
     * fields) or a RowDescription is sent; the RowDescription uses the portal's result format
     * codes when describing a portal.
     */
    private void handleDescribeMessage(ByteBuf buffer, Channel channel) {
        byte type = buffer.readByte();
        String portalOrStatement = readCString(buffer);
        DescribeResult describeResult = session.describe((char) type, portalOrStatement);
        List<Symbol> fields = describeResult.getFields();
        if (type == 'S') {
            Messages.sendParameterDescription(channel, describeResult.getParameters());
        }
        if (fields == null) {
            Messages.sendNoData(channel);
        } else {
            FormatCodes.FormatCode[] resultFormatCodes =
                type == 'P' ? session.getResultFormatCodes(portalOrStatement) : null;
            Messages.sendRowDescription(channel, fields, resultFormatCodes, describeResult.relation());
        }
    }

    /**
     * Execute ('E'): portal name and an int32 row limit.
     *
     * <p>An empty query closes the portal and sends EmptyQueryResponse. Statements without output
     * types are run with a row-count receiver and a row limit of 0. Otherwise a result-set
     * receiver is used. Further writes on the channel are delayed until the returned execution
     * future, if any, completes.
     */
    private void handleExecute(ByteBuf buffer, DelayableWriteChannel channel) {
        String portalName = readCString(buffer);
        int maxRows = buffer.readInt();
        String query = session.getQuery(portalName);
        if (query.isEmpty()) {
            session.close((byte) 'P', portalName);
            Messages.sendEmptyQueryResponse(channel);
            return;
        }
        AccessControl accessControl = getAccessControl.apply(session.sessionContext());
        List<? extends DataType> outputTypes = session.getOutputTypes(portalName);
        BaseResultReceiver resultReceiver;
        if (outputTypes == null) {
            maxRows = 0;
            resultReceiver = new RowCountReceiver(query, channel.bypassDelay(), accessControl);
        } else {
            resultReceiver = new ResultSetReceiver(
                query,
                channel.bypassDelay(),
                session.transactionState(),
                accessControl,
                Lists2.map(outputTypes, PGTypes::get),
                session.getResultFormatCodes(portalName));
        }
        CompletableFuture<?> execute = session.execute(portalName, maxRows, resultReceiver);
        if (execute != null) {
            channel.delayWritesUntil(execute);
        }
    }

    /**
     * Sync ('S'). This has two paths:
     * <ul>
     *   <li>After a failure, the error state is cleared, deferred executions are reset and
     *       ReadyForQuery (failed transaction) is sent.</li>
     *   <li>Otherwise the session is synced, and ReadyForQuery is sent when the sync completes
     *       (unless the client interrupted).</li>
     * </ul>
     */
    private void handleSync(Channel channel) {
        if (ignoreTillSync) {
            ignoreTillSync = false;
            session.resetDeferredExecutions();
            Messages.sendReadyForQuery(channel, TransactionState.FAILED_TRANSACTION);
            return;
        }
        try {
            ReadyForQueryCallback readyForQueryCallback =
                new ReadyForQueryCallback(channel, session.transactionState());
            session.sync().whenComplete(readyForQueryCallback);
        } catch (Throwable t) {
            Messages.sendErrorResponse(channel, getAccessControl.apply(session.sessionContext()), t);
            Messages.sendReadyForQuery(channel, TransactionState.FAILED_TRANSACTION);
        }
    }

    /** Close ('C'): closes the named portal or statement (type byte 'P' or 'S') and sends CloseComplete. */
    private void handleClose(ByteBuf buffer, Channel channel) {
        byte type = buffer.readByte();
        String portalOrStatementName = readCString(buffer);
        session.close(type, portalOrStatementName);
        Messages.sendCloseComplete(channel);
    }

    /**
     * Simple query ('Q'). The query string is split into individual statements, which are
     * executed one after another. Each statement starts only after the previous statement's
     * future completes successfully. A single ReadyForQuery (idle) is sent at the end.
     */
    @VisibleForTesting
    void handleSimpleQuery(ByteBuf buffer, DelayableWriteChannel channel) {
        String queryString = readCString(buffer);
        assert queryString != null : "query must not be nulL";
        List<String> queries = QueryStringSplitter.splitQuery(queryString);
        CompletableFuture<?> composedFuture = CompletableFuture.completedFuture(null);
        for (String query : queries) {
            composedFuture = composedFuture.thenCompose(result -> handleSingleQuery(query, channel));
        }
        composedFuture.whenComplete(new ReadyForQueryCallback(channel, TransactionState.IDLE));
    }

    /**
     * Runs one statement of a simple query through the unnamed statement/portal, using
     * parse, bind, describe, execute and sync.
     *
     * @return the session's sync future, or an already completed future for empty queries. On
     *         failure, an error response is sent and an exceptionally completed future is
     *         returned.
     */
    private CompletableFuture<?> handleSingleQuery(String query, DelayableWriteChannel channel) {
        CompletableFuture<Object> result = new CompletableFuture<>();
        if (query.isEmpty() || ";".equals(query)) {
            Messages.sendEmptyQueryResponse(channel);
            result.complete(null);
            return result;
        }
        AccessControl accessControl = getAccessControl.apply(session.sessionContext());
        try {
            session.parse("", query, Collections.emptyList());
            session.bind("", "", Collections.emptyList(), null);
            DescribeResult describeResult = session.describe('P', "");
            List<Symbol> fields = describeResult.getFields();
            CompletableFuture<?> execute;
            if (fields == null) {
                RowCountReceiver rowCountReceiver = new RowCountReceiver(query, channel.bypassDelay(), accessControl);
                execute = session.execute("", 0, rowCountReceiver);
            } else {
                Messages.sendRowDescription(channel, fields, null, describeResult.relation());
                ResultSetReceiver resultSetReceiver = new ResultSetReceiver(
                    query,
                    channel.bypassDelay(),
                    TransactionState.IDLE,
                    accessControl,
                    Lists2.map(fields, x -> PGTypes.get(x.valueType())),
                    null);
                execute = session.execute("", 0, resultSetReceiver);
            }
            if (execute != null) {
                channel.delayWritesUntil(execute);
            }
            return session.sync();
        } catch (Throwable t) {
            Messages.sendErrorResponse(channel, accessControl, t);
            result.completeExceptionally(t);
            return result;
        }
    }

    /** Framing state shared by {@link MessageDecoder} and {@link MessageHandler}. */
    enum State {
        PRE_STARTUP,
        STARTUP_HEADER,
        STARTUP_BODY,
        MSG_HEADER,
        MSG_BODY
    }

    /**
     * Frames the inbound byte stream into message bodies and emits each complete body downstream.
     *
     * <p>Two header layouts are handled:
     * <ul>
     *   <li>startup: {@code int32 length | int32 protocol}</li>
     *   <li>regular message: {@code byte type | int32 length}</li>
     * </ul>
     * The length field counts itself, and for startup also the protocol field. This is why 8 or 4
     * is subtracted to get the body length. If the body is not fully buffered yet, the reader
     * index is rewound to the header and decoding is retried when more bytes arrive.
     */
    private class MessageDecoder extends ByteToMessageDecoder {

        private MessageDecoder() {
            setCumulator(COMPOSITE_CUMULATOR);
        }

        @Override
        protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
            ByteBuf decoded = decode(in, ctx.pipeline());
            if (decoded != null) {
                out.add(decoded);
            }
        }

        @Nullable
        private ByteBuf decode(ByteBuf buffer, ChannelPipeline pipeline) {
            switch (state) {
                case PRE_STARTUP:
                    if (sslReqHandler.process(buffer, pipeline) == SslReqHandler.State.DONE) {
                        state = State.STARTUP_HEADER;
                        // bytes following the SSL handling may already contain the startup packet
                        return decode(buffer, pipeline);
                    }
                    return null;
                case STARTUP_HEADER:
                    if (buffer.readableBytes() < 8) {
                        return null;
                    }
                    // mark so nullOrBuffer() rewinds to this header if the body is incomplete
                    buffer.markReaderIndex();
                    msgLength = buffer.readInt() - 8;
                    LOGGER.trace("Header pkgLength: {}", msgLength);
                    int protocol = buffer.readInt();
                    traceLogProtocol(protocol);
                    return nullOrBuffer(buffer, State.STARTUP_BODY);
                case MSG_HEADER:
                    if (buffer.readableBytes() < 5) {
                        return null;
                    }
                    buffer.markReaderIndex();
                    msgType = buffer.readByte();
                    msgLength = buffer.readInt() - 4;
                    return nullOrBuffer(buffer, State.MSG_BODY);
                case STARTUP_BODY:
                case MSG_BODY:
                    return nullOrBuffer(buffer, state);
                default:
                    throw new IllegalStateException("Invalid state " + state);
            }
        }

        /**
         * Returns the next {@code msgLength} bytes as a new buffer and moves to
         * {@code nextState}. If not enough bytes are buffered, it rewinds to the marked reader
         * index and returns {@code null}, leaving the state unchanged.
         */
        @Nullable
        private ByteBuf nullOrBuffer(ByteBuf buffer, State nextState) {
            if (buffer.readableBytes() < msgLength) {
                buffer.resetReaderIndex();
                return null;
            }
            state = nextState;
            return buffer.readBytes(msgLength);
        }
    }

    /**
     * Dispatches framed message bodies to the protocol handlers. It manages the channel wrapper
     * and the session lifetime, and converts handler failures into error responses.
     */
    private class MessageHandler extends SimpleChannelInboundHandler<ByteBuf> {

        private MessageHandler() {
        }

        @Override
        public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
            channel = new DelayableWriteChannel(ctx.channel());
        }

        /**
         * Dispatches one framed body. Any failure puts the protocol into skip-until-Sync mode
         * and sends an error response. If sending that response also fails, both errors are
         * logged.
         */
        @Override
        public void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
            assert channel != null : "Channel must be initialized";
            try {
                dispatchState(buffer, channel);
            } catch (Throwable t) {
                ignoreTillSync = true;
                try {
                    AccessControl accessControl = session == null
                        ? AccessControl.DISABLED
                        : getAccessControl.apply(session.sessionContext());
                    Messages.sendErrorResponse(channel, accessControl, t);
                } catch (Throwable ti) {
                    LOGGER.error("Error trying to send error to client: {}", t, ti);
                }
            }
        }

        private void dispatchState(ByteBuf buffer, DelayableWriteChannel channel) {
            switch (state) {
                case STARTUP_HEADER:
                case MSG_HEADER:
                    throw new IllegalStateException("Decoder should've processed the headers");
                case STARTUP_BODY:
                    state = State.MSG_HEADER;
                    handleStartupBody(buffer, channel);
                    return;
                case MSG_BODY:
                    state = State.MSG_HEADER;
                    LOGGER.trace("msg={} msgLength={} readableBytes={}",
                        (char) msgType, msgLength, buffer.readableBytes());
                    if (ignoreTillSync && msgType != 'S') {
                        buffer.skipBytes(msgLength);
                        return;
                    }
                    dispatchMessage(buffer, channel);
                    return;
                default:
                    throw new IllegalStateException("Illegal state: " + state);
            }
        }

        private void dispatchMessage(ByteBuf buffer, DelayableWriteChannel channel) {
            switch (msgType) {
                case 'Q': // simple query
                    handleSimpleQuery(buffer, channel);
                    return;
                case 'P': // parse
                    handleParseMessage(buffer, channel);
                    return;
                case 'p': // password
                    handlePassword(buffer, channel);
                    return;
                case 'B': // bind
                    handleBindMessage(buffer, channel);
                    return;
                case 'D': // describe
                    handleDescribeMessage(buffer, channel);
                    return;
                case 'E': // execute
                    handleExecute(buffer, channel);
                    return;
                case 'H': // flush
                    handleFlush(channel);
                    return;
                case 'S': // sync
                    handleSync(channel);
                    return;
                case 'C': // close
                    handleClose(buffer, channel);
                    return;
                case 'X': // terminate
                    closeSession();
                    channel.close();
                    return;
                default:
                    Messages.sendErrorResponse(
                        channel,
                        session == null ? AccessControl.DISABLED : getAccessControl.apply(session.sessionContext()),
                        new UnsupportedOperationException("Unsupported messageType: " + msgType));
            }
        }

        private void closeSession() {
            if (session != null) {
                session.close();
                session = null;
            }
        }

        /**
         * A "Connection reset" {@link SocketException} closes the session and is logged at info
         * level. Other exceptions are logged as errors. The message comparison is null-safe
         * because a SocketException may carry no message.
         */
        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            if (cause instanceof SocketException && "Connection reset".equals(cause.getMessage())) {
                LOGGER.info("Connection reset. Client likely terminated connection");
                closeSession();
            } else {
                LOGGER.error("Uncaught exception: ", cause);
            }
        }

        @Override
        public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
            LOGGER.trace("channelDisconnected");
            channel = null;
            closeSession();
            super.channelUnregistered(ctx);
        }
    }

    /**
     * Completion callback that sends ReadyForQuery with a fixed transaction state. Nothing is
     * sent when the failure is, or is caused by, {@link ClientInterrupted}.
     */
    private static class ReadyForQueryCallback implements BiConsumer<Object, Throwable> {

        private final Channel channel;
        private final TransactionState transactionState;

        private ReadyForQueryCallback(Channel channel, TransactionState transactionState) {
            this.channel = channel;
            this.transactionState = transactionState;
        }

        @Override
        public void accept(Object result, Throwable t) {
            boolean clientInterrupted = t instanceof ClientInterrupted
                || (t != null && t.getCause() instanceof ClientInterrupted);
            if (!clientInterrupted) {
                Messages.sendReadyForQuery(channel, transactionState);
            }
        }
    }
}

