package kafka.tier.store.encryption;

import com.google.crypto.tink.Aead;
import com.google.crypto.tink.JsonKeysetReader;
import com.google.crypto.tink.JsonKeysetWriter;
import com.google.crypto.tink.KeyTemplates;
import com.google.crypto.tink.KeysetHandle;
import com.google.crypto.tink.KeysetWriter;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.HashMap;
import java.util.Map;
import kafka.tier.exceptions.TierObjectStoreFatalException;
import kafka.tier.exceptions.TierObjectStoreRetriableException;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.ByteBufferInputStream;
import org.apache.kafka.common.utils.ByteBufferOutputStream;
import org.apache.kafka.common.utils.Time;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:kafka/tier/store/encryption/EncryptionKeyManager.class */
/**
 * Manages data encryption keys (DEKs) for tiered-storage objects.
 *
 * <p>DEKs are generated with Tink ({@value #DATA_KEY_TEMPLATE}), wrapped with a remote
 * key-encryption key ({@code remoteKek}) and serialized into object metadata under the
 * {@code METADATA_*} keys. Decoded keys are kept in an in-memory {@link KeyCache}. One key is
 * designated "active"; it is rotated once it is older than the configured refresh interval.
 * The active key can optionally be persisted to, and recovered from, a "well-known keypath"
 * through a {@link WellKnownKeypathHook}.
 */
public class EncryptionKeyManager {
    private static final Logger log = LoggerFactory.getLogger(EncryptionKeyManager.class);

    /** Object-metadata key holding the base64-encoded {@link KeySha} of the data key. */
    static final String METADATA_SHA_KEY = "io.confluent/key-sha-256";
    /** Object-metadata key holding the base64-encoded, KEK-encrypted data key. */
    static final String METADATA_DATA_KEY = "io.confluent/base64-encrypted-data-key";
    /** Object-metadata key holding the key creation time, in epoch milliseconds. */
    static final String METADATA_KEY_CREATION_TIME = "io.confluent/key-creation-time";
    /** Tink key template name used when generating new data keys. */
    private static final String DATA_KEY_TEMPLATE = "AES256_GCM_RAW";

    /** Metrics sink; {@code null} when no {@link Metrics} registry was supplied. */
    final EncryptionKeyManagerMetrics metrics;
    private final Time time;
    private final KeyCache cache = new KeyCache();
    /** Maximum age of the active key before it is rotated. */
    private final Duration keyRefreshInterval;
    /** Key-encryption key used to wrap and unwrap data keys. */
    private final Aead remoteKek;
    /** Optional persistence hook for the active key; {@code null} until {@link #bindHook} is called. */
    private WellKnownKeypathHook wellKnownKeypathHook;

    /**
     * In-memory map from {@link KeySha} to decoded key material, plus a pointer to the active key.
     * Every method synchronizes on the cache instance.
     */
    public static class KeyCache {
        private KeySha active;
        private final HashMap<KeySha, DataEncryptionKeyHolder> cache;

        private KeyCache() {
            this.cache = new HashMap<>();
        }

        /** Adds (or overwrites) a key without changing which key is active. */
        synchronized void add(DataEncryptionKeyHolder dataEncryptionKeyHolder) {
            this.cache.put(dataEncryptionKeyHolder.keySha, dataEncryptionKeyHolder);
        }

        /** Returns the sha of the active key, or {@code null} if none has been set since the last clear. */
        synchronized KeySha activeKeySha() {
            return this.active;
        }

        /** Adds the key to the cache and makes it the active key. */
        synchronized void replaceActiveKeySha(DataEncryptionKeyHolder dataEncryptionKeyHolder) {
            this.active = dataEncryptionKeyHolder.keySha;
            this.cache.put(dataEncryptionKeyHolder.keySha, dataEncryptionKeyHolder);
        }

        /** Returns the cached key for {@code keySha}, or {@code null} if it is not cached. */
        synchronized DataEncryptionKeyHolder get(KeySha keySha) {
            return this.cache.get(keySha);
        }

        /** Drops every cached key and unsets the active key. */
        synchronized void clear() {
            this.active = null;
            this.cache.clear();
        }
    }

    /** Callback used to persist and recover the active key's object metadata at a well-known location. */
    public interface WellKnownKeypathHook {
        /** Persists the given key metadata (as produced for object metadata) to the well-known path. */
        void writeWellKnownPathMetadata(Map<String, String> map);

        /** Returns the previously persisted key metadata; {@code null} or empty when nothing is stored. */
        Map<String, String> fetchWellKnownPathMetadata();
    }

    /**
     * @param time               clock used for key creation timestamps, expiry checks and KEK call latency
     * @param metrics            metrics registry, or {@code null} to disable metrics
     * @param remoteKek          key-encryption key used to wrap and unwrap data keys
     * @param keyRefreshInterval maximum age of the active key before it is rotated
     */
    public EncryptionKeyManager(Time time, Metrics metrics, Aead remoteKek, Duration keyRefreshInterval) {
        if (metrics != null) {
            this.metrics = new EncryptionKeyManagerMetrics(metrics);
            this.metrics.updateMaxKeyAge(keyRefreshInterval);
        } else {
            this.metrics = null;
        }
        this.time = time;
        this.remoteKek = remoteKek;
        this.keyRefreshInterval = keyRefreshInterval;
    }

    /** Installs the hook used to persist and recover the active key; replaces any previous hook. */
    public void bindHook(WellKnownKeypathHook wellKnownKeypathHook) {
        this.wellKnownKeypathHook = wellKnownKeypathHook;
    }

    /** Closes the metrics, if any were registered. */
    public void close() {
        if (this.metrics != null) {
            this.metrics.close();
        }
    }

    /**
     * Builds a {@link KeyContext} (cleartext key plus the object metadata describing it) for a cached key.
     *
     * @return the context, or {@code null} if {@code keySha} is not in the cache
     */
    public KeyContext keyContext(KeySha keySha) {
        DataEncryptionKeyHolder holder = this.cache.get(keySha);
        if (holder == null) {
            return null;
        }
        return new KeyContext(holder.cleartextDataKey, keyToObjectMetadata(holder), keySha);
    }

    /**
     * Decodes and unwraps the key described by {@code map}, then caches it, overwriting any entry
     * with the same sha.
     *
     * @return the sha of the registered key
     * @throws TierObjectStoreFatalException     if the metadata is missing fields, is malformed, or the
     *                                           decoded key does not match the advertised sha
     * @throws TierObjectStoreRetriableException if unwrapping the key with the KEK fails
     */
    public KeySha registerKeyFromObjectMetadata(Map<String, String> map) {
        DataEncryptionKeyHolder holder = parseKeyFromObjectMetadata(map);
        log.info("Registering key {} decoded from metadata", holder.keySha);
        this.cache.add(holder);
        return holder.keySha;
    }

    /**
     * Like {@link #registerKeyFromObjectMetadata}, but leaves an existing cache entry for the same sha
     * untouched. Note that the metadata is always decoded (and the KEK invoked), even when the key
     * turns out to be cached already.
     */
    public KeySha registerKeyIfAbsentFromObjectMetadata(Map<String, String> map) {
        DataEncryptionKeyHolder holder = parseKeyFromObjectMetadata(map);
        if (this.cache.get(holder.keySha) == null) {
            this.cache.add(holder);
        }
        return holder.keySha;
    }

    /**
     * Returns the sha of the active key. First seeds the cache or rotates the key if there is no
     * active key or it has outlived the refresh interval.
     */
    public KeySha activeKeySha() {
        // Return the sha chosen under the manager lock. Re-reading the cache after the lock is
        // released could observe a concurrent clear() and return null.
        return maybeRotate();
    }

    /**
     * Drops all cached keys, including the active one.
     *
     * <p>Synchronized on the manager so it cannot interleave with {@link #maybeRotate()}. Otherwise a
     * clear between reading the active sha and looking it up would cause a NullPointerException.
     */
    public synchronized void clear() {
        this.cache.clear();
    }

    /** Serializes a key into the object-metadata form understood by {@link #parseKeyFromObjectMetadata}. */
    private static HashMap<String, String> keyToObjectMetadata(DataEncryptionKeyHolder holder) {
        HashMap<String, String> metadata = new HashMap<>();
        metadata.put(METADATA_SHA_KEY, holder.keySha.base64Encoded());
        metadata.put(METADATA_KEY_CREATION_TIME, Long.toString(holder.keyCreationTime.toEpochMilli()));
        metadata.put(METADATA_DATA_KEY, holder.encryptedDataKey.base64Encoded());
        return metadata;
    }

    /** Returns the non-empty value of {@code field}, or throws a fatal error if it is absent or empty. */
    private static String requireMetadataField(Map<String, String> metadata, String field) {
        String value = metadata.get(field);
        if (value == null || value.isEmpty()) {
            throw new TierObjectStoreFatalException(String.format("%s metadata field not present", field));
        }
        return value;
    }

    /**
     * Decodes the key described by object metadata, unwraps it with the KEK and checks that its sha
     * matches the advertised one.
     *
     * @throws TierObjectStoreFatalException     on missing fields, a non-numeric creation time, or a sha mismatch
     * @throws TierObjectStoreRetriableException if unwrapping fails with an I/O or security exception
     */
    private DataEncryptionKeyHolder parseKeyFromObjectMetadata(Map<String, String> map) {
        String encodedSha = requireMetadataField(map, METADATA_SHA_KEY);
        String encodedDataKey = requireMetadataField(map, METADATA_DATA_KEY);
        String creationTimeMs = requireMetadataField(map, METADATA_KEY_CREATION_TIME);

        KeySha expectedSha = KeySha.fromBase64Encoded(encodedSha);
        Instant creationTime;
        try {
            creationTime = Instant.ofEpochMilli(Long.parseLong(creationTimeMs));
        } catch (NumberFormatException e) {
            // Corrupt metadata: report it like the other malformed-field cases, not as a raw parse error.
            throw new TierObjectStoreFatalException(String.format(
                "%s metadata field '%s' is not a valid epoch-millisecond timestamp",
                METADATA_KEY_CREATION_TIME, creationTimeMs), e);
        }

        DataEncryptionKeyHolder decrypted;
        try {
            decrypted = decryptDek(EncryptedDataKey.fromBase64Encoded(encodedDataKey), creationTime);
        } catch (IOException | GeneralSecurityException e) {
            throw new TierObjectStoreRetriableException("Failed to decrypt data encryption key from object metadata", e);
        }
        if (!decrypted.keySha.equals(expectedSha)) {
            throw new TierObjectStoreFatalException(String.format(
                "KeySha parsed from object metadata '%s' does not match decoded KeySha '%s'",
                expectedSha, decrypted.keySha));
        }
        return decrypted;
    }

    /**
     * Ensures there is a non-expired active key, seeding or rotating as needed.
     *
     * @return the sha of the active key after any seeding or rotation
     */
    private synchronized KeySha maybeRotate() {
        KeySha activeKeySha = this.cache.activeKeySha();
        if (activeKeySha == null) {
            log.info("No active key found, seeding key cache");
            forceRotate();
            activeKeySha = this.cache.activeKeySha();
        }
        Instant createdAt = this.cache.get(activeKeySha).keyCreationTime;
        Instant now = Instant.ofEpochMilli(this.time.milliseconds());
        if (now.isAfter(createdAt.plus(this.keyRefreshInterval))) {
            log.info("Key corresponding to {} created at {} has expired determined by the refresh interval {}, seeding key cache",
                activeKeySha, createdAt, this.keyRefreshInterval);
            forceRotate();
            activeKeySha = this.cache.activeKeySha();
        }
        return activeKeySha;
    }

    /** Delegates to the bound hook; returns {@code null} when no hook is bound. */
    private Map<String, String> fetchWellKnownPathMetadata() {
        if (this.wellKnownKeypathHook != null) {
            return this.wellKnownKeypathHook.fetchWellKnownPathMetadata();
        }
        return null;
    }

    /** Delegates to the bound hook; does nothing when no hook is bound. */
    private void writeWellKnownPathMetadata(Map<String, String> map) {
        if (this.wellKnownKeypathHook != null) {
            this.wellKnownKeypathHook.writeWellKnownPathMetadata(map);
        }
    }

    /** Records the key's creation time in metrics (if enabled) and makes it the active key. */
    private void activate(DataEncryptionKeyHolder holder) {
        if (this.metrics != null) {
            this.metrics.updateActiveKeyCreationTime(holder.keyCreationTime);
        }
        this.cache.replaceActiveKeySha(holder);
    }

    /**
     * Installs a new active key.
     *
     * <p>If the well-known keypath holds a key that is still within the refresh interval, that key
     * is reused. Otherwise any recovered (expired) key is cached for decryption only, and a freshly
     * generated key is activated and written back to the well-known keypath.
     */
    private void forceRotate() {
        Map<String, String> wellKnownMetadata = fetchWellKnownPathMetadata();
        if (wellKnownMetadata != null && !wellKnownMetadata.isEmpty()) {
            DataEncryptionKeyHolder recovered = parseKeyFromObjectMetadata(wellKnownMetadata);
            Instant createdAt = recovered.keyCreationTime;
            Instant now = Instant.ofEpochMilli(this.time.milliseconds());
            log.info("Recovered previously written key {} created at {} from the well-known keypath", recovered.keySha, createdAt);
            if (now.isBefore(createdAt.plus(this.keyRefreshInterval))) {
                log.info("Using key {} as the active key", recovered.keySha);
                activate(recovered);
                return;
            }
            log.info("Key {} recovered from the well-known keypath is too old to use as the active key", recovered.keySha);
            // Keep the expired key around so objects already written with it remain decryptable.
            this.cache.add(recovered);
        }

        log.info("Unable to restore a valid active key from the well-known keypath, generating a new one");
        DataEncryptionKeyHolder generated;
        try {
            generated = generateNewDek();
        } catch (IOException e) {
            throw new TierObjectStoreRetriableException("Failed to generate data encryption key for rotation", e);
        } catch (GeneralSecurityException e) {
            throw new TierObjectStoreFatalException("Failed to generate data encryption key for rotation", e);
        }
        log.info("Using key {} as the active key", generated.keySha);
        activate(generated);
        log.info("Writing out newly generated key {} to the well-known key path", generated.keySha);
        writeWellKnownPathMetadata(keyToObjectMetadata(generated));
    }

    /**
     * Generates a new Tink keyset and wraps it with the remote KEK.
     *
     * @throws GeneralSecurityException          if the key template cannot be resolved or the keyset
     *                                           cannot be generated
     * @throws TierObjectStoreRetriableException if any step of wrapping or serializing the key fails
     */
    private DataEncryptionKeyHolder generateNewDek() throws GeneralSecurityException, IOException {
        KeysetHandle keyset = KeysetHandle.generateNew(KeyTemplates.get(DATA_KEY_TEMPLATE));
        ByteBufferOutputStream out = new ByteBufferOutputStream(256);
        try {
            KeysetWriter writer = JsonKeysetWriter.withOutputStream(out);
            long startMs = this.time.hiResClockMs();
            keyset.write(writer, this.remoteKek);
            if (this.metrics != null) {
                this.metrics.recordEncryptCall(this.time.hiResClockMs() - startMs);
            }
            // The buffer's position marks the end of the written data; flip to read [0, position).
            ByteBuffer written = out.buffer();
            written.flip();
            byte[] wrappedKey = new byte[written.remaining()];
            written.get(wrappedKey);
            CleartextDataKey cleartextDataKey = new CleartextDataKey(keyset);
            return new DataEncryptionKeyHolder(new EncryptedDataKey(wrappedKey), cleartextDataKey,
                new KeySha(cleartextDataKey), Instant.ofEpochMilli(this.time.milliseconds()));
        } catch (Exception e) {
            throw new TierObjectStoreRetriableException("Exception trying to encrypt key using master key", e);
        }
    }

    /** Unwraps an encrypted Tink keyset with the remote KEK, recording the call latency in metrics. */
    private DataEncryptionKeyHolder decryptDek(EncryptedDataKey encryptedDataKey, Instant instant) throws IOException, GeneralSecurityException {
        JsonKeysetReader reader = JsonKeysetReader.withInputStream(
            new ByteBufferInputStream(ByteBuffer.wrap(encryptedDataKey.keyMaterial())));
        long startMs = this.time.hiResClockMs();
        KeysetHandle keyset = KeysetHandle.read(reader, this.remoteKek);
        if (this.metrics != null) {
            this.metrics.recordDecryptCall(this.time.hiResClockMs() - startMs);
        }
        CleartextDataKey cleartextDataKey = new CleartextDataKey(keyset);
        return new DataEncryptionKeyHolder(encryptedDataKey, cleartextDataKey, new KeySha(cleartextDataKey), instant);
    }
}
