package org.apache.kafka.common.security.oauthbearer;

import java.io.File;
import java.io.IOException;
import java.util.Base64;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.TimeZone;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.FileTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.HttpAccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandlerTest.class */
/**
 * Tests for {@link OAuthBearerLoginCallbackHandler}: handling of token and SASL extension
 * callbacks, rejection of unsupported callbacks and invalid tokens, and selection of the
 * access token retriever implementation from the configuration.
 *
 * <p>Every test that initializes or configures a handler and then invokes it closes the
 * handler in a {@code finally} block, so it is released whether or not the assertions pass.
 */
public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {

    /** SASL configuration key carrying the token endpoint URL. */
    private static final String TOKEN_ENDPOINT_URL_CONFIG = "sasl.oauthbearer.token.endpoint.url";

    /** Placeholder HTTP endpoint used by the tests that configure client credentials. */
    private static final String EXAMPLE_ENDPOINT_URL = "http://www.example.com";

    /** JAAS option name for the client ID. */
    private static final String CLIENT_ID_OPTION = "clientId";

    /** JAAS option name for the client secret. */
    private static final String CLIENT_SECRET_OPTION = "clientSecret";

    /**
     * A token callback must be populated with a token whose value, principal, lifetime and
     * start time match the access token produced by the builder.
     */
    @Test
    public void testHandleTokenCallback() throws Exception {
        Map<String, ?> configs = getSaslConfigs();
        AccessTokenBuilder builder = new AccessTokenBuilder().jwk(createRsaJwk()).alg("RS256");
        String accessToken = builder.build();
        OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, configs);

        try {
            OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
            handler.handle(new Callback[] {callback});

            OAuthBearerToken token = callback.token();
            Assertions.assertNotNull(token);
            Assertions.assertEquals(accessToken, token.value());
            Assertions.assertEquals(builder.subject(), token.principalName());
            // The builder reports seconds while the token reports milliseconds.
            Assertions.assertEquals(builder.expirationSeconds().longValue() * 1000, token.lifetimeMs());
            Assertions.assertEquals(builder.issuedAtSeconds().longValue() * 1000, token.startTimeMs());
        } finally {
            handler.close();
        }
    }

    /**
     * JAAS options prefixed with {@code extension_} must be exposed as SASL extensions with the
     * prefix removed and the value rendered as a string; the {@code EXTENSION_} (upper-case)
     * option must not produce an extension.
     */
    @Test
    public void testHandleSaslExtensionsCallback() throws Exception {
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        Map<String, ?> configs = getSaslConfigs(TOKEN_ENDPOINT_URL_CONFIG, EXAMPLE_ENDPOINT_URL);
        Map<String, Object> jaasOptions = clientCredentialsJaasOptions();
        jaasOptions.put("extension_foo", "1");
        jaasOptions.put("extension_bar", 2);
        jaasOptions.put("EXTENSION_baz", "3");
        configureHandler(handler, configs, jaasOptions);

        try {
            SaslExtensionsCallback callback = new SaslExtensionsCallback();
            handler.handle(new Callback[] {callback});

            Assertions.assertNotNull(callback.extensions());
            Map<String, String> extensions = callback.extensions().map();
            Assertions.assertEquals("1", extensions.get("foo"));
            Assertions.assertEquals("2", extensions.get("bar"));
            Assertions.assertNull(extensions.get("baz"));
            Assertions.assertEquals(2, extensions.size());
        } finally {
            handler.close();
        }
    }

    /**
     * An {@code extension_auth} JAAS option must cause the extensions callback to fail with a
     * {@link ConfigException} naming the offending extension.
     */
    @Test
    public void testHandleSaslExtensionsCallbackWithInvalidExtension() {
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        Map<String, ?> configs = getSaslConfigs(TOKEN_ENDPOINT_URL_CONFIG, EXAMPLE_ENDPOINT_URL);
        Map<String, Object> jaasOptions = clientCredentialsJaasOptions();
        jaasOptions.put("extension_auth", "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions");
        configureHandler(handler, configs, jaasOptions);

        try {
            SaslExtensionsCallback callback = new SaslExtensionsCallback();
            assertThrowsWithMessage(ConfigException.class,
                () -> handler.handle(new Callback[] {callback}),
                "Extension name auth is invalid");
        } finally {
            handler.close();
        }
    }

    /**
     * A callback type the handler does not recognize must be rejected with an
     * {@link UnsupportedCallbackException}.
     */
    @Test
    public void testInvalidCallbackGeneratesUnsupportedCallbackException() {
        Map<String, ?> configs = getSaslConfigs();
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        handler.init(() -> "foo", AccessTokenValidatorFactory.create(configs));

        try {
            // An anonymous Callback that is neither a token nor an extensions callback.
            Callback unsupportedCallback = new Callback() { };
            Assertions.assertThrows(UnsupportedCallbackException.class,
                () -> handler.handle(new Callback[] {unsupportedCallback}));
        } finally {
            handler.close();
        }
    }

    /**
     * Progressively "less broken" access tokens must each be reported through the callback's
     * error description with the expected message fragment.
     */
    @Test
    public void testInvalidAccessToken() throws Exception {
        testInvalidAccessToken("this isn't valid", "Malformed JWT provided");
        testInvalidAccessToken("this.isn't.valid", "malformed Base64 URL encoded value");
        testInvalidAccessToken(createAccessKey("this", "isn't", "valid"), "malformed JSON");
        testInvalidAccessToken(createAccessKey("{}", "{}", "{}"), "exp value must be non-null");
    }

    /**
     * An {@link IOException} raised by the access token retriever must propagate out of
     * {@code handle} with its message intact.
     */
    @Test
    public void testMissingAccessToken() {
        AccessTokenRetriever failingRetriever = () -> {
            throw new IOException("The token endpoint response access_token value must be non-null");
        };
        OAuthBearerLoginCallbackHandler handler = createHandler(failingRetriever, getSaslConfigs());

        try {
            OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
            assertThrowsWithMessage(IOException.class,
                () -> handler.handle(new Callback[] {callback}),
                "token endpoint response access_token value must be non-null");
        } finally {
            handler.close();
        }
    }

    /**
     * A token file whose content ends with a newline must yield a token value without that
     * trailing newline.
     */
    @Test
    public void testFileTokenRetrieverHandlesNewline() throws IOException {
        long nowSeconds = Calendar.getInstance(TimeZone.getTimeZone("UTC")).getTimeInMillis() / 1000;
        String exp = "" + (nowSeconds + 3600); // one hour in the future
        String iat = "" + nowSeconds;
        String expected = createAccessKey(
            "{}",
            String.format("{\"exp\":%s, \"iat\":%s, \"sub\":\"subj\"}", exp, iat),
            "sign");

        File tokenDir = createTempDir("access-token");
        File tokenFile = createTempFile(tokenDir, "access-token-", ".json", expected + "\n");
        OAuthBearerLoginCallbackHandler handler =
            createHandler(new FileTokenRetriever(tokenFile.toPath()), getSaslConfigs());
        OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();

        try {
            handler.handle(new Callback[] {callback});
            Assertions.assertEquals(expected, callback.token().value());
        } catch (Exception e) {
            // Any exception from handle() is reported as a test failure with the cause attached.
            Assertions.fail(e);
        } finally {
            handler.close();
        }
    }

    /**
     * Calling {@code handle} on a handler that was never configured or initialized must fail
     * with an {@link IllegalStateException}.
     */
    @Test
    public void testNotConfigured() {
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        assertThrowsWithMessage(IllegalStateException.class,
            () -> handler.handle(new Callback[0]),
            "first call the configure or init method");
    }

    /**
     * When the token endpoint URL is a file URI, the handler must use a
     * {@link FileTokenRetriever}.
     */
    @Test
    public void testConfigureWithAccessTokenFile() throws Exception {
        File tokenDir = createTempDir("access-token");
        File tokenFile = createTempFile(tokenDir, "access-token-", ".json", "{}");
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        Map<String, ?> configs = getSaslConfigs(TOKEN_ENDPOINT_URL_CONFIG, tokenFile.toURI().toString());

        configureHandler(handler, configs, Collections.emptyMap());

        Assertions.assertTrue(handler.getAccessTokenRetriever() instanceof FileTokenRetriever);
    }

    /**
     * When the token endpoint URL is an HTTP URL and client credentials are supplied, the
     * handler must use an {@link HttpAccessTokenRetriever}.
     */
    @Test
    public void testConfigureWithAccessClientCredentials() {
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        Map<String, ?> configs = getSaslConfigs(TOKEN_ENDPOINT_URL_CONFIG, EXAMPLE_ENDPOINT_URL);

        configureHandler(handler, configs, clientCredentialsJaasOptions());

        Assertions.assertTrue(handler.getAccessTokenRetriever() instanceof HttpAccessTokenRetriever);
    }

    /**
     * Feeds {@code accessToken} to a handler and verifies that no token is set on the callback
     * and that the callback's error description contains {@code expectedMessageSubstring}.
     *
     * @param accessToken              the (invalid) access token the retriever returns
     * @param expectedMessageSubstring text that must appear in the error description
     */
    private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception {
        OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, getSaslConfigs());

        try {
            OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
            handler.handle(new Callback[] {callback});

            Assertions.assertNull(callback.token());
            String actualMessage = callback.errorDescription();
            Assertions.assertNotNull(actualMessage);
            Assertions.assertTrue(
                actualMessage.contains(expectedMessageSubstring),
                String.format("The error message \"%s\" didn't contain the expected substring \"%s\"",
                    actualMessage, expectedMessageSubstring));
        } finally {
            handler.close();
        }
    }

    /**
     * Builds a three-part, dot-separated token string by Base64-encoding the UTF-8 bytes of
     * each part with the basic (non-URL) encoder.
     *
     * @param header    raw header text
     * @param payload   raw payload text
     * @param signature raw signature text
     * @return {@code base64(header).base64(payload).base64(signature)}
     */
    private String createAccessKey(String header, String payload, String signature) {
        Base64.Encoder encoder = Base64.getEncoder();
        String encodedHeader = encoder.encodeToString(Utils.utf8(header));
        String encodedPayload = encoder.encodeToString(Utils.utf8(payload));
        String encodedSignature = encoder.encodeToString(Utils.utf8(signature));
        return String.format("%s.%s.%s", encodedHeader, encodedPayload, encodedSignature);
    }

    /**
     * Creates a handler initialized with the given retriever and a validator built from
     * {@code configs}. The caller is responsible for closing the returned handler.
     *
     * @param accessTokenRetriever source of the access token
     * @param configs              SASL configuration used to create the validator
     * @return the initialized handler
     */
    private OAuthBearerLoginCallbackHandler createHandler(AccessTokenRetriever accessTokenRetriever, Map<String, ?> configs) {
        OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
        handler.init(accessTokenRetriever, AccessTokenValidatorFactory.create(configs));
        return handler;
    }

    /**
     * Returns a fresh, mutable set of JAAS options containing only the client ID and client
     * secret used throughout these tests; callers may add further options.
     */
    private Map<String, Object> clientCredentialsJaasOptions() {
        Map<String, Object> jaasOptions = new HashMap<>();
        jaasOptions.put(CLIENT_ID_OPTION, "an ID");
        jaasOptions.put(CLIENT_SECRET_OPTION, "a secret");
        return jaasOptions;
    }
}
