/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.security.oauthbearer.internals;

import io.confluent.kafka.server.plugins.auth.stats.AuthenticationStats;
import io.confluent.kafka.util.ClientContext;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import javax.security.sasl.SaslServerFactory;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.authenticator.PathAwareSniHostName;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.PreTokenValidationExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.server.audit.AuditEventStatus;
import org.apache.kafka.server.audit.AuthenticationErrorInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class OAuthBearerSaslServer
implements SaslServer {
    private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslServer.class);
    private static final String NEGOTIATED_PROPERTY_KEY_TOKEN = "OAUTHBEARER.token";
    private static final String INTERNAL_ERROR_ON_SERVER = "Authentication could not be performed due to an internal error on the server";
    private static final AuthenticationStats STATS = AuthenticationStats.getInstance();
    private final AuthenticateCallbackHandler callbackHandler;
    private final PathAwareSniHostName sniHostName;
    private boolean complete;
    private OAuthBearerToken tokenForNegotiatedProperty = null;
    private String errorMessage = null;
    private SaslExtensions extensions;

    public OAuthBearerSaslServer(CallbackHandler callbackHandler) {
        this(callbackHandler, null);
    }

    public OAuthBearerSaslServer(CallbackHandler callbackHandler, PathAwareSniHostName sniHostName) {
        if (!(Objects.requireNonNull(callbackHandler) instanceof AuthenticateCallbackHandler)) {
            throw new IllegalArgumentException(String.format("Callback handler must be castable to %s: %s", AuthenticateCallbackHandler.class.getName(), callbackHandler.getClass().getName()));
        }
        this.callbackHandler = (AuthenticateCallbackHandler)callbackHandler;
        this.sniHostName = sniHostName;
    }

    @Override
    public byte[] evaluateResponse(byte[] response) throws SaslException, SaslAuthenticationException {
        OAuthBearerClientInitialResponse clientResponse;
        if (response.length == 1 && response[0] == 1 && this.errorMessage != null) {
            STATS.incrFailed();
            log.debug("Received %x01 response from client after it received our error");
            throw new SaslAuthenticationException(this.errorMessage, AuthenticationErrorInfo.UNKNOWN_USER_ERROR);
        }
        this.errorMessage = null;
        try {
            clientResponse = new OAuthBearerClientInitialResponse(response);
        }
        catch (SaslException e) {
            STATS.incrFailed();
            log.debug(e.getMessage());
            throw e;
        }
        HashMap<String, String> extensions = new HashMap<String, String>(clientResponse.extensions().map());
        if (this.sniHostName != null) {
            extensions.put("__confluent_sni_broker_host_name", this.sniHostName.strippedHostname());
        }
        return this.process(clientResponse.tokenValue(), clientResponse.authorizationId(), new SaslExtensions(extensions));
    }

    @Override
    public String getAuthorizationID() {
        if (!this.complete) {
            throw new IllegalStateException("Authentication exchange has not completed");
        }
        return this.tokenForNegotiatedProperty.principalName();
    }

    @Override
    public String getMechanismName() {
        return "OAUTHBEARER";
    }

    @Override
    public Object getNegotiatedProperty(String propName) {
        if (!this.complete) {
            throw new IllegalStateException("Authentication exchange has not completed");
        }
        if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName)) {
            return this.tokenForNegotiatedProperty;
        }
        if ("CREDENTIAL.LIFETIME.MS".equals(propName)) {
            return this.tokenForNegotiatedProperty.lifetimeMs();
        }
        return this.extensions.map().get(propName);
    }

    @Override
    public boolean isComplete() {
        return this.complete;
    }

    @Override
    public byte[] unwrap(byte[] incoming, int offset, int len) {
        if (!this.complete) {
            throw new IllegalStateException("Authentication exchange has not completed");
        }
        return Arrays.copyOfRange(incoming, offset, offset + len);
    }

    @Override
    public byte[] wrap(byte[] outgoing, int offset, int len) {
        if (!this.complete) {
            throw new IllegalStateException("Authentication exchange has not completed");
        }
        return Arrays.copyOfRange(outgoing, offset, offset + len);
    }

    @Override
    public void dispose() {
        this.complete = false;
        this.tokenForNegotiatedProperty = null;
        this.extensions = null;
    }

    private byte[] process(String tokenValue, String authorizationId, SaslExtensions extensions) throws SaslException {
        ClientContext context = new ClientContext();
        this.preProcessExtensions(extensions, context);
        OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(tokenValue, context);
        try {
            this.callbackHandler.handle(new Callback[]{callback});
        }
        catch (IOException | UnsupportedCallbackException e) {
            this.handleCallbackError(e);
        }
        OAuthBearerToken token = callback.token();
        if (token == null) {
            STATS.incrFailed();
            this.errorMessage = OAuthBearerSaslServer.jsonErrorResponse(callback.errorStatus(), callback.errorScope(), callback.errorOpenIDConfiguration());
            log.debug(this.errorMessage);
            return this.errorMessage.getBytes(StandardCharsets.UTF_8);
        }
        if (!authorizationId.isEmpty() && !authorizationId.equals(token.principalName())) {
            STATS.incrFailed();
            AuthenticationErrorInfo errorInfo = new AuthenticationErrorInfo(AuditEventStatus.UNAUTHENTICATED, "", token.principalName(), "");
            throw new SaslAuthenticationException(String.format("Authentication failed: Client requested an authorization id (%s) that is different from the token's principal name (%s)", authorizationId, token.principalName()), errorInfo);
        }
        Map<String, String> validExtensions = this.processExtensions(token, extensions);
        this.tokenForNegotiatedProperty = token;
        this.extensions = new SaslExtensions(validExtensions);
        this.complete = true;
        log.debug("Successfully authenticate User={}", (Object)token.principalName());
        STATS.incrSucceeded();
        return new byte[0];
    }

    private void preProcessExtensions(SaslExtensions extensions, ClientContext context) throws SaslException {
        PreTokenValidationExtensionsValidatorCallback callback = new PreTokenValidationExtensionsValidatorCallback(extensions, context);
        try {
            this.callbackHandler.handle(new Callback[]{callback});
        }
        catch (UnsupportedCallbackException unsupportedCallbackException) {
        }
        catch (IOException e) {
            this.handleCallbackError(e);
        }
        if (!callback.invalidExtensions().isEmpty()) {
            STATS.incrFailed();
            String errorMessage = String.format("Authentication failed: %d extensions are invalid! They are: %s", callback.invalidExtensions().size(), Utils.mkString(callback.invalidExtensions(), "", "", ": ", "; "));
            log.debug(errorMessage);
            AuthenticationErrorInfo errorInfo = new AuthenticationErrorInfo(AuditEventStatus.UNAUTHENTICATED, callback.errorMessage(), null, "");
            errorInfo.saslExtensions(callback.inputExtensions().map());
            throw new SaslAuthenticationException(errorMessage, errorInfo);
        }
    }

    private Map<String, String> processExtensions(OAuthBearerToken token, SaslExtensions extensions) throws SaslException {
        OAuthBearerExtensionsValidatorCallback extensionsCallback = new OAuthBearerExtensionsValidatorCallback(token, extensions);
        try {
            this.callbackHandler.handle(new Callback[]{extensionsCallback});
        }
        catch (UnsupportedCallbackException unsupportedCallbackException) {
        }
        catch (IOException e) {
            this.handleCallbackError(e);
        }
        if (!extensionsCallback.invalidExtensions().isEmpty()) {
            STATS.incrFailed();
            Set<String> scope = OAuthBearerSaslServer.oauthScope(extensionsCallback.token());
            String errorMessage = String.format("Authentication failed: %d extensions are invalid! They are: %s", extensionsCallback.invalidExtensions().size(), Utils.mkString(extensionsCallback.invalidExtensions(), "", "", ": ", "; "));
            log.debug(errorMessage);
            AuthenticationErrorInfo errorInfo = new AuthenticationErrorInfo(scope.isEmpty() ? AuditEventStatus.UNKNOWN_USER_DENIED : AuditEventStatus.UNAUTHENTICATED, extensionsCallback.errorMessage(), token.principalName(), "");
            errorInfo.saslExtensions(extensionsCallback.inputExtensions().map());
            for (Map.Entry<String, String> entry : extensionsCallback.data().entrySet()) {
                errorInfo.data(entry.getKey(), entry.getValue());
            }
            throw new SaslAuthenticationException(errorMessage, errorInfo);
        }
        return extensionsCallback.validatedExtensions();
    }

    private static Set<String> oauthScope(OAuthBearerToken jwtToken) {
        return jwtToken.scope() == null ? Collections.emptySet() : jwtToken.scope();
    }

    private static String jsonErrorResponse(String errorStatus, String errorScope, String errorOpenIDConfiguration) {
        String jsonErrorResponse = String.format("{\"status\":\"%s\"", errorStatus);
        if (errorScope != null) {
            jsonErrorResponse = String.format("%s, \"scope\":\"%s\"", jsonErrorResponse, errorScope);
        }
        if (errorOpenIDConfiguration != null) {
            jsonErrorResponse = String.format("%s, \"openid-configuration\":\"%s\"", jsonErrorResponse, errorOpenIDConfiguration);
        }
        jsonErrorResponse = String.format("%s}", jsonErrorResponse);
        return jsonErrorResponse;
    }

    private void handleCallbackError(Exception e) throws SaslException {
        STATS.incrFailed();
        String msg = String.format("%s: %s", INTERNAL_ERROR_ON_SERVER, e.getMessage());
        log.debug(msg, (Throwable)e);
        throw new SaslException(msg);
    }

    public static String[] mechanismNamesCompatibleWithPolicy(Map<String, ?> props) {
        String[] stringArray;
        if (props != null && "true".equals(String.valueOf(props.get("javax.security.sasl.policy.noplaintext")))) {
            stringArray = new String[]{};
        } else {
            String[] stringArray2 = new String[1];
            stringArray = stringArray2;
            stringArray2[0] = "OAUTHBEARER";
        }
        return stringArray;
    }

    public static class OAuthBearerSaslServerFactory
    implements SaslServerFactory {
        @Override
        public SaslServer createSaslServer(String mechanism, String protocol, String serverName, Map<String, ?> props, CallbackHandler callbackHandler) {
            String[] mechanismNamesCompatibleWithPolicy = this.getMechanismNames(props);
            for (int i = 0; i < mechanismNamesCompatibleWithPolicy.length; ++i) {
                if (!mechanismNamesCompatibleWithPolicy[i].equals(mechanism)) continue;
                return new OAuthBearerSaslServer(callbackHandler, (PathAwareSniHostName)props.get("__confluent_sni_broker_host_name"));
            }
            return null;
        }

        @Override
        public String[] getMechanismNames(Map<String, ?> props) {
            return OAuthBearerSaslServer.mechanismNamesCompatibleWithPolicy(props);
        }
    }
}

