package org.springframework.security.oauth2.jwt;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.time.temporal.TemporalUnit;
import java.util.Base64;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.proc.DefaultJOSEObjectTypeVerifier;
import com.nimbusds.jose.proc.JOSEObjectTypeVerifier;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;

import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidatorResult;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/* loaded from: input_file:org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory.class */
public final class DPoPProofJwtDecoderFactory implements JwtDecoderFactory<DPoPProofContext> {
    public static final Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> DEFAULT_JWT_VALIDATOR_FACTORY = defaultJwtValidatorFactory();
    private static final JOSEObjectTypeVerifier<SecurityContext> DPOP_TYPE_VERIFIER = new DefaultJOSEObjectTypeVerifier(new JOSEObjectType[]{new JOSEObjectType("dpop+jwt")});
    private Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = DEFAULT_JWT_VALIDATOR_FACTORY;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory$JtiClaimValidator.class */
    public static final class JtiClaimValidator implements OAuth2TokenValidator<Jwt> {
        private static final Map<String, Long> JTI_CACHE = Collections.synchronizedMap(new JtiCache());

        /* loaded from: input_file:org/springframework/security/oauth2/jwt/DPoPProofJwtDecoderFactory$JtiClaimValidator$JtiCache.class */
        private static final class JtiCache extends LinkedHashMap<String, Long> {
            private static final int MAX_SIZE = 1000;

            private JtiCache() {
            }

            @Override // java.util.LinkedHashMap
            protected boolean removeEldestEntry(Map.Entry<String, Long> entry) {
                if (size() > MAX_SIZE) {
                    return true;
                }
                return Instant.now().isAfter(Instant.ofEpochMilli(entry.getValue().longValue()));
            }
        }

        private JtiClaimValidator() {
        }

        public OAuth2TokenValidatorResult validate(Jwt jwt) {
            Assert.notNull(jwt, "DPoP proof jwt cannot be null");
            String id = jwt.getId();
            if (!StringUtils.hasText(id)) {
                return OAuth2TokenValidatorResult.failure(new OAuth2Error[]{createOAuth2Error("jti claim is required.")});
            }
            try {
                return JTI_CACHE.putIfAbsent(computeSHA256(id), Long.valueOf(Instant.now().plus(1L, (TemporalUnit) ChronoUnit.HOURS).toEpochMilli())) != null ? OAuth2TokenValidatorResult.failure(new OAuth2Error[]{createOAuth2Error("jti claim is invalid.")}) : OAuth2TokenValidatorResult.success();
            } catch (Exception e) {
                return OAuth2TokenValidatorResult.failure(new OAuth2Error[]{createOAuth2Error("jti claim is invalid.")});
            }
        }

        private static OAuth2Error createOAuth2Error(String str) {
            return new OAuth2Error("invalid_dpop_proof", str, (String) null);
        }

        private static String computeSHA256(String str) throws Exception {
            return Base64.getUrlEncoder().withoutPadding().encodeToString(MessageDigest.getInstance("SHA-256").digest(str.getBytes(StandardCharsets.UTF_8)));
        }
    }

    @Override // org.springframework.security.oauth2.jwt.JwtDecoderFactory
    public JwtDecoder createDecoder(DPoPProofContext dPoPProofContext) {
        Assert.notNull(dPoPProofContext, "dPoPProofContext cannot be null");
        NimbusJwtDecoder buildDecoder = buildDecoder();
        buildDecoder.setJwtValidator(this.jwtValidatorFactory.apply(dPoPProofContext));
        return buildDecoder;
    }

    public void setJwtValidatorFactory(Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> function) {
        Assert.notNull(function, "jwtValidatorFactory cannot be null");
        this.jwtValidatorFactory = function;
    }

    private static NimbusJwtDecoder buildDecoder() {
        DefaultJWTProcessor defaultJWTProcessor = new DefaultJWTProcessor();
        defaultJWTProcessor.setJWSTypeVerifier(DPOP_TYPE_VERIFIER);
        defaultJWTProcessor.setJWSKeySelector(jwsKeySelector());
        defaultJWTProcessor.setJWTClaimsSetVerifier((jWTClaimsSet, securityContext) -> {
        });
        return new NimbusJwtDecoder(defaultJWTProcessor);
    }

    private static JWSKeySelector<SecurityContext> jwsKeySelector() {
        return (jWSHeader, securityContext) -> {
            JWSAlgorithm algorithm = jWSHeader.getAlgorithm();
            if (!JWSAlgorithm.Family.RSA.contains(algorithm) && !JWSAlgorithm.Family.EC.contains(algorithm)) {
                throw new BadJwtException("Unsupported alg parameter in JWS Header: " + algorithm.getName());
            }
            RSAKey jwk = jWSHeader.getJWK();
            if (jwk == null) {
                throw new BadJwtException("Missing jwk parameter in JWS Header.");
            }
            if (jwk.isPrivate()) {
                throw new BadJwtException("Invalid jwk parameter in JWS Header.");
            }
            try {
                if (JWSAlgorithm.Family.RSA.contains(algorithm) && (jwk instanceof RSAKey)) {
                    return Collections.singletonList(jwk.toRSAPublicKey());
                }
                if (JWSAlgorithm.Family.EC.contains(algorithm) && (jwk instanceof ECKey)) {
                    return Collections.singletonList(((ECKey) jwk).toECPublicKey());
                }
                throw new BadJwtException("Invalid alg / jwk parameter in JWS Header: alg=" + algorithm.getName() + ", jwk.kty=" + jwk.getKeyType().getValue());
            } catch (JOSEException e) {
                throw new BadJwtException("Invalid jwk parameter in JWS Header.");
            }
        };
    }

    private static Function<DPoPProofContext, OAuth2TokenValidator<Jwt>> defaultJwtValidatorFactory() {
        return dPoPProofContext -> {
            String method = dPoPProofContext.getMethod();
            Objects.requireNonNull(method);
            String targetUri = dPoPProofContext.getTargetUri();
            Objects.requireNonNull(targetUri);
            return new DelegatingOAuth2TokenValidator(new OAuth2TokenValidator[]{new JwtClaimValidator("htm", method::equals), new JwtClaimValidator("htu", targetUri::equals), new JtiClaimValidator(), new JwtIssuedAtValidator(true)});
        };
    }
}
