/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ui

import java.net.{URI, URL}
import javax.servlet.DispatcherType
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}

import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.xml.Node

import org.eclipse.jetty.client.api.Response
import org.eclipse.jetty.proxy.ProxyServlet
import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler}
import org.eclipse.jetty.security.authentication.BasicAuthenticator
import org.eclipse.jetty.server.{HttpConnectionFactory, Request, Server, ServerConnector}
import org.eclipse.jetty.server.handler._
import org.eclipse.jetty.servlet._
import org.eclipse.jetty.servlets.gzip.GzipHandler
import org.eclipse.jetty.util.component.LifeCycle
import org.eclipse.jetty.util.security.{Constraint, Credential}
import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}
import org.json4s.JValue
import org.json4s.jackson.JsonMethods.{pretty, render}

import org.apache.spark.{SecurityManager, SparkConf, SSLOptions}
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
 * Utilities for launching a web server using Jetty's HTTP Server class
 */
private[spark] object JettyUtils extends Logging {

  // Names given to the Jetty connectors. Handlers are tied to a connector by setting a
  // virtual host of the form "@<connector name>".
  val SPARK_CONNECTOR_NAME: String = "Spark"
  val REDIRECT_CONNECTOR_NAME: String = "HttpsRedirect"

  // Realm name and role list used when basic authentication is configured.
  val snappyDataRealm: String = "SnappyDataPulse"
  val snappyDataRoles: Array[String] = Array("user")
  // Optional authenticator. It is read by startJettyServer: when defined, a security handler
  // built around it is installed on every context handler.
  var customAuthenticator: Option[BasicAuthenticator] = None

  /**
   * Constraint mapping that covers every path and demands an authenticated user holding one
   * of the roles in `snappyDataRoles`. Built lazily on first use.
   */
  lazy val constraintMapping = {
    val basicAuth = new Constraint()
    basicAuth.setName(Constraint.__BASIC_AUTH)
    basicAuth.setRoles(snappyDataRoles)
    basicAuth.setAuthenticate(true)

    val mapping = new ConstraintMapping()
    mapping.setPathSpec("/*")
    mapping.setConstraint(basicAuth)
    mapping
  }

  /**
   * In-memory login service for the `snappyDataRealm` realm, holding a single user that is
   * granted the roles in `snappyDataRoles`. Built lazily on first use.
   *
   * NOTE(review): the user name and password are hard-coded literals in this file.
   */
  lazy val snappyHashLoginService = {
    val loginService = new HashLoginService()
    loginService.putUser("snappyuser", Credential.getCredential("snappyuser"), snappyDataRoles)
    loginService.setName(snappyDataRealm)
    loginService
  }

  // A Responder computes a value of type T from an incoming HTTP request. Modelling it as a
  // plain function type is what lets ordinary functions be implicitly converted into the
  // servlet parameters consumed by jetty handlers.
  type Responder[T] = HttpServletRequest => T

  /**
   * Everything needed to turn a [[Responder]] into a servlet.
   *
   * @param responder   function producing the response value for a request
   * @param contentType MIME type reported to the client
   * @param extractFn   renders the response value as the body text; defaults to `toString`
   */
  class ServletParams[T <% AnyRef](
      val responder: Responder[T],
      val contentType: String,
      val extractFn: T => String = (in: Any) => in.toString)

  // Implicit adapters from the supported Responder result types to servlet parameters.

  /** JSON responders are served as "text/json", pretty-printed. */
  implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] = {
    val renderJson = (json: JValue) => pretty(render(json))
    new ServletParams(responder, "text/json", renderJson)
  }

  /** HTML responders are served as "text/html" with a doctype declaration prepended. */
  implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] = {
    val renderHtml = (nodes: Seq[Node]) => "<!DOCTYPE html>" + nodes.toString
    new ServletParams(responder, "text/html", renderHtml)
  }

  /** Plain-text responders are served as "text/plain" using the default `toString` rendering. */
  implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = {
    val plainText = "text/plain"
    new ServletParams(responder, plainText)
  }

  /**
   * Build a servlet that answers GET requests with the output of `servletParams.responder`.
   *
   * The remote user is checked through `securityMgr.checkUIViewPermissions`; unauthorized
   * requests get a 403. An IllegalArgumentException raised while serving is reported as a 400,
   * any other exception is logged and rethrown. TRACE is always rejected with a 405.
   */
  def createServlet[T <% AnyRef](
      servletParams: ServletParams[T],
      securityMgr: SecurityManager,
      conf: SparkConf): HttpServlet = {

    // SPARK-10589 avoid frame-related click-jacking vulnerability, using X-Frame-Options
    // (see http://tools.ietf.org/html/rfc7034). Framing is restricted to the same origin
    // unless a specific URI is configured.
    // Example: spark.ui.allowFramingFrom = https://example.com/
    val xFrameOptionsValue = conf.getOption("spark.ui.allowFramingFrom") match {
      case Some(uri) => s"ALLOW-FROM $uri"
      case None => "SAMEORIGIN"
    }
    val noCache = "no-cache, no-store, must-revalidate"

    new HttpServlet {
      override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
        try {
          val authorized = securityMgr.checkUIViewPermissions(request.getRemoteUser)
          if (!authorized) {
            response.setStatus(HttpServletResponse.SC_FORBIDDEN)
            response.setHeader("Cache-Control", noCache)
            response.sendError(HttpServletResponse.SC_FORBIDDEN,
              "User is not authorized to access this page.")
          } else {
            response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
            response.setStatus(HttpServletResponse.SC_OK)
            // Run the responder before touching the remaining headers or the writer.
            val payload = servletParams.responder(request)
            response.setHeader("Cache-Control", noCache)
            response.setHeader("X-Frame-Options", xFrameOptionsValue)
            response.getWriter.print(servletParams.extractFn(payload))
          }
        } catch {
          case badRequest: IllegalArgumentException =>
            response.sendError(HttpServletResponse.SC_BAD_REQUEST, badRequest.getMessage)
          case failure: Exception =>
            logWarning(s"GET ${request.getRequestURI} failed: $failure", failure)
            throw failure
        }
      }

      // SPARK-5983 ensure TRACE is not supported
      protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
        res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
      }
    }
  }

  /**
   * Create a context handler for the given path prefix, backed by a servlet built from
   * `servletParams` via [[createServlet]].
   */
  def createServletHandler[T <% AnyRef](
      path: String,
      servletParams: ServletParams[T],
      securityMgr: SecurityManager,
      conf: SparkConf,
      basePath: String = ""): ServletContextHandler = {
    val servlet = createServlet(servletParams, securityMgr, conf)
    createServletHandler(path, servlet, basePath)
  }

  /**
   * Create a context handler that serves `servlet` under `basePath + path`.
   *
   * A trailing "/" is stripped from the combined path, except for the root case (empty
   * `basePath` and `path` equal to "/"), which is kept as "/".
   */
  def createServletHandler(
      path: String,
      servlet: HttpServlet,
      basePath: String): ServletContextHandler = {
    val isRoot = basePath == "" && path == "/"
    val contextPath = if (isRoot) path else (basePath + path).stripSuffix("/")

    val handler = new ServletContextHandler
    handler.setContextPath(contextPath)
    handler.addServlet(new ServletHolder(servlet), "/")
    handler
  }

  /**
   * Create a handler mounted at `srcPath` that redirects the user to `basePath + destPath`.
   *
   * Only GET and POST are considered, and each is honoured only when listed in `httpMethods`;
   * otherwise the request gets a 405. `beforeRedirect` runs on the request just before the
   * redirect is sent. TRACE is always rejected with a 405.
   */
  def createRedirectHandler(
      srcPath: String,
      destPath: String,
      beforeRedirect: HttpServletRequest => Unit = x => (),
      basePath: String = "",
      httpMethods: Set[String] = Set("GET")): ServletContextHandler = {
    val redirectTarget = basePath + destPath

    val redirectServlet = new HttpServlet {
      // Redirect when `method` is enabled for this handler, otherwise answer 405.
      private def redirectIfAllowed(
          method: String,
          request: HttpServletRequest,
          response: HttpServletResponse): Unit = {
        if (!httpMethods.contains(method)) {
          response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
        } else {
          beforeRedirect(request)
          // Resolving against the request URL makes sure we don't end up with "//" in the middle
          val requestUrl = new URL(request.getRequestURL.toString)
          response.sendRedirect(new URL(requestUrl, redirectTarget).toString)
        }
      }

      override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = {
        redirectIfAllowed("GET", request, response)
      }

      override def doPost(request: HttpServletRequest, response: HttpServletResponse): Unit = {
        redirectIfAllowed("POST", request, response)
      }

      // SPARK-5983 ensure TRACE is not supported
      protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
        res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
      }
    }

    createServletHandler(srcPath, redirectServlet, basePath)
  }

  /**
   * Create a handler mounted at `path` that serves static files from `resourceBase`, which is
   * looked up on the Spark class loader.
   *
   * Throws an Exception when the resource cannot be found.
   */
  def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = {
    val resourceUrl = Option(Utils.getSparkClassLoader.getResource(resourceBase)).getOrElse {
      throw new Exception("Could not find resource path for Web UI: " + resourceBase)
    }

    val holder = new ServletHolder(new DefaultServlet)
    holder.setInitParameter("resourceBase", resourceUrl.toString)

    val contextHandler = new ServletContextHandler
    contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false")
    contextHandler.setContextPath(path)
    contextHandler.addServlet(holder, "/")
    contextHandler
  }

  /**
   * Create a handler for proxying request to Workers and Application Drivers.
   *
   * Requests under `prefix` are rewritten onto `target` with [[createProxyURI]]; "Location"
   * headers coming back from the proxied server are mapped back under `prefix` with
   * [[createProxyLocationHeader]].
   */
  def createProxyHandler(
      prefix: String,
      target: String): ServletContextHandler = {
    val proxyServlet = new ProxyServlet {
      // A null result means the request cannot be rewritten or its destination is rejected.
      override def rewriteTarget(request: HttpServletRequest): String = {
        val rewrittenURI = createProxyURI(
          prefix, target, request.getRequestURI(), request.getQueryString())
        if (rewrittenURI != null &&
            validateDestination(rewrittenURI.getHost(), rewrittenURI.getPort())) {
          rewrittenURI.toString()
        } else {
          null
        }
      }

      override def filterServerResponseHeader(
          clientRequest: HttpServletRequest,
          serverResponse: Response,
          headerName: String,
          headerValue: String): String = {
        // Only "Location" headers are candidates for rewriting.
        val rewrittenLocation =
          if (headerName.equalsIgnoreCase("location")) {
            createProxyLocationHeader(
              prefix, headerValue, clientRequest, serverResponse.getRequest().getURI())
          } else {
            null
          }
        if (rewrittenLocation != null) {
          rewrittenLocation
        } else {
          super.filterServerResponseHeader(
            clientRequest, serverResponse, headerName, headerValue)
        }
      }
    }

    val contextHandler = new ServletContextHandler
    contextHandler.setContextPath(prefix)
    contextHandler.addServlet(new ServletHolder(proxyServlet), "/")
    contextHandler
  }

  /**
   * Add filters, if any, to the given list of ServletContextHandlers.
   *
   * Filter class names come from the comma-separated `spark.ui.filters` setting. Init
   * parameters for a filter are read from `spark.<filter>.params` (comma-separated `key=value`
   * pairs; entries that do not split into exactly two parts on "=" are skipped) and then from
   * every `spark.<filter>.param.<name>` setting.
   */
  def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf): Unit = {
    val filterNames = conf.get("spark.ui.filters", "").split(',').map(_.trim()).filter(!_.isEmpty)

    for (filter <- filterNames) {
      logInfo("Adding filter: " + filter)
      val holder = new FilterHolder()
      holder.setClassName(filter)

      // Parameters packed into the single "spark.<filter>.params" setting
      val packedParams: Set[String] =
        conf.get("spark." + filter + ".params", "").split(',').map(_.trim()).toSet
      for (param <- packedParams if !param.isEmpty) {
        param.split("=") match {
          case Array(key, value) => holder.setInitParameter(key, value)
          case _ => // not a plain key=value pair: skipped
        }
      }

      // Parameters given one per setting as "spark.<filter>.param.<name>"
      val paramPrefix = s"spark.$filter.param."
      conf.getAll.foreach { case (key, value) =>
        if (key.length() > paramPrefix.length() && key.startsWith(paramPrefix)) {
          holder.setInitParameter(key.substring(paramPrefix.length()), value)
        }
      }

      val dispatcherTypes = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR,
        DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST)
      for (handler <- handlers) {
        handler.addFilter(holder, "/*", dispatcherTypes)
      }
    }
  }

  /**
   * Attempt to start a Jetty server bound to the supplied hostName:port using the given
   * context handlers.
   *
   * If the desired port number is contended, continues incrementing ports until a free port is
   * found. Return the jetty Server object, the chosen port, and a mutable collection of handlers.
   *
   * Before binding, the configured UI filters are added to every handler and each handler is
   * wrapped in a GzipHandler and tied to the connector named SPARK_CONNECTOR_NAME. When
   * `sslOptions` yields an SSL context factory, an HTTPS connector is created as well and the
   * plain HTTP connector only serves redirects to it.
   *
   * @param hostName   host the connectors bind to
   * @param port       first HTTP port to try
   * @param sslOptions source of the optional Jetty SSL context factory
   * @param handlers   context handlers to serve
   * @param conf       Spark configuration (filters; also passed to Utils.startServiceOnPort)
   * @param serverName used for the thread pool and scheduler names and passed to
   *                   Utils.startServiceOnPort
   * @return a ServerInfo holding the server, the bound HTTP port, the bound HTTPS port if
   *         any, and the root handler collection
   */
  def startJettyServer(
      hostName: String,
      port: Int,
      sslOptions: SSLOptions,
      handlers: Seq[ServletContextHandler],
      conf: SparkConf,
      serverName: String = ""): ServerInfo = {

    addFilters(handlers, conf)

    // Wrap every context handler in a GzipHandler and restrict it to the main Spark connector.
    val gzipHandlers = handlers.map { h =>
      h.setVirtualHosts(Array("@" + SPARK_CONNECTOR_NAME))
      val gzipHandler = new GzipHandler
      gzipHandler.setHandler(h)
      gzipHandler
    }

    // Bind to the given port, or throw a java.net.BindException if the port is occupied
    // Returns ((server, bound HTTPS port if any), bound HTTP port).
    def connect(currentPort: Int): ((Server, Option[Int]), Int) = {
      val pool = new QueuedThreadPool
      if (serverName.nonEmpty) {
        pool.setName(serverName)
      }
      pool.setDaemon(true)

      // Set SnappyData authenticator into the SecurityHandler.
      // Has to be done inside connect because a failure to bind to port will
      // clear the handler so auth will fail even if bind on next port succeeds.
      customAuthenticator match {
        case Some(_) =>
          gzipHandlers.foreach { gh =>
            gh.getHandler.asInstanceOf[ServletContextHandler]
                .setSecurityHandler(basicAuthenticationHandler())
          }
        case None => logDebug("Not setting auth handler")
      }

      val server = new Server(pool)
      val connectors = new ArrayBuffer[ServerConnector]()
      val collection = new ContextHandlerCollection

      // Create a connector on port currentPort to listen for HTTP requests
      val httpConnector = new ServerConnector(
        server,
        null,
        // Call this full constructor to set this, which forces daemon threads:
        new ScheduledExecutorScheduler(s"$serverName-JettyScheduler", true),
        null,
        -1,
        -1,
        new HttpConnectionFactory())
      httpConnector.setPort(currentPort)
      connectors += httpConnector

      val httpsConnector = sslOptions.createJettySslContextFactory() match {
        case Some(factory) =>
          // The HTTPS port is derived from the HTTP port (offset by 400); port 0 stays 0.
          // If the new port wraps around, do not try a privileged port.
          val securePort =
            if (currentPort != 0) {
              (currentPort + 400 - 1024) % (65536 - 1024) + 1024
            } else {
              0
            }
          val scheme = "https"
          // Create a connector on port securePort to listen for HTTPS requests
          val connector = new ServerConnector(server, factory)
          connector.setPort(securePort)
          connector.setName(SPARK_CONNECTOR_NAME)
          connectors += connector

          // redirect the HTTP requests to HTTPS port
          httpConnector.setName(REDIRECT_CONNECTOR_NAME)
          collection.addHandler(createRedirectHttpsHandler(connector, scheme))
          Some(connector)

        case None =>
          // No SSL, so the HTTP connector becomes the official one where all contexts bind.
          httpConnector.setName(SPARK_CONNECTOR_NAME)
          None
      }

      // As each acceptor and each selector will use one thread, the number of threads should at
      // least be the number of acceptors and selectors plus 1. (See SPARK-13776)
      var minThreads = 1
      connectors.foreach { connector =>
        // Currently we only use "SelectChannelConnector"
        // Limit the max acceptor number to 8 so that we don't waste a lot of threads
        // NOTE(review): the call below sets the accept queue size, not the number of
        // acceptors described by the comment above -- confirm which one is intended.
        connector.setAcceptQueueSize(math.min(connector.getAcceptors, 8))
        connector.setHost(hostName)
        // The number of selectors always equals to the number of acceptors
        minThreads += connector.getAcceptors * 2
      }
      // Only ever raises the pool's maximum; a larger existing maximum is kept.
      pool.setMaxThreads(math.max(pool.getMaxThreads, minThreads))

      val errorHandler = new ErrorHandler()
      errorHandler.setShowStacks(true)
      errorHandler.setServer(server)
      server.addBean(errorHandler)

      gzipHandlers.foreach(collection.addHandler)
      server.setHandler(collection)

      server.setConnectors(connectors.toArray)
      try {
        server.start()
        ((server, httpsConnector.map(_.getLocalPort())), httpConnector.getLocalPort)
      } catch {
        case e: Exception =>
          // Release the server and its thread pool before propagating the failure.
          server.stop()
          pool.stop()
          throw e
      }
    }

    // Utils.startServiceOnPort drives connect(); presumably it retries on other ports when a
    // bind fails, as described in the method documentation -- its code is not in this file.
    val ((server, securePort), boundPort) = Utils.startServiceOnPort(port, connect, conf,
      serverName)
    ServerInfo(server, boundPort, securePort,
      server.getHandler().asInstanceOf[ContextHandlerCollection])
  }
  /**
   * Build a security handler that enforces `constraintMapping` for the `snappyDataRealm`
   * realm, authenticating through `customAuthenticator` against `snappyHashLoginService`.
   *
   * `customAuthenticator` must be defined when this is called; `.get` fails otherwise.
   */
  private def basicAuthenticationHandler(): SecurityHandler = {
    val securityHandler = new ConstraintSecurityHandler()
    securityHandler.setAuthenticator(customAuthenticator.get)
    securityHandler.setRealmName(snappyDataRealm)
    securityHandler.addConstraintMapping(constraintMapping)
    securityHandler.setLoginService(snappyHashLoginService)
    securityHandler
  }

  /**
   * Create the root context handler, tied to the connector named REDIRECT_CONNECTOR_NAME,
   * that redirects every non-secure request to the same host, path and query on the port of
   * `httpsConnector` using `scheme`.
   */
  private def createRedirectHttpsHandler(
      httpsConnector: ServerConnector,
      scheme: String): ContextHandler = {
    val redirector = new AbstractHandler {
      override def handle(
          target: String,
          baseRequest: Request,
          request: HttpServletRequest,
          response: HttpServletResponse): Unit = {
        // Requests that are already secure are left untouched (and unhandled).
        if (!baseRequest.isSecure) {
          val location = createRedirectURI(scheme, baseRequest.getServerName,
            httpsConnector.getLocalPort, baseRequest.getRequestURI, baseRequest.getQueryString)
          response.setContentLength(0)
          // NOTE(review): the value returned by encodeRedirectURL is discarded here.
          response.encodeRedirectURL(location)
          response.sendRedirect(location)
          baseRequest.setHandled(true)
        }
      }
    }

    val contextHandler = new ContextHandler
    contextHandler.setContextPath("/")
    contextHandler.setVirtualHosts(Array("@" + REDIRECT_CONNECTOR_NAME))
    contextHandler.setHandler(redirector)
    contextHandler
  }

  /**
   * Rewrite a request path under `prefix` into a URI on `target`.
   *
   * @param prefix path prefix the proxy is mounted at
   * @param target base URI of the proxied server
   * @param path   raw (undecoded) request path
   * @param query  raw (undecoded) query string, or null when the request has none
   * @return the normalized target URI, or null when `path` does not start with `prefix`
   */
  def createProxyURI(prefix: String, target: String, path: String, query: String): URI = {
    if (!path.startsWith(prefix)) {
      null
    } else {
      val uri = new StringBuilder(target)
      val rest = path.substring(prefix.length())

      if (!rest.isEmpty()) {
        if (!rest.startsWith("/")) {
          uri.append("/")
        }
        uri.append(rest)
      }

      val base = URI.create(uri.toString())
      val rewrittenURI =
        if (query == null) {
          base
        } else {
          // The multi-argument URI constructor quotes illegal characters and ALWAYS quotes
          // '%', so it is only suitable when the raw query is not already a legal query.
          def quotedForm: URI = new URI(
            base.getScheme(), base.getAuthority(), base.getPath(), query, base.getFragment())

          // Assemble from the raw components so an already percent-encoded path or query is
          // passed through unchanged instead of being encoded a second time.
          val raw = new StringBuilder
          if (base.getScheme() != null) raw.append(base.getScheme()).append(":")
          if (base.getRawAuthority() != null) raw.append("//").append(base.getRawAuthority())
          if (base.getRawPath() != null) raw.append(base.getRawPath())
          raw.append("?").append(query)
          if (base.getRawFragment() != null) raw.append("#").append(base.getRawFragment())

          try {
            val parsed = URI.create(raw.toString())
            // Guard against a query that changes how the URI is split (e.g. it contains '#').
            if (parsed.getRawQuery() == query) parsed else quotedForm
          } catch {
            case _: IllegalArgumentException => quotedForm
          }
        }
      rewrittenURI.normalize()
    }
  }

  /**
   * Map a "Location" header returned by the proxied server back under the proxy `prefix`.
   *
   * @return the rewritten header when `headerValue` starts with the scheme and authority of
   *         `targetUri`, otherwise null
   */
  def createProxyLocationHeader(
      prefix: String,
      headerValue: String,
      clientRequest: HttpServletRequest,
      targetUri: URI): String = {
    val targetOrigin = s"${targetUri.getScheme()}://${targetUri.getAuthority()}"
    if (!headerValue.startsWith(targetOrigin)) {
      null
    } else {
      val clientScheme = clientRequest.getScheme()
      val clientHost = clientRequest.getHeader("host")
      val remainder = headerValue.substring(targetOrigin.length())
      s"$clientScheme://$clientHost$prefix$remainder"
    }
  }

  /**
   * Create a redirect URI string from the arguments, wrapping a bare IPv6 host in brackets.
   *
   * `path` and `query` are expected in raw (undecoded) form. They are passed through
   * unchanged when the result is a legal URI, so escapes such as "%20" are not encoded a
   * second time. Otherwise the quoting multi-argument URI constructor is used.
   */
  private def createRedirectURI(
      scheme: String, server: String, port: Int, path: String, query: String): String = {
    val redirectServer = if (server.contains(":") && !server.startsWith("[")) {
      s"[${server}]"
    } else {
      server
    }
    val authority = s"$redirectServer:$port"

    // Quotes illegal characters and ALWAYS quotes '%'; used only as a fallback.
    def quotedForm: String = new URI(scheme, authority, path, query, null).toString

    // Plain concatenation is only attempted when the pieces cannot run into each other.
    val canConcatenate =
      scheme != null && (path == null || path.isEmpty || path.startsWith("/"))
    if (!canConcatenate) {
      quotedForm
    } else {
      val raw = new StringBuilder(s"$scheme://$authority")
      if (path != null) raw.append(path)
      if (query != null) raw.append("?").append(query)
      try {
        val parsed = URI.create(raw.toString())
        // Make sure the path and query survived parsing exactly as supplied.
        val expectedPath = if (path == null) "" else path
        if (parsed.getRawPath() == expectedPath && parsed.getRawQuery() == query) {
          parsed.toString
        } else {
          quotedForm
        }
      } catch {
        case _: IllegalArgumentException => quotedForm
      }
    }
  }

}

/**
 * Handle on a started Jetty server.
 *
 * @param server      the Jetty server
 * @param boundPort   port the HTTP connector is bound to
 * @param securePort  port the HTTPS connector is bound to, if there is one
 * @param rootHandler collection that handlers are added to and removed from
 */
private[spark] case class ServerInfo(
    server: Server,
    boundPort: Int,
    securePort: Option[Int],
    private val rootHandler: ContextHandlerCollection) {

  /** Tie `handler` to the main Spark connector, register it, and start it if needed. */
  def addHandler(handler: ContextHandler): Unit = {
    val sparkVirtualHost = "@" + JettyUtils.SPARK_CONNECTOR_NAME
    handler.setVirtualHosts(Array(sparkVirtualHost))
    rootHandler.addHandler(handler)
    if (!handler.isStarted()) {
      handler.start()
    }
  }

  /** Unregister `handler` and stop it if it is running. */
  def removeHandler(handler: ContextHandler): Unit = {
    rootHandler.removeHandler(handler)
    if (handler.isStarted) {
      handler.stop()
    }
  }

  /** Stop the server and then its thread pool. */
  def stop(): Unit = {
    server.stop()
    // Stop the ThreadPool if it supports stop() method (through LifeCycle).
    // It is needed because stopping the Server won't stop the ThreadPool it uses.
    server.getThreadPool match {
      case lifeCycle: LifeCycle => lifeCycle.stop()
      case _ => // no thread pool, or one without a LifeCycle: nothing to stop
    }
  }
}
