/*
 * Decompiled with CFR 0.152.
 */
package tuwien.auto.calimero.secure;

import java.io.ByteArrayOutputStream;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.IntStream;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import tuwien.auto.calimero.GroupAddress;
import tuwien.auto.calimero.IndividualAddress;
import tuwien.auto.calimero.KNXAddress;
import tuwien.auto.calimero.KNXFormatException;
import tuwien.auto.calimero.KNXIllegalArgumentException;
import tuwien.auto.calimero.secure.KnxSecureException;
import tuwien.auto.calimero.xml.KNXMLException;
import tuwien.auto.calimero.xml.XmlInputFactory;
import tuwien.auto.calimero.xml.XmlReader;

/**
 * Keyring of an ETS keyring file ({@code *.knxkeys}).
 * <p>
 * {@link #load()} reads the optional backbone settings, the interfaces (grouped by their host address), the group
 * address keys, and the devices contained in the file. Key and password attributes are only Base64-decoded while
 * loading; {@link #decryptKey(byte[], char[])} and {@link #decryptPassword(byte[], char[])} decrypt such data using
 * the keyring password.
 */
public final class Keyring {
    private static final String keyringNamespace = "http://knx.org/xml/keyring/1";
    private static final byte[] keyringSalt = utf8Bytes("1.keyring.ets.knx.org");
    private static final byte[] emptyPwd = new byte[0];
    private static final Logger logger = LoggerFactory.getLogger("calimero.keyring");

    // XML event type codes as returned by XmlReader; presumably these mirror the StAX values
    // (XMLStreamConstants: START_ELEMENT = 1, END_ELEMENT = 2, END_DOCUMENT = 8)
    private static final int StartElement = 1;
    private static final int EndElement = 2;
    private static final int EndDocument = 8;

    // lowest IPv4 multicast address accepted as KNX routing multicast group (224.0.23.12)
    private static final long DefaultMulticast = unsigned(new byte[] { (byte) 224, 0, 23, 12 });

    // one individual address per whitespace-separated token of a group's 'Senders' attribute
    private static final Pattern senderToken = Pattern.compile("[^\\s]+");

    // PBKDF2 parameters used to derive the keyring password hash
    private static final int pbkdf2Iterations = 65_536;
    private static final int pbkdf2KeyLengthBits = 128;

    private final String keyringUri;
    private final char[] keyringPassword;
    private byte[] passwordHash = new byte[0];
    private byte[] createdHash = new byte[0];
    private byte[] signature;

    private volatile Backbone backbone;
    private volatile Map<IndividualAddress, List<Interface>> interfaces = Map.of();
    private volatile Map<GroupAddress, byte[]> groups = Map.of();
    private volatile Map<IndividualAddress, Device> devices = Map.of();

    /** Returns a defensive copy of {@code ba}, or an empty optional if {@code ba} is {@code null}. */
    private static Optional<byte[]> optional(final byte[] ba) {
        return Optional.ofNullable(ba).map(byte[]::clone);
    }

    /**
     * Loads the keyring file at {@code keyringUri} without a keyring password, i.e., without signature verification.
     *
     * @param keyringUri location of the keyring file, has to end with {@code .knxkeys}
     * @return the loaded keyring
     * @throws KNXIllegalArgumentException if {@code keyringUri} does not end with {@code .knxkeys}
     * @throws KNXMLException on malformed keyring content
     */
    public static Keyring load(final String keyringUri) {
        final Keyring keyring = new Keyring(keyringUri, new char[0]);
        keyring.load();
        return keyring;
    }

    Keyring(final String keyringUri, final char[] keyringPassword) {
        Objects.requireNonNull(keyringUri, "keyring URI");
        if (!keyringUri.endsWith(".knxkeys"))
            throw new KNXIllegalArgumentException("'" + keyringUri + "' is not a keyring file");
        this.keyringUri = keyringUri;
        this.keyringPassword = keyringPassword;
    }

    /**
     * Parses the keyring file. If a non-empty keyring password was supplied, the file signature is verified before
     * any entries are read.
     *
     * @throws KNXMLException on unsupported namespace or root element, an invalid backbone multicast address, or a
     *         malformed address attribute
     * @throws KnxSecureException if signature verification fails or on a crypto error
     */
    void load() {
        int line = 0;
        try (XmlReader reader = XmlInputFactory.newInstance().createXMLReader(keyringUri)) {
            reader.nextTag();
            final String namespace = reader.getNamespaceURI();
            if (!keyringNamespace.equals(namespace))
                throw new KNXMLException("keyring '" + keyringUri + "' with unsupported namespace '" + namespace + "'");
            if (!"Keyring".equals(reader.getLocalName()))
                throw new KNXMLException("keyring '" + keyringUri + "' requires 'Keyring' element");

            final String project = reader.getAttributeValue(null, "Project");
            final String createdBy = reader.getAttributeValue(null, "CreatedBy");
            final String created = reader.getAttributeValue(null, "Created");
            logger.debug("read keyring for project '{}', created by {} on {}", project, createdBy, created);

            passwordHash = hashKeyringPwd(keyringPassword);
            // the hash of the creation timestamp serves as IV for decrypting keys and passwords
            createdHash = sha256(utf8Bytes(created));
            signature = decode(reader.getAttributeValue(null, "Signature"));
            if (keyringPassword.length > 0 && !verifySignature(passwordHash))
                throw new KnxSecureException("signature verification failed for keyring '" + keyringUri + "'");

            Interface iface = null;
            boolean inDevices = false;
            boolean inGroupAddresses = false;
            final Map<IndividualAddress, List<Interface>> parsedInterfaces = new HashMap<>();
            final Map<GroupAddress, byte[]> parsedGroups = new HashMap<>();
            final Map<IndividualAddress, Device> parsedDevices = new HashMap<>();

            reader.next();
            while (reader.getEventType() != EndDocument) {
                final int event = reader.getEventType();
                if (event == EndElement) {
                    // interface groups are only collected while inside an Interface element; freeze them on close
                    if ("Interface".equals(reader.getLocalName()) && iface != null) {
                        iface.groups = Map.copyOf(iface.groups);
                        logger.trace("add {}", iface);
                        iface = null;
                    }
                }
                else if (event == StartElement) {
                    final String name = reader.getLocalName();
                    line = reader.getLocation().getLineNumber();
                    if ("Backbone".equals(name)) {
                        backbone = readBackbone(reader);
                    }
                    else if ("Interface".equals(name)) {
                        inGroupAddresses = false;
                        final IndividualAddress host = readIndividualAddress(reader, "Host");
                        iface = readInterface(reader);
                        parsedInterfaces.computeIfAbsent(host, key -> new ArrayList<>()).add(iface);
                    }
                    else if (iface != null && "Group".equals(name)) {
                        readInterfaceGroup(reader, iface);
                    }
                    else if ("Devices".equals(name)) {
                        inDevices = true;
                    }
                    else if (inDevices && "Device".equals(name)) {
                        final Device device = readDevice(reader);
                        parsedDevices.put(device.addr, device);
                        logger.trace("add {}", device);
                    }
                    else if ("GroupAddresses".equals(name)) {
                        inGroupAddresses = true;
                    }
                    else if (inGroupAddresses && "Group".equals(name)) {
                        final GroupAddress group = new GroupAddress(reader.getAttributeValue(null, "Address"));
                        parsedGroups.put(group, decode(reader.getAttributeValue(null, "Key")));
                    }
                    else {
                        logger.warn("keyring '{}': skip unknown element '{}'", keyringUri, name);
                    }
                }
                reader.next();
            }

            // publish immutable snapshots, including the per-host interface lists
            parsedInterfaces.replaceAll((host, list) -> List.copyOf(list));
            this.interfaces = Map.copyOf(parsedInterfaces);
            this.groups = Map.copyOf(parsedGroups);
            this.devices = Map.copyOf(parsedDevices);
        }
        catch (final UnknownHostException | KNXFormatException e) {
            final String location = line != 0 ? " [line " + line + "]" : "";
            throw new KNXMLException("loading keyring '" + keyringUri + "'" + location + " address element with "
                    + e.getMessage());
        }
        catch (final GeneralSecurityException e) {
            throw new KnxSecureException("crypto error", e);
        }
        finally {
            Arrays.fill(passwordHash, (byte) 0);
        }
    }

    /** Reads the attributes of a {@code Backbone} element. */
    private Backbone readBackbone(final XmlReader reader) throws UnknownHostException {
        final InetAddress mcastGroup = InetAddress.getByName(reader.getAttributeValue(null, "MulticastAddress"));
        if (!validRoutingMulticast(mcastGroup))
            throw new KNXMLException("loading keyring '" + keyringUri + "': " + mcastGroup.getHostAddress()
                    + " is not a valid KNX multicast address");
        final byte[] groupKey = decode(reader.getAttributeValue(null, "Key"));
        final Duration latency = Duration.ofMillis(Integer.parseInt(reader.getAttributeValue(null, "Latency")));
        return new Backbone(mcastGroup, groupKey, latency);
    }

    /** Reads the attributes of an {@code Interface} element (except its host address). */
    private static Interface readInterface(final XmlReader reader) throws KNXFormatException {
        final String type = reader.getAttributeValue(null, "Type");
        final IndividualAddress addr = readIndividualAddress(reader, "IndividualAddress");
        final int user = readAttribute(reader, "UserID", Integer::parseInt, 0);
        final byte[] pwd = readAttribute(reader, "Password", Keyring::decode, null);
        final byte[] auth = readAttribute(reader, "Authentication", Keyring::decode, null);
        return new Interface(type, addr, user, pwd, auth);
    }

    /** Reads a {@code Group} element nested in an interface and adds it with its senders to {@code iface}. */
    private static void readInterfaceGroup(final XmlReader reader, final Interface iface) throws KNXFormatException {
        final GroupAddress group = new GroupAddress(reader.getAttributeValue(null, "Address"));
        final String senders = reader.getAttributeValue(null, "Senders");
        final List<IndividualAddress> list = new ArrayList<>();
        if (senders != null) {
            final Matcher matcher = senderToken.matcher(senders);
            while (matcher.find())
                list.add(new IndividualAddress(matcher.group()));
        }
        // the initial groups map is the immutable Map.of(); switch to a mutable map on first use
        if (iface.groups.isEmpty())
            iface.groups = new HashMap<>();
        // Set.copyOf tolerates duplicate senders, Set.of would throw
        iface.groups.put(group, Set.copyOf(list));
    }

    /** Reads the attributes of a {@code Device} element. */
    private static Device readDevice(final XmlReader reader) throws KNXFormatException {
        final IndividualAddress addr = new IndividualAddress(reader.getAttributeValue(null, "IndividualAddress"));
        final byte[] toolkey = readAttribute(reader, "ToolKey", Keyring::decode, null);
        final long seq = readAttribute(reader, "SequenceNumber", Long::parseLong, 0L);
        final byte[] pwd = readAttribute(reader, "ManagementPassword", Keyring::decode, null);
        final byte[] auth = readAttribute(reader, "Authentication", Keyring::decode, null);
        return new Device(addr, toolkey, pwd, auth, seq);
    }

    /** Parses an individual address attribute, defaulting to address 0 if the attribute is absent. */
    private static IndividualAddress readIndividualAddress(final XmlReader reader, final String attribute)
            throws KNXFormatException {
        final String value = reader.getAttributeValue(null, attribute);
        return value != null ? new IndividualAddress(value) : new IndividualAddress(0);
    }

    /**
     * Verifies the keyring file signature using {@code keyringPassword}.
     *
     * @param keyringPassword keyring password
     * @return {@code true} if the signature matches, {@code false} otherwise or on a crypto error
     */
    public boolean verifySignature(final char[] keyringPassword) {
        final byte[] pwdHash = hashKeyringPwd(keyringPassword);
        try {
            return verifySignature(pwdHash);
        }
        catch (final GeneralSecurityException e) {
            return false;
        }
        finally {
            Arrays.fill(pwdHash, (byte) 0);
        }
    }

    /** @return the backbone settings, or an empty optional if the keyring contains no {@code Backbone} element */
    public Optional<Backbone> backbone() {
        return Optional.ofNullable(backbone);
    }

    /** @return unmodifiable map of host address to the unmodifiable list of interfaces of that host */
    public Map<IndividualAddress, List<Interface>> interfaces() {
        return interfaces;
    }

    /** @return unmodifiable map of group address to its (Base64-decoded) key data */
    public Map<GroupAddress, byte[]> groups() {
        return groups;
    }

    /** @return unmodifiable map of device address to device */
    public Map<IndividualAddress, Device> devices() {
        return devices;
    }

    /**
     * Recomputes the keyring signature and compares it to the one stored in the file. The signed data is built from
     * every start element (name and sorted attributes, excluding {@code xmlns} and {@code Signature}), an end marker
     * for every end element, and finally the Base64-encoded password hash.
     */
    private boolean verifySignature(final byte[] passwordHash) throws GeneralSecurityException {
        final ByteArrayOutputStream output = new ByteArrayOutputStream();
        try (XmlReader reader = XmlInputFactory.newInstance().createXMLReader(keyringUri)) {
            while (reader.next() != EndDocument) {
                final int event = reader.getEventType();
                if (event == StartElement)
                    appendElement(reader, output);
                else if (event == EndElement)
                    output.write(2); // end-of-element marker
            }
        }
        appendString(Base64.getEncoder().encode(passwordHash), output);
        final byte[] outputHash = sha256(output.toByteArray());
        // comparison time does not depend on how many leading bytes match
        return MessageDigest.isEqual(outputHash, signature);
    }

    /**
     * Decrypts key data stored in this keyring.
     *
     * @param input encrypted key data (AES-128-CBC, no padding)
     * @param keyringPassword keyring password
     * @return the decrypted key data
     * @throws KnxSecureException on any decryption error
     */
    public byte[] decryptKey(final byte[] input, final char[] keyringPassword) {
        final byte[] pwdHash = hashKeyringPwd(keyringPassword);
        try {
            return aes128Cbc(input, pwdHash, createdHash);
        }
        catch (final RuntimeException | GeneralSecurityException e) {
            throw new KnxSecureException("decrypting key data", e);
        }
        finally {
            Arrays.fill(pwdHash, (byte) 0);
        }
    }

    /**
     * Decrypts password data stored in this keyring.
     *
     * @param input encrypted password data
     * @param keyringPassword keyring password
     * @return the decrypted password; each byte is mapped to one char (ISO-8859-1 style)
     * @throws KnxSecureException on any decryption error or invalid padding
     */
    public char[] decryptPassword(final byte[] input, final char[] keyringPassword) {
        final byte[] keyringPwdHash = hashKeyringPwd(keyringPassword);
        byte[] decrypted = null;
        byte[] pwdData = null;
        try {
            decrypted = aes128Cbc(input, keyringPwdHash, createdHash);
            pwdData = extractPassword(decrypted);
            final char[] chars = new char[pwdData.length];
            for (int i = 0; i < pwdData.length; i++)
                chars[i] = (char) (pwdData[i] & 0xff);
            return chars;
        }
        catch (final RuntimeException | GeneralSecurityException e) {
            throw new KnxSecureException("decrypting password data", e);
        }
        finally {
            // wipe every intermediate buffer holding key or plaintext password material
            Arrays.fill(keyringPwdHash, (byte) 0);
            if (decrypted != null)
                Arrays.fill(decrypted, (byte) 0);
            if (pwdData != null)
                Arrays.fill(pwdData, (byte) 0);
        }
    }

    /**
     * Returns whether {@code address} is an IPv4 multicast address not lower than 224.0.23.12. IPv6 addresses are
     * rejected; they would also overflow the {@code long} conversion in {@link #unsigned(byte[])}.
     */
    private static boolean validRoutingMulticast(final InetAddress address) {
        if (address == null || !address.isMulticastAddress())
            return false;
        final byte[] raw = address.getAddress();
        return raw.length == 4 && unsigned(raw) >= DefaultMulticast;
    }

    /** Interprets {@code data} as big-endian unsigned number; only meaningful for up to 8 bytes. */
    private static long unsigned(final byte[] data) {
        long value = 0;
        for (final byte b : data)
            value = value << 8 | (b & 0xff);
        return value;
    }

    /** Returns the parsed attribute value, or {@code defaultValue} if the attribute is absent. */
    private static <R> R readAttribute(final XmlReader reader, final String attribute,
        final Function<String, R> parser, final R defaultValue) {
        final String attr = reader.getAttributeValue(null, attribute);
        return attr == null ? defaultValue : parser.apply(attr);
    }

    /** Appends a start-element marker, the element name, and its sorted attributes (signature input). */
    private static void appendElement(final XmlReader reader, final ByteArrayOutputStream output) {
        output.write(1); // start-of-element marker
        appendString(utf8Bytes(reader.getLocalName()), output);
        final Predicate<String> excluded = Predicate.<String>isEqual("xmlns").or(Predicate.isEqual("Signature"));
        IntStream.range(0, reader.getAttributeCount())
                .mapToObj(reader::getAttributeLocalName)
                .filter(Predicate.not(excluded))
                .sorted()
                .forEach(attr -> appendAttribute(attr, reader, output));
    }

    private static void appendAttribute(final String attr, final XmlReader reader,
        final ByteArrayOutputStream output) {
        appendString(utf8Bytes(attr), output);
        appendString(utf8Bytes(reader.getAttributeValue(null, attr)), output);
    }

    /** Appends a one-byte length prefix (low 8 bits of the length) followed by {@code str}. */
    private static void appendString(final byte[] str, final ByteArrayOutputStream output) {
        output.write(str.length);
        output.write(str, 0, str.length);
    }

    private static byte[] decode(final String base64) {
        return Base64.getDecoder().decode(base64);
    }

    /**
     * Extracts the password from decrypted password data: the first 8 bytes are skipped, and the value of the last
     * byte gives the number of trailing padding bytes to remove.
     */
    private static byte[] extractPassword(final byte[] data) {
        if (data.length == 0)
            return emptyPwd;
        final int padding = data[data.length - 1] & 0xff;
        if (padding > data.length - 8)
            throw new IllegalArgumentException("invalid password data padding (" + padding + " bytes of "
                    + data.length + ")");
        return Arrays.copyOfRange(data, 8, data.length - padding);
    }

    private static byte[] hashKeyringPwd(final char[] keyringPwd) {
        try {
            return pbkdf2WithHmacSha256(keyringPwd, keyringSalt);
        }
        catch (final GeneralSecurityException e) {
            throw new KnxSecureException("hashing keyring password", e);
        }
    }

    private static byte[] aes128Cbc(final byte[] input, final byte[] key, final byte[] iv)
            throws GeneralSecurityException {
        final Cipher cipher = Cipher.getInstance("AES/CBC/NoPadding");
        final SecretKeySpec keySpec = new SecretKeySpec(key, "AES");
        cipher.init(Cipher.DECRYPT_MODE, keySpec, new IvParameterSpec(iv));
        return cipher.doFinal(input);
    }

    /** Returns the first 16 bytes of the SHA-256 digest of {@code input}. */
    private static byte[] sha256(final byte[] input) throws NoSuchAlgorithmException {
        final MessageDigest digest = MessageDigest.getInstance("SHA-256");
        digest.update(input);
        return Arrays.copyOf(digest.digest(), 16);
    }

    private static byte[] pbkdf2WithHmacSha256(final char[] password, final byte[] salt)
            throws GeneralSecurityException {
        final PBEKeySpec keySpec = new PBEKeySpec(password, salt, pbkdf2Iterations, pbkdf2KeyLengthBits);
        try {
            final SecretKeyFactory secretKeyFactory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
            final SecretKey secretKey = secretKeyFactory.generateSecret(keySpec);
            return secretKey.getEncoded();
        }
        finally {
            keySpec.clearPassword();
        }
    }

    private static byte[] utf8Bytes(final String s) {
        return s.getBytes(StandardCharsets.UTF_8);
    }

    /** An interface entry of the keyring. */
    public static final class Interface {
        private final String type;
        private final IndividualAddress addr;
        private final int user;
        private final byte[] pwd;
        private final byte[] auth;
        private volatile Map<GroupAddress, Set<IndividualAddress>> groups = Map.of();

        Interface(final String type, final IndividualAddress addr, final int user, final byte[] pwd,
            final byte[] auth) {
            this.type = type;
            this.addr = addr;
            this.user = user;
            this.pwd = pwd;
            this.auth = auth;
        }

        /** @return the interface address, address 0 if the keyring did not specify one */
        public IndividualAddress address() {
            return addr;
        }

        /** @return the user ID, 0 if the keyring did not specify one */
        public int user() {
            return user;
        }

        /** @return a copy of the (Base64-decoded) password data, if present */
        public Optional<byte[]> password() {
            return optional(pwd);
        }

        /** @return a copy of the (Base64-decoded) authentication data, if present */
        public Optional<byte[]> authentication() {
            return optional(auth);
        }

        /** @return map of group address to the set of senders listed for that group */
        public Map<GroupAddress, Set<IndividualAddress>> groups() {
            return groups;
        }

        @Override
        public String toString() {
            return type + " interface " + addr + ", user " + user + ", groups " + groups.keySet();
        }
    }

    /** Backbone settings of the keyring. */
    public static final class Backbone {
        private final InetAddress mcGroup;
        private final byte[] groupKey;
        private final Duration latency;

        Backbone(final InetAddress multicastGroup, final byte[] groupKey, final Duration latency) {
            this.mcGroup = multicastGroup;
            this.groupKey = groupKey.clone();
            this.latency = latency;
        }

        public InetAddress multicastGroup() {
            return mcGroup;
        }

        /** @return a copy of the (Base64-decoded) group key data */
        public byte[] groupKey() {
            return groupKey.clone();
        }

        public Duration latencyTolerance() {
            return latency;
        }

        @Override
        public int hashCode() {
            return Objects.hash(mcGroup, Arrays.hashCode(groupKey), latency);
        }

        @Override
        public boolean equals(final Object obj) {
            if (this == obj)
                return true;
            if (!(obj instanceof Backbone))
                return false;
            final Backbone other = (Backbone) obj;
            return Objects.equals(mcGroup, other.mcGroup) && Objects.equals(latency, other.latency)
                    && Arrays.equals(groupKey, other.groupKey);
        }

        @Override
        public String toString() {
            return multicastGroup().getHostAddress() + " (latency tolerance " + latency.toMillis() + " ms)";
        }
    }

    /** A device entry of the keyring. */
    public static final class Device {
        private final IndividualAddress addr;
        private final byte[] toolkey;
        private final byte[] pwd;
        private final byte[] auth;
        private final long sequence;

        Device(final IndividualAddress addr, final byte[] toolkey, final byte[] pwd, final byte[] auth,
            final long sequence) {
            this.addr = addr;
            this.toolkey = toolkey;
            this.pwd = pwd;
            this.auth = auth;
            this.sequence = sequence;
        }

        /** @return a copy of the (Base64-decoded) tool key data, if present */
        public Optional<byte[]> toolKey() {
            return optional(toolkey);
        }

        /** @return a copy of the (Base64-decoded) management password data, if present */
        public Optional<byte[]> password() {
            return optional(pwd);
        }

        /** @return a copy of the (Base64-decoded) authentication data, if present */
        public Optional<byte[]> authentication() {
            return optional(auth);
        }

        /** @return the sequence number, 0 if the keyring did not specify one */
        public long sequenceNumber() {
            return sequence;
        }

        @Override
        public String toString() {
            return "device " + addr + " (seq " + sequence + ")";
        }
    }
}

