/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.security.oauth2.server.authorization.authentication;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.authentication.CodeVerifierAuthenticator;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
import org.springframework.security.oauth2.server.authorization.context.ProviderContext;
import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * An {@link AuthenticationProvider} that authenticates OAuth 2.0 clients presenting a JWT
 * client assertion (client assertion type
 * {@code urn:ietf:params:oauth:client-assertion-type:jwt-bearer}, RFC 7523).
 *
 * <p>The client is looked up by {@code client_id} and must be registered for
 * {@link ClientAuthenticationMethod#PRIVATE_KEY_JWT} or
 * {@link ClientAuthenticationMethod#CLIENT_SECRET_JWT}. The assertion is verified with a
 * per-client {@link JwtDecoder}:
 * <ul>
 * <li>for a {@link SignatureAlgorithm}, keys come from the client's JWK Set URL;</li>
 * <li>for a {@link MacAlgorithm}, the key is the client secret.</li>
 * </ul>
 * The decoded claims must satisfy all of the following:
 * <ul>
 * <li>{@code iss} and {@code sub} equal the {@code client_id};</li>
 * <li>{@code aud} contains the provider issuer or one of its token, introspection or
 * revocation endpoint URLs;</li>
 * <li>{@code exp} is present, and the {@link JwtTimestampValidator} checks pass.</li>
 * </ul>
 * Every failure is reported as an {@code invalid_client} {@link OAuth2AuthenticationException}.
 */
public final class JwtClientAssertionAuthenticationProvider
implements AuthenticationProvider {
    private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
    private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD = new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
    private final RegisteredClientRepository registeredClientRepository;
    private final CodeVerifierAuthenticator codeVerifierAuthenticator;
    private final JwtClientAssertionDecoderFactory jwtClientAssertionDecoderFactory;

    /**
     * Constructs a {@code JwtClientAssertionAuthenticationProvider}.
     *
     * @param registeredClientRepository the repository used to resolve the client by {@code client_id}
     * @param authorizationService the service handed to the internal {@link CodeVerifierAuthenticator}
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public JwtClientAssertionAuthenticationProvider(RegisteredClientRepository registeredClientRepository, OAuth2AuthorizationService authorizationService) {
        Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
        Assert.notNull(authorizationService, "authorizationService cannot be null");
        this.registeredClientRepository = registeredClientRepository;
        this.codeVerifierAuthenticator = new CodeVerifierAuthenticator(authorizationService);
        this.jwtClientAssertionDecoderFactory = new JwtClientAssertionDecoderFactory();
    }

    /**
     * Authenticates a client that presented a JWT client assertion.
     *
     * @param authentication an {@link OAuth2ClientAuthenticationToken} (guaranteed by {@link #supports})
     * @return {@code null} when the token's client authentication method is not the JWT
     *         client-assertion type (so another provider can handle it); otherwise a new
     *         {@link OAuth2ClientAuthenticationToken} carrying the {@link RegisteredClient}, the
     *         resolved method ({@code private_key_jwt} or {@code client_secret_jwt}) and the decoded {@link Jwt}
     * @throws OAuth2AuthenticationException with error {@code invalid_client} if the client is
     *         unknown, not registered for JWT client authentication, sent no assertion, has no
     *         usable verification key, or the assertion fails decoding/validation
     */
    @Override
    public Authentication authenticate(Authentication authentication) throws AuthenticationException {
        OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
        if (!JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) {
            return null;
        }

        String clientId = clientAuthentication.getPrincipal().toString();
        RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
        if (registeredClient == null) {
            throw invalidClient("client_id");
        }
        if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
                && !registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) {
            throw invalidClient("authentication_method");
        }
        Object credentials = clientAuthentication.getCredentials();
        if (credentials == null) {
            throw invalidClient("credentials");
        }

        // Built outside the try: a misconfigured client (no JWK Set URL / secret / valid algorithm)
        // surfaces as the factory's own invalid_client error, not as a "client_assertion" failure.
        JwtDecoder jwtDecoder = this.jwtClientAssertionDecoderFactory.createDecoder(registeredClient);
        Jwt jwtAssertion;
        try {
            jwtAssertion = jwtDecoder.decode(credentials.toString());
        } catch (JwtException ex) {
            throw invalidClient("client_assertion", ex);
        }

        this.codeVerifierAuthenticator.authenticateIfAvailable(clientAuthentication, registeredClient);

        // The configured signing algorithm decides which JWT method the client actually used.
        JwsAlgorithm signingAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
        ClientAuthenticationMethod clientAuthenticationMethod = signingAlgorithm instanceof SignatureAlgorithm
                ? ClientAuthenticationMethod.PRIVATE_KEY_JWT
                : ClientAuthenticationMethod.CLIENT_SECRET_JWT;
        return new OAuth2ClientAuthenticationToken(registeredClient, clientAuthenticationMethod, jwtAssertion);
    }

    /**
     * Returns {@code true} for {@link OAuth2ClientAuthenticationToken} and its subtypes.
     */
    @Override
    public boolean supports(Class<?> authentication) {
        return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
    }

    /** Creates an {@code invalid_client} exception naming the offending parameter. */
    private static OAuth2AuthenticationException invalidClient(String parameterName) {
        return invalidClient(parameterName, null);
    }

    /**
     * Creates an {@code invalid_client} exception naming the offending parameter.
     *
     * @param parameterName the parameter that failed, appended to the error description
     * @param cause the underlying failure, or {@code null}
     */
    private static OAuth2AuthenticationException invalidClient(String parameterName, Throwable cause) {
        OAuth2Error error = new OAuth2Error("invalid_client", "Client authentication failed: " + parameterName, ERROR_URI);
        return new OAuth2AuthenticationException(error, error.toString(), cause);
    }

    /**
     * Builds and caches, per registered client id, the {@link JwtDecoder} used to verify
     * client assertions.
     *
     * <p>A cached decoder is reused only while the client's {@code client_id}, signing
     * algorithm, JWK Set URL and client secret are unchanged. Otherwise it is rebuilt, so that
     * rotated credentials or updated settings take effect without a restart.
     */
    private static final class JwtClientAssertionDecoderFactory
    implements JwtDecoderFactory<RegisteredClient> {
        private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";
        private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;

        static {
            // JCA algorithm names used to build the HMAC SecretKeySpec from the client secret.
            Map<JwsAlgorithm, String> mappings = new HashMap<>();
            mappings.put(MacAlgorithm.HS256, "HmacSHA256");
            mappings.put(MacAlgorithm.HS384, "HmacSHA384");
            mappings.put(MacAlgorithm.HS512, "HmacSHA512");
            JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
        }

        private final Map<String, CachedJwtDecoder> jwtDecoders = new ConcurrentHashMap<>();

        private JwtClientAssertionDecoderFactory() {
        }

        /**
         * Returns the decoder for {@code registeredClient}, building it on first use and
         * rebuilding it when the client's verification settings have changed.
         *
         * @throws IllegalArgumentException if {@code registeredClient} is {@code null}
         * @throws OAuth2AuthenticationException ({@code invalid_client}) if no verifier can be
         *         built from the client's settings
         */
        @Override
        public JwtDecoder createDecoder(RegisteredClient registeredClient) {
            Assert.notNull(registeredClient, "registeredClient cannot be null");
            CachedJwtDecoder cached = this.jwtDecoders.get(registeredClient.getId());
            if (cached != null && cached.isBuiltFrom(registeredClient)) {
                return cached.jwtDecoder;
            }
            // Slow path: new client, or settings changed since the cached decoder was built.
            // compute() re-checks under the map's lock so concurrent callers do not rebuild twice.
            return this.jwtDecoders.compute(registeredClient.getId(), (id, existing) -> {
                if (existing != null && existing.isBuiltFrom(registeredClient)) {
                    return existing;
                }
                NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient);
                jwtDecoder.setJwtValidator(createJwtValidator(registeredClient));
                return new CachedJwtDecoder(registeredClient, jwtDecoder);
            }).jwtDecoder;
        }

        /**
         * Builds a decoder from the client's signing algorithm.
         * <ul>
         * <li>{@link SignatureAlgorithm}: verified against the client's JWK Set URL.</li>
         * <li>{@link MacAlgorithm}: verified with the client secret (UTF-8 bytes) as the HMAC key.</li>
         * </ul>
         *
         * @throws OAuth2AuthenticationException ({@code invalid_client}) if the required JWK Set
         *         URL or client secret is missing, or the algorithm is neither kind
         */
        private static NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
            JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
            if (jwsAlgorithm instanceof SignatureAlgorithm) {
                String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
                if (!StringUtils.hasText(jwkSetUrl)) {
                    throw signatureVerifierNotFound(registeredClient, "the JWK Set URL.");
                }
                return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
            }
            if (jwsAlgorithm instanceof MacAlgorithm) {
                String clientSecret = registeredClient.getClientSecret();
                if (!StringUtils.hasText(clientSecret)) {
                    throw signatureVerifierNotFound(registeredClient, "the client secret.");
                }
                SecretKey secretKey = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
                return NimbusJwtDecoder.withSecretKey(secretKey).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
            }
            throw signatureVerifierNotFound(registeredClient, "a valid JWS Algorithm: '" + jwsAlgorithm + "'.");
        }

        /** Creates the {@code invalid_client} error raised when no verifier can be built for a client. */
        private static OAuth2AuthenticationException signatureVerifierNotFound(RegisteredClient registeredClient, String missingConfiguration) {
            OAuth2Error oauth2Error = new OAuth2Error("invalid_client", "Failed to find a Signature Verifier for Client: '" + registeredClient.getId() + "'. Check to ensure you have configured " + missingConfiguration, JWT_CLIENT_AUTHENTICATION_ERROR_URI);
            return new OAuth2AuthenticationException(oauth2Error);
        }

        /**
         * Validates the client assertion claims.
         * <ul>
         * <li>{@code iss} and {@code sub} equal the client id.</li>
         * <li>{@code aud} contains a provider audience.</li>
         * <li>{@code exp} is present.</li>
         * <li>{@link JwtTimestampValidator} passes.</li>
         * </ul>
         */
        private static OAuth2TokenValidator<Jwt> createJwtValidator(RegisteredClient registeredClient) {
            String clientId = registeredClient.getClientId();
            return new DelegatingOAuth2TokenValidator<Jwt>(
                    new JwtClaimValidator<String>("iss", clientId::equals),
                    new JwtClaimValidator<String>("sub", clientId::equals),
                    new JwtClaimValidator<List<String>>("aud", containsProviderAudience()),
                    new JwtClaimValidator<Object>("exp", Objects::nonNull),
                    new JwtTimestampValidator());
        }

        /**
         * Returns a predicate that accepts a non-empty {@code aud} claim containing at least one
         * provider audience value. The provider audience is re-evaluated on each test.
         */
        private static Predicate<List<String>> containsProviderAudience() {
            return audienceClaim -> {
                if (CollectionUtils.isEmpty(audienceClaim)) {
                    return false;
                }
                List<String> providerAudience = getProviderAudience();
                for (String audience : audienceClaim) {
                    if (providerAudience.contains(audience)) {
                        return true;
                    }
                }
                return false;
            };
        }

        /**
         * Returns the accepted audience values: the provider issuer plus the token,
         * introspection and revocation endpoint URLs resolved against it. The list is empty
         * (so no audience matches) when {@link ProviderContextHolder} has no context or the
         * context has no issuer.
         */
        private static List<String> getProviderAudience() {
            ProviderContext providerContext = ProviderContextHolder.getProviderContext();
            // Guard against a missing context: an NPE here would escape the decoder as a
            // non-JwtException and bypass the invalid_client mapping in authenticate().
            if (providerContext == null || !StringUtils.hasText(providerContext.getIssuer())) {
                return Collections.emptyList();
            }
            String issuer = providerContext.getIssuer();
            ProviderSettings providerSettings = providerContext.getProviderSettings();
            List<String> providerAudience = new ArrayList<>(4);
            providerAudience.add(issuer);
            providerAudience.add(asUrl(issuer, providerSettings.getTokenEndpoint()));
            providerAudience.add(asUrl(issuer, providerSettings.getTokenIntrospectionEndpoint()));
            providerAudience.add(asUrl(issuer, providerSettings.getTokenRevocationEndpoint()));
            return providerAudience;
        }

        /** Appends {@code endpoint} as a path to {@code issuer}. */
        private static String asUrl(String issuer, String endpoint) {
            return UriComponentsBuilder.fromUriString(issuer).path(endpoint).build().toUriString();
        }

        /** A cached decoder together with the client settings it was built from. */
        private static final class CachedJwtDecoder {
            private final String clientId;
            private final JwsAlgorithm jwsAlgorithm;
            private final String jwkSetUrl;
            private final String clientSecret;
            private final JwtDecoder jwtDecoder;

            private CachedJwtDecoder(RegisteredClient registeredClient, JwtDecoder jwtDecoder) {
                this.clientId = registeredClient.getClientId();
                this.jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
                this.jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
                this.clientSecret = registeredClient.getClientSecret();
                this.jwtDecoder = jwtDecoder;
            }

            /** Returns {@code true} if every setting the decoder depends on is unchanged. */
            private boolean isBuiltFrom(RegisteredClient registeredClient) {
                return Objects.equals(this.clientId, registeredClient.getClientId())
                        && Objects.equals(this.jwsAlgorithm, registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm())
                        && Objects.equals(this.jwkSetUrl, registeredClient.getClientSettings().getJwkSetUrl())
                        && Objects.equals(this.clientSecret, registeredClient.getClientSecret());
            }
        }
    }
}

