/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.config.annotation.web.socket;

import io.micrometer.observation.ObservationRegistry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Import;
import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
import org.springframework.messaging.simp.config.ChannelRegistration;
import org.springframework.messaging.support.ChannelInterceptor;
import org.springframework.security.authorization.AuthorizationEventPublisher;
import org.springframework.security.authorization.AuthorizationManager;
import org.springframework.security.authorization.ObservationAuthorizationManager;
import org.springframework.security.authorization.SpringAuthorizationEventPublisher;
import org.springframework.security.config.annotation.web.socket.MessageMatcherAuthorizationManagerConfiguration;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.context.SecurityContextHolderStrategy;
import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor;
import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
import org.springframework.security.messaging.web.csrf.XorCsrfChannelInterceptor;
import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
import org.springframework.util.Assert;
import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
import org.springframework.web.socket.config.annotation.WebSocketMessageBrokerConfigurer;
import org.springframework.web.socket.server.HandshakeInterceptor;
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler;
import org.springframework.web.socket.sockjs.SockJsService;
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;

@Order(value=-2147483548)
@Import(value={MessageMatcherAuthorizationManagerConfiguration.class})
final class WebSocketMessageBrokerSecurityConfiguration
implements WebSocketMessageBrokerConfigurer,
SmartInitializingSingleton {
    private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
    private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
    private MessageMatcherDelegatingAuthorizationManager b;
    private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager.builder().anyMessage().authenticated().build();
    private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder.getContextHolderStrategy();
    private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
    private ChannelInterceptor csrfChannelInterceptor = new XorCsrfChannelInterceptor();
    private AuthorizationManager<Message<?>> authorizationManager = ANY_MESSAGE_AUTHENTICATED;
    private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
    private ApplicationContext context;

    WebSocketMessageBrokerSecurityConfiguration(ApplicationContext context) {
        this.context = context;
    }

    public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
        AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver();
        resolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
        argumentResolvers.add((HandlerMethodArgumentResolver)resolver);
    }

    public void configureClientInboundChannel(ChannelRegistration registration) {
        ChannelInterceptor csrfChannelInterceptor = this.getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME, ChannelInterceptor.class);
        if (csrfChannelInterceptor != null) {
            this.csrfChannelInterceptor = csrfChannelInterceptor;
        }
        AuthorizationManager<Message<?>> manager = this.authorizationManager;
        if (!this.observationRegistry.isNoop()) {
            manager = new ObservationAuthorizationManager(this.observationRegistry, manager);
        }
        AuthorizationChannelInterceptor interceptor = new AuthorizationChannelInterceptor(manager);
        interceptor.setAuthorizationEventPublisher((AuthorizationEventPublisher)new SpringAuthorizationEventPublisher(this.context));
        interceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
        this.securityContextChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
        registration.interceptors(new ChannelInterceptor[]{this.securityContextChannelInterceptor, this.csrfChannelInterceptor, interceptor});
    }

    @Autowired(required=false)
    void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
        Assert.notNull((Object)securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
        this.securityContextHolderStrategy = securityContextHolderStrategy;
    }

    @Autowired(required=false)
    void setAuthorizationManager(AuthorizationManager<Message<?>> authorizationManager) {
        this.authorizationManager = authorizationManager;
    }

    @Autowired(required=false)
    void setObservationRegistry(ObservationRegistry observationRegistry) {
        this.observationRegistry = observationRegistry;
    }

    @Override
    public void afterSingletonsInstantiated() {
        SimpleUrlHandlerMapping mapping = this.getBeanOrNull(SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME, SimpleUrlHandlerMapping.class);
        if (mapping == null) {
            return;
        }
        this.configureCsrf(mapping);
    }

    private <T> T getBeanOrNull(String name, Class<T> type) {
        Map<String, T> beans2 = this.context.getBeansOfType(type);
        return beans2.get(name);
    }

    private void configureCsrf(SimpleUrlHandlerMapping mapping) {
        Map mappings = mapping.getHandlerMap();
        for (Object object : mappings.values()) {
            if (object instanceof SockJsHttpRequestHandler) {
                this.setHandshakeInterceptors((SockJsHttpRequestHandler)object);
                continue;
            }
            if (object instanceof WebSocketHttpRequestHandler) {
                this.setHandshakeInterceptors((WebSocketHttpRequestHandler)object);
                continue;
            }
            throw new IllegalStateException("Bean stompWebSocketHandlerMapping is expected to contain mappings to either a SockJsHttpRequestHandler or a WebSocketHttpRequestHandler but got " + object);
        }
    }

    private void setHandshakeInterceptors(SockJsHttpRequestHandler handler) {
        SockJsService sockJsService = handler.getSockJsService();
        Assert.state(sockJsService instanceof TransportHandlingSockJsService, () -> "sockJsService must be instance of TransportHandlingSockJsService got " + sockJsService);
        TransportHandlingSockJsService transportHandlingSockJsService = (TransportHandlingSockJsService)sockJsService;
        List handshakeInterceptors = transportHandlingSockJsService.getHandshakeInterceptors();
        ArrayList<CsrfTokenHandshakeInterceptor> interceptorsToSet = new ArrayList<CsrfTokenHandshakeInterceptor>(handshakeInterceptors.size() + 1);
        interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
        interceptorsToSet.addAll(handshakeInterceptors);
        transportHandlingSockJsService.setHandshakeInterceptors(interceptorsToSet);
    }

    private void setHandshakeInterceptors(WebSocketHttpRequestHandler handler) {
        List handshakeInterceptors = handler.getHandshakeInterceptors();
        ArrayList<CsrfTokenHandshakeInterceptor> interceptorsToSet = new ArrayList<CsrfTokenHandshakeInterceptor>(handshakeInterceptors.size() + 1);
        interceptorsToSet.add(new CsrfTokenHandshakeInterceptor());
        interceptorsToSet.addAll(handshakeInterceptors);
        handler.setHandshakeInterceptors(interceptorsToSet);
    }
}

