/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.kafka.server.plugins.auth.oauth;

import io.confluent.kafka.clients.plugins.auth.jwt.JwtAuthenticator;
import io.confluent.kafka.clients.plugins.auth.jwt.JwtAuthenticatorConfig;
import io.confluent.kafka.clients.plugins.auth.jwt.JwtVerificationException;
import io.confluent.kafka.common.multitenant.oauth.OAuthBearerJwsToken;
import io.confluent.kafka.multitenant.BasePhysicalClusterMetadata;
import io.confluent.kafka.multitenant.KafkaLogicalClusterMetadata;
import io.confluent.kafka.multitenant.PhysicalClusterMetadata;
import io.confluent.kafka.server.plugins.auth.SniValidationMode;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;
import kafka.server.KafkaConfig;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.PathAwareSniHostName;
import org.apache.kafka.common.security.oauthbearer.CommonExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class OAuthBearerValidatorCallbackHandler
implements AuthenticateCallbackHandler {
    private static final Logger log = LoggerFactory.getLogger(OAuthBearerValidatorCallbackHandler.class);
    private static final String DEFAULT_SCOPE_CLAIM = "orgResourceId";
    private static final String AUTH_ERROR_MESSAGE = "Authentication failed";
    private JwtAuthenticator jwtAuthenticator;
    private BasePhysicalClusterMetadata clusterMetadata;
    private SniValidationMode mode;
    private boolean configured = false;
    private boolean enableOrgIdCheck = true;

    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
        JaasOptionsUtils.validateOAuthMechanismAndNonNullJaasConfig((String)saslMechanism, jaasConfigEntries);
        HashMap moduleOptions = new HashMap(jaasConfigEntries.get(0).getOptions());
        if (moduleOptions.containsKey("publicKeyPath")) {
            moduleOptions.put("jwksLocation", (String)moduleOptions.remove("publicKeyPath"));
        }
        JwtAuthenticatorConfig authenticatorConfig = new JwtAuthenticatorConfig(moduleOptions);
        this.jwtAuthenticator = new JwtAuthenticator(authenticatorConfig);
        Object uuid = configs.get(KafkaConfig.BrokerSessionUuidProp());
        if (uuid == null || uuid.toString().isEmpty()) {
            throw new ConfigException("Broker session UUID must be set in the Kafka config!");
        }
        this.clusterMetadata = BasePhysicalClusterMetadata.getInstance((String)uuid.toString());
        if (this.clusterMetadata == null) {
            throw new ConfigException("Could not get a PhysicalClusterMetadata instance with broker session UUID " + uuid.toString());
        }
        if (this.clusterMetadata instanceof PhysicalClusterMetadata) {
            this.enableOrgIdCheck = false;
        }
        this.mode = SniValidationMode.fromString((String)moduleOptions.get("sni_host_name_validation_mode"));
        this.configured = true;
    }

    public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
        if (!this.configured) {
            throw new IllegalStateException("Callback handler not configured");
        }
        for (Callback callback : callbacks) {
            if (callback instanceof OAuthBearerValidatorCallback) {
                this.handleCallback((OAuthBearerValidatorCallback)callback);
                continue;
            }
            if (callback instanceof OAuthBearerExtensionsValidatorCallback) {
                this.handleExtensionsCallback((OAuthBearerExtensionsValidatorCallback)callback);
                continue;
            }
            throw new UnsupportedCallbackException(callback);
        }
    }

    private void handleCallback(OAuthBearerValidatorCallback callback) {
        try {
            this.handleValidatorCallback(callback);
        }
        catch (JwtVerificationException e) {
            log.info("Failed to verify OAuth JWT token", (Throwable)e);
            callback.error(AUTH_ERROR_MESSAGE, null, null);
        }
    }

    public void close() {
        if (this.jwtAuthenticator != null) {
            try {
                this.jwtAuthenticator.close();
            }
            catch (IOException e) {
                log.error("Failed to close Authenticator", (Throwable)e);
            }
        }
    }

    private void handleValidatorCallback(OAuthBearerValidatorCallback callback) throws JwtVerificationException {
        String tokenValue = callback.tokenValue();
        if (tokenValue == null) {
            throw new IllegalArgumentException("Callback missing required token value");
        }
        OAuthBearerToken token = this.processToken(tokenValue);
        callback.token(token);
        log.debug("Successfully validated token");
    }

    private void handleExtensionsCallback(OAuthBearerExtensionsValidatorCallback callback) {
        OAuthBearerJwsToken token = (OAuthBearerJwsToken)callback.token();
        String logicalCluster = (String)callback.inputExtensions().map().get("logicalCluster");
        String sniHostName = (String)callback.inputExtensions().map().get("__confluent_sni_broker_host_name");
        KafkaLogicalClusterMetadata metadata = null;
        try {
            metadata = this.clusterMetadataMatched(callback, token, logicalCluster);
            if (Objects.isNull(metadata)) {
                return;
            }
        }
        catch (IllegalStateException e) {
            this.reportErrorGettingMetadata(callback, e);
            return;
        }
        if (!(this.doesClusterExtensionExist((CommonExtensionsValidatorCallback)callback, logicalCluster) && this.isSniHostNameMatched(callback, logicalCluster, sniHostName, this.mode) && this.isLogicalClusterBelongToOrg(callback, token, metadata))) {
            return;
        }
        callback.valid("logicalCluster");
        log.debug("Successfully authenticated for user: {} (cluster: {})", (Object)token.principalName(), (Object)logicalCluster);
    }

    private boolean isLogicalClusterBelongToOrg(OAuthBearerExtensionsValidatorCallback callback, OAuthBearerJwsToken token, KafkaLogicalClusterMetadata metadata) {
        if (!this.enableOrgIdCheck) {
            return true;
        }
        String orgResourceId = (String)token.jwtClaims().get(DEFAULT_SCOPE_CLAIM);
        if (orgResourceId != null && orgResourceId.equals(metadata.organizationId())) {
            return true;
        }
        String errorMessage = String.format("The principal %s's logical cluster %s is not belong to the org in this token (%s).", token.principalName(), metadata.logicalClusterId(), orgResourceId);
        this.handleExtensionError((CommonExtensionsValidatorCallback)callback, errorMessage, "logicalCluster");
        return false;
    }

    private void reportErrorGettingMetadata(OAuthBearerExtensionsValidatorCallback callback, IllegalStateException e) {
        log.error("Could not get physical cluster metadata to validate the token. ", (Throwable)e);
        callback.errorMessage("Could not get cluster metadata to validate the token");
        callback.error("logicalCluster", AUTH_ERROR_MESSAGE);
    }

    private KafkaLogicalClusterMetadata clusterMetadataMatched(OAuthBearerExtensionsValidatorCallback callback, OAuthBearerJwsToken token, String logicalCluster) {
        if (!this.clusterMetadata.logicalClusterIds().contains(logicalCluster)) {
            if (this.clusterMetadata.logicalClusterIdsIncludingStale().contains(logicalCluster)) {
                log.info("Failing OAuth authentication because the metadata for the logical cluster {} is stale.", (Object)logicalCluster);
            }
            String errorMessage = String.format("The principal %s's logical cluster %s is not hosted on this broker.", token.principalName(), logicalCluster);
            this.handleExtensionError((CommonExtensionsValidatorCallback)callback, errorMessage, "logicalCluster");
            return null;
        }
        return (KafkaLogicalClusterMetadata)this.clusterMetadata.metadata(logicalCluster);
    }

    private boolean doesClusterExtensionExist(CommonExtensionsValidatorCallback callback, String logicalCluster) {
        if (logicalCluster == null || logicalCluster.isEmpty()) {
            String errorMessage = "The logical cluster extension is missing or is empty";
            this.handleExtensionError(callback, errorMessage, "logicalCluster");
            return false;
        }
        return true;
    }

    protected boolean isSniHostNameMatched(OAuthBearerExtensionsValidatorCallback callback, String logicalClusterId, String sniHostName, SniValidationMode sniValidationMode) {
        Optional<PathAwareSniHostName> sniHostNameOptional = sniHostName == null ? Optional.empty() : Optional.of(new PathAwareSniHostName(sniHostName));
        Optional<String> sniClusterId = sniHostNameOptional.map(PathAwareSniHostName::logicalClusterId);
        if (sniValidationMode.sniHostNameMatches(logicalClusterId, sniClusterId, sniHostNameOptional)) {
            return true;
        }
        String errorMessage = String.format("The SNI cluster Id: %s doesn't match with logical cluster extension: %s.", sniClusterId.orElse("<empty>"), logicalClusterId);
        this.handleExtensionError((CommonExtensionsValidatorCallback)callback, errorMessage, "__confluent_sni_broker_host_name");
        return false;
    }

    private void handleExtensionError(CommonExtensionsValidatorCallback callback, String errorMessage, String invalidExtensionName) {
        log.info(errorMessage);
        callback.errorMessage(errorMessage);
        callback.error(invalidExtensionName, AUTH_ERROR_MESSAGE);
    }

    OAuthBearerToken processToken(String jws) throws JwtVerificationException {
        return this.jwtAuthenticator.login(jws, DEFAULT_SCOPE_CLAIM);
    }
}

