// Copyright 2022-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package build.buf.connect

import build.buf.connect.compression.CompressionPool
import build.buf.connect.compression.IdentityCompressionPool
import build.buf.connect.http.HTTPClientInterface
import build.buf.connect.protocolclientoptions.SerializationStrategy

/**
 *  Set of configuration (usually modified through `ClientOption` types) used to set up clients.
 */
/**
 * Set of configuration (usually modified through `ClientOption` types) used to set up clients.
 *
 * All properties are read-only (`val`); use [clone] to derive a copy with modifications.
 */
class ProtocolClientConfig(
    // The host (e.g., https://buf.build).
    val host: String,
    // The client to use for performing requests.
    val httpClient: HTTPClientInterface,
    // The serialization strategy for decoding messages.
    val serializationStrategy: SerializationStrategy,
    // The compression type that should be used for requests (e.g., "gzip").
    // Requires a matching `compressionPools` entry.
    val requestCompressionName: String = IdentityCompressionPool.name(),
    // The minimum number of bytes that a request message should be for compression to be used.
    val compressionMinBytes: Int? = null,
    // Factories for the interceptors that should be invoked with requests/responses. Each factory
    // is called with this config whenever an interceptor chain is created.
    val interceptors: List<(ProtocolClientConfig) -> Interceptor> = emptyList(),
    // Compression pools that provide support for the provided `requestCompressionName`, as well as
    // any other compression methods that need to be supported for inbound responses.
    val compressionPools: Map<String, CompressionPool> = emptyMap(),
) {
    /**
     * Helps create a new configuration with modifications.
     *
     * Every parameter defaults to the corresponding value of this instance, so callers only
     * need to name the properties they want to change.
     *
     * @return A new [ProtocolClientConfig]; this instance is not modified.
     */
    fun clone(
        host: String = this.host,
        httpClient: HTTPClientInterface = this.httpClient,
        serializationStrategy: SerializationStrategy = this.serializationStrategy,
        requestCompressionName: String = this.requestCompressionName,
        compressionMinBytes: Int? = this.compressionMinBytes,
        interceptors: List<(ProtocolClientConfig) -> Interceptor> = this.interceptors,
        compressionPools: Map<String, CompressionPool> = this.compressionPools,
    ): ProtocolClientConfig {
        // Named arguments keep this call correct even if the constructor's parameter order changes.
        return ProtocolClientConfig(
            host = host,
            httpClient = httpClient,
            serializationStrategy = serializationStrategy,
            requestCompressionName = requestCompressionName,
            compressionMinBytes = compressionMinBytes,
            interceptors = interceptors,
            compressionPools = compressionPools,
        )
    }

    /**
     * Get the compression pool by name.
     *
     * @param name The name of the compression pool; may be `null`.
     * @return The pool registered under [name] in [compressionPools], or `null` when [name] is
     *   `null` or no pool is registered under it.
     */
    fun compressionPool(name: String?): CompressionPool? {
        // A single lookup replaces the previous containsKey + get!! pair; map values are non-null,
        // so a null result can only mean "absent".
        return name?.let { compressionPools[it] }
    }

    /**
     * Creates an interceptor chain from the list of interceptors for unary based requests.
     *
     * Request functions run in the order of [interceptors]; response functions run in reverse
     * order. The interceptor factories are invoked anew on every call.
     *
     * @return The combined [UnaryFunction], or a default-constructed `UnaryFunction()` when no
     *   interceptors are configured.
     */
    fun createInterceptorChain(): UnaryFunction {
        if (interceptors.isEmpty()) {
            return UnaryFunction()
        }
        return chain(interceptors).unaryFunction()
    }

    /**
     * Creates an interceptor chain from the list of interceptors for streaming based requests.
     *
     * Request and request-body functions run in the order of [interceptors]; stream result
     * functions run in reverse order. The interceptor factories are invoked anew on every call.
     *
     * @return The combined [StreamFunction], or a default-constructed `StreamFunction()` when no
     *   interceptors are configured.
     */
    fun createStreamingInterceptorChain(): StreamFunction {
        if (interceptors.isEmpty()) {
            return StreamFunction()
        }
        return chain(interceptors).streamFunction()
    }

    /**
     * Instantiates every interceptor from [interceptorFactories] (passing this config) and
     * combines them into a single [Interceptor].
     *
     * Outbound hooks are applied first-to-last and inbound hooks last-to-first, so the first
     * interceptor in the list is the outermost layer around the call.
     */
    private fun chain(
        interceptorFactories: List<(ProtocolClientConfig) -> Interceptor>,
    ): Interceptor {
        val interceptors = interceptorFactories.map { factory -> factory(this) }
        return object : Interceptor {
            override fun unaryFunction(): UnaryFunction {
                val unaryFunctions = interceptors.map { interceptor -> interceptor.unaryFunction() }
                return UnaryFunction(
                    requestFunction = { httpRequest ->
                        unaryFunctions.fold(httpRequest) { request, unaryFunction ->
                            unaryFunction.requestFunction(request)
                        }
                    },
                    // foldRight visits the list from the end without allocating a reversed copy.
                    responseFunction = { httpResponse ->
                        unaryFunctions.foldRight(httpResponse) { unaryFunction, response ->
                            unaryFunction.responseFunction(response)
                        }
                    },
                )
            }

            override fun streamFunction(): StreamFunction {
                val streamFunctions = interceptors.map { interceptor -> interceptor.streamFunction() }
                return StreamFunction(
                    requestFunction = { httpRequest ->
                        streamFunctions.fold(httpRequest) { request, streamFunction ->
                            streamFunction.requestFunction(request)
                        }
                    },
                    requestBodyFunction = { requestBody ->
                        streamFunctions.fold(requestBody) { body, streamFunction ->
                            streamFunction.requestBodyFunction(body)
                        }
                    },
                    // foldRight visits the list from the end without allocating a reversed copy.
                    streamResultFunction = { streamResult ->
                        streamFunctions.foldRight(streamResult) { streamFunction, result ->
                            streamFunction.streamResultFunction(result)
                        }
                    },
                )
            }
        }
    }
}
