/*
 * Decompiled with CFR 0.152.
 */
package org.apache.activemq.transport.ws.jetty9;

import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.broker.BrokerServiceAware;
import org.apache.activemq.transport.Transport;
import org.apache.activemq.transport.TransportAcceptListener;
import org.apache.activemq.transport.TransportFactory;
import org.apache.activemq.transport.util.HttpTransportUtils;
import org.apache.activemq.transport.ws.WSTransportProxy;
import org.apache.activemq.transport.ws.jetty9.MQTTSocket;
import org.apache.activemq.transport.ws.jetty9.StompSocket;
import org.eclipse.jetty.websocket.api.WebSocketListener;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;

public class WSServlet
extends WebSocketServlet
implements BrokerServiceAware {
    private static final long serialVersionUID = -4716657876092884139L;
    private TransportAcceptListener listener;
    private static final Map<String, Integer> stompProtocols = new ConcurrentHashMap<String, Integer>();
    private static final Map<String, Integer> mqttProtocols = new ConcurrentHashMap<String, Integer>();
    private Map<String, Object> transportOptions;
    private BrokerService brokerService;

    public void init() throws ServletException {
        super.init();
        this.listener = (TransportAcceptListener)this.getServletContext().getAttribute("acceptListener");
        if (this.listener == null) {
            throw new ServletException("No such attribute 'acceptListener' available in the ServletContext");
        }
    }

    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
    }

    public void configure(WebSocketServletFactory factory) {
        factory.setCreator(new WebSocketCreator(){

            public Object createWebSocket(ServletUpgradeRequest req, ServletUpgradeResponse resp) {
                WebSocketListener socket;
                Protocol requestedProtocol = Protocol.UNKNOWN;
                if (!req.getSubProtocols().isEmpty()) {
                    for (String subProtocol : req.getSubProtocols()) {
                        if (subProtocol.startsWith("mqtt")) {
                            requestedProtocol = Protocol.MQTT;
                            continue;
                        }
                        if (!subProtocol.contains("stomp")) continue;
                        requestedProtocol = Protocol.STOMP;
                    }
                } else {
                    requestedProtocol = Protocol.STOMP;
                }
                switch (requestedProtocol) {
                    case MQTT: {
                        socket = new MQTTSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                        ((MQTTSocket)socket).setTransportOptions(new HashMap<String, Object>(WSServlet.this.transportOptions));
                        ((MQTTSocket)socket).setPeerCertificates(req.getCertificates());
                        resp.setAcceptedSubProtocol(WSServlet.this.getAcceptedSubProtocol(mqttProtocols, req.getSubProtocols(), "mqtt"));
                        break;
                    }
                    case UNKNOWN: {
                        socket = WSServlet.this.findWSTransport(req, resp);
                        if (socket != null) break;
                    }
                    case STOMP: {
                        socket = new StompSocket(HttpTransportUtils.generateWsRemoteAddress(req.getHttpServletRequest()));
                        ((StompSocket)socket).setPeerCertificates(req.getCertificates());
                        resp.setAcceptedSubProtocol(WSServlet.this.getAcceptedSubProtocol(stompProtocols, req.getSubProtocols(), "stomp"));
                        break;
                    }
                    default: {
                        socket = null;
                        WSServlet.this.listener.onAcceptError(new IOException("Unknown protocol requested"));
                    }
                }
                if (socket != null) {
                    WSServlet.this.listener.onAccept((Transport)socket);
                }
                return socket;
            }
        });
    }

    private WebSocketListener findWSTransport(ServletUpgradeRequest request, ServletUpgradeResponse response) {
        WSTransportProxy proxy = null;
        for (String subProtocol : request.getSubProtocols()) {
            try {
                String remoteAddress = HttpTransportUtils.generateWsRemoteAddress(request.getHttpServletRequest(), subProtocol);
                URI remoteURI = new URI(remoteAddress);
                TransportFactory factory = TransportFactory.findTransportFactory(remoteURI);
                if (factory instanceof BrokerServiceAware) {
                    ((BrokerServiceAware)((Object)factory)).setBrokerService(this.brokerService);
                }
                Transport transport = factory.doConnect(remoteURI);
                proxy = new WSTransportProxy(remoteAddress, transport);
                proxy.setPeerCertificates(request.getCertificates());
                proxy.setTransportOptions(this.transportOptions);
                response.setAcceptedSubProtocol(proxy.getSubProtocol());
            }
            catch (Exception e) {
                proxy = null;
            }
        }
        return proxy;
    }

    private String getAcceptedSubProtocol(Map<String, Integer> protocols, List<String> subProtocols, String defaultProtocol) {
        ArrayList<SubProtocol> matchedProtocols = new ArrayList<SubProtocol>();
        if (subProtocols != null && subProtocols.size() > 0) {
            for (String subProtocol : subProtocols) {
                Integer priority = protocols.get(subProtocol);
                if (subProtocol == null || priority == null) continue;
                matchedProtocols.add(new SubProtocol(subProtocol, priority));
            }
            if (matchedProtocols.size() > 0) {
                Collections.sort(matchedProtocols, new Comparator<SubProtocol>(){

                    @Override
                    public int compare(SubProtocol s1, SubProtocol s2) {
                        return s2.priority.compareTo(s1.priority);
                    }
                });
                return ((SubProtocol)matchedProtocols.get(0)).protocol;
            }
        }
        return defaultProtocol;
    }

    public void setTransportOptions(Map<String, Object> transportOptions) {
        this.transportOptions = transportOptions;
    }

    @Override
    public void setBrokerService(BrokerService brokerService) {
        this.brokerService = brokerService;
    }

    static {
        stompProtocols.put("v12.stomp", 3);
        stompProtocols.put("v11.stomp", 2);
        stompProtocols.put("v10.stomp", 1);
        stompProtocols.put("stomp", 0);
        mqttProtocols.put("mqttv3.1", 1);
        mqttProtocols.put("mqtt", 0);
    }

    private class SubProtocol {
        private String protocol;
        private Integer priority;

        public SubProtocol(String protocol, Integer priority) {
            this.protocol = protocol;
            this.priority = priority;
        }
    }

    private static enum Protocol {
        MQTT,
        STOMP,
        UNKNOWN;

    }
}

