package org.springframework.security.oauth2.server.authorization.authentication;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Predicate;
import javax.crypto.spec.SecretKeySpec;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimValidator;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext;
import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenClaimNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

/* loaded from: input_file:org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionDecoderFactory.class */
/**
 * Creates and caches the {@link JwtDecoder} used to verify a JWT client assertion presented by a
 * {@link RegisteredClient}.
 *
 * <p>The verification key is selected from the client's token endpoint authentication signing
 * algorithm (read from its client settings):
 * <ul>
 * <li>a {@link SignatureAlgorithm}: keys are fetched from the client's JWK Set URL, using an HTTP
 * client with 15 second connect and read timeouts;</li>
 * <li>a {@link MacAlgorithm}: the client secret, encoded as UTF-8, is used as the HMAC key.</li>
 * </ul>
 * Any other configuration is rejected with an {@code invalid_client} {@link OAuth2Error}.
 *
 * <p>Decoders are cached per {@link RegisteredClient#getId()}. The validator produced by the
 * configured validator factory is attached when a decoder is first built, so later changes to a
 * client's settings, or to the validator factory, do not affect an already-cached decoder.
 */
public final class JwtClientAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {

    /**
     * The default validator factory. For a given client, the resulting validator requires:
     * <ul>
     * <li>{@code iss} and {@code sub} to equal the client's {@code client_id};</li>
     * <li>{@code aud} to contain the authorization server issuer, or the token, token
     * introspection or token revocation endpoint URL resolved against that issuer;</li>
     * <li>{@code exp} to be present;</li>
     * <li>the token timestamps to pass {@link JwtTimestampValidator}.</li>
     * </ul>
     */
    public static final Function<RegisteredClient, OAuth2TokenValidator<Jwt>> DEFAULT_JWT_VALIDATOR_FACTORY = defaultJwtValidatorFactory();

    private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";

    /** Connect and read timeout, in milliseconds, used when fetching a client's JWK Set. */
    private static final int JWK_SET_TIMEOUT_MILLIS = 15_000;

    /** JCA algorithm names used to build the {@link SecretKeySpec} for each supported MAC algorithm. */
    private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS = createJcaAlgorithmMappings();

    /** HTTP client shared by every JWK Set-backed decoder built by this class. */
    private static final RestTemplate REST_TEMPLATE = createRestTemplate();

    /** Decoders built so far, keyed by {@link RegisteredClient#getId()}. */
    private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();

    private Function<RegisteredClient, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = DEFAULT_JWT_VALIDATOR_FACTORY;

    /**
     * Returns the decoder for the given client, building and caching it on first use.
     *
     * <p>If building fails, nothing is cached, so the next call for the same client tries again.
     *
     * @param registeredClient the client whose JWT client assertions are to be verified
     * @return the cached or newly built decoder
     * @throws IllegalArgumentException if {@code registeredClient} is {@code null}
     * @throws OAuth2AuthenticationException ({@code invalid_client}) if no verification key can be
     * derived from the client's configuration
     */
    @Override
    public JwtDecoder createDecoder(RegisteredClient registeredClient) {
        Assert.notNull(registeredClient, "registeredClient cannot be null");
        return this.jwtDecoders.computeIfAbsent(registeredClient.getId(), (clientKey) -> {
            NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient);
            jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(registeredClient));
            return jwtDecoder;
        });
    }

    /**
     * Sets the factory that creates the claim validator attached to each newly built decoder.
     * Decoders that are already cached keep the validator they were built with.
     *
     * @param jwtValidatorFactory the validator factory, must not be {@code null}
     * @throws IllegalArgumentException if {@code jwtValidatorFactory} is {@code null}
     */
    public void setJwtValidatorFactory(Function<RegisteredClient, OAuth2TokenValidator<Jwt>> jwtValidatorFactory) {
        Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
        this.jwtValidatorFactory = jwtValidatorFactory;
    }

    /** Dispatches on the client's configured signing algorithm to build the matching decoder. */
    private static NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
        // Typed as JwsAlgorithm: the setting may hold either a SignatureAlgorithm or a MacAlgorithm.
        JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
        if (jwsAlgorithm instanceof SignatureAlgorithm) {
            return buildJwkSetDecoder(registeredClient, (SignatureAlgorithm) jwsAlgorithm);
        }
        if (jwsAlgorithm instanceof MacAlgorithm) {
            return buildSecretKeyDecoder(registeredClient, (MacAlgorithm) jwsAlgorithm);
        }
        // Also reached when no algorithm is configured: null fails both instanceof checks.
        throw signatureVerifierNotFound(registeredClient,
                "Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'.");
    }

    /** Builds a decoder that verifies signatures with keys fetched from the client's JWK Set URL. */
    private static NimbusJwtDecoder buildJwkSetDecoder(RegisteredClient registeredClient,
            SignatureAlgorithm signatureAlgorithm) {
        String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
        if (!StringUtils.hasText(jwkSetUrl)) {
            throw signatureVerifierNotFound(registeredClient,
                    "Check to ensure you have configured the JWK Set URL.");
        }
        return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl)
                .jwsAlgorithm(signatureAlgorithm)
                .restOperations(REST_TEMPLATE)
                .build();
    }

    /** Builds a decoder that verifies HMAC signatures keyed by the client secret (UTF-8 bytes). */
    private static NimbusJwtDecoder buildSecretKeyDecoder(RegisteredClient registeredClient,
            MacAlgorithm macAlgorithm) {
        // NOTE(review): assumes the stored client secret is the raw shared secret rather than a
        // password-encoder hash; an encoded secret would never verify. Confirm with callers.
        String clientSecret = registeredClient.getClientSecret();
        if (!StringUtils.hasText(clientSecret)) {
            throw signatureVerifierNotFound(registeredClient,
                    "Check to ensure you have configured the client secret.");
        }
        SecretKeySpec secretKey = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
                JCA_ALGORITHM_MAPPINGS.get(macAlgorithm));
        return NimbusJwtDecoder.withSecretKey(secretKey).macAlgorithm(macAlgorithm).build();
    }

    /**
     * Creates the {@code invalid_client} error raised when no verification key can be set up.
     *
     * @param registeredClient the client being configured
     * @param hint the remediation sentence appended to the error description
     * @return the exception to throw
     */
    private static OAuth2AuthenticationException signatureVerifierNotFound(RegisteredClient registeredClient,
            String hint) {
        OAuth2Error error = new OAuth2Error("invalid_client",
                "Failed to find a Signature Verifier for Client: '" + registeredClient.getId() + "'. " + hint,
                JWT_CLIENT_AUTHENTICATION_ERROR_URI);
        return new OAuth2AuthenticationException(error);
    }

    /** Creates the validator factory exposed as {@link #DEFAULT_JWT_VALIDATOR_FACTORY}. */
    private static Function<RegisteredClient, OAuth2TokenValidator<Jwt>> defaultJwtValidatorFactory() {
        return (registeredClient) -> {
            String clientId = registeredClient.getClientId();
            return new DelegatingOAuth2TokenValidator<>(
                    new JwtClaimValidator<>(OAuth2TokenClaimNames.ISS, clientId::equals),
                    new JwtClaimValidator<>(OAuth2TokenClaimNames.SUB, clientId::equals),
                    new JwtClaimValidator<>(OAuth2TokenClaimNames.AUD, containsAudience()),
                    new JwtClaimValidator<>(OAuth2TokenClaimNames.EXP, Objects::nonNull),
                    new JwtTimestampValidator());
        };
    }

    /**
     * Returns a predicate that accepts an {@code aud} claim when at least one of its values is an
     * acceptable audience for the current authorization server context.
     */
    private static Predicate<List<String>> containsAudience() {
        return (audienceClaim) -> {
            if (CollectionUtils.isEmpty(audienceClaim)) {
                return false;
            }
            // Resolved once per validation, not once per claim value.
            List<String> acceptedAudience = getAudience();
            return audienceClaim.stream().anyMatch(acceptedAudience::contains);
        };
    }

    /**
     * Lists the audience values accepted for a client assertion. These are the issuer, plus the
     * token, token introspection and token revocation endpoints resolved against it. Returns an
     * empty list when no context is bound or it has no issuer, so the {@code aud} check fails.
     */
    private static List<String> getAudience() {
        AuthorizationServerContext context = AuthorizationServerContextHolder.getContext();
        if (context == null || !StringUtils.hasText(context.getIssuer())) {
            return Collections.emptyList();
        }
        String issuer = context.getIssuer();
        AuthorizationServerSettings settings = context.getAuthorizationServerSettings();
        List<String> audience = new ArrayList<>(4);
        audience.add(issuer);
        audience.add(asUrl(issuer, settings.getTokenEndpoint()));
        audience.add(asUrl(issuer, settings.getTokenIntrospectionEndpoint()));
        audience.add(asUrl(issuer, settings.getTokenRevocationEndpoint()));
        return audience;
    }

    /** Appends {@code endpointPath} to the path of {@code issuer} and returns the resulting URI string. */
    private static String asUrl(String issuer, String endpointPath) {
        return UriComponentsBuilder.fromUriString(issuer).path(endpointPath).build().toUriString();
    }

    /** Builds the unmodifiable MAC-algorithm to JCA-name table. */
    private static Map<JwsAlgorithm, String> createJcaAlgorithmMappings() {
        Map<JwsAlgorithm, String> mappings = new HashMap<>();
        mappings.put(MacAlgorithm.HS256, "HmacSHA256");
        mappings.put(MacAlgorithm.HS384, "HmacSHA384");
        mappings.put(MacAlgorithm.HS512, "HmacSHA512");
        return Collections.unmodifiableMap(mappings);
    }

    /** Builds the HTTP client used to fetch JWK Sets, with bounded connect and read timeouts. */
    private static RestTemplate createRestTemplate() {
        SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
        requestFactory.setConnectTimeout(JWK_SET_TIMEOUT_MILLIS);
        requestFactory.setReadTimeout(JWK_SET_TIMEOUT_MILLIS);
        RestTemplate restTemplate = new RestTemplate();
        restTemplate.setRequestFactory(requestFactory);
        return restTemplate;
    }
}
