/*
 * Decompiled with CFR 0.152.
 */
package eu.clarussecure.proxy.protocol.plugins.pgsql.message;

import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlConfiguration;
import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlConstants;
import eu.clarussecure.proxy.protocol.plugins.pgsql.PgsqlSession;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.PgsqlErrorMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.PgsqlMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.parser.PgsqlMessageParser;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.sql.EventProcessor;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.sql.SQLSession;
import eu.clarussecure.proxy.protocol.plugins.pgsql.message.writer.PgsqlMessageWriter;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.DefaultFullPgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.DefaultLastPgsqlRawContent;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.FullPgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.PgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.pgsql.raw.handler.codec.MutablePgsqlRawMessage;
import eu.clarussecure.proxy.protocol.plugins.tcp.TCPConstants;
import eu.clarussecure.proxy.protocol.plugins.tcp.handler.forwarder.DirectedMessage;
import eu.clarussecure.proxy.spi.CString;
import eu.clarussecure.proxy.spi.buffer.MutableByteBufInputStream;
import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.ReferenceCountUtil;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class PgsqlMessageHandler<T extends PgsqlMessage>
extends MessageToMessageDecoder<PgsqlRawMessage> {
    private static final Logger LOGGER = LoggerFactory.getLogger(PgsqlMessageHandler.class);
    protected final Map<Byte, Class<? extends T>> msgTypes;
    protected int numberOfPeerChannels = 0;
    protected int preferredPeerChannel = Integer.MIN_VALUE;

    @SafeVarargs
    protected PgsqlMessageHandler(Class<? extends T> ... msgTypes) {
        this.msgTypes = Arrays.stream(msgTypes).collect(Collectors.toMap(msgType -> {
            try {
                return msgType.getField("TYPE").getByte(null);
            }
            catch (IllegalAccessException | IllegalArgumentException | NoSuchFieldException | SecurityException e) {
                LOGGER.error("Cannot read TYPE field of message class {}: ", (Object)msgType.getSimpleName(), (Object)e);
                throw new IllegalArgumentException(String.format("Cannot read TYPE field of message class %s: ", msgType.getSimpleName(), e));
            }
        }, msgType -> msgType));
    }

    public boolean acceptInboundMessage(Object msg) throws Exception {
        if (!super.acceptInboundMessage(msg)) {
            return false;
        }
        if (msg instanceof FullPgsqlRawMessage || msg instanceof MutablePgsqlRawMessage) {
            return this.msgTypes.keySet().stream().anyMatch(type -> ((PgsqlRawMessage)msg).getType() == type.byteValue());
        }
        return false;
    }

    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Object msg2;
        if (msg instanceof DirectedMessage && this.acceptInboundMessage(msg2 = ((DirectedMessage)msg).getMsg())) {
            msg = msg2;
        }
        super.channelRead(ctx, msg);
    }

    protected void decode(ChannelHandlerContext ctx, PgsqlRawMessage rawMsg, List<Object> out) throws Exception {
        if (this.isStreamingSupported(rawMsg.getType()) && rawMsg instanceof MutablePgsqlRawMessage && !((MutablePgsqlRawMessage)((Object)rawMsg)).isComplete()) {
            LOGGER.trace("Decoding raw message in streaming mode: {}...", (Object)rawMsg);
            this.decodeStream(ctx, rawMsg);
            LOGGER.trace("Full raw message decoded: {}", (Object)rawMsg);
        } else {
            LOGGER.trace("Decoding full raw message: {}...", (Object)rawMsg);
            T pgsqlMessage = this.decode(ctx, rawMsg.getType(), rawMsg.getContent());
            LOGGER.trace("PGSQL message decoded: {}", pgsqlMessage);
            if (this.getNumberOfPeerChannels(ctx) > 1) {
                List<DirectedMessage<T>> directedMsgs = this.directedProcess(ctx, pgsqlMessage);
                if (directedMsgs != null) {
                    ByteBuf bufferToRecycle = rawMsg.getBytes();
                    for (DirectedMessage<T> directedMsg : directedMsgs) {
                        int to = directedMsg.getTo();
                        PgsqlMessage newPgsqlMessage = (PgsqlMessage)directedMsg.getMsg();
                        if (newPgsqlMessage != pgsqlMessage) {
                            LOGGER.trace("Encoding modified PGSQL message {}...", (Object)newPgsqlMessage);
                            ByteBuf buffer = this.allocate(ctx, newPgsqlMessage, bufferToRecycle);
                            buffer = this.encode(ctx, newPgsqlMessage, buffer);
                            rawMsg = new DefaultFullPgsqlRawMessage(buffer, newPgsqlMessage.getType(), buffer.capacity());
                            LOGGER.trace("Full raw message encoded: {}", (Object)rawMsg);
                            bufferToRecycle = null;
                        } else {
                            ReferenceCountUtil.retain((Object)rawMsg);
                        }
                        out.add(new DirectedMessage(to, (Object)rawMsg));
                        LOGGER.trace("Full raw message retained in the pipeline : {}", (Object)rawMsg);
                    }
                } else {
                    LOGGER.trace("Full raw message consumed {}...", (Object)rawMsg);
                }
            } else {
                T newPgsqlMessage = this.process(ctx, pgsqlMessage);
                if (newPgsqlMessage != null) {
                    if (newPgsqlMessage != pgsqlMessage) {
                        LOGGER.trace("Encoding modified PGSQL message {}...", newPgsqlMessage);
                        ByteBuf buffer = this.allocate(ctx, newPgsqlMessage, rawMsg.getBytes());
                        buffer = this.encode(ctx, newPgsqlMessage, buffer);
                        rawMsg = new DefaultFullPgsqlRawMessage(buffer, newPgsqlMessage.getType(), buffer.capacity());
                        LOGGER.trace("Full raw message encoded: {}", (Object)rawMsg);
                    } else {
                        ReferenceCountUtil.retain((Object)rawMsg);
                    }
                    rawMsg.filter(false);
                    out.add(rawMsg);
                    LOGGER.trace("Full raw message retained in the pipeline : {}", (Object)rawMsg);
                } else {
                    LOGGER.trace("Full raw message consumed {}...", (Object)rawMsg);
                }
            }
        }
    }

    protected void decodeStream(ChannelHandlerContext ctx, PgsqlRawMessage rawMsg) throws IOException {
        LOGGER.trace("Creating input stream...");
        try (MutableByteBufInputStream in = new MutableByteBufInputStream(rawMsg.getBytes(), rawMsg.getTotalLength());){
            LOGGER.trace("Input stream created to read from {}", (Object)rawMsg);
            this.decodeStream(ctx, rawMsg.getType(), in);
        }
    }

    protected boolean isStreamingSupported(byte type) {
        return false;
    }

    protected void decodeStream(ChannelHandlerContext ctx, byte type, MutableByteBufInputStream in) throws IOException {
        throw new UnsupportedOperationException("Unsupported decoding from input stream");
    }

    protected T decode(ChannelHandlerContext ctx, byte type, ByteBuf content) throws IOException {
        Class<? extends T> msgType = this.msgTypes.get(type);
        if (msgType == null) {
            LOGGER.error("Unsupported decoding of full raw message for type {}", (Object)type);
            throw new UnsupportedOperationException(String.format("Unsupported decoding of full raw message for type %d", type));
        }
        PgsqlMessageParser parser = this.getParser(ctx, msgType);
        content.markReaderIndex();
        Object msg = parser.parse(content);
        content.resetReaderIndex();
        return (T)msg;
    }

    protected List<DirectedMessage<T>> directedProcess(ChannelHandlerContext ctx, T msg) throws IOException {
        T newMsg = this.process(ctx, msg);
        if (newMsg == null) {
            return null;
        }
        int preferredPeerChannel = this.getPreferredPeerChannel(ctx);
        DirectedMessage directedMsg = new DirectedMessage(preferredPeerChannel, newMsg);
        List<DirectedMessage<T>> directedMsgs = Collections.singletonList(directedMsg);
        return directedMsgs;
    }

    protected T process(ChannelHandlerContext ctx, T msg) throws IOException {
        List<DirectedMessage<T>> directedMsgs = this.directedProcess(ctx, msg);
        if (directedMsgs != null && directedMsgs.size() > 1) {
            throw new IllegalStateException(String.format("%d new messages, 1 expected", directedMsgs.size()));
        }
        return (T)(directedMsgs != null ? (PgsqlMessage)directedMsgs.get(0).getMsg() : null);
    }

    protected ByteBuf allocate(ChannelHandlerContext ctx, T msg, ByteBuf buffer) {
        PgsqlMessageWriter writer = this.getWriter(ctx, msg.getClass());
        if (writer == null) {
            LOGGER.error("Unsupported allocating buffer for {} message", (Object)msg.getClass().getSimpleName());
            throw new UnsupportedOperationException(String.format("Unsupported allocating buffer for %s message", msg.getClass().getSimpleName()));
        }
        return writer.allocate(msg, buffer);
    }

    protected ByteBuf encode(ChannelHandlerContext ctx, T msg) throws IOException {
        return this.encode(ctx, msg, null);
    }

    protected ByteBuf encode(ChannelHandlerContext ctx, T msg, ByteBuf buffer) throws IOException {
        PgsqlMessageWriter writer = this.getWriter(ctx, msg.getClass());
        if (writer == null) {
            LOGGER.error("Unsupported encoding of {} message", (Object)msg.getClass().getSimpleName());
            throw new UnsupportedOperationException(String.format("Unsupported encoding of %s message", msg.getClass().getSimpleName()));
        }
        return writer.write(msg, buffer);
    }

    protected <M extends T> PgsqlMessageParser<M> getParser(ChannelHandlerContext ctx, Class<? extends T> msgType) {
        PgsqlMessageParser parser;
        HashMap<Class<? extends T>, PgsqlMessageParser> map = (HashMap<Class<? extends T>, PgsqlMessageParser>)ctx.channel().attr(PgsqlConstants.MSG_PARSERS_KEY).get();
        if (map == null) {
            map = new HashMap<Class<? extends T>, PgsqlMessageParser>();
            ctx.channel().attr(PgsqlConstants.MSG_PARSERS_KEY).set(map);
        }
        if ((parser = (PgsqlMessageParser)map.get(msgType)) == null) {
            parser = (PgsqlMessageParser)PgsqlMessageHandler.buildParserWriter(msgType, true);
            map.put(msgType, parser);
        }
        return parser;
    }

    protected <M extends PgsqlMessage> PgsqlMessageWriter<M> getWriter(ChannelHandlerContext ctx, Class<? extends PgsqlMessage> msgType) {
        PgsqlMessageWriter writer;
        HashMap<Class<? extends PgsqlMessage>, PgsqlMessageWriter> map = (HashMap<Class<? extends PgsqlMessage>, PgsqlMessageWriter>)ctx.channel().attr(PgsqlConstants.MSG_WRITERS_KEY).get();
        if (map == null) {
            map = new HashMap<Class<? extends PgsqlMessage>, PgsqlMessageWriter>();
            ctx.channel().attr(PgsqlConstants.MSG_WRITERS_KEY).set(map);
        }
        if ((writer = (PgsqlMessageWriter)map.get(msgType)) == null) {
            writer = (PgsqlMessageWriter)PgsqlMessageHandler.buildParserWriter(msgType, false);
            map.put(msgType, writer);
        }
        return writer;
    }

    private static <WP> WP buildParserWriter(Class<? extends PgsqlMessage> msgType, boolean parser) {
        String msgTypeName = msgType.getSimpleName();
        String pkgName = PgsqlMessage.class.getPackage().getName();
        String suffix = parser ? "Parser" : "Writer";
        String className = pkgName + "." + suffix.toLowerCase() + "." + msgTypeName + suffix;
        try {
            Class<?> loadClass = PgsqlMessage.class.getClassLoader().loadClass(className);
            return (WP)loadClass.newInstance();
        }
        catch (ClassNotFoundException | IllegalAccessException | InstantiationException e) {
            throw new IllegalArgumentException(e);
        }
    }

    protected int getNumberOfPeerChannels(ChannelHandlerContext ctx) {
        if (this.numberOfPeerChannels == 0) {
            PgsqlSession psqlSession = this.getPgsqlSession(ctx);
            if (ctx.channel() == psqlSession.getClientSideChannel()) {
                PgsqlConfiguration psqlConfiguration = this.getPsqlConfiguration(ctx);
                this.numberOfPeerChannels = psqlConfiguration.getServerEndpoints().size();
            } else {
                this.numberOfPeerChannels = 1;
            }
        }
        return this.numberOfPeerChannels;
    }

    protected int getPreferredPeerChannel(ChannelHandlerContext ctx) {
        if (this.preferredPeerChannel == Integer.MIN_VALUE) {
            PgsqlSession pgsqlSession = this.getPgsqlSession(ctx);
            if (ctx.channel() == pgsqlSession.getClientSideChannel()) {
                Integer preferredServerEndpoint = (Integer)ctx.channel().attr(TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY).get();
                if (preferredServerEndpoint == null) {
                    throw new NullPointerException(TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY.name() + " is not set");
                }
                if (preferredServerEndpoint < 0 || preferredServerEndpoint >= pgsqlSession.getServerSideChannels().size()) {
                    throw new IndexOutOfBoundsException(String.format("invalid %s: value: %d, number of server endpoints: %d ", TCPConstants.PREFERRED_SERVER_ENDPOINT_KEY.name(), preferredServerEndpoint, pgsqlSession.getServerSideChannels().size()));
                }
                this.preferredPeerChannel = preferredServerEndpoint;
            } else {
                this.preferredPeerChannel = 0;
            }
        }
        return this.preferredPeerChannel;
    }

    protected PgsqlConfiguration getPsqlConfiguration(ChannelHandlerContext ctx) {
        PgsqlConfiguration pgsqlConfiguration = (PgsqlConfiguration)((Object)ctx.channel().attr(PgsqlConstants.CONFIGURATION_KEY).get());
        return pgsqlConfiguration;
    }

    protected PgsqlSession getPgsqlSession(ChannelHandlerContext ctx) {
        PgsqlSession pgsqlSession = (PgsqlSession)((Object)ctx.channel().attr(PgsqlConstants.SESSION_KEY).get());
        return pgsqlSession;
    }

    protected SQLSession getSqlSession(ChannelHandlerContext ctx) {
        return this.getPgsqlSession(ctx).getSqlSession();
    }

    protected EventProcessor getEventProcessor(ChannelHandlerContext ctx) {
        return this.getPgsqlSession(ctx).getEventProcessor();
    }

    protected void sendErrorResponse(ChannelHandlerContext ctx, Map<Byte, CString> errorDetails) throws IOException {
        PgsqlErrorMessage msg = new PgsqlErrorMessage(errorDetails);
        this.sendResponse(ctx, msg);
    }

    protected <M extends PgsqlMessage> void sendResponse(ChannelHandlerContext ctx, M msg) throws IOException {
        PgsqlMessageWriter<M> writer = this.getWriter(ctx, msg.getClass());
        ByteBuf buffer = writer.allocate(msg);
        buffer = writer.write(msg, buffer);
        DefaultLastPgsqlRawContent content = new DefaultLastPgsqlRawContent(buffer);
        ctx.channel().writeAndFlush((Object)content);
    }

    protected <M extends PgsqlMessage> void sendRequest(ChannelHandlerContext ctx, M msg, int backend) throws IOException {
        PgsqlMessageWriter<M> writer = this.getWriter(ctx, msg.getClass());
        ByteBuf buffer = writer.allocate(msg);
        buffer = writer.write(msg, buffer);
        DefaultLastPgsqlRawContent content = new DefaultLastPgsqlRawContent(buffer);
        if (backend == -1) {
            for (int i = 0; i < this.getPgsqlSession(ctx).getServerSideChannels().size(); ++i) {
                Channel sinkChannel = this.getPgsqlSession(ctx).getServerSideChannel(i);
                if (i < this.getPgsqlSession(ctx).getServerSideChannels().size() - 1) {
                    sinkChannel.writeAndFlush((Object)content.retainedDuplicate());
                    continue;
                }
                sinkChannel.writeAndFlush((Object)content);
            }
        } else {
            this.getPgsqlSession(ctx).getServerSideChannel(backend).writeAndFlush((Object)content);
        }
    }
}

