package io.confluent.security.auth.provider.oauth;

import io.confluent.kafka.multitenant.KafkaLogicalClusterUtils;
import io.confluent.security.test.utils.JwtTestUtils;
import java.nio.charset.StandardCharsets;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslServer;
import org.jose4j.jwk.JsonWebKey;
import org.jose4j.jwk.JsonWebKeySet;
import org.jose4j.jwk.PublicJsonWebKey;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwt.JwtClaims;
import org.jose4j.lang.JoseException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Integration tests that drive {@link OAuthBearerSaslServer} through the callback handler created by
 * {@code createCallbackHandler("AuthConfig.yaml")} and verify identity-pool resolution, with and
 * without the {@code confluent.oauth.union.of.pools.enable} setting.
 */
@Tag("integration")
public class SaslOAuthIntegrationTest extends EnhancedOAuthBearerValidatorCallbackHandlerTest {
    private static final Logger LOG = LoggerFactory.getLogger(SaslOAuthIntegrationTest.class);

    /** Key id set on the generated signing key and echoed in the JWS header. */
    private static final String KEY_ID = "key-12345";
    /** Value of the {@code sub} claim placed in every generated token. */
    private static final String SUBJECT = "123456789";
    private static final String ISSUER_A = "https://test-issuer-a.com";
    private static final String ISSUER_B = "https://test-issuer-b.com";
    private static final String JWKS_ENDPOINT_A = "https://test-issuer-a.com/json.jwks";
    private static final String JWKS_ENDPOINT_B = "https://test-issuer-b.com/json.jwks";

    private static final String UNION_OF_POOLS_CONFIG = "confluent.oauth.union.of.pools.enable";
    private static final String AUTH_CONFIG_FILE = "AuthConfig.yaml";
    // Claim name passed to the provider/pool registrations; presumably the JWT subject claim — confirm
    // against JwtTestUtils.
    private static final String SUB_CLAIM = "sub";
    /** Value reported as {@code identityPoolId} when the union of several pools was negotiated. */
    private static final String UNION_POOL_ID = "OAuth-ClientCredentials";

    private static final String LKC_ID = KafkaLogicalClusterUtils.LC_META_ABC.logicalClusterId();
    private static final String ORG_RESOURCE_ID = KafkaLogicalClusterUtils.LC_META_ABC.organizationId();

    @Override
    @BeforeEach
    public void setUp() throws Exception {
        super.setUp();
    }

    @Override
    @AfterEach
    public void tearDown() {
        super.tearDown();
    }

    @Test
    public void testUnionOfPoolsEnabled_NoProvidedPools_Passes() throws Exception {
        PublicJsonWebKey key = generateKey(KEY_ID);
        String token = generateToken(key, ISSUER_A, SUBJECT);
        this.configs.put(UNION_OF_POOLS_CONFIG, true);
        OAuthBearerSaslServer server = newSaslServer();
        registerJwksForIssuerA(key);
        registerProvider("op-abc", ISSUER_A, JWKS_ENDPOINT_A);
        registerPool("pool-a", "op-abc", true);
        registerPool("pool-b", "op-abc", false);
        registerPool("pool-c", "op-abc", true);

        // No pool extension: every pool whose filter matches is unioned.
        assertAuthenticates(server, getAuthenticateRequest(null, token));
        assertNegotiatedPools(server, UNION_POOL_ID, "pool-c,pool-a");
    }

    @Test
    public void testUnionOfPoolsEnabled_SingleProvidedPool_Passes() throws Exception {
        PublicJsonWebKey key = generateKey(KEY_ID);
        String token = generateToken(key, ISSUER_A, SUBJECT);
        this.configs.put(UNION_OF_POOLS_CONFIG, true);
        OAuthBearerSaslServer server = newSaslServer();
        registerJwksForIssuerA(key);
        registerProvider("op-abc", ISSUER_A, JWKS_ENDPOINT_A);
        registerPool("pool-abc", "op-abc", true);

        // A single explicit pool is negotiated directly, with no pool list property.
        assertAuthenticates(server, getAuthenticateRequest("pool-abc", token));
        assertNegotiatedPools(server, "pool-abc", null);
    }

    @Test
    public void testUnionOfPoolsEnabled_ExplicitProvidedPools_Passes() throws Exception {
        PublicJsonWebKey key = generateKey(KEY_ID);
        String token = generateToken(key, ISSUER_A, SUBJECT);
        this.configs.put(UNION_OF_POOLS_CONFIG, true);
        OAuthBearerSaslServer server = newSaslServer();
        registerJwksForIssuerA(key);
        registerProvider("op-abc", ISSUER_A, JWKS_ENDPOINT_A);
        registerPool("pool-a", "op-abc", true);
        registerPool("pool-b", "op-abc", false);
        registerPool("pool-c", "op-abc", true);
        String request = getAuthenticateRequest(String.join(",", "pool-b", "pool-c"), token);

        // pool-b's filter does not match yet, so only pool-c is selected from the requested list.
        assertAuthenticates(server, request);
        assertNegotiatedPools(server, UNION_POOL_ID, "pool-c");

        // Once pool-b's filter matches, both requested pools are selected.
        registerPool("pool-b", "op-abc", true);
        assertAuthenticates(server, request);
        assertNegotiatedPools(server, UNION_POOL_ID, "pool-b,pool-c");
    }

    @Test
    public void testUnionOfPoolsEnabled_ExplicitProvidedPoolsFromDifferentProviders_Fails() throws Exception {
        PublicJsonWebKey key = generateKey(KEY_ID);
        String token = generateToken(key, ISSUER_A, SUBJECT);
        this.configs.put(UNION_OF_POOLS_CONFIG, true);
        OAuthBearerSaslServer server = newSaslServer();
        registerJwksForIssuerA(key);
        registerProvider("op-a", ISSUER_A, JWKS_ENDPOINT_A);
        registerProvider("op-b", ISSUER_B, JWKS_ENDPOINT_B);
        registerPool("pool-a", "op-a", true);
        registerPool("pool-b", "op-a", false);
        // NOTE(review): pool-c belongs to op-b but is registered with issuer-A values (kept as-is).
        registerPool("pool-c", "op-b", true);
        String request = getAuthenticateRequest(String.join(",", "pool-a", "pool-b", "pool-c"), token);

        assertAuthenticationFails(server, request,
                "List of pools provided in sasl extension contains multiple providers");
    }

    @Test
    public void testUnionOfPoolsDisabled_ExplicitProvidedPools_Fails() throws Exception {
        PublicJsonWebKey key = generateKey(KEY_ID);
        String token = generateToken(key, ISSUER_A, SUBJECT);
        OAuthBearerSaslServer server = newSaslServer();
        registerJwksForIssuerA(key);
        registerProvider("op-abc", ISSUER_A, JWKS_ENDPOINT_A);
        registerPool("pool-a", "op-abc", true);
        registerPool("pool-b", "op-abc", false);
        registerPool("pool-c", "op-abc", true);
        String request = getAuthenticateRequest(String.join(",", "pool-b", "pool-c"), token);

        // Without the union setting, a comma-separated pool list is rejected as malformed.
        assertAuthenticationFails(server, request, "Invalid format found for pool id extension: pool-b,pool-c");
    }

    @Test
    public void testUnionOfPoolsDisabled_ExplicitProvidedPoolsFromDifferentProviders_Fails() throws Exception {
        PublicJsonWebKey key = generateKey(KEY_ID);
        String token = generateToken(key, ISSUER_A, SUBJECT);
        OAuthBearerSaslServer server = newSaslServer();
        registerJwksForIssuerA(key);
        registerProvider("op-a", ISSUER_A, JWKS_ENDPOINT_A);
        registerProvider("op-b", ISSUER_B, JWKS_ENDPOINT_B);
        registerPool("pool-a", "op-a", true);
        registerPool("pool-b", "op-a", false);
        // NOTE(review): pool-c belongs to op-b but is registered with issuer-A values (kept as-is).
        registerPool("pool-c", "op-b", true);
        String request = getAuthenticateRequest(String.join(",", "pool-a", "pool-b", "pool-c"), token);

        // The format check fires before any provider comparison when the union setting is off.
        assertAuthenticationFails(server, request,
                "Invalid format found for pool id extension: pool-a,pool-b,pool-c");
    }

    /** Builds a SASL server around a freshly created callback handler; call after adjusting {@code configs}. */
    private OAuthBearerSaslServer newSaslServer() throws Exception {
        return new OAuthBearerSaslServer(createCallbackHandler(AUTH_CONFIG_FILE));
    }

    /** Publishes a JWKS containing only {@code key} for issuer A in the auth cache. */
    private void registerJwksForIssuerA(JsonWebKey key) throws Exception {
        JwtTestUtils.updateJwks(this.authCache, ISSUER_A, JWKS_ENDPOINT_A, new JsonWebKeySet(new JsonWebKey[]{key}));
    }

    /** Registers identity provider {@code providerId} for the test organization. */
    private void registerProvider(String providerId, String issuer, String jwksEndpoint) throws Exception {
        JwtTestUtils.updateIdentityProvider(this.authCache, ORG_RESOURCE_ID, providerId, SUB_CLAIM, issuer, jwksEndpoint);
    }

    /**
     * Registers (or re-registers) identity pool {@code poolId} under {@code providerId}, using issuer-A
     * values for issuer and JWKS endpoint.
     *
     * @param filterMatches written as the {@code "true"}/{@code "false"} argument; the tests show that a
     *     pool registered with {@code false} is left out of the negotiated pool list
     */
    private void registerPool(String poolId, String providerId, boolean filterMatches) throws Exception {
        JwtTestUtils.updateIdentityPool(this.authCache, poolId, 1, ISSUER_A, providerId, JWKS_ENDPOINT_A,
                SUB_CLAIM, poolId, String.valueOf(filterMatches), ORG_RESOURCE_ID);
    }

    /** Sends {@code request} to {@code server} and asserts it completes with an empty challenge. */
    private static void assertAuthenticates(OAuthBearerSaslServer server, String request) throws Exception {
        byte[] challenge = server.evaluateResponse(request.getBytes(StandardCharsets.UTF_8));
        Assertions.assertEquals(0, challenge.length);
    }

    /** Asserts that {@code request} is rejected with a {@link SaslAuthenticationException} carrying {@code expectedMessage}. */
    private static void assertAuthenticationFails(OAuthBearerSaslServer server, String request, String expectedMessage) {
        SaslAuthenticationException thrown = Assertions.assertThrows(SaslAuthenticationException.class,
                () -> server.evaluateResponse(request.getBytes(StandardCharsets.UTF_8)));
        Assertions.assertEquals(expectedMessage, thrown.errorInfo().errorMessage());
    }

    /**
     * Checks the negotiated identity-pool properties after a successful authentication.
     *
     * @param expectedPoolId expected value of {@code identityPoolId} and {@code identityPoolId-sub}
     * @param expectedPoolList expected value of {@code identityPoolId-identityPoolId}, or {@code null} if unset
     */
    private static void assertNegotiatedPools(OAuthBearerSaslServer server, String expectedPoolId, String expectedPoolList) {
        Assertions.assertEquals(expectedPoolId, server.getNegotiatedProperty("identityPoolId"));
        Assertions.assertEquals(SUBJECT, server.getNegotiatedProperty("identityPoolId-azp"));
        Assertions.assertEquals(expectedPoolId, server.getNegotiatedProperty("identityPoolId-sub"));
        Assertions.assertEquals(expectedPoolList, server.getNegotiatedProperty("identityPoolId-identityPoolId"));
    }

    /**
     * Builds the initial OAUTHBEARER client response (GS2 header, extensions, bearer token).
     *
     * @param poolIds value for the {@code identityPoolId} extension, or {@code null} to omit it
     * @param token the compact-serialized JWT
     */
    private String getAuthenticateRequest(String poolIds, String token) {
        if (poolIds == null) {
            return String.format("n,,\u0001%s=%s\u0001auth=Bearer %s\u0001\u0001", "logicalCluster", LKC_ID, token);
        }
        return String.format("n,,\u0001%s=%s\u0001%s=%s\u0001auth=Bearer %s\u0001\u0001",
                "logicalCluster", LKC_ID, "identityPoolId", poolIds, token);
    }

    /** Generates a fresh 2048-bit RSA key pair as a JWK holding both keys, tagged with {@code keyId}. */
    private PublicJsonWebKey generateKey(String keyId) throws NoSuchAlgorithmException, JoseException {
        KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
        generator.initialize(2048);
        KeyPair pair = generator.genKeyPair();
        PublicJsonWebKey jwk = PublicJsonWebKey.Factory.newPublicJwk(pair.getPublic());
        jwk.setPrivateKey(pair.getPrivate());
        jwk.setKeyId(keyId);
        return jwk;
    }

    /**
     * Creates an RS256-signed JWT for {@code issuer}/{@code subject} that expires 60 minutes from now,
     * signed with the private half of {@code signingKey} and carrying its key id in the header.
     */
    private String generateToken(PublicJsonWebKey signingKey, String issuer, String subject) throws JoseException {
        JwtClaims claims = new JwtClaims();
        claims.setIssuer(issuer);
        claims.setExpirationTimeMinutesInTheFuture(60.0f);
        claims.setGeneratedJwtId();
        claims.setIssuedAtToNow();
        claims.setSubject(subject);
        JsonWebSignature jws = new JsonWebSignature();
        jws.setPayload(claims.toJson());
        jws.setKey(signingKey.getPrivateKey());
        jws.setKeyIdHeaderValue(signingKey.getKeyId());
        jws.setAlgorithmHeaderValue("RS256");
        return jws.getCompactSerialization();
    }
}
