package io.confluent.kafka.server.plugins.auth.token;

import io.confluent.kafka.test.utils.TokenTestUtils;
import io.confluent.security.authentication.http.HttpClient;
import jakarta.ws.rs.client.Entity;
import jakarta.ws.rs.core.Form;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/confluent/kafka/server/plugins/auth/token/CompositeBearerValidatorCallbackHandlerTest.class */
/**
 * Integration-style tests for {@link CompositeBearerValidatorCallbackHandler}.
 *
 * <p>Two kinds of bearer tokens are exercised:
 * <ul>
 *   <li>locally generated JWS tokens created through {@link TokenTestUtils#setUpJws}, whose public
 *       key is handed to the handler via the {@code publicKeyPath} JAAS option, and</li>
 *   <li>access tokens fetched from the shared {@link IdentityProviderService}, whose JWKS endpoint
 *       is handed to the handler via the SASL OAUTHBEARER configuration.</li>
 * </ul>
 *
 * <p>The identity provider is started once for the whole class; the handler configuration is
 * rebuilt before every test.
 */
public class CompositeBearerValidatorCallbackHandlerTest {
    private static final String DEFAULT_ISSUER = "Confluent";
    private static final String DEFAULT_SUBJECT = "Customer";

    /** Error status asserted for rejected tokens whose issuer is {@link #DEFAULT_ISSUER}. */
    private static final String CONFLUENT_VALIDATOR_ERROR_MESSAGE = "Authentication failed";
    /** Error status asserted for rejected tokens that are not issued by {@link #DEFAULT_ISSUER}. */
    private static final String IDP_VALIDATOR_ERROR_MESSAGE = "invalid_token";

    private static final String REALM_CONFIG_PATH = "/CompositeBearerValidatorRealm.json";
    private static final IdentityProviderService IDP_SERVICE = new IdentityProviderService(REALM_CONFIG_PATH);

    private static final String ALLOWED_URLS_PROPERTY = "org.apache.kafka.sasl.oauthbearer.allowed.urls";
    private static final String JWKS_ENDPOINT_URL_CONFIG = "sasl.oauthbearer.jwks.endpoint.url";
    private static final String EXPECTED_AUDIENCE_CONFIG = "sasl.oauthbearer.expected.audience";
    private static final String PUBLIC_KEY_PATH_OPTION = "publicKeyPath";
    private static final String AUDIENCE_OPTION = "audience";

    private static final String JAAS_CONTEXT_NAME = "Kafka";
    private static final String LOGIN_MODULE_CLASS =
            "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule";
    private static final String SASL_MECHANISM = "OAUTHBEARER";

    // Lifetime arguments passed to TokenTestUtils.setUpJws.
    // NOTE(review): the unit is defined by TokenTestUtils and is not visible here — confirm.
    private static final int VALID_JWS_LIFETIME = 36000;
    private static final int SHORT_JWS_LIFETIME = 50;
    /** Pause after creating a {@link #SHORT_JWS_LIFETIME} token so that it is expired when validated. */
    private static final long SHORT_JWS_EXPIRY_WAIT_MS = 100L;
    /**
     * Pause that lets an IdP-issued access token expire.
     * NOTE(review): presumably the realm config issues tokens that live for less than 40 s —
     * verify against {@link #REALM_CONFIG_PATH}.
     */
    private static final long IDP_TOKEN_EXPIRY_WAIT_MS = 40000L;

    private TokenTestUtils.JwsContainer jwsContainer;
    private Map<String, Object> configs;

    /** Starts the shared identity provider, allowing up to 20 minutes for it to come up. */
    @BeforeAll
    public static void beforeAll() {
        IDP_SERVICE.setStartupTimeout(Duration.ofMinutes(20L));
        IDP_SERVICE.start();
    }

    /** Stops the shared identity provider and removes the allowed-URLs system property. */
    @AfterAll
    public static void afterAll() {
        IDP_SERVICE.shutdown();
        System.clearProperty(ALLOWED_URLS_PROPERTY);
    }

    /**
     * Builds the SASL configuration pointing at the IdP's JWKS endpoint and registers that
     * endpoint in the allowed-URLs system property.
     */
    @BeforeEach
    public void setUp() throws Exception {
        this.configs = new HashMap<>();
        this.configs.put(JWKS_ENDPOINT_URL_CONFIG, IDP_SERVICE.getJwksEndpoint());
        this.configs.put(EXPECTED_AUDIENCE_CONFIG, "account");
        System.setProperty(ALLOWED_URLS_PROPERTY, IDP_SERVICE.getJwksEndpoint());
    }

    /** Clears the JVM-wide system property set in {@link #setUp()} so it does not outlive the test. */
    @AfterEach
    public void tearDown() {
        System.clearProperty(ALLOWED_URLS_PROPERTY);
    }

    @Test
    public void testAttachesConfluentJws() throws Exception {
        this.jwsContainer = TokenTestUtils.setUpJws(VALID_JWS_LIFETIME, DEFAULT_ISSUER, DEFAULT_SUBJECT);
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());
        String jws = this.jwsContainer.getJwsToken();

        assertAccepted(jws, validate(handler, jws));
    }

    @Test
    public void testAttachesIdpJws() throws Exception {
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());
        String idpToken = getJwtFromIdp("clusterAdmin1", "clusterAdmin1");

        assertAccepted(idpToken, validate(handler, idpToken));
    }

    @Test
    public void testConfigureRaisesExceptionWhenInvalidKeyPath() throws Exception {
        this.jwsContainer = TokenTestUtils.setUpJws(VALID_JWS_LIFETIME, DEFAULT_ISSUER, DEFAULT_SUBJECT);
        Map<String, String> options = baseOptions();
        // Append a path segment below the key *file* so the configured location cannot be read.
        options.put(PUBLIC_KEY_PATH_OPTION, this.jwsContainer.getPublicKeyFile().getAbsolutePath() + "/invalid!");

        Assertions.assertThrows(ConfigException.class, () -> createCallbackHandler(options));
    }

    @Test
    public void testConfigureRaisesExceptionWhenInvalidJWKSUrl() throws Exception {
        Map<String, String> options = baseOptions();
        this.configs.put(JWKS_ENDPOINT_URL_CONFIG, "xyz");

        Assertions.assertThrows(ConfigException.class, () -> createCallbackHandler(options));
    }

    @Test
    public void testErrorWhenInvalidConfluentJws() throws Exception {
        this.jwsContainer = TokenTestUtils.setUpJws(VALID_JWS_LIFETIME, DEFAULT_ISSUER, DEFAULT_SUBJECT);
        // Overwrite the published public key with an unrelated one so the signature no longer verifies.
        TokenTestUtils.writePemFile(this.jwsContainer.getPublicKeyFile(), TokenTestUtils.generateKeyPair().getPublic());
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());

        assertRejected(CONFLUENT_VALIDATOR_ERROR_MESSAGE, validate(handler, this.jwsContainer.getJwsToken()));
    }

    @Test
    public void testErrorWhenInvalidIdpJws() throws Exception {
        // A locally signed token with a non-Confluent issuer; the test expects the IdP error status.
        this.jwsContainer = TokenTestUtils.setUpJws(VALID_JWS_LIFETIME, "AWS", DEFAULT_SUBJECT);
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());

        assertRejected(IDP_VALIDATOR_ERROR_MESSAGE, validate(handler, this.jwsContainer.getJwsToken()));
    }

    @Test
    public void testErrorWhenExpiredConfluentJws() throws Exception {
        this.jwsContainer = TokenTestUtils.setUpJws(SHORT_JWS_LIFETIME, DEFAULT_ISSUER, DEFAULT_SUBJECT);
        Thread.sleep(SHORT_JWS_EXPIRY_WAIT_MS);
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());

        assertRejected(CONFLUENT_VALIDATOR_ERROR_MESSAGE, validate(handler, this.jwsContainer.getJwsToken()));
    }

    @Test
    public void testErrorWhenExpiredIdpJws() throws Exception {
        // The token is deliberately not printed or logged: it is a live credential.
        String idpToken = getJwtFromIdp("clusterAdmin1", "clusterAdmin1");
        Thread.sleep(IDP_TOKEN_EXPIRY_WAIT_MS);
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());

        assertRejected(IDP_VALIDATOR_ERROR_MESSAGE, validate(handler, idpToken));
    }

    @Test
    public void testErrorIfMissingSubject() throws Exception {
        this.jwsContainer = TokenTestUtils.setUpJws(VALID_JWS_LIFETIME, DEFAULT_ISSUER, null);
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());

        assertRejected(CONFLUENT_VALIDATOR_ERROR_MESSAGE, validate(handler, this.jwsContainer.getJwsToken()));
    }

    @Test
    public void testErrorIfNoExpirationTime() throws Exception {
        this.jwsContainer = TokenTestUtils.setUpJws(null, DEFAULT_ISSUER, DEFAULT_SUBJECT);
        CompositeBearerValidatorCallbackHandler handler = createCallbackHandler(baseOptions());

        assertRejected(CONFLUENT_VALIDATOR_ERROR_MESSAGE, validate(handler, this.jwsContainer.getJwsToken()));
    }

    /**
     * Runs {@code token} through {@code handler} and returns the populated callback.
     *
     * <p>The callback is kept as an {@link OAuthBearerValidatorCallback} (not the bare
     * {@link Callback} interface) so that {@code token()} and {@code errorStatus()} are accessible.
     */
    private static OAuthBearerValidatorCallback validate(
            CompositeBearerValidatorCallbackHandler handler, String token) throws Exception {
        OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(token);
        handler.handle(new Callback[] {callback});
        return callback;
    }

    /** Asserts that validation succeeded and that the attached token carries {@code expectedToken}. */
    private static void assertAccepted(String expectedToken, OAuthBearerValidatorCallback callback) {
        Assertions.assertNotNull(callback.token());
        Assertions.assertEquals(expectedToken, callback.token().value());
        Assertions.assertNull(callback.errorStatus());
    }

    /** Asserts that validation failed with exactly {@code expectedStatus}. */
    private static void assertRejected(String expectedStatus, OAuthBearerValidatorCallback callback) {
        Assertions.assertNotNull(callback.errorStatus());
        Assertions.assertEquals(expectedStatus, callback.errorStatus());
    }

    /**
     * Creates and configures a handler.
     *
     * @param map JAAS login-module options (see {@link #baseOptions()})
     * @return a handler configured with the current {@link #configs} and the given JAAS options
     * @throws ConfigException if the handler rejects the configuration (asserted by the
     *         {@code testConfigureRaisesException*} tests)
     */
    private CompositeBearerValidatorCallbackHandler createCallbackHandler(Map<String, ?> map) {
        TestJaasConfig jaasConfig = new TestJaasConfig();
        // Copy into an Object-valued map: a Map<String, ?> cannot be passed where Map<String, Object>
        // is required, and the copy also decouples the JAAS entry from the caller's map.
        jaasConfig.createOrUpdateEntry(JAAS_CONTEXT_NAME, LOGIN_MODULE_CLASS, new HashMap<String, Object>(map));

        ConfigDef configDef = new ConfigDef();
        SaslConfigs.addClientSaslSupport(configDef);
        AbstractConfig saslConfig = new AbstractConfig(configDef, this.configs);

        CompositeBearerValidatorCallbackHandler handler = new CompositeBearerValidatorCallbackHandler();
        handler.configure(
                saslConfig.values(),
                SASL_MECHANISM,
                Collections.singletonList(jaasConfig.getAppConfigurationEntry(JAAS_CONTEXT_NAME)[0]));
        return handler;
    }

    /**
     * Builds the default JAAS options: the public key path of the current JWS container and an
     * empty audience list. A default JWS container is created first if the test has not set one up.
     */
    private Map<String, String> baseOptions() throws Exception {
        if (this.jwsContainer == null) {
            this.jwsContainer = TokenTestUtils.setUpJws(VALID_JWS_LIFETIME, DEFAULT_ISSUER, DEFAULT_SUBJECT);
        }
        Map<String, String> options = new HashMap<>();
        options.put(PUBLIC_KEY_PATH_OPTION, this.jwsContainer.getPublicKeyFile().getAbsolutePath());
        // Joining zero audiences yields the empty string.
        options.put(AUDIENCE_OPTION, String.join(","));
        return options;
    }

    /**
     * Requests an access token from the IdP token endpoint using the {@code client_credentials}
     * grant with HTTP Basic authentication.
     *
     * @param clientId     user part of the Basic credentials
     * @param clientSecret password part of the Basic credentials
     * @return the {@code access_token} field of the JSON response
     */
    private String getJwtFromIdp(String clientId, String clientSecret) throws Exception {
        String basicCredentials = Base64.getEncoder()
                .encodeToString((clientId + ":" + clientSecret).getBytes(StandardCharsets.UTF_8));
        Form grantForm = new Form().param("grant_type", "client_credentials");

        String accessToken = HttpClient.builder().build()
                .target(URI.create(IDP_SERVICE.getTokenEndpoint()))
                .request()
                .header("Authorization", "Basic " + basicCredentials)
                .accept("application/json")
                .rx()
                .post(Entity.entity(grantForm, "application/x-www-form-urlencoded"))
                .thenApply(response -> (Map<?, ?>) response.readEntity(Map.class))
                .thenApply(body -> (String) body.get("access_token"))
                .toCompletableFuture()
                .get();
        // Fail at the source instead of handing a null token to the validator callback.
        Assertions.assertNotNull(accessToken, "IdP token response did not contain an access_token");
        return accessToken;
    }
}
