/*
 * Decompiled with CFR 0.152.
 */
package io.confluent.kafka.clients.plugins.auth.oauth;

import io.confluent.kafka.clients.plugins.auth.oauth.SpireJwtLoginCallbackHandler;
import io.confluent.security.authentication.credential.BearerCredential;
import io.confluent.security.authentication.oauthbearer.MockJwtSource;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.types.Password;
import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.JwtValidatorException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.jwt.NumericDate;
import org.jose4j.lang.JoseException;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link SpireJwtLoginCallbackHandler}.
 *
 * <p>The handler under test is a {@link MockSpireJwtLoginCallbackHandler}, which installs a
 * {@link MockSpireJwtTokenRetriever} instead of a real SPIRE-agent retriever. That retriever returns
 * the encoded form of {@link #bearerCredential} at call time, so each test controls the token the
 * handler "receives" by mutating {@link #claims} and re-signing it with {@link #signClaims()}.
 */
public class SpireJwtLoginCallbackHandlerTest {
    private static final String SPIFFE_ID_1 = "spiffe://" + String.valueOf(MockJwtSource.SPIRE_TRUST_DOMAIN_1) + "/test-workload";
    // Deliberately misspelled scheme ("spife://") so the subject is not a valid SPIFFE ID.
    private static final String SPIFFE_ID_INVALID = "spife://" + String.valueOf(MockJwtSource.SPIRE_TRUST_DOMAIN_1) + "/test-workload";
    private static final String LOGICAL_CLUSTER = "lkc-test";
    private static final String OAUTHBEARER = "OAUTHBEARER";
    private static final String SPIRE_AGENT_ENDPOINT = "tcp://0.0.0.0:31523";

    private MockSpireJwtLoginCallbackHandler spireJwtLoginCallbackHandler;
    private JwtClaims claims;
    // Read by MockSpireJwtTokenRetriever on every retrieve(), so reassigning it changes the token.
    private BearerCredential bearerCredential;

    /** Builds a fresh handler and a valid, signed token (valid SPIFFE subject, expires in 60 minutes). */
    @BeforeEach
    public void setUp() {
        this.spireJwtLoginCallbackHandler = new MockSpireJwtLoginCallbackHandler();
        this.claims = new JwtClaims();
        this.claims.setIssuer("test.prefix.spire.internal.confluent.cloud");
        this.claims.setSubject(SPIFFE_ID_1);
        this.claims.setExpirationTimeMinutesInTheFuture(60.0f);
        this.claims.setIssuedAt(NumericDate.now());
        try {
            this.bearerCredential = signClaims();
        } catch (JoseException e) {
            throw new RuntimeException(e);
        }
    }

    /** configure() must reject any SASL mechanism other than OAUTHBEARER. */
    @Test
    public void testConfigurationRaisesExceptionOnWrongMechanism() {
        Map<String, Object> jaasConfig = buildClientJaasConfig(LOGICAL_CLUSTER);
        IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
            () -> this.spireJwtLoginCallbackHandler.configure(jaasConfig, "PLAINTEXT",
                JaasContext.loadClientContext(jaasConfig).configurationEntries()));
        Assertions.assertTrue(exception.getMessage().contains("Unexpected SASL mechanism:"));
    }

    /** configure() must reject an empty list of JAAS configuration entries. */
    @Test
    public void testConfigurationRaisesExceptionOnMissingJaasConfig() {
        Map<String, Password> configs = new HashMap<>();
        configs.put("sasl.jaas.config", new Password(""));
        IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class,
            () -> this.spireJwtLoginCallbackHandler.configure(configs, OAUTHBEARER, Collections.emptyList()));
        Assertions.assertTrue(exception.getMessage().contains("Must supply exactly 1 non-null JAAS mechanism configuration"));
    }

    /** configure() must fail when the JAAS config has no logicalCluster option. */
    @Test
    public void testConfigureRaisesExceptionOnMissingLogicalCluster() {
        Map<String, Object> jaasConfig = buildClientJaasConfig(null);
        ConfigException exception = Assertions.assertThrows(ConfigException.class,
            () -> configureHandler(jaasConfig));
        // Expected text mirrors the handler's current message verbatim.
        Assertions.assertEquals("Logical cluster for must be set in the JAAS config.", exception.getMessage());
    }

    /** configure() must fail when the SPIRE agent endpoint property is absent. */
    @Test
    public void testConfigureRaisesExceptionOnMissingSpireAgentEndPoint() {
        String jaasConfigText = "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required logicalCluster=\"lkc-test\";";
        Map<String, Password> configs = new HashMap<>();
        configs.put("sasl.jaas.config", new Password(jaasConfigText));
        ConfigException exception = Assertions.assertThrows(ConfigException.class,
            () -> configureHandler(configs));
        Assertions.assertEquals("The OAuth configuration option sasl.oauthbearer.token.spire.agent.endpoint value must be non-null", exception.getMessage());
    }

    @Test
    public void testSuccessfulConfiguration() {
        Map<String, Object> jaasConfig = buildClientJaasConfig(LOGICAL_CLUSTER);
        Assertions.assertDoesNotThrow(() -> configureHandler(jaasConfig));
    }

    /** The logicalCluster JAAS option is surfaced as a SASL extension. */
    @Test
    public void testHandleExtensionsCallback() throws IOException, UnsupportedCallbackException {
        configureHandler(buildClientJaasConfig(LOGICAL_CLUSTER));
        SaslExtensionsCallback saslExtensionsCallback = new SaslExtensionsCallback();
        this.spireJwtLoginCallbackHandler.handle(new Callback[]{saslExtensionsCallback});
        Assertions.assertEquals(LOGICAL_CLUSTER, saslExtensionsCallback.extensions().map().get("logicalCluster"));
    }

    /**
     * Each invalid token variant (bad subject, expired, issued in the future) must be rejected by
     * handle() with the corresponding validator message.
     */
    @Test
    public void testTokenLoginValidationFailure() throws JoseException {
        configureHandler(buildClientJaasConfig(LOGICAL_CLUSTER));
        OAuthBearerTokenCallback oAuthBearerTokenCallback = new OAuthBearerTokenCallback();

        // Subject is not a SPIFFE ID.
        this.claims.setSubject(SPIFFE_ID_INVALID);
        this.bearerCredential = signClaims();
        assertHandleFailsWith(oAuthBearerTokenCallback, "sub value must be a spiffe id");

        // Already-expired token. Back-date iat and exp (keeping iat < exp < now) instead of
        // sleeping, so the test is fast and independent of scheduler timing.
        long nowMs = System.currentTimeMillis();
        this.claims.setSubject(SPIFFE_ID_1);
        this.claims.setIssuedAt(NumericDate.fromMilliseconds(nowMs - 120_000L));
        this.claims.setExpirationTime(NumericDate.fromMilliseconds(nowMs - 60_000L));
        this.bearerCredential = signClaims();
        assertHandleFailsWith(oAuthBearerTokenCallback, "Token has expired");

        // Token whose iat is 30 minutes in the future.
        this.claims.setExpirationTimeMinutesInTheFuture(60.0f);
        this.claims.setIssuedAt(NumericDate.fromMilliseconds(System.currentTimeMillis() + 1_800_000L));
        this.bearerCredential = signClaims();
        assertHandleFailsWith(oAuthBearerTokenCallback, "iat has future value");
    }

    /** A valid token from the retriever is placed unchanged into the token callback. */
    @Test
    public void testSuccessfulHandleTokenCallback() throws IOException, UnsupportedCallbackException {
        configureHandler(buildClientJaasConfig(LOGICAL_CLUSTER));
        OAuthBearerTokenCallback oAuthBearerTokenCallback = new OAuthBearerTokenCallback();
        this.spireJwtLoginCallbackHandler.handle(new Callback[]{oAuthBearerTokenCallback});
        Assertions.assertEquals(this.bearerCredential.bearerToken(), oAuthBearerTokenCallback.token().value());
    }

    /** Configures the handler for OAUTHBEARER using the JAAS entries parsed from {@code configs}. */
    private void configureHandler(Map<String, ?> configs) {
        this.spireJwtLoginCallbackHandler.configure(configs, OAUTHBEARER,
            JaasContext.loadClientContext(configs).configurationEntries());
    }

    /** Signs the current {@link #claims} with the RSA_SPIRE_1 test key. */
    private BearerCredential signClaims() throws JoseException {
        return MockJwtSource.createEncodedJws(MockJwtSource.Kid.RSA_SPIRE_1, this.claims);
    }

    /** Asserts that handling {@code callback} throws a JwtValidatorException with {@code expectedMessage}. */
    private void assertHandleFailsWith(OAuthBearerTokenCallback callback, String expectedMessage) {
        JwtValidatorException exception = Assertions.assertThrows(JwtValidatorException.class,
            () -> this.spireJwtLoginCallbackHandler.handle(new Callback[]{callback}));
        Assertions.assertEquals(expectedMessage, exception.getMessage());
    }

    /**
     * Builds an immutable client config with a JAAS entry for OAuthBearerLoginModule and the SPIRE
     * agent endpoint.
     *
     * @param logicalCluster value for the {@code logicalCluster} JAAS option; omitted when null or empty
     * @return unmodifiable config map
     */
    private Map<String, Object> buildClientJaasConfig(String logicalCluster) {
        StringBuilder jaasConfigText = new StringBuilder(
            "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required");
        if (logicalCluster != null && !logicalCluster.isEmpty()) {
            jaasConfigText.append(" logicalCluster=\"").append(logicalCluster).append('"');
        }
        jaasConfigText.append(';');
        Map<String, Object> configs = new HashMap<>();
        configs.put("sasl.jaas.config", new Password(jaasConfigText.toString()));
        configs.put("sasl.oauthbearer.token.spire.agent.endpoint", SPIRE_AGENT_ENDPOINT);
        return Collections.unmodifiableMap(configs);
    }

    /** Handler that installs {@link MockSpireJwtTokenRetriever} instead of contacting a SPIRE agent. */
    private class MockSpireJwtLoginCallbackHandler extends SpireJwtLoginCallbackHandler {
        @Override
        protected void initAccessTokenRetriever(String spireAgentEndpoint) {
            this.accessTokenRetriever = new MockSpireJwtTokenRetriever(spireAgentEndpoint);
        }
    }

    /** Retriever that returns whatever {@link #bearerCredential} holds at the moment of the call. */
    private class MockSpireJwtTokenRetriever implements JwtRetriever {
        public MockSpireJwtTokenRetriever(String spireAgentEndpoint) {
            // Endpoint intentionally ignored: no SPIRE agent is contacted in these tests.
        }

        @Override
        public String retrieve() {
            return SpireJwtLoginCallbackHandlerTest.this.bearerCredential.bearerToken();
        }
    }
}

