package org.keycloak.crypto.fips;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.InvalidAlgorithmParameterException;
import java.security.Key;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.ECPublicKeySpec;
import java.security.spec.InvalidKeySpecException;

import org.bouncycastle.asn1.nist.NISTNamedCurves;
import org.bouncycastle.asn1.x9.X9ECParameters;
import org.bouncycastle.crypto.SymmetricSecretKey;
import org.bouncycastle.crypto.asymmetric.AsymmetricECPrivateKey;
import org.bouncycastle.crypto.asymmetric.AsymmetricECPublicKey;
import org.bouncycastle.crypto.asymmetric.ECDomainParameters;
import org.bouncycastle.crypto.fips.FipsAES;
import org.bouncycastle.crypto.fips.FipsEC;
import org.bouncycastle.crypto.fips.FipsKDF;
import org.bouncycastle.jcajce.spec.ECDomainParameterSpec;
import org.keycloak.common.util.Base64Url;
import org.keycloak.jose.jwe.JWEHeader;
import org.keycloak.jose.jwe.JWEKeyStorage;
import org.keycloak.jose.jwe.alg.JWEAlgorithmProvider;
import org.keycloak.jose.jwe.enc.JWEEncryptionProvider;
import org.keycloak.jose.jwk.ECPublicJWK;
import org.keycloak.jose.jwk.JWKUtil;

/* loaded from: input_file:org/keycloak/crypto/fips/BCFIPSEcdhEsAlgorithmProvider.class */
public class BCFIPSEcdhEsAlgorithmProvider implements JWEAlgorithmProvider {
    public byte[] decodeCek(byte[] bArr, Key key, JWEHeader jWEHeader, JWEEncryptionProvider jWEEncryptionProvider) throws Exception {
        byte[] deriveKey = deriveKey(toPublicKey(jWEHeader.getEphemeralPublicKey()), key, getKeyDataLength(jWEHeader.getAlgorithm(), jWEEncryptionProvider), getAlgorithmID(jWEHeader.getAlgorithm(), jWEHeader.getEncryptionAlgorithm()), base64UrlDecode(jWEHeader.getAgreementPartyUInfo()), base64UrlDecode(jWEHeader.getAgreementPartyVInfo()));
        if ("ECDH-ES".equals(jWEHeader.getAlgorithm())) {
            return deriveKey;
        }
        return new FipsAES.KeyWrapOperatorFactory().createKeyUnwrapper(new SymmetricSecretKey(FipsAES.KW, deriveKey), FipsAES.KW).unwrap(bArr, 0, bArr.length);
    }

    public byte[] encodeCek(JWEEncryptionProvider jWEEncryptionProvider, JWEKeyStorage jWEKeyStorage, Key key, JWEHeader.JWEHeaderBuilder jWEHeaderBuilder) throws Exception {
        JWEHeader build = jWEHeaderBuilder.build();
        int keyDataLength = getKeyDataLength(build.getAlgorithm(), jWEEncryptionProvider);
        KeyPair generateEcKeyPair = generateEcKeyPair(((ECPublicKey) key).getParams());
        ECPublicKey eCPublicKey = (ECPublicKey) generateEcKeyPair.getPublic();
        ECPrivateKey eCPrivateKey = (ECPrivateKey) generateEcKeyPair.getPrivate();
        byte[] base64UrlDecode = build.getAgreementPartyUInfo() != null ? base64UrlDecode(build.getAgreementPartyUInfo()) : new byte[0];
        byte[] base64UrlDecode2 = build.getAgreementPartyVInfo() != null ? base64UrlDecode(build.getAgreementPartyVInfo()) : new byte[0];
        jWEHeaderBuilder.ephemeralPublicKey(toECPublicJWK(eCPublicKey));
        byte[] deriveKey = deriveKey(key, eCPrivateKey, keyDataLength, getAlgorithmID(build.getAlgorithm(), build.getEncryptionAlgorithm()), base64UrlDecode, base64UrlDecode2);
        if (!"ECDH-ES".equals(build.getAlgorithm())) {
            byte[] cekBytes = jWEKeyStorage.getCekBytes();
            return new FipsAES.KeyWrapOperatorFactory().createKeyWrapper(new SymmetricSecretKey(FipsAES.KW, deriveKey), FipsAES.KW).wrap(cekBytes, 0, cekBytes.length);
        }
        jWEKeyStorage.setCEKBytes(deriveKey);
        jWEEncryptionProvider.deserializeCEK(jWEKeyStorage);
        return new byte[0];
    }

    private byte[] base64UrlDecode(String str) {
        return Base64Url.decode(str == null ? "" : str);
    }

    private static KeyPair generateEcKeyPair(ECParameterSpec eCParameterSpec) {
        try {
            KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC", "BCFIPS");
            keyPairGenerator.initialize(eCParameterSpec, SecureRandom.getInstance("DEFAULT", "BCFIPS"));
            return keyPairGenerator.generateKeyPair();
        } catch (InvalidAlgorithmParameterException | NoSuchAlgorithmException | NoSuchProviderException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /* JADX WARN: Type inference failed for: r0v11, types: [byte[], byte[][]] */
    private static byte[] deriveOtherInfo(int i, String str, byte[] bArr, byte[] bArr2) {
        return concat(new byte[]{encodeDataLengthData(str.getBytes(Charset.forName("ASCII"))), encodeDataLengthData(bArr), encodeDataLengthData(bArr2), toByteArray(i), emptyBytes()});
    }

    public static byte[] deriveKey(Key key, Key key2, int i, String str, byte[] bArr, byte[] bArr2) {
        return new FipsEC.DHAgreementFactory().createAgreement(new AsymmetricECPrivateKey(FipsEC.ALGORITHM, key2.getEncoded()), FipsEC.DH.withKDF(FipsKDF.CONCATENATION.withPRF(FipsKDF.AgreementKDFPRF.SHA256), deriveOtherInfo(i, str, bArr, bArr2), i / 8)).calculate(new AsymmetricECPublicKey(FipsEC.ALGORITHM, key.getEncoded()));
    }

    private static ECPublicJWK toECPublicJWK(ECPublicKey eCPublicKey) {
        ECPublicJWK eCPublicJWK = new ECPublicJWK();
        int fieldSize = eCPublicKey.getParams().getCurve().getField().getFieldSize();
        eCPublicJWK.setCrv("P-" + fieldSize);
        eCPublicJWK.setKeyType("EC");
        eCPublicJWK.setX(Base64Url.encode(JWKUtil.toIntegerBytes(eCPublicKey.getW().getAffineX(), fieldSize)));
        eCPublicJWK.setY(Base64Url.encode(JWKUtil.toIntegerBytes(eCPublicKey.getW().getAffineY(), fieldSize)));
        return eCPublicJWK;
    }

    private static PublicKey toPublicKey(ECPublicJWK eCPublicJWK) {
        String crv = eCPublicJWK.getCrv();
        String x = eCPublicJWK.getX();
        String y = eCPublicJWK.getY();
        if (crv == null) {
            throw new IllegalArgumentException("JWK crv must be set");
        }
        if (x == null) {
            throw new IllegalArgumentException("JWK x must be set");
        }
        if (y == null) {
            throw new IllegalArgumentException("JWK y must be set");
        }
        try {
            ECPoint eCPoint = new ECPoint(new BigInteger(1, Base64Url.decode(x)), new BigInteger(1, Base64Url.decode(y)));
            X9ECParameters byName = NISTNamedCurves.getByName(crv);
            return KeyFactory.getInstance("EC", "BCFIPS").generatePublic(new ECPublicKeySpec(eCPoint, new ECDomainParameterSpec(new ECDomainParameters(byName.getCurve(), byName.getG(), byName.getN(), byName.getH()))));
        } catch (NoSuchAlgorithmException | NoSuchProviderException | InvalidKeySpecException e) {
            throw new IllegalArgumentException(e);
        }
    }

    private static String getAlgorithmID(String str, String str2) {
        if ("ECDH-ES+A128KW".equals(str) || "ECDH-ES+A192KW".equals(str) || "ECDH-ES+A256KW".equals(str)) {
            return str;
        }
        if ("ECDH-ES".equals(str)) {
            return str2;
        }
        throw new IllegalArgumentException("Unsupported algorithm");
    }

    private static int getKeyDataLength(String str, JWEEncryptionProvider jWEEncryptionProvider) {
        if ("ECDH-ES+A128KW".equals(str)) {
            return 128;
        }
        if ("ECDH-ES+A192KW".equals(str)) {
            return 192;
        }
        if ("ECDH-ES+A256KW".equals(str)) {
            return 256;
        }
        if ("ECDH-ES".equals(str)) {
            return jWEEncryptionProvider.getExpectedCEKLength() * 8;
        }
        throw new IllegalArgumentException("Unsupported algorithm");
    }

    /* JADX WARN: Type inference failed for: r0v8, types: [byte[], byte[][]] */
    private static byte[] encodeDataLengthData(byte[] bArr) {
        byte[] bArr2 = bArr != null ? bArr : new byte[0];
        return concat(new byte[]{toByteArray(bArr2.length), bArr2});
    }

    private static byte[] emptyBytes() {
        return new byte[0];
    }

    private static byte[] toByteArray(int i) {
        return new byte[]{(byte) (i >> 24), (byte) (i >> 16), (byte) (i >> 8), (byte) i};
    }

    private static byte[] concat(byte[]... bArr) {
        try {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            try {
                for (byte[] bArr2 : bArr) {
                    if (bArr2 != null) {
                        byteArrayOutputStream.write(bArr2);
                    }
                }
                byte[] byteArray = byteArrayOutputStream.toByteArray();
                byteArrayOutputStream.close();
                return byteArray;
            } finally {
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
