/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.network.client;

import io.netty.channel.Channel;
import java.io.IOException;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.network.client.BaseResponseCallback;
import org.apache.spark.network.client.ChunkFetchFailureException;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.MergedBlockMetaResponseCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.StreamInterceptor;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.server.MessageHandler;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportFrameDecoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.sparkproject.guava.annotations.VisibleForTesting;

public class TransportResponseHandler
extends MessageHandler<ResponseMessage> {
    private static final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
    private final Channel channel;
    private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
    private final Map<Long, BaseResponseCallback> outstandingRpcs;
    private final Queue<Pair<String, StreamCallback>> streamCallbacks;
    private volatile boolean streamActive;
    private final AtomicLong timeOfLastRequestNs;

    public TransportResponseHandler(Channel channel) {
        this.channel = channel;
        this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
        this.outstandingRpcs = new ConcurrentHashMap<Long, BaseResponseCallback>();
        this.streamCallbacks = new ConcurrentLinkedQueue<Pair<String, StreamCallback>>();
        this.timeOfLastRequestNs = new AtomicLong(0L);
    }

    public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
        this.updateTimeOfLastRequest();
        this.outstandingFetches.put(streamChunkId, callback);
    }

    public void removeFetchRequest(StreamChunkId streamChunkId) {
        this.outstandingFetches.remove(streamChunkId);
    }

    public void addRpcRequest(long requestId, BaseResponseCallback callback) {
        this.updateTimeOfLastRequest();
        this.outstandingRpcs.put(requestId, callback);
    }

    public void removeRpcRequest(long requestId) {
        this.outstandingRpcs.remove(requestId);
    }

    public void addStreamCallback(String streamId, StreamCallback callback) {
        this.updateTimeOfLastRequest();
        this.streamCallbacks.offer((Pair<String, StreamCallback>)ImmutablePair.of((Object)streamId, (Object)callback));
    }

    @VisibleForTesting
    public void deactivateStream() {
        this.streamActive = false;
    }

    private void failOutstandingRequests(Throwable cause) {
        for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : this.outstandingFetches.entrySet()) {
            try {
                entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
            }
            catch (Exception e) {
                logger.warn("ChunkReceivedCallback.onFailure throws exception", (Throwable)e);
            }
        }
        for (Map.Entry<Object, Object> entry : this.outstandingRpcs.entrySet()) {
            try {
                ((BaseResponseCallback)entry.getValue()).onFailure(cause);
            }
            catch (Exception e) {
                logger.warn("RpcResponseCallback.onFailure throws exception", (Throwable)e);
            }
        }
        for (Pair pair : this.streamCallbacks) {
            try {
                ((StreamCallback)pair.getValue()).onFailure((String)pair.getKey(), cause);
            }
            catch (Exception e) {
                logger.warn("StreamCallback.onFailure throws exception", (Throwable)e);
            }
        }
        this.outstandingFetches.clear();
        this.outstandingRpcs.clear();
        this.streamCallbacks.clear();
    }

    @Override
    public void channelActive() {
    }

    @Override
    public void channelInactive() {
        if (this.numOutstandingRequests() > 0) {
            String remoteAddress = NettyUtils.getRemoteAddress(this.channel);
            logger.error("Still have {} requests outstanding when connection from {} is closed", (Object)this.numOutstandingRequests(), (Object)remoteAddress);
            this.failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
        }
    }

    @Override
    public void exceptionCaught(Throwable cause) {
        if (this.numOutstandingRequests() > 0) {
            String remoteAddress = NettyUtils.getRemoteAddress(this.channel);
            logger.error("Still have {} requests outstanding when connection from {} is closed", (Object)this.numOutstandingRequests(), (Object)remoteAddress);
            this.failOutstandingRequests(cause);
        }
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void handle(ResponseMessage message) throws Exception {
        if (message instanceof ChunkFetchSuccess) {
            ChunkFetchSuccess resp = (ChunkFetchSuccess)message;
            ChunkReceivedCallback listener = this.outstandingFetches.get(resp.streamChunkId);
            if (listener == null) {
                logger.warn("Ignoring response for block {} from {} since it is not outstanding", (Object)resp.streamChunkId, (Object)NettyUtils.getRemoteAddress(this.channel));
                resp.body().release();
            } else {
                this.outstandingFetches.remove(resp.streamChunkId);
                listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
                resp.body().release();
            }
        } else if (message instanceof ChunkFetchFailure) {
            ChunkFetchFailure resp = (ChunkFetchFailure)message;
            ChunkReceivedCallback listener = this.outstandingFetches.get(resp.streamChunkId);
            if (listener == null) {
                logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", new Object[]{resp.streamChunkId, NettyUtils.getRemoteAddress(this.channel), resp.errorString});
            } else {
                this.outstandingFetches.remove(resp.streamChunkId);
                listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException("Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
            }
        } else if (message instanceof RpcResponse) {
            RpcResponse resp = (RpcResponse)message;
            RpcResponseCallback listener = (RpcResponseCallback)this.outstandingRpcs.get(resp.requestId);
            if (listener == null) {
                logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", new Object[]{resp.requestId, NettyUtils.getRemoteAddress(this.channel), resp.body().size()});
                resp.body().release();
            } else {
                this.outstandingRpcs.remove(resp.requestId);
                try {
                    listener.onSuccess(resp.body().nioByteBuffer());
                }
                finally {
                    resp.body().release();
                }
            }
        } else if (message instanceof RpcFailure) {
            RpcFailure resp = (RpcFailure)message;
            BaseResponseCallback listener = this.outstandingRpcs.get(resp.requestId);
            if (listener == null) {
                logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", new Object[]{resp.requestId, NettyUtils.getRemoteAddress(this.channel), resp.errorString});
            } else {
                this.outstandingRpcs.remove(resp.requestId);
                listener.onFailure(new RuntimeException(resp.errorString));
            }
        } else if (message instanceof MergedBlockMetaSuccess) {
            MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess)message;
            try {
                MergedBlockMetaResponseCallback listener = (MergedBlockMetaResponseCallback)this.outstandingRpcs.get(resp.requestId);
                if (listener == null) {
                    logger.warn("Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not outstanding", new Object[]{resp.requestId, NettyUtils.getRemoteAddress(this.channel), resp.body().size()});
                }
                this.outstandingRpcs.remove(resp.requestId);
                listener.onSuccess(resp.getNumChunks(), resp.body());
            }
            finally {
                resp.body().release();
            }
        } else if (message instanceof StreamResponse) {
            StreamResponse resp = (StreamResponse)message;
            Pair<String, StreamCallback> entry = this.streamCallbacks.poll();
            if (entry != null) {
                StreamCallback callback = (StreamCallback)entry.getValue();
                if (resp.byteCount > 0L) {
                    StreamInterceptor<ResponseMessage> interceptor = new StreamInterceptor<ResponseMessage>(this, resp.streamId, resp.byteCount, callback);
                    try {
                        TransportFrameDecoder frameDecoder = (TransportFrameDecoder)this.channel.pipeline().get("frameDecoder");
                        frameDecoder.setInterceptor(interceptor);
                        this.streamActive = true;
                    }
                    catch (Exception e) {
                        logger.error("Error installing stream handler.", (Throwable)e);
                        this.deactivateStream();
                    }
                } else {
                    try {
                        callback.onComplete(resp.streamId);
                    }
                    catch (Exception e) {
                        logger.warn("Error in stream handler onComplete().", (Throwable)e);
                    }
                }
            } else {
                logger.error("Could not find callback for StreamResponse.");
            }
        } else if (message instanceof StreamFailure) {
            StreamFailure resp = (StreamFailure)message;
            Pair<String, StreamCallback> entry = this.streamCallbacks.poll();
            if (entry != null) {
                StreamCallback callback = (StreamCallback)entry.getValue();
                try {
                    callback.onFailure(resp.streamId, new RuntimeException(resp.error));
                }
                catch (IOException ioe) {
                    logger.warn("Error in stream failure handler.", (Throwable)ioe);
                }
            } else {
                logger.warn("Stream failure with unknown callback: {}", (Object)resp.error);
            }
        } else {
            throw new IllegalStateException("Unknown response type: " + message.type());
        }
    }

    public int numOutstandingRequests() {
        return this.outstandingFetches.size() + this.outstandingRpcs.size() + this.streamCallbacks.size() + (this.streamActive ? 1 : 0);
    }

    public long getTimeOfLastRequestNs() {
        return this.timeOfLastRequestNs.get();
    }

    public void updateTimeOfLastRequest() {
        this.timeOfLastRequestNs.set(System.nanoTime());
    }
}

