/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.security.scram.internals;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.HashMap;
import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslException;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.scram.internals.ScramMessages;
import org.apache.kafka.common.security.scram.internals.ScramSaslServer;
import org.apache.kafka.common.security.scram.internals.ScramServerCallbackHandler;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

public class ScramSaslServerTest {
    private static final String USER_A = "userA";
    private static final String USER_B = "userB";
    private ScramFormatter formatter;
    private ScramSaslServer saslServer;

    @BeforeEach
    public void setUp() throws Exception {
        ScramMechanism mechanism = ScramMechanism.SCRAM_SHA_256;
        this.formatter = new ScramFormatter(mechanism);
        CredentialCache.Cache credentialCache = new CredentialCache().createCache(mechanism.mechanismName(), ScramCredential.class);
        credentialCache.put(USER_A, (Object)this.formatter.generateCredential("passwordA", 4096));
        credentialCache.put(USER_B, (Object)this.formatter.generateCredential("passwordB", 4096));
        ScramServerCallbackHandler callbackHandler = new ScramServerCallbackHandler(credentialCache, new DelegationTokenCache(ScramMechanism.mechanismNames()));
        this.saslServer = new ScramSaslServer(mechanism, new HashMap(), (CallbackHandler)callbackHandler);
    }

    @Test
    public void noAuthorizationIdSpecified() throws Exception {
        byte[] nextChallenge = this.saslServer.evaluateResponse(this.clientFirstMessage(USER_A, null));
        Assertions.assertTrue((nextChallenge.length > 0 ? 1 : 0) != 0, (String)"Next challenge is empty");
    }

    @Test
    public void authorizationIdEqualsAuthenticationId() throws Exception {
        byte[] nextChallenge = this.saslServer.evaluateResponse(this.clientFirstMessage(USER_A, USER_A));
        Assertions.assertTrue((nextChallenge.length > 0 ? 1 : 0) != 0, (String)"Next challenge is empty");
    }

    @Test
    public void authorizationIdNotEqualsAuthenticationId() {
        Assertions.assertThrows(SaslAuthenticationException.class, () -> this.saslServer.evaluateResponse(this.clientFirstMessage(USER_A, USER_B)));
    }

    @Test
    public void validateNonceExchange() throws SaslException {
        ScramSaslServer spySaslServer = (ScramSaslServer)Mockito.spy((Object)this.saslServer);
        byte[] clientFirstMsgBytes = this.clientFirstMessage(USER_A, USER_A);
        ScramMessages.ClientFirstMessage clientFirstMessage = new ScramMessages.ClientFirstMessage(clientFirstMsgBytes);
        byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
        ScramMessages.ServerFirstMessage serverFirstMessage = new ScramMessages.ServerFirstMessage(serverFirstMsgBytes);
        Assertions.assertTrue((boolean)serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()), (String)"Nonce in server message should start with client first message's nonce");
        byte[] clientFinalMessage = this.clientFinalMessage(serverFirstMessage.nonce());
        ((ScramSaslServer)Mockito.doNothing().when((Object)spySaslServer)).verifyClientProof((ScramMessages.ClientFinalMessage)Mockito.any(ScramMessages.ClientFinalMessage.class));
        byte[] serverFinalMsgBytes = spySaslServer.evaluateResponse(clientFinalMessage);
        ScramMessages.ServerFinalMessage serverFinalMessage = new ScramMessages.ServerFinalMessage(serverFinalMsgBytes);
        Assertions.assertNull((Object)serverFinalMessage.error(), (String)"Server final message should not contain error");
    }

    @Test
    public void validateFailedNonceExchange() throws SaslException {
        ScramSaslServer spySaslServer = (ScramSaslServer)Mockito.spy((Object)this.saslServer);
        byte[] clientFirstMsgBytes = this.clientFirstMessage(USER_A, USER_A);
        ScramMessages.ClientFirstMessage clientFirstMessage = new ScramMessages.ClientFirstMessage(clientFirstMsgBytes);
        byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes);
        ScramMessages.ServerFirstMessage serverFirstMessage = new ScramMessages.ServerFirstMessage(serverFirstMsgBytes);
        Assertions.assertTrue((boolean)serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()), (String)"Nonce in server message should start with client first message's nonce");
        byte[] clientFinalMessage = this.clientFinalMessage(this.formatter.secureRandomString());
        ((ScramSaslServer)Mockito.doNothing().when((Object)spySaslServer)).verifyClientProof((ScramMessages.ClientFinalMessage)Mockito.any(ScramMessages.ClientFinalMessage.class));
        SaslException saslException = (SaslException)Assertions.assertThrows(SaslException.class, () -> spySaslServer.evaluateResponse(clientFinalMessage));
        Assertions.assertEquals((Object)"Invalid client nonce in the final client message.", (Object)saslException.getMessage(), (String)("Failure message: " + saslException.getMessage()));
    }

    private byte[] clientFirstMessage(String userName, String authorizationId) {
        String nonce = this.formatter.secureRandomString();
        Object authorizationField = authorizationId != null ? "a=" + authorizationId : "";
        String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce);
        return firstMessage.getBytes(StandardCharsets.UTF_8);
    }

    private byte[] clientFinalMessage(String nonce) {
        String channelBinding = this.randomBytesAsString();
        String proof = this.randomBytesAsString();
        String message = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof);
        return message.getBytes(StandardCharsets.UTF_8);
    }

    private String randomBytesAsString() {
        return Base64.getEncoder().encodeToString(this.formatter.secureRandomBytes());
    }
}

