package io.helidon.security.jwt;

import io.helidon.common.Errors;
import io.helidon.security.jwt.JwtHeaders;
import io.helidon.security.jwt.jwk.Jwk;
import io.helidon.security.jwt.jwk.JwkEC;
import io.helidon.security.jwt.jwk.JwkKeys;
import io.helidon.security.jwt.jwk.JwkRSA;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Arrays;
import java.util.Base64;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

/* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt.class */
public final class EncryptedJwt {
    // Complete serialized token string this instance was built from or parsed out of.
    private final String token;
    // Headers of the token (parsed from the first segment, supplied by the caller, or built by Builder).
    private final JwtHeaders header;
    // Decoded third token segment: initialization vector for the content cipher.
    private final byte[] iv;
    // Decoded second token segment: content encryption key, encrypted with the recipient's key.
    private final byte[] encryptedKey;
    // Decoded fifth token segment: authentication tag.
    private final byte[] authTag;
    // Decoded fourth token segment: encrypted payload.
    private final byte[] encryptedPayload;
    // Five dot-separated groups of non-whitespace characters. Each group uses \S, so a group
    // can itself contain dots; such input is only rejected later, when it is base64-decoded.
    private static final Pattern JWE_PATTERN = Pattern.compile("(^[\\S]+)\\.([\\S]+)\\.([\\S]+)\\.([\\S]+)\\.([\\S]+$)");
    // URL-safe base64 codec; the encoder omits padding.
    private static final Base64.Decoder URL_DECODER = Base64.getUrlDecoder();
    private static final Base64.Encoder URL_ENCODER = Base64.getUrlEncoder().withoutPadding();
    // Key-encryption algorithm -> JCA cipher transformation used to wrap/unwrap the content key.
    private static final Map<SupportedAlgorithm, String> RSA_ALGORITHMS = Map.of(SupportedAlgorithm.RSA_OAEP, "RSA/ECB/OAEPWithSHA-1AndMGF1Padding", SupportedAlgorithm.RSA_OAEP_256, "RSA/ECB/OAEPWithSHA-256AndMGF1Padding", SupportedAlgorithm.RSA1_5, "RSA/ECB/PKCS1Padding");
    // Content-encryption method -> AES implementation (GCM variants, or CBC combined with an HMAC).
    private static final Map<SupportedEncryption, AesAlgorithm> CONTENT_ENCRYPTION = Map.of(SupportedEncryption.A128GCM, new AesGcmAlgorithm(128), SupportedEncryption.A192GCM, new AesGcmAlgorithm(192), SupportedEncryption.A256GCM, new AesGcmAlgorithm(256), SupportedEncryption.A128CBC_HS256, new AesAlgorithmWithHmac("AES/CBC/PKCS5Padding", 128, 16, "HmacSHA256"), SupportedEncryption.A192CBC_HS384, new AesAlgorithmWithHmac("AES/CBC/PKCS5Padding", 192, 16, "HmacSHA384"), SupportedEncryption.A256CBC_HS512, new AesAlgorithmWithHmac("AES/CBC/PKCS5Padding", 256, 16, "HmacSHA512"));

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$AesAlgorithm.class */
    /**
     * Base AES content encryption: generates a fresh key and IV per encryption and runs the
     * configured cipher transformation. Subclasses customise the parameter spec and may
     * prepare the cipher further (for example to feed additional authenticated data).
     */
    public static class AesAlgorithm {
        private static final SecureRandom RANDOM = new SecureRandom();
        private final String cipher;
        private final int keySize;
        private final int ivSize;

        /**
         * @param transformation JCA cipher transformation passed to {@link Cipher#getInstance(String)}
         * @param keyBits size handed to {@link KeyGenerator#init(int, SecureRandom)}
         * @param ivBytes length of the random IV array
         */
        private AesAlgorithm(String transformation, int keyBits, int ivBytes) {
            this.cipher = transformation;
            this.keySize = keyBits;
            this.ivSize = ivBytes;
        }

        /**
         * Encrypts content with a newly generated AES key and random IV.
         *
         * @param bArr plain content to encrypt
         * @param bArr2 additional authenticated data, carried into the result unchanged
         * @return parts holding key, IV, AAD and cipher output; the auth tag is {@code null} here
         * @throws JwtException wrapping any failure during key generation or encryption
         */
        EncryptionParts encrypt(byte[] bArr, byte[] bArr2) {
            try {
                KeyGenerator aesKeys = KeyGenerator.getInstance("AES");
                aesKeys.init(keySize, RANDOM);
                SecretKey contentKey = aesKeys.generateKey();

                byte[] freshIv = new byte[ivSize];
                RANDOM.nextBytes(freshIv);

                // Interim parts (no cipher text yet) give the hooks access to IV and AAD.
                EncryptionParts inputs = new EncryptionParts(contentKey.getEncoded(), freshIv, bArr2, null, null);
                Cipher aes = Cipher.getInstance(cipher);
                aes.init(Cipher.ENCRYPT_MODE, contentKey, createParameterSpec(inputs));
                postCipherConstruct(aes, inputs);

                byte[] cipherOutput = aes.doFinal(bArr);
                return new EncryptionParts(contentKey.getEncoded(), freshIv, bArr2, cipherOutput, null);
            } catch (Exception e) {
                throw new JwtException("Exception during content encryption", e);
            }
        }

        /**
         * Decrypts the encrypted content of the given parts using their key and IV.
         *
         * @param encryptionParts key, IV, AAD and encrypted content
         * @return decrypted bytes
         * @throws JwtException wrapping any failure during decryption
         */
        byte[] decrypt(EncryptionParts encryptionParts) {
            try {
                SecretKeySpec contentKey = new SecretKeySpec(encryptionParts.key(), "AES");
                Cipher aes = Cipher.getInstance(cipher);
                aes.init(Cipher.DECRYPT_MODE, contentKey, createParameterSpec(encryptionParts));
                postCipherConstruct(aes, encryptionParts);
                return aes.doFinal(encryptionParts.encryptedContent());
            } catch (Exception e) {
                throw new JwtException("Exception during content decryption.", e);
            }
        }

        /**
         * Hook invoked after the cipher is initialised and before data is processed.
         * The base implementation does nothing.
         */
        protected void postCipherConstruct(Cipher cipher, EncryptionParts encryptionParts) {
        }

        /**
         * Creates the algorithm parameters for cipher initialisation; the base
         * implementation wraps the IV in an {@link IvParameterSpec}.
         */
        protected AlgorithmParameterSpec createParameterSpec(EncryptionParts encryptionParts) {
            return new IvParameterSpec(encryptionParts.iv());
        }
    }

    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$AesAlgorithmWithHmac.class */
    /**
     * AES content encryption combined with an HMAC over the additional authenticated data
     * and the cipher text. The HMAC is keyed with the same bytes as the AES key and its
     * full output is used as the authentication tag.
     *
     * <p>NOTE(review): RFC 7518 section 5.2.2 (AES_CBC_HMAC_SHA2) describes a different
     * construction (separate MAC and encryption key halves, MAC input that also covers the
     * IV and the AAD length, truncated tag). The behaviour here is intentionally left as-is
     * so tokens already produced by this class keep decrypting; confirm whether interop with
     * other JWE implementations is required.
     */
    private static class AesAlgorithmWithHmac extends AesAlgorithm {
        private final String hmac;

        /**
         * @param transformation JCA cipher transformation
         * @param keyBits AES key size handed to the key generator
         * @param ivBytes IV length
         * @param hmacAlgorithm JCA MAC algorithm name used for the authentication tag
         */
        private AesAlgorithmWithHmac(String transformation, int keyBits, int ivBytes, String hmacAlgorithm) {
            super(transformation, keyBits, ivBytes);
            this.hmac = hmacAlgorithm;
        }

        /**
         * Encrypts the content and attaches the HMAC of AAD and cipher text as auth tag.
         *
         * @param bArr plain content
         * @param bArr2 additional authenticated data
         * @return encryption parts including the computed authentication tag
         */
        @Override // io.helidon.security.jwt.EncryptedJwt.AesAlgorithm
        public EncryptionParts encrypt(byte[] bArr, byte[] bArr2) {
            EncryptionParts encrypted = super.encrypt(bArr, bArr2);
            return new EncryptionParts(encrypted.key(),
                                       encrypted.iv(),
                                       encrypted.aad(),
                                       encrypted.encryptedContent(),
                                       sign(encrypted));
        }

        /**
         * Verifies the authentication tag and only then decrypts the content.
         *
         * @param encryptionParts parts including the received authentication tag
         * @return decrypted bytes
         * @throws JwtException if the tag does not match or decryption fails
         */
        @Override // io.helidon.security.jwt.EncryptedJwt.AesAlgorithm
        public byte[] decrypt(EncryptionParts encryptionParts) {
            if (!verifySignature(encryptionParts)) {
                throw new JwtException("HMAC signature does not match");
            }
            return super.decrypt(encryptionParts);
        }

        /** Computes the authentication tag for freshly encrypted parts. */
        private byte[] sign(EncryptionParts encryptionParts) {
            try {
                return computeMac(encryptionParts);
            } catch (InvalidKeyException e) {
                throw new JwtException("Exception occurred while HMAC signature", e);
            }
        }

        /** Recomputes the tag and compares it with the received one. */
        private boolean verifySignature(EncryptionParts encryptionParts) {
            try {
                // MessageDigest.isEqual is time-constant, unlike Arrays.equals, so the
                // comparison does not reveal how many leading tag bytes matched.
                return MessageDigest.isEqual(computeMac(encryptionParts), encryptionParts.authTag());
            } catch (InvalidKeyException e) {
                throw new JwtException("Exception occurred while HMAC signature.", e);
            }
        }

        /** HMAC over AAD followed by the encrypted content, keyed with the content key bytes. */
        private byte[] computeMac(EncryptionParts encryptionParts) throws InvalidKeyException {
            Mac mac = macInstance();
            mac.init(new SecretKeySpec(encryptionParts.key(), "AES"));
            mac.update(encryptionParts.aad());
            mac.update(encryptionParts.encryptedContent());
            return mac.doFinal();
        }

        /** Looks up the configured MAC algorithm. */
        private Mac macInstance() {
            try {
                return Mac.getInstance(this.hmac);
            } catch (NoSuchAlgorithmException e) {
                throw new JwtException("Could not find MAC instance: " + this.hmac, e);
            }
        }
    }

    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$AesGcmAlgorithm.class */
    /**
     * AES in GCM mode with a 12 byte IV and a 128 bit authentication tag. The cipher emits
     * the tag appended to the cipher text; this class splits it off on encryption and
     * re-appends it on decryption so the two can be stored separately.
     */
    private static class AesGcmAlgorithm extends AesAlgorithm {
        private static final int IV_BYTES = 12;
        private static final int TAG_BYTES = 16;
        private static final int TAG_BITS = 128;

        /**
         * @param keyBits AES key size handed to the key generator
         */
        private AesGcmAlgorithm(int keyBits) {
            super("AES/GCM/NoPadding", keyBits, IV_BYTES);
        }

        /**
         * Encrypts and separates the trailing 16 bytes of the cipher output as the auth tag.
         *
         * @param bArr plain content
         * @param bArr2 additional authenticated data
         * @return parts with cipher text and authentication tag held separately
         */
        @Override // io.helidon.security.jwt.EncryptedJwt.AesAlgorithm
        public EncryptionParts encrypt(byte[] bArr, byte[] bArr2) {
            EncryptionParts combined = super.encrypt(bArr, bArr2);
            byte[] cipherOutput = combined.encryptedContent();

            int cipherTextLength = cipherOutput.length - TAG_BYTES;
            byte[] cipherText = new byte[cipherTextLength];
            byte[] tag = new byte[TAG_BYTES];
            System.arraycopy(cipherOutput, 0, cipherText, 0, cipherText.length);
            System.arraycopy(cipherOutput, cipherTextLength, tag, 0, tag.length);

            return new EncryptionParts(combined.key(), combined.iv(), combined.aad(), cipherText, tag);
        }

        /**
         * Re-joins cipher text and authentication tag and delegates to the base decryption.
         *
         * @param encryptionParts parts with separate cipher text and authentication tag
         * @return decrypted bytes
         */
        @Override // io.helidon.security.jwt.EncryptedJwt.AesAlgorithm
        byte[] decrypt(EncryptionParts encryptionParts) {
            byte[] cipherText = encryptionParts.encryptedContent();
            byte[] tag = encryptionParts.authTag();

            // cipher text first, tag directly behind it
            byte[] joined = Arrays.copyOf(cipherText, cipherText.length + tag.length);
            System.arraycopy(tag, 0, joined, cipherText.length, tag.length);

            EncryptionParts rejoined = new EncryptionParts(encryptionParts.key(),
                                                           encryptionParts.iv(),
                                                           encryptionParts.aad(),
                                                           joined,
                                                           tag);
            return super.decrypt(rejoined);
        }

        /** GCM parameters: 128 bit tag length together with the IV of the given parts. */
        @Override // io.helidon.security.jwt.EncryptedJwt.AesAlgorithm
        protected AlgorithmParameterSpec createParameterSpec(EncryptionParts encryptionParts) {
            return new GCMParameterSpec(TAG_BITS, encryptionParts.iv());
        }

        /** Feeds the additional authenticated data into the initialised cipher. */
        @Override // io.helidon.security.jwt.EncryptedJwt.AesAlgorithm
        protected void postCipherConstruct(Cipher cipher, EncryptionParts encryptionParts) {
            cipher.updateAAD(encryptionParts.aad());
        }
    }

    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$Builder.class */
    /**
     * Fluent builder of {@link EncryptedJwt} from a signed JWT. Defaults to
     * {@link SupportedAlgorithm#RSA_OAEP} for key encryption and
     * {@link SupportedEncryption#A256GCM} for content encryption.
     */
    public static class Builder implements io.helidon.common.Builder<EncryptedJwt> {
        private final SignedJwt jwt;
        private Jwk jwk;
        private JwkKeys jwks;
        private String kid;
        private final JwtHeaders.Builder headersBuilder = JwtHeaders.builder();
        private SupportedAlgorithm algorithm = SupportedAlgorithm.RSA_OAEP;
        private SupportedEncryption encryption = SupportedEncryption.A256GCM;

        private Builder(SignedJwt signedJwt) {
            this.jwt = (SignedJwt) Objects.requireNonNull(signedJwt);
        }

        /**
         * Configures a key set and the id of the key to pick from it. Only consulted
         * when no explicit JWK was set through {@link #jwk(Jwk)}.
         *
         * @param jwkKeys key set, must not be null
         * @param str key id, must not be null
         * @return this builder
         */
        public Builder jwks(JwkKeys jwkKeys, String str) {
            this.jwks = (JwkKeys) Objects.requireNonNull(jwkKeys);
            this.kid = (String) Objects.requireNonNull(str);
            return this;
        }

        /**
         * Configures the JWK used to encrypt the content key.
         *
         * @param jwk key, must not be null
         * @return this builder
         */
        public Builder jwk(Jwk jwk) {
            this.jwk = (Jwk) Objects.requireNonNull(jwk);
            return this;
        }

        /**
         * Configures the key encryption algorithm.
         *
         * @param supportedAlgorithm algorithm, must not be null
         * @return this builder
         */
        public Builder algorithm(SupportedAlgorithm supportedAlgorithm) {
            this.algorithm = (SupportedAlgorithm) Objects.requireNonNull(supportedAlgorithm);
            return this;
        }

        /**
         * Configures the content encryption method.
         *
         * @param supportedEncryption encryption, must not be null
         * @return this builder
         */
        public Builder encryption(SupportedEncryption supportedEncryption) {
            this.encryption = (SupportedEncryption) Objects.requireNonNull(supportedEncryption);
            return this;
        }

        /**
         * Encrypts the signed JWT and assembles the five-part token.
         *
         * @return encrypted JWT
         * @throws JwtException if no usable JWK is available or encryption fails
         */
        /* renamed from: build, reason: merged with bridge method [inline-methods] */
        public EncryptedJwt m2build() {
            this.headersBuilder.algorithm(this.algorithm.toString());
            this.headersBuilder.encryption(this.encryption.toString());
            this.headersBuilder.contentType("JWT");

            // Fall back to the key set only when no explicit JWK was configured;
            // the key id header is written only in that case.
            if (this.jwk == null && this.jwks != null) {
                this.jwk = this.jwks.forKeyId(this.kid)
                        .orElseThrow(() -> new JwtException("Could not determine which JWK should be used for encryption."));
                this.headersBuilder.keyId(this.kid);
            }
            if (this.jwk == null) {
                throw new JwtException("No JWK specified for encrypted JWT creation.");
            }

            PublicKey publicKey;
            if (this.jwk instanceof JwkRSA) {
                publicKey = ((JwkRSA) this.jwk).publicKey();
            } else if (this.jwk instanceof JwkEC) {
                publicKey = ((JwkEC) this.jwk).publicKey();
            } else {
                throw new JwtException("Unsupported JWK type: " + this.jwk.keyType());
            }

            JwtHeaders headers = this.headersBuilder.m10build();
            String headerBase64 = EncryptedJwt.encode(headers.headerJson().toString());
            String keyCipher = EncryptedJwt.RSA_ALGORITHMS.get(this.algorithm);

            // The encoded header doubles as additional authenticated data.
            byte[] plainContent = this.jwt.tokenContent().getBytes(StandardCharsets.UTF_8);
            byte[] aad = headerBase64.getBytes(StandardCharsets.US_ASCII);
            EncryptionParts parts = EncryptedJwt.CONTENT_ENCRYPTION.get(this.encryption).encrypt(plainContent, aad);

            byte[] wrappedKey = EncryptedJwt.encryptRsa(keyCipher, publicKey, parts.key());

            String serialized = String.join(".",
                                            headerBase64,
                                            EncryptedJwt.encode(wrappedKey),
                                            EncryptedJwt.encode(parts.iv()),
                                            EncryptedJwt.encode(parts.encryptedContent()),
                                            EncryptedJwt.encode(parts.authTag()));

            return new EncryptedJwt(serialized,
                                    headers,
                                    parts.iv(),
                                    wrappedKey,
                                    parts.authTag(),
                                    parts.encryptedContent());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$EncryptionParts.class */
    /**
     * Plain holder for the byte arrays that travel together through content encryption and
     * decryption. Arrays are stored and returned as given, without copying.
     */
    public static final class EncryptionParts {
        private final byte[] key;
        private final byte[] iv;
        private final byte[] aad;
        private final byte[] encryptedContent;
        private final byte[] authTag;

        /**
         * @param contentKey content encryption key bytes
         * @param initVector initialization vector
         * @param additionalData additional authenticated data
         * @param cipherText encrypted content, may be {@code null} before encryption ran
         * @param tag authentication tag, may be {@code null} when not yet computed
         */
        private EncryptionParts(byte[] contentKey,
                                byte[] initVector,
                                byte[] additionalData,
                                byte[] cipherText,
                                byte[] tag) {
            this.key = contentKey;
            this.iv = initVector;
            this.aad = additionalData;
            this.encryptedContent = cipherText;
            this.authTag = tag;
        }

        /** @return content encryption key bytes */
        public byte[] key() {
            return key;
        }

        /** @return initialization vector */
        public byte[] iv() {
            return iv;
        }

        /** @return additional authenticated data */
        public byte[] aad() {
            return aad;
        }

        /** @return encrypted content */
        public byte[] encryptedContent() {
            return encryptedContent;
        }

        /** @return authentication tag */
        public byte[] authTag() {
            return authTag;
        }
    }

    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$SupportedAlgorithm.class */
    /**
     * Key encryption algorithms supported for wrapping the content encryption key.
     * {@link #toString()} yields the name used in the token header.
     */
    public enum SupportedAlgorithm {
        RSA_OAEP("RSA-OAEP"),
        RSA_OAEP_256("RSA-OAEP-256"),
        RSA1_5("RSA1_5");

        private final String algorithmName;

        SupportedAlgorithm(String str) {
            this.algorithmName = str;
        }

        /** @return the header name of this algorithm */
        @Override // java.lang.Enum
        public String toString() {
            return this.algorithmName;
        }

        /**
         * Finds the algorithm by its header name, ignoring case.
         *
         * @param str header name to look up, may be {@code null}
         * @return matching algorithm
         * @throws IllegalArgumentException if no algorithm has that name; the message names
         *         the rejected value
         */
        static SupportedAlgorithm getValue(String str) {
            for (SupportedAlgorithm candidate : values()) {
                if (candidate.algorithmName.equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Unsupported key encryption algorithm: " + str);
        }
    }

    /* loaded from: input_file:io/helidon/security/jwt/EncryptedJwt$SupportedEncryption.class */
    /**
     * Content encryption methods supported for the token payload.
     * {@link #toString()} yields the name used in the token header.
     */
    public enum SupportedEncryption {
        A128GCM("A128GCM"),
        A192GCM("A192GCM"),
        A256GCM("A256GCM"),
        A128CBC_HS256("A128CBC-HS256"),
        A192CBC_HS384("A192CBC-HS384"),
        A256CBC_HS512("A256CBC-HS512");

        private final String encryptionName;

        SupportedEncryption(String str) {
            this.encryptionName = str;
        }

        /** @return the header name of this encryption method */
        @Override // java.lang.Enum
        public String toString() {
            return this.encryptionName;
        }

        /**
         * Finds the encryption method by its header name, ignoring case.
         *
         * @param str header name to look up, may be {@code null}
         * @return matching encryption method
         * @throws IllegalArgumentException if no method has that name; the message names
         *         the rejected value
         */
        static SupportedEncryption getValue(String str) {
            for (SupportedEncryption candidate : values()) {
                if (candidate.encryptionName.equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
            throw new IllegalArgumentException("Unsupported content encryption: " + str);
        }
    }

    /**
     * Creates an instance from already separated token parts. Arrays are stored as given.
     *
     * @param str complete serialized token
     * @param jwtHeaders token headers
     * @param bArr initialization vector
     * @param bArr2 encrypted content encryption key
     * @param bArr3 authentication tag
     * @param bArr4 encrypted payload
     */
    private EncryptedJwt(String str, JwtHeaders jwtHeaders, byte[] bArr, byte[] bArr2, byte[] bArr3, byte[] bArr4) {
        // textual form and headers
        this.header = jwtHeaders;
        this.token = str;
        // binary segments
        this.encryptedKey = bArr2;
        this.iv = bArr;
        this.encryptedPayload = bArr4;
        this.authTag = bArr3;
    }

    /**
     * Starts building an encrypted JWT around the given signed JWT.
     *
     * @param signedJwt JWT to encrypt, must not be null
     * @return new builder
     */
    public static Builder builder(SignedJwt signedJwt) {
        Builder freshBuilder = new Builder(signedJwt);
        return freshBuilder;
    }

    /**
     * Encrypts the signed JWT for the given JWK using the builder defaults.
     *
     * @param signedJwt JWT to encrypt
     * @param jwk key used to encrypt the content key
     * @return encrypted JWT
     */
    public static EncryptedJwt create(SignedJwt signedJwt, Jwk jwk) {
        Builder defaults = new Builder(signedJwt);
        defaults.jwk(jwk);
        return defaults.m2build();
    }

    /**
     * Parses a serialized token, reading the headers from its first segment.
     *
     * @param str serialized token
     * @return parsed encrypted JWT
     * @throws JwtException if the string does not have the five-segment token shape
     */
    public static EncryptedJwt parseToken(String str) {
        Errors.Collector collector = Errors.collector();
        Matcher segments = JWE_PATTERN.matcher(str);
        if (segments.matches()) {
            JwtHeaders parsedHeaders = JwtHeaders.parseBase64(segments.group(1), collector);
            return parse(str,
                         collector,
                         parsedHeaders,
                         segments.group(2),
                         segments.group(3),
                         segments.group(4),
                         segments.group(5));
        }
        throw new JwtException("Not a JWE token: " + str);
    }

    /**
     * Parses a serialized token using headers supplied by the caller; the first token
     * segment is not decoded here.
     *
     * @param jwtHeaders headers to associate with the token
     * @param str serialized token
     * @return parsed encrypted JWT
     * @throws JwtException if the string does not have the five-segment token shape
     */
    public static EncryptedJwt parseToken(JwtHeaders jwtHeaders, String str) {
        Errors.Collector collector = Errors.collector();
        Matcher segments = JWE_PATTERN.matcher(str);
        if (!segments.matches()) {
            throw new JwtException("Not a JWE token: " + str);
        }
        String encryptedKeyPart = segments.group(2);
        String ivPart = segments.group(3);
        String payloadPart = segments.group(4);
        String authTagPart = segments.group(5);
        return parse(str, collector, jwtHeaders, encryptedKeyPart, ivPart, payloadPart, authTagPart);
    }

    /**
     * Base64-decodes the four binary token segments, validates the collected errors and
     * creates the instance.
     *
     * @param str complete serialized token
     * @param collector collector that receives decoding problems
     * @param jwtHeaders token headers
     * @param str2 encoded encrypted key
     * @param str3 encoded initialization vector
     * @param str4 encoded payload
     * @param str5 encoded authentication tag
     * @return parsed encrypted JWT
     */
    private static EncryptedJwt parse(String str, Errors.Collector collector, JwtHeaders jwtHeaders, String str2, String str3, String str4, String str5) {
        // Decode in token order so collected messages follow the segment order.
        byte[] wrappedKey = decodeBytes(str2, collector, "JWE encrypted key");
        byte[] initVector = decodeBytes(str3, collector, "JWE initialization vector");
        byte[] payload = decodeBytes(str4, collector, "JWE payload");
        byte[] tag = decodeBytes(str5, collector, "JWE authentication tag");

        // Fails here if any segment (or anything recorded earlier) was invalid.
        collector.collect().checkValid();

        return new EncryptedJwt(str, jwtHeaders, initVector, wrappedKey, tag, payload);
    }

    /**
     * Encrypts the content encryption key with the recipient's public key.
     *
     * @param str JCA cipher transformation to use
     * @param publicKey key to encrypt with
     * @param bArr raw content encryption key
     * @return encrypted key bytes
     * @throws JwtException wrapping any failure of the cipher
     */
    private static byte[] encryptRsa(String str, PublicKey publicKey, byte[] bArr) {
        try {
            Cipher cipher = Cipher.getInstance(str);
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            return cipher.doFinal(bArr);
        } catch (Exception e) {
            // Message previously said "decryption", which mislabelled failures on this path.
            throw new JwtException("Exception during aes key encryption occurred.", e);
        }
    }

    /**
     * Decrypts the content encryption key with the recipient's private key.
     *
     * @param str JCA cipher transformation to use
     * @param privateKey key to decrypt with
     * @param bArr encrypted content encryption key
     * @return raw content encryption key
     * @throws JwtException wrapping any failure of the cipher
     */
    private static byte[] decryptRsa(String str, PrivateKey privateKey, byte[] bArr) {
        try {
            Cipher keyCipher = Cipher.getInstance(str);
            keyCipher.init(Cipher.DECRYPT_MODE, privateKey);
            byte[] contentKey = keyCipher.doFinal(bArr);
            return contentKey;
        } catch (Exception e) {
            throw new JwtException("Exception during aes key decryption occurred.", e);
        }
    }

    /**
     * Encodes the UTF-8 bytes of a string as unpadded URL-safe base64.
     *
     * @param str text to encode
     * @return base64url text
     */
    private static String encode(String str) {
        byte[] utf8 = str.getBytes(StandardCharsets.UTF_8);
        return URL_ENCODER.encodeToString(utf8);
    }

    /**
     * Encodes bytes as unpadded URL-safe base64.
     *
     * @param bArr bytes to encode
     * @return base64url text
     */
    private static String encode(byte[] bArr) {
        String encoded = URL_ENCODER.encodeToString(bArr);
        return encoded;
    }

    /**
     * Decodes URL-safe base64 text; a failure is recorded on the collector instead of thrown.
     *
     * @param str text to decode
     * @param collector collector that receives a fatal message when decoding fails
     * @param str2 human readable name of the value, used as prefix of the message
     * @return decoded bytes, or {@code null} when decoding failed
     */
    private static byte[] decodeBytes(String str, Errors.Collector collector, String str2) {
        byte[] decoded = null;
        try {
            decoded = URL_DECODER.decode(str);
        } catch (Exception e) {
            // Best effort: remember the problem and let the caller validate the collector.
            collector.fatal(str, str2 + " is not a base64 encoded string.");
        }
        return decoded;
    }

    /**
     * Decrypts this token looking the key up in the given key set, without a default JWK.
     *
     * @param jwkKeys key set to search by key id
     * @return decrypted signed JWT
     */
    public SignedJwt decrypt(JwkKeys jwkKeys) {
        final Jwk noDefaultJwk = null;
        return decrypt(jwkKeys, noDefaultJwk);
    }

    /**
     * Decrypts this token with a single JWK, without a key set.
     *
     * @param jwk key to decrypt with
     * @return decrypted signed JWT
     */
    public SignedJwt decrypt(Jwk jwk) {
        final JwkKeys noKeySet = null;
        return decrypt(noKeySet, jwk);
    }

    /**
     * Decrypts this token and parses the result as a signed JWT.
     *
     * <p>Key selection: without a key id header the given JWK is used; with a key id the
     * key set is searched when present, otherwise the given JWK must carry the same key id.
     * Problems found while resolving key and algorithms are gathered on an error collector
     * and raised together by its validity check before any cryptographic work starts.
     *
     * @param jwkKeys key set to search by key id, may be {@code null}
     * @param jwk JWK used when there is no key id or no key set, may be {@code null}
     * @return decrypted signed JWT
     * @throws JwtException if key or content decryption fails or the content encryption
     *         is not supported
     */
    public SignedJwt decrypt(JwkKeys jwkKeys, Jwk jwk) {
        Errors.Collector collector = Errors.collector();
        // Additional authenticated data is rebuilt from the header JSON.
        // NOTE(review): this assumes re-serializing the headers reproduces the first token
        // segment byte for byte - confirm for tokens issued by other implementations.
        String headerBase64 = encode(this.header.headerJson().toString().getBytes(StandardCharsets.UTF_8));
        String alg = this.header.algorithm().orElse(null);
        String kid = this.header.keyId().orElse(null);
        String enc = this.header.encryption().orElse(null);

        Jwk selected = null;
        if (kid == null) {
            selected = jwk;
            if (selected == null) {
                collector.fatal("Could not find any suitable JWK.");
            }
        } else if (jwkKeys != null) {
            // A kid missing from the key set leaves the key unresolved without a collected
            // error; the failure then surfaces from the key decryption step below.
            selected = jwkKeys.forKeyId(kid).orElse(null);
        } else if (jwk != null && kid.equals(jwk.keyId())) {
            selected = jwk;
        } else {
            // Also reached when neither key set nor JWK was supplied (used to be an NPE).
            collector.fatal("Could not find JWK for kid: " + kid);
        }

        if (enc == null) {
            collector.fatal("Content encryption algorithm not set.");
        }

        String keyCipher = null;
        if (alg == null) {
            collector.fatal("No alg header was present among JWE headers");
        } else {
            try {
                keyCipher = RSA_ALGORITHMS.get(SupportedAlgorithm.getValue(alg));
            } catch (IllegalArgumentException e) {
                collector.fatal("Value of the claim alg not supported. alg: " + alg);
            }
        }

        PrivateKey privateKey = null;
        // Effectively final copy for use inside the lambdas.
        final Jwk keyForMessages = selected;
        if (selected instanceof JwkRSA) {
            privateKey = (PrivateKey) ((JwkRSA) selected).privateKey().orElseGet(() -> {
                collector.fatal("No private key present in RSA JWK kid: " + keyForMessages.keyId());
                return null;
            });
        } else if (selected instanceof JwkEC) {
            privateKey = (PrivateKey) ((JwkEC) selected).privateKey().orElseGet(() -> {
                collector.fatal("No private key present in EC JWK kid: " + keyForMessages.keyId());
                return null;
            });
        } else if (selected != null) {
            collector.fatal("Not supported JWK type: " + selected.keyType() + ", JWK class: " + selected.getClass().getName());
        }

        // Raise everything collected so far before touching any key material.
        collector.collect().checkValid();

        try {
            AesAlgorithm contentAlgorithm = CONTENT_ENCRYPTION.get(SupportedEncryption.getValue(enc));
            byte[] contentKey = decryptRsa(keyCipher, privateKey, this.encryptedKey);
            EncryptionParts parts = new EncryptionParts(contentKey,
                                                        this.iv,
                                                        headerBase64.getBytes(StandardCharsets.US_ASCII),
                                                        this.encryptedPayload,
                                                        this.authTag);
            byte[] plainContent = contentAlgorithm.decrypt(parts);
            return SignedJwt.parseToken(new String(plainContent, StandardCharsets.UTF_8));
        } catch (IllegalArgumentException e) {
            throw new JwtException("Unsupported content encryption: " + enc, e);
        }
    }

    /**
     * Headers of this token.
     *
     * @return the headers instance held by this token
     */
    public JwtHeaders headers() {
        return header;
    }

    /**
     * Serialized form of this token.
     *
     * @return the complete token string
     */
    public String token() {
        return token;
    }

    /**
     * Initialization vector of this token.
     *
     * @return a copy of the initialization vector bytes
     */
    public byte[] iv() {
        return this.iv.clone();
    }

    /**
     * Encrypted content encryption key of this token.
     *
     * @return a copy of the encrypted key bytes
     */
    public byte[] encryptedKey() {
        return this.encryptedKey.clone();
    }

    /**
     * Authentication tag of this token.
     *
     * @return a copy of the authentication tag bytes
     */
    public byte[] authTag() {
        return this.authTag.clone();
    }

    /**
     * Encrypted payload of this token.
     *
     * @return a copy of the encrypted payload bytes
     */
    public byte[] encryptedPayload() {
        return this.encryptedPayload.clone();
    }
}
