package io.strimzi.kafka.oauth.validator;

import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.SignedJWT;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import io.strimzi.kafka.oauth.common.HttpUtil;
import io.strimzi.kafka.oauth.common.JSONUtil;
import io.strimzi.kafka.oauth.common.LogUtil;
import io.strimzi.kafka.oauth.common.OAuthAuthenticator;
import io.strimzi.kafka.oauth.common.PrincipalExtractor;
import io.strimzi.kafka.oauth.common.TimeUtil;
import io.strimzi.kafka.oauth.common.TokenInfo;
import io.strimzi.kafka.oauth.common.TokenProvider;
import io.strimzi.kafka.oauth.jsonpath.JsonPathFilterQuery;
import io.strimzi.kafka.oauth.jsonpath.JsonPathQuery;
import io.strimzi.kafka.oauth.metrics.JwksHttpSensorKeyProducer;
import io.strimzi.kafka.oauth.metrics.SensorKeyProducer;
import io.strimzi.kafka.oauth.services.OAuthMetrics;
import io.strimzi.kafka.oauth.services.ServiceException;
import io.strimzi.kafka.oauth.services.Services;
import io.strimzi.kafka.oauth.validator.TokenValidationException;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.PublicKey;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/strimzi/kafka/oauth/validator/JWTSignatureValidator.class */
/**
 * A {@link TokenValidator} that validates JWT tokens locally using public keys published at a JWKS endpoint.
 *
 * <p>Keys are fetched from {@code keysEndpointUri} during construction and then refreshed periodically on a
 * single-threaded scheduled executor. A token is accepted when:
 * <ul>
 *   <li>its signature verifies against the cached key whose id matches the token's {@code kid} header,</li>
 *   <li>it has an {@code exp} claim that is not in the past,</li>
 *   <li>it passes the optional issuer, token-type, audience and custom-claim checks, and</li>
 *   <li>a principal can be extracted from it.</li>
 * </ul>
 */
public class JWTSignatureValidator implements TokenValidator {
    private static final Logger log = LoggerFactory.getLogger(JWTSignatureValidator.class);
    private static final DefaultJWSVerifierFactory VERIFIER_FACTORY = new DefaultJWSVerifierFactory();

    private final String validatorId;
    private final String clientId;
    private final String clientSecret;
    private final TokenProvider bearerTokenProvider;
    private final URI keysUri;
    private final String issuerUri;
    private final int maxStaleSeconds;
    private final boolean checkAccessTokenType;
    private final String audience;
    private final JsonPathFilterQuery customClaimMatcher;
    private final JsonPathQuery groupsQuery;
    private final String groupsDelimiter;
    private final SSLSocketFactory socketFactory;
    private final HostnameVerifier hostnameVerifier;
    private final PrincipalExtractor principalExtractor;
    private final boolean ignoreKeyUse;
    private final int connectTimeout;
    private final int readTimeout;

    // These three are replaced by fetchKeys() on the refresh executor thread and read by validate() callers.
    // volatile guarantees the new map references are visible to readers, and prevents torn reads of the
    // 64-bit timestamp.
    private volatile long lastFetchTime;
    private volatile Map<String, PublicKey> cache = Collections.emptyMap();
    private volatile Map<String, PublicKey> oldCache = Collections.emptyMap();

    private BackOffTaskScheduler fastScheduler;
    private final ScheduledExecutorService executor;
    private final boolean enableMetrics;
    private final OAuthMetrics metrics;
    private final SensorKeyProducer jwksHttpSensorKeyProducer;
    private final boolean includeAcceptHeader;

    /**
     * Creates the validator, performs the initial JWKS fetch and schedules the periodic refresh.
     *
     * @param validatorId             id of this validator instance (required)
     * @param clientId                client id for Basic auth to the JWKS endpoint; mutually exclusive with
     *                                {@code bearerTokenProvider}
     * @param clientSecret            client secret used with {@code clientId}
     * @param bearerTokenProvider     provider of a Bearer token for the JWKS endpoint, or {@code null}
     * @param keysEndpointUri         JWKS endpoint URI (required)
     * @param sslSocketFactory        TLS socket factory; only allowed for an {@code https} endpoint
     * @param verifier                TLS hostname verifier; only allowed for an {@code https} endpoint
     * @param principalExtractor      extracts the principal name from the token payload
     * @param groupsClaimQuery        optional JSONPath query selecting the groups claim
     * @param groupsClaimDelimiter    delimiter used to split a string groups claim; {@code null} means ","
     * @param validIssuerUri          required {@code iss} value, or {@code null} to skip the check
     * @param refreshSeconds          period of the scheduled JWKS refresh; must be positive
     * @param refreshMinPauseSeconds  minimum pause passed to the on-demand refresh scheduler
     * @param expirySeconds           age after which cached keys are no longer used; must be at least
     *                                {@code refreshSeconds + 60}
     * @param ignoreKeyUse            if {@code true}, keys are used regardless of their {@code use} attribute
     * @param checkAccessTokenType    if {@code true}, require {@code typ} or {@code token_type} to be "Bearer"
     * @param audience                required audience, or {@code null} to skip the check
     * @param customClaimCheck        optional JSONPath filter the token payload must match
     * @param connectTimeoutSeconds   connect timeout for the JWKS request
     * @param readTimeoutSeconds      read timeout for the JWKS request
     * @param enableMetrics           whether to record JWKS HTTP request timing metrics
     * @param failFast                if {@code true}, a failed initial JWKS fetch fails construction
     * @param includeAcceptHeader     whether the JWKS request sends an Accept header
     * @throws IllegalArgumentException if any configuration value is invalid
     */
    public JWTSignatureValidator(String validatorId,
                                 String clientId,
                                 String clientSecret,
                                 TokenProvider bearerTokenProvider,
                                 String keysEndpointUri,
                                 SSLSocketFactory sslSocketFactory,
                                 HostnameVerifier verifier,
                                 PrincipalExtractor principalExtractor,
                                 String groupsClaimQuery,
                                 String groupsClaimDelimiter,
                                 String validIssuerUri,
                                 int refreshSeconds,
                                 int refreshMinPauseSeconds,
                                 int expirySeconds,
                                 boolean ignoreKeyUse,
                                 boolean checkAccessTokenType,
                                 String audience,
                                 String customClaimCheck,
                                 int connectTimeoutSeconds,
                                 int readTimeoutSeconds,
                                 boolean enableMetrics,
                                 boolean failFast,
                                 boolean includeAcceptHeader) {
        if (validatorId == null) {
            throw new IllegalArgumentException("validatorId == null");
        }
        this.validatorId = validatorId;
        this.clientId = clientId;
        this.clientSecret = clientSecret;
        this.bearerTokenProvider = bearerTokenProvider;
        checkAuthorizationOptions(clientId, bearerTokenProvider);
        this.issuerUri = checkIssuerUri(validIssuerUri);
        // keysUri must be assigned before the socket factory / hostname verifier checks, which inspect its scheme
        this.keysUri = checkKeysEndpointUri(keysEndpointUri);
        this.socketFactory = checkSocketFactory(sslSocketFactory);
        this.hostnameVerifier = checkHostnameVerifier(verifier);
        this.principalExtractor = principalExtractor;
        validateRefreshConfig(refreshSeconds, expirySeconds);
        this.maxStaleSeconds = expirySeconds;
        this.checkAccessTokenType = checkAccessTokenType;
        this.audience = audience;
        this.customClaimMatcher = parseCustomClaimCheck(customClaimCheck);
        this.groupsQuery = parseGroupsQuery(groupsClaimQuery);
        this.groupsDelimiter = parseGroupsDelimiter(groupsClaimDelimiter);
        this.connectTimeout = connectTimeoutSeconds;
        this.readTimeout = readTimeoutSeconds;
        this.enableMetrics = enableMetrics;
        this.ignoreKeyUse = ignoreKeyUse;
        this.includeAcceptHeader = includeAcceptHeader;

        // Logged before the steps that can fail (metrics lookup, initial JWKS fetch) so the effective
        // configuration is visible even when construction fails.
        if (log.isDebugEnabled()) {
            log.debug("Configured JWTSignatureValidator:"
                    + "\n\t  validatorId: " + validatorId
                    + "\n\t  clientId: " + clientId
                    + "\n\t  clientSecret: " + LogUtil.mask(clientSecret)
                    + "\n\t  bearerTokenProvider: " + String.valueOf(bearerTokenProvider)
                    + "\n\t  keysEndpointUri: " + keysEndpointUri
                    + "\n\t  sslSocketFactory: " + String.valueOf(sslSocketFactory)
                    + "\n\t  hostnameVerifier: " + String.valueOf(this.hostnameVerifier)
                    + "\n\t  principalExtractor: " + String.valueOf(principalExtractor)
                    + "\n\t  groupsClaimQuery: " + groupsClaimQuery
                    + "\n\t  groupsClaimDelimiter: " + groupsClaimDelimiter
                    + "\n\t  validIssuerUri: " + validIssuerUri
                    + "\n\t  certsRefreshSeconds: " + refreshSeconds
                    + "\n\t  certsRefreshMinPauseSeconds: " + refreshMinPauseSeconds
                    + "\n\t  certsExpirySeconds: " + expirySeconds
                    + "\n\t  certsIgnoreKeyUse: " + ignoreKeyUse
                    + "\n\t  checkAccessTokenType: " + checkAccessTokenType
                    + "\n\t  audience: " + audience
                    + "\n\t  customClaimCheck: " + customClaimCheck
                    + "\n\t  connectTimeoutSeconds: " + connectTimeoutSeconds
                    + "\n\t  readTimeoutSeconds: " + readTimeoutSeconds
                    + "\n\t  enableMetrics: " + enableMetrics
                    + "\n\t  failFast: " + failFast
                    + "\n\t  includeAcceptHeader: " + includeAcceptHeader);
        }

        this.metrics = enableMetrics ? Services.getInstance().getMetrics() : null;
        this.jwksHttpSensorKeyProducer = new JwksHttpSensorKeyProducer(validatorId, this.keysUri);
        this.executor = setupExecutorAndFetchInitialKeys(refreshSeconds, refreshMinPauseSeconds, failFast);
        setupRefreshKeysJob(this.executor, refreshSeconds);
    }

    /** Rejects configurations that specify both Basic (clientId) and Bearer authentication. */
    private static void checkAuthorizationOptions(String clientId, TokenProvider bearerTokenProvider) {
        if (clientId != null && bearerTokenProvider != null) {
            throw new IllegalArgumentException("Can't use both clientId and bearerToken");
        }
    }

    /** Parses the required JWKS endpoint URI, rejecting {@code null} and malformed values. */
    private URI checkKeysEndpointUri(String keysEndpointUri) {
        if (keysEndpointUri == null) {
            throw new IllegalArgumentException("keysEndpointUri == null");
        }
        try {
            return new URI(keysEndpointUri);
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Invalid keys endpoint uri: " + keysEndpointUri, e);
        }
    }

    /** Allows a custom hostname verifier only when the keys endpoint uses https. */
    private HostnameVerifier checkHostnameVerifier(HostnameVerifier verifier) {
        if (verifier != null && !"https".equals(keysUri.getScheme())) {
            throw new IllegalArgumentException("Certificate hostname verifier set but keysEndpointUri not 'https'");
        }
        return verifier;
    }

    /** Validates that the optional issuer is a syntactically valid URI and returns it unchanged. */
    private static String checkIssuerUri(String validIssuerUri) {
        if (validIssuerUri != null) {
            try {
                new URI(validIssuerUri);
            } catch (URISyntaxException e) {
                throw new IllegalArgumentException("Value of validIssuerUri not a valid URI: " + validIssuerUri, e);
            }
        }
        return validIssuerUri;
    }

    /** Allows a custom SSL socket factory only when the keys endpoint uses https. */
    private SSLSocketFactory checkSocketFactory(SSLSocketFactory sslSocketFactory) {
        if (sslSocketFactory != null && !"https".equals(keysUri.getScheme())) {
            throw new IllegalArgumentException("SSL socket factory set but keysEndpointUri not 'https'");
        }
        return sslSocketFactory;
    }

    /**
     * Performs the initial key fetch and creates the refresh executor and on-demand scheduler.
     * When the initial fetch fails and {@code failFast} is off, an immediate retry is scheduled instead.
     */
    private ScheduledExecutorService setupExecutorAndFetchInitialKeys(int refreshSeconds, int refreshMinPauseSeconds,
                                                                      boolean failFast) {
        boolean initialFetchFailed = false;
        try {
            fetchKeys();
        } catch (Exception e) {
            if (failFast) {
                throw e;
            }
            initialFetchFailed = true;
            log.warn("[IGNORED] Fetching JWKS keys has failed, but fail-fast is disabled: ", e);
        }
        ScheduledExecutorService scheduledExecutor =
                Executors.newSingleThreadScheduledExecutor(new DaemonThreadFactory());
        this.fastScheduler = new BackOffTaskScheduler(scheduledExecutor, refreshMinPauseSeconds, refreshSeconds,
                this::fetchKeys);
        if (initialFetchFailed) {
            this.fastScheduler.scheduleTask();
        }
        return scheduledExecutor;
    }

    /** Parses the optional custom claim filter; an all-whitespace value is rejected. */
    private JsonPathFilterQuery parseCustomClaimCheck(String customClaimCheck) {
        if (customClaimCheck == null) {
            return null;
        }
        String query = customClaimCheck.trim();
        if (query.isEmpty()) {
            throw new IllegalArgumentException("Value of customClaimCheck is empty");
        }
        return JsonPathFilterQuery.parse(query);
    }

    /** Parses the optional groups claim query; an all-whitespace value is rejected. */
    private JsonPathQuery parseGroupsQuery(String groupsClaimQuery) {
        if (groupsClaimQuery == null) {
            return null;
        }
        String query = groupsClaimQuery.trim();
        if (query.isEmpty()) {
            throw new IllegalArgumentException("Value of groupsClaimQuery is empty");
        }
        return JsonPathQuery.parse(query);
    }

    /**
     * Returns the configured groups delimiter, or "," when none is configured.
     * An empty delimiter is rejected.
     */
    private String parseGroupsDelimiter(String groupsClaimDelimiter) {
        if (groupsClaimDelimiter == null) {
            return ",";
        }
        if (groupsClaimDelimiter.isEmpty()) {
            throw new IllegalArgumentException("Value of groupsClaimDelimiter is empty");
        }
        return groupsClaimDelimiter;
    }

    /**
     * Requires a positive refresh period and an expiry at least 60 seconds longer than it.
     * The sum is computed in {@code long} so a very large refresh period cannot overflow and slip through.
     */
    private void validateRefreshConfig(int refreshSeconds, int expirySeconds) {
        if (refreshSeconds <= 0) {
            throw new IllegalArgumentException("refreshSeconds has to be a positive number - (refreshSeconds=" + refreshSeconds + ")");
        }
        if (expirySeconds < (long) refreshSeconds + 60) {
            throw new IllegalArgumentException("expirySeconds has to be at least 60 seconds longer than refreshSeconds - (expirySeconds=" + expirySeconds + ", refreshSeconds=" + refreshSeconds + ")");
        }
    }

    /**
     * Schedules a periodic refresh that delegates to the on-demand scheduler. Errors are logged rather than
     * propagated, because an exception escaping a fixed-rate task would cancel all future runs.
     */
    private void setupRefreshKeysJob(ScheduledExecutorService scheduledExecutor, int refreshSeconds) {
        scheduledExecutor.scheduleAtFixedRate(() -> {
            try {
                this.fastScheduler.scheduleTask();
            } catch (Throwable th) {
                log.error("{}", th.getMessage(), th);
            }
        }, refreshSeconds, refreshSeconds, TimeUnit.SECONDS);
    }

    /** Looks up the public key for {@code kid}; returns {@code null} if unknown or if the cache is stale. */
    private PublicKey getPublicKey(String kid) {
        return getKeyUnlessStale(kid);
    }

    /**
     * Returns the cached key for {@code kid}, or {@code null} when the last successful fetch is older than
     * {@code maxStaleSeconds} or no key has that id.
     */
    private PublicKey getKeyUnlessStale(String kid) {
        // 1000L: int multiplication would overflow for expiry values above ~24.8 days
        long staleAtMillis = lastFetchTime + maxStaleSeconds * 1000L;
        if (staleAtMillis <= System.currentTimeMillis()) {
            log.warn("The cached public key with id '{}' is expired!", kid);
            return null;
        }
        PublicKey publicKey = cache.get(kid);
        if (publicKey == null) {
            log.warn("No public key for id: {}", kid);
        }
        return publicKey;
    }

    /**
     * Downloads the JWKS document and rebuilds the key cache from its EC and RSA keys.
     * Keys whose {@code use} is not "sig" are skipped unless {@code ignoreKeyUse} is set.
     * On change, the previous cache is kept in {@code oldCache} so tokens signed with a removed key can be
     * reported as "no longer valid" rather than "unknown".
     *
     * @throws ServiceException if the fetch or parse fails
     */
    private void fetchKeys() {
        long startTime = System.currentTimeMillis();
        try {
            String response = (String) HttpUtil.get(keysUri, socketFactory, hostnameVerifier,
                    generateAuthorizationHeader(), String.class, connectTimeout, readTimeout, includeAcceptHeader);
            addJwksHttpMetricSuccessTime(startTime);

            Map<String, PublicKey> fetched = new HashMap<>();
            for (JWK jwk : JWKSet.parse(response).getKeys()) {
                if (!ignoreKeyUse && !KeyUse.SIGNATURE.equals(jwk.getKeyUse())) {
                    continue;
                }
                PublicKey publicKey;
                if (jwk instanceof ECKey) {
                    publicKey = ((ECKey) jwk).toPublicKey();
                } else if (jwk instanceof RSAKey) {
                    publicKey = ((RSAKey) jwk).toPublicKey();
                } else {
                    log.warn("Unsupported JWK key type: {}", jwk.getKeyType());
                    continue;
                }
                fetched.put(jwk.getKeyID(), publicKey);
            }

            Map<String, PublicKey> newCache = Collections.unmodifiableMap(fetched);
            if (!cache.equals(newCache)) {
                log.info("JWKS keys change detected. Keys updated.");
                oldCache = cache;
                cache = newCache;
            }
            lastFetchTime = System.currentTimeMillis();
        } catch (Throwable th) {
            addJwksHttpMetricErrorTime(th, startTime);
            throw new ServiceException("Failed to fetch public keys needed to validate JWT signatures: " + String.valueOf(keysUri), th);
        }
    }

    /**
     * Builds the Authorization header for the JWKS request.
     * Returns Bearer when a token provider is configured, otherwise Basic when a client id is configured,
     * otherwise {@code null}.
     */
    private String generateAuthorizationHeader() {
        if (bearerTokenProvider != null) {
            return "Bearer " + bearerTokenProvider.token();
        }
        if (clientId != null) {
            return "Basic " + OAuthAuthenticator.base64encode(clientId + ":" + clientSecret);
        }
        return null;
    }

    /**
     * Parses and validates a JWT.
     *
     * @param str the raw JWT
     * @return token info containing the payload, principal and (optionally) groups
     * @throws TokenValidationException with status INVALID_TOKEN if the token cannot be parsed; otherwise a
     *         {@link TokenValidationException} (or subclass such as {@link TokenExpiredException} or
     *         {@link TokenSignatureException}) describing the failed check
     */
    @Override // io.strimzi.kafka.oauth.validator.TokenValidator
    @SuppressFBWarnings(value = {"BC_UNCONFIRMED_CAST_OF_RETURN_VALUE"}, justification = "We tell TokenVerifier to parse AccessToken. It will return AccessToken or fail.")
    public TokenInfo validate(String str) {
        // Only parse failures map to INVALID_TOKEN; validation failures below keep their own type and message.
        SignedJWT jwt;
        try {
            jwt = SignedJWT.parse(str);
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed: Failed to parse JWT.", e).status(TokenValidationException.Status.INVALID_TOKEN);
        }

        try {
            verifySignature(jwt);
            JsonNode payload = JSONUtil.asJson(jwt.getPayload().toJSONObject());
            checkExpiry(payload);
            validateTokenPayload(payload);
            if (customClaimMatcher != null && !customClaimMatcher.matches(payload)) {
                throw new TokenValidationException("Token validation failed: Custom claim check failed");
            }
            return new TokenInfo(payload, str, extractPrincipal(payload), extractGroups(payload));
        } catch (TokenValidationException e) {
            throw e;
        } catch (Exception e) {
            throw new TokenValidationException("Token validation failed", e);
        }
    }

    /**
     * Verifies the token signature with the cached key identified by its {@code kid} header.
     * If no usable key exists, an early key refresh is requested and the token is rejected.
     */
    private void verifySignature(SignedJWT jwt) throws JOSEException {
        String kid = jwt.getHeader().getKeyID();
        PublicKey publicKey = getPublicKey(kid);
        if (publicKey == null) {
            if (oldCache.get(kid) != null) {
                throw new TokenValidationException("Token validation failed: The signing key is no longer valid (kid:" + kid + ")");
            }
            // The JWKS may have been rotated since the last fetch; ask the scheduler for an early refresh.
            try {
                fastScheduler.scheduleTask();
            } catch (RuntimeException e) {
                log.error("Failed to reschedule JWKS keys refresh: ", e);
            }
            throw new TokenValidationException("Token validation failed: Unknown signing key (kid:" + kid + ")");
        }
        if (!jwt.verify(VERIFIER_FACTORY.createJWSVerifier(jwt.getHeader(), publicKey))) {
            throw new TokenSignatureException("Signature check failed: Invalid token signature");
        }
    }

    /** Requires an {@code exp} claim (seconds since epoch) that is not earlier than the current time. */
    private static void checkExpiry(JsonNode payload) {
        JsonNode exp = payload.get(TokenInfo.EXP);
        if (exp == null) {
            throw new TokenValidationException("Token validation failed: Expiry not set");
        }
        long expiresAtMillis = exp.asLong(0L) * 1000;
        if (System.currentTimeMillis() > expiresAtMillis) {
            throw new TokenExpiredException("Token expired at: " + expiresAtMillis
                    + " (" + TimeUtil.formatIsoDateTimeUTC(expiresAtMillis) + " UTC)");
        }
    }

    /**
     * Extracts the principal name. When the extractor is configured, only its configured claims are used.
     * Otherwise {@code sub} is used.
     *
     * @throws ValidationException if no principal can be determined
     */
    private String extractPrincipal(JsonNode payload) {
        String principal = principalExtractor.isConfigured()
                ? principalExtractor.getPrincipal(payload)
                : principalExtractor.getSub(payload);
        if (principal == null) {
            throw new ValidationException("Failed to extract principal - check usernameClaim, fallbackUsernameClaim configuration");
        }
        return principal;
    }

    /**
     * Applies the groups query to the payload and returns the trimmed, non-blank group names.
     * Returns {@code null} when no query is configured, the query matches nothing, or no groups remain.
     */
    private Set<String> extractGroups(JsonNode payload) {
        if (groupsQuery == null) {
            return null;
        }
        JsonNode groupsNode = groupsQuery.apply(payload);
        if (groupsNode == null) {
            return null;
        }
        Set<String> groups = JSONUtil.asListOfString(groupsNode, groupsDelimiter).stream()
                .map(String::trim)
                .filter(group -> !group.isEmpty())
                .collect(Collectors.toSet());
        return groups.isEmpty() ? null : groups;
    }

    /** Applies the optional issuer, access-token-type and audience checks to the token payload. */
    private void validateTokenPayload(JsonNode payload) {
        if (issuerUri != null) {
            JsonNode iss = payload.get(TokenInfo.ISS);
            if (iss == null) {
                throw new TokenValidationException("Token validation failed: Issuer not set");
            }
            String issuer = iss.asText();
            if (!issuerUri.equals(issuer)) {
                throw new TokenValidationException("Token validation failed: Issuer not allowed: " + issuer);
            }
        }
        if (checkAccessTokenType) {
            // 'typ' takes precedence; 'token_type' is the fallback claim
            JsonNode type = payload.get(TokenInfo.TYP);
            if (type == null) {
                type = payload.get(TokenInfo.TOKEN_TYPE);
            }
            if (type == null) {
                throw new TokenValidationException("Token validation failed: Token type not set ('token_type' or 'typ' claim not present)");
            }
            String tokenType = type.asText();
            if (!"Bearer".equals(tokenType)) {
                throw new TokenValidationException("Token validation failed: Token type not allowed: " + tokenType);
            }
        }
        if (audience != null) {
            JsonNode aud = payload.get(TokenInfo.AUD);
            if (aud == null || !JSONUtil.asListOfString(aud).contains(audience)) {
                throw new TokenValidationException("Token validation failed: Expected audience not available in the token");
            }
        }
    }

    @Override // io.strimzi.kafka.oauth.validator.TokenValidator
    public String getValidatorId() {
        return validatorId;
    }

    /** Stops the key refresh executor, interrupting any in-flight fetch. */
    @Override // io.strimzi.kafka.oauth.validator.TokenValidator
    public void close() {
        executor.shutdownNow();
    }

    /** Records the duration of a successful JWKS request when metrics are enabled. */
    private void addJwksHttpMetricSuccessTime(long startTimeMillis) {
        if (enableMetrics) {
            metrics.addTime(jwksHttpSensorKeyProducer.successKey(), System.currentTimeMillis() - startTimeMillis);
        }
    }

    /** Records the duration of a failed JWKS request, keyed by the failure, when metrics are enabled. */
    private void addJwksHttpMetricErrorTime(Throwable failure, long startTimeMillis) {
        if (enableMetrics) {
            metrics.addTime(jwksHttpSensorKeyProducer.errorKey(failure), System.currentTimeMillis() - startTimeMillis);
        }
    }
}
