/*
 * Decompiled with CFR 0.152.
 */
package io.inverno.mod.security.http.cors;

import io.inverno.mod.http.base.BadRequestException;
import io.inverno.mod.http.base.ExchangeContext;
import io.inverno.mod.http.base.ForbiddenException;
import io.inverno.mod.http.base.Method;
import io.inverno.mod.http.base.OutboundResponseHeaders;
import io.inverno.mod.http.base.Status;
import io.inverno.mod.http.server.Exchange;
import io.inverno.mod.http.server.ExchangeInterceptor;
import java.net.URI;
import java.util.Collections;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import reactor.core.publisher.Mono;

public class CORSInterceptor<A extends ExchangeContext, B extends Exchange<A>>
implements ExchangeInterceptor<A, B> {
    private static final int DEFAULT_HTTP_PORT = 80;
    private static final int DEFAULT_HTTPS_PORT = 443;
    private static final int DEFAULT_FTP_PORT = 21;
    protected final Set<Origin> allowedOrigins;
    protected final Set<Pattern> allowedOriginsPattern;
    protected final boolean allowCredentials;
    protected final String allowedHeaders;
    protected final String allowedMethods;
    protected final String exposedHeaders;
    protected final Integer maxAge;
    protected final boolean allowPrivateNetwork;
    protected final boolean isWildcard;
    protected final boolean isStatic;

    protected CORSInterceptor(Set<Origin> allowedOrigins, Set<Pattern> allowedOriginsPattern, boolean allowCredentials, Set<String> allowedHeaders, Set<Method> allowedMethods, Set<String> exposedHeaders, Integer maxAge, boolean allowPrivateNetwork) {
        this.allowedOrigins = allowedOrigins != null ? Collections.unmodifiableSet(allowedOrigins) : Set.of();
        this.allowedOriginsPattern = allowedOriginsPattern != null ? Collections.unmodifiableSet(allowedOriginsPattern) : Set.of();
        this.allowCredentials = allowCredentials;
        this.allowedHeaders = allowedHeaders != null ? allowedHeaders.stream().collect(Collectors.joining(",")) : null;
        this.allowedMethods = allowedMethods != null ? allowedMethods.stream().map(Enum::toString).collect(Collectors.joining(",")) : null;
        this.exposedHeaders = exposedHeaders != null ? exposedHeaders.stream().collect(Collectors.joining(",")) : null;
        this.maxAge = maxAge;
        this.allowPrivateNetwork = allowPrivateNetwork;
        this.isWildcard = this.allowedOrigins.isEmpty() && this.allowedOriginsPattern.isEmpty();
        this.isStatic = this.allowedOriginsPattern.isEmpty() && this.allowedOrigins.size() == 1;
    }

    public static Builder builder(String ... allowedOrigins) {
        Builder builder = new Builder();
        if (allowedOrigins != null) {
            for (String allowedOrigin : allowedOrigins) {
                builder.allowOrigin(allowedOrigin);
            }
        }
        return builder;
    }

    public Mono<? extends B> intercept(B exchange) {
        Origin origin;
        Optional originOpt = exchange.request().headers().get((CharSequence)"origin");
        if (!originOpt.isPresent()) {
            if (!this.isWildcard && !this.isStatic) {
                exchange.response().headers(headers -> CORSInterceptor.addVaryHeader(headers, "origin"));
            } else if (this.isWildcard) {
                exchange.response().headers(headers -> headers.set((CharSequence)"access-control-allow-origin", (CharSequence)"*"));
            } else if (this.isStatic) {
                exchange.response().headers(headers -> headers.set((CharSequence)"access-control-allow-origin", (CharSequence)this.allowedOrigins.iterator().next().toString()));
            }
            return Mono.just(exchange);
        }
        String originValue = (String)originOpt.get();
        try {
            origin = new Origin(originValue);
        }
        catch (IllegalArgumentException e) {
            throw new BadRequestException("Invalid origin header: " + (String)originOpt.get(), (Throwable)e);
        }
        if (this.isSameOrigin(exchange, origin)) {
            return Mono.just(exchange);
        }
        this.checkOrigin(origin);
        boolean preflight = exchange.request().getMethod().equals((Object)Method.OPTIONS) && exchange.request().headers().contains((CharSequence)"access-control-request-method");
        exchange.response().headers(headers -> {
            if (this.allowCredentials) {
                headers.set((CharSequence)"access-control-allow-credentials", (CharSequence)"true");
                headers.set((CharSequence)"access-control-allow-origin", (CharSequence)originValue);
            } else if (this.isWildcard) {
                headers.set((CharSequence)"access-control-allow-origin", (CharSequence)"*");
            } else {
                headers.set((CharSequence)"access-control-allow-origin", (CharSequence)originValue);
            }
            if (preflight) {
                headers.status(Status.NO_CONTENT).contentLength(0L);
                if (this.allowedMethods != null) {
                    headers.set((CharSequence)"access-control-allow-methods", (CharSequence)this.allowedMethods);
                }
                if (this.allowedHeaders != null) {
                    headers.set((CharSequence)"access-control-allow-headers", (CharSequence)this.allowedHeaders);
                } else {
                    exchange.request().headers().get((CharSequence)"access-control-request-headers").ifPresent(value -> {
                        CORSInterceptor.addVaryHeader(headers, "access-control-request-headers");
                        headers.set((CharSequence)"access-control-allow-headers", (CharSequence)value);
                    });
                }
                if (this.maxAge != null) {
                    headers.set((CharSequence)"access-control-max-age", (CharSequence)this.maxAge.toString());
                }
                if (this.allowPrivateNetwork) {
                    exchange.request().headers().get((CharSequence)"access-control-request-private-network").ifPresent(value -> {
                        if (Boolean.valueOf(value).booleanValue()) {
                            headers.set((CharSequence)"access-control-allow-private-network", (CharSequence)"true");
                        }
                    });
                }
            } else {
                CORSInterceptor.addVaryHeader(headers, "origin");
                if (this.exposedHeaders != null) {
                    headers.set((CharSequence)"access-control-expose-headers", (CharSequence)this.exposedHeaders);
                }
            }
        });
        if (preflight) {
            exchange.response().body().empty();
            return Mono.empty();
        }
        return Mono.just(exchange);
    }

    private static void addVaryHeader(OutboundResponseHeaders responseHeaders, String headerName) {
        responseHeaders.get((CharSequence)"www-vary").ifPresentOrElse(vary -> responseHeaders.set((CharSequence)"www-vary", (CharSequence)(vary + "," + headerName)), () -> responseHeaders.set((CharSequence)"www-vary", (CharSequence)headerName));
    }

    protected boolean isSameOrigin(B exchange, Origin origin) {
        String scheme = exchange.request().getScheme();
        String authority = exchange.request().getAuthority();
        if (scheme != null && authority != null) {
            String[] splitAuthority = authority.split(":");
            return (switch (splitAuthority.length) {
                case 1 -> new Origin(scheme, splitAuthority[0], null);
                case 2 -> new Origin(scheme, splitAuthority[0], Integer.parseInt(splitAuthority[1]));
                default -> throw new BadRequestException("Invalid authority");
            }).equals(origin);
        }
        return false;
    }

    protected void checkOrigin(Origin origin) throws BadRequestException, ForbiddenException {
        if (this.isWildcard) {
            return;
        }
        for (Pattern matcher : this.allowedOriginsPattern) {
            if (!matcher.matcher(origin.toString()).matches()) continue;
            return;
        }
        for (Origin allowedOrigin : this.allowedOrigins) {
            if (!allowedOrigin.equals(origin)) continue;
            return;
        }
        throw new ForbiddenException("Rejected CORS: " + origin.toString() + " is not auhorized to access resources");
    }

    public static class Builder {
        protected Set<Origin> allowedOrigins;
        protected Set<Pattern> allowedOriginsPattern;
        protected boolean allowCredentials;
        protected Set<String> allowedHeaders;
        protected Set<Method> allowedMethods;
        protected Set<String> exposedHeaders;
        protected Integer maxAge;
        protected boolean allowPrivateNetwork;

        protected Builder() {
        }

        public Builder allowOrigin(String allowedOrigin) throws IllegalArgumentException {
            if (StringUtils.isNotBlank((CharSequence)allowedOrigin)) {
                if (allowedOrigin.equals("*")) {
                    this.allowedOrigins = Set.of();
                    this.allowedOriginsPattern = Set.of();
                    return this;
                }
                if (this.allowedOrigins == null) {
                    this.allowedOrigins = new HashSet<Origin>();
                } else if (this.allowedOrigins.isEmpty()) {
                    return this;
                }
                this.allowedOrigins.add(new Origin(allowedOrigin));
            }
            return this;
        }

        public Builder allowOriginPattern(String allowedOriginRegex) {
            if (StringUtils.isNotBlank((CharSequence)allowedOriginRegex)) {
                if (allowedOriginRegex.equals(".*")) {
                    this.allowedOrigins = Set.of();
                    this.allowedOriginsPattern = Set.of();
                    return this;
                }
                if (this.allowedOriginsPattern == null) {
                    this.allowedOriginsPattern = new HashSet<Pattern>();
                } else if (this.allowedOriginsPattern.isEmpty()) {
                    return this;
                }
                this.allowedOriginsPattern.add(Pattern.compile(allowedOriginRegex));
            }
            return this;
        }

        public Builder allowCredentials(boolean allowCredentials) {
            this.allowCredentials = allowCredentials;
            return this;
        }

        public Builder allowHeader(String allowedHeader) {
            if (StringUtils.isNotBlank((CharSequence)allowedHeader)) {
                if (this.allowedHeaders == null) {
                    this.allowedHeaders = new HashSet<String>();
                }
                this.allowedHeaders.add(allowedHeader);
            }
            return this;
        }

        public Builder allowMethod(Method allowedMethod) {
            if (allowedMethod != null) {
                if (this.allowedMethods == null) {
                    this.allowedMethods = new HashSet<Method>();
                }
                this.allowedMethods.add(allowedMethod);
            }
            return this;
        }

        public Builder exposeHeader(String exposedHeader) {
            if (StringUtils.isNotBlank((CharSequence)exposedHeader)) {
                if (this.exposedHeaders == null) {
                    this.exposedHeaders = new HashSet<String>();
                }
                this.exposedHeaders.add(exposedHeader);
            }
            return this;
        }

        public Builder maxAge(int maxAge) {
            this.maxAge = maxAge;
            return this;
        }

        public Builder allowPrivateNetwork(boolean allowPrivateNetwork) {
            this.allowPrivateNetwork = allowPrivateNetwork;
            return this;
        }

        public <A extends ExchangeContext, B extends Exchange<A>> CORSInterceptor<A, B> build() {
            return new CORSInterceptor(this.allowedOrigins, this.allowedOriginsPattern, this.allowCredentials, this.allowedHeaders, this.allowedMethods, this.exposedHeaders, this.maxAge, this.allowPrivateNetwork);
        }
    }

    protected static class Origin {
        protected final String scheme;
        protected final String host;
        protected final int port;
        protected final String value;

        protected Origin(String origin) throws IllegalArgumentException {
            URI originURI = URI.create(origin);
            String originURIScheme = originURI.getScheme();
            String originURIHost = originURI.getHost();
            if (originURIScheme == null || originURIHost == null) {
                throw new IllegalArgumentException("Invalid origin: " + origin);
            }
            this.scheme = originURIScheme.toLowerCase();
            this.host = originURIHost.toLowerCase();
            StringBuilder valueBuilder = new StringBuilder();
            valueBuilder.append(this.scheme).append("://").append(this.host);
            if (originURI.getPort() == -1) {
                this.port = Origin.getDefaultPort(this.scheme);
            } else {
                this.port = originURI.getPort();
                valueBuilder.append(":").append(this.port);
            }
            this.value = valueBuilder.toString();
        }

        protected Origin(String scheme, String host, Integer port) throws IllegalArgumentException {
            this.scheme = scheme;
            this.host = host;
            StringBuilder valueBuilder = new StringBuilder();
            valueBuilder.append(this.scheme).append("://").append(this.host);
            if (port == null || port < 0) {
                this.port = Origin.getDefaultPort(scheme);
            } else {
                this.port = port;
                valueBuilder.append(":").append(this.port);
            }
            this.value = valueBuilder.toString();
        }

        public static Integer getDefaultPort(String scheme) {
            switch (scheme.toLowerCase()) {
                case "http": {
                    return 80;
                }
                case "https": {
                    return 443;
                }
                case "ftp": {
                    return 21;
                }
            }
            return -1;
        }

        public String toString() {
            return this.value;
        }

        public int hashCode() {
            int hash = 5;
            hash = 59 * hash + Objects.hashCode(this.scheme);
            hash = 59 * hash + Objects.hashCode(this.host);
            hash = 59 * hash + this.port;
            return hash;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null) {
                return false;
            }
            if (this.getClass() != obj.getClass()) {
                return false;
            }
            Origin other = (Origin)obj;
            if (this.port != other.port) {
                return false;
            }
            if (!Objects.equals(this.scheme, other.scheme)) {
                return false;
            }
            return Objects.equals(this.host, other.host);
        }
    }
}

