/*
* Copyright (c) 2013, Oracle and/or its affiliates. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
*
*   - Redistributions of source code must retain the above copyright
*     notice, this list of conditions and the following disclaimer.
*
*   - Redistributions in binary form must reproduce the above copyright
*     notice, this list of conditions and the following disclaimer in the
*     documentation and/or other materials provided with the distribution.
*
*   - Neither the name of Oracle or the names of its
*     contributors may be used to endorse or promote products derived
*     from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
* IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
* THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package org.wildfly.security.ssl;

import java.nio.ByteBuffer;
import java.nio.BufferUnderflowException;
import java.io.IOException;

import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLException;
import javax.net.ssl.StandardConstants;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.wildfly.security._private.ElytronMessages;

/**
 * Utility class that explores the initial network data of an
 * SSL/TLS connection.
 */
final class SSLExplorer {

    // Queried through getCipherSuiteById(byte1, byte2) to translate the
    // cipher suite identifiers found in a hello message into names.
    private static final MechanismDatabase database = MechanismDatabase.getInstance();

    // Private constructor prevents construction outside this class.
    private SSLExplorer() {
    }

    /**
     * The header size of TLS/SSL records.
     * <P>
     * The value of this constant is {@value}.
     */
    public static final int RECORD_HEADER_SIZE = 0x05;

    // The length in bytes of the Random structure of a ClientHello.
    private static final int RANDOM_LENGTH = 32;

    /**
     * Returns the required number of bytes in the {@code source}
     * {@link ByteBuffer} necessary to explore SSL/TLS connection.
     * <P>
     * This method tries to parse as few bytes as possible from
     * {@code source} byte buffer to get the length of an
     * SSL/TLS record.
     * <P>
     * For a V2ClientHello the returned size covers the whole message
     * (2 byte header plus {@code msg_length}), because
     * {@link #explore(ByteBuffer)} reads the cipher specs that follow the
     * first {@code RECORD_HEADER_SIZE} bytes. The result is never smaller
     * than {@code RECORD_HEADER_SIZE}.
     * <P>
     * This method accesses the {@code source} parameter in read-only
     * mode, and does not update the buffer's properties such as capacity,
     * limit, position, and mark values.
     *
     * @param  source
     *         a {@link ByteBuffer} containing
     *         inbound or outbound network data for an SSL/TLS connection.
     * @throws BufferUnderflowException if less than {@code RECORD_HEADER_SIZE}
     *         bytes remaining in {@code source}
     * @return the required size in byte to explore an SSL/TLS connection
     */
    public static int getRequiredSize(ByteBuffer source) {

        // Work on a duplicate so the caller's position/limit/mark stay untouched.
        ByteBuffer input = source.duplicate();

        // Do we have a complete header?
        if (input.remaining() < RECORD_HEADER_SIZE) {
            throw new BufferUnderflowException();
        }

        // Is it a handshake message?
        byte firstByte = input.get();
        byte secondByte = input.get();
        byte thirdByte = input.get();
        if ((firstByte & 0x80) != 0 && thirdByte == 0x01) {
            // Looks like a V2ClientHello: the low 15 bits of the first two
            // bytes hold the length of the data following that 2 byte header.
            int v2Length = (((firstByte & 0x7F) << 8) | (secondByte & 0xFF)) + 2;
            return Math.max(v2Length, RECORD_HEADER_SIZE);
        } else {
            // TLSPlaintext: bytes 4 and 5 hold the fragment length.
            return getInt16(input) + RECORD_HEADER_SIZE;
        }
    }

    /**
     * Returns the required number of bytes in the {@code source} byte array
     * necessary to explore SSL/TLS connection.
     * <P>
     * This method tries to parse as few bytes as possible from
     * {@code source} byte array to get the length of an
     * SSL/TLS record.
     *
     * @param  source
     *         a byte array containing inbound or outbound network data for
     *         an SSL/TLS connection.
     * @param  offset
     *         the start offset in array {@code source} at which the
     *         network data is read from.
     * @param  length
     *         the maximum number of bytes to read.
     *
     * @throws IOException declared for API compatibility
     * @throws BufferUnderflowException if less than {@code RECORD_HEADER_SIZE}
     *         bytes remaining in {@code source}
     * @return the required size in byte to explore an SSL/TLS connection
     *
     * @see #getRequiredSize(ByteBuffer)
     */
    public static int getRequiredSize(byte[] source,
            int offset, int length) throws IOException {

        ByteBuffer byteBuffer =
            ByteBuffer.wrap(source, offset, length).asReadOnlyBuffer();
        return getRequiredSize(byteBuffer);
    }

    /**
     * Launch and explore the security capabilities from byte buffer.
     * <P>
     * This method tries to parse as few records as possible from
     * {@code source} byte buffer to get the capabilities
     * of an SSL/TLS connection.
     * <P>
     * Please NOTE that this method must be called before any handshaking
     * occurs.  The behavior of this method is not defined in this release
     * if the handshake has begun, or has completed.
     * <P>
     * This method accesses the {@code source} parameter in read-only
     * mode, and does not update the buffer's properties such as capacity,
     * limit, position, and mark values.
     *
     * @param  source
     *         a {@link ByteBuffer} containing
     *         inbound or outbound network data for an SSL/TLS connection.
     *
     * @throws SSLException on network data error
     * @throws BufferUnderflowException if not enough source bytes available
     *         to make a complete exploration.
     *
     * @return the explored capabilities of the SSL/TLS
     *         connection
     */
    public static SSLConnectionInformationImpl explore(ByteBuffer source)
            throws SSLException {

        // Work on a duplicate so the caller's position/limit/mark stay untouched.
        ByteBuffer input = source.duplicate();

        // Do we have a complete header?
        if (input.remaining() < RECORD_HEADER_SIZE) {
            throw new BufferUnderflowException();
        }

        // Is it a handshake message?
        byte firstByte = input.get();
        byte secondByte = input.get();
        byte thirdByte = input.get();
        if ((firstByte & 0x80) != 0 && thirdByte == 0x01) {
            // looks like a V2ClientHello
            return exploreV2HelloRecord(input,
                                    firstByte, secondByte, thirdByte);
        } else if (firstByte == 22) {   // 22: handshake record
            return exploreTLSRecord(input,
                                    firstByte, secondByte, thirdByte);
        } else {
            throw ElytronMessages.log.notHandshakeRecord();
        }
    }

    /**
     * Launch and explore the security capabilities from byte array.
     * <P>
     * Please NOTE that this method must be called before any handshaking
     * occurs.  The behavior of this method is not defined in this release
     * if the handshake has begun, or has completed.  Once handshake has
     * begun, or has completed, the security capabilities can not and
     * should not be launched with this method.
     *
     * @param  source
     *         a byte array containing inbound or outbound network data for
     *         an SSL/TLS connection.
     * @param  offset
     *         the start offset in array {@code source} at which the
     *         network data is read from.
     * @param  length
     *         the maximum number of bytes to read.
     *
     * @throws IOException on network data error
     * @throws BufferUnderflowException if not enough source bytes available
     *         to make a complete exploration.
     * @return the explored capabilities of the SSL/TLS
     *         connection
     *
     * @see #explore(ByteBuffer)
     */
    public static SSLConnectionInformationImpl explore(byte[] source,
            int offset, int length) throws IOException {
        ByteBuffer byteBuffer =
            ByteBuffer.wrap(source, offset, length).asReadOnlyBuffer();
        return explore(byteBuffer);
    }

    /*
     * uint8 V2CipherSpec[3];
     * struct {
     *     uint16 msg_length;         // The highest bit MUST be 1;
     *                                // the remaining bits contain the length
     *                                // of the following data in bytes.
     *     uint8 msg_type;            // MUST be 1
     *     Version version;
     *     uint16 cipher_spec_length; // It cannot be zero and MUST be a
     *                                // multiple of the V2CipherSpec length.
     *     uint16 session_id_length;  // This field MUST be empty.
     *     uint16 challenge_length;   // SHOULD use a 32-byte challenge
     *     V2CipherSpec cipher_specs[V2ClientHello.cipher_spec_length];
     *     opaque session_id[V2ClientHello.session_id_length];
     *     opaque challenge[V2ClientHello.challenge_length;
     * } V2ClientHello;
     *
     * On entry the first three bytes (msg_length and msg_type) have already
     * been consumed from input and are passed in as arguments.
     */
    private static SSLConnectionInformationImpl exploreV2HelloRecord(
            ByteBuffer input, byte firstByte, byte secondByte,
            byte thirdByte) throws SSLException {

        try {
            // Is it a V2ClientHello?
            if (thirdByte != 0x01) {
                throw ElytronMessages.log.unsupportedSslRecord();
            }

            // What's the hello version?
            byte helloVersionMajor = input.get();
            byte helloVersionMinor = input.get();

            int csLen = getInt16(input); // in bytes, 3 bytes per cipher spec
            input.getShort(); // session_id_length
            input.getShort(); // challenge_length

            List<String> ciphers = new ArrayList<>();

            while (csLen >= 3) {
                int lead = getInt8(input);
                int byte1 = getInt8(input);
                int byte2 = getInt8(input);
                if (lead == 0) {
                    // a zero lead byte marks a two-byte TLS cipher suite id
                    final MechanismDatabase.Entry entry = database.getCipherSuiteById(byte1, byte2);
                    if (entry != null) ciphers.add(entry.getName());
                }
                // skip any non-TLS cipher suites
                csLen -= 3;
            }

            // 0x00: major version of SSLv20
            // 0x02: minor version of SSLv20
            //
            // SNIServerName is an extension, SSLv20 doesn't support extension.
            return new SSLConnectionInformationImpl((byte)0x00, (byte)0x02,
                        helloVersionMajor, helloVersionMinor,
                        Collections.emptyList(), Collections.emptyList(),
                        ciphers.isEmpty() ? Collections.emptyList() : ciphers);
        } catch (BufferUnderflowException ignored) {
            // the message is shorter than its own fields claim
            throw ElytronMessages.log.invalidHandshakeRecord();
        }
    }

    /*
     * struct {
     *     uint8 major;
     *     uint8 minor;
     * } ProtocolVersion;
     *
     * enum {
     *     change_cipher_spec(20), alert(21), handshake(22),
     *     application_data(23), (255)
     * } ContentType;
     *
     * struct {
     *     ContentType type;
     *     ProtocolVersion version;
     *     uint16 length;
     *     opaque fragment[TLSPlaintext.length];
     * } TLSPlaintext;
     *
     * On entry the content type and the protocol version have already been
     * consumed from input and are passed in as arguments.
     */
    private static SSLConnectionInformationImpl exploreTLSRecord(
            ByteBuffer input, byte firstByte, byte secondByte,
            byte thirdByte) throws SSLException {

        // Is it a handshake message?
        if (firstByte != 22) {        // 22: handshake record
            throw ElytronMessages.log.notHandshakeRecord();
        }

        // Is there enough data for a full record?
        int recordLength = getInt16(input);
        if (recordLength > input.remaining()) {
            // the caller has to supply more bytes before exploring
            throw new BufferUnderflowException();
        }

        // We have already had enough source bytes, so an underflow from here
        // on means the record itself is malformed.
        try {
            return exploreHandshake(input,
                secondByte, thirdByte, recordLength);
        } catch (BufferUnderflowException ignored) {
            throw ElytronMessages.log.invalidHandshakeRecord();
        }
    }

    /*
     * enum {
     *     hello_request(0), client_hello(1), server_hello(2),
     *     certificate(11), server_key_exchange (12),
     *     certificate_request(13), server_hello_done(14),
     *     certificate_verify(15), client_key_exchange(16),
     *     finished(20)
     *     (255)
     * } HandshakeType;
     *
     * struct {
     *     HandshakeType msg_type;
     *     uint24 length;
     *     select (HandshakeType) {
     *         case hello_request:       HelloRequest;
     *         case client_hello:        ClientHello;
     *         case server_hello:        ServerHello;
     *         case certificate:         Certificate;
     *         case server_key_exchange: ServerKeyExchange;
     *         case certificate_request: CertificateRequest;
     *         case server_hello_done:   ServerHelloDone;
     *         case certificate_verify:  CertificateVerify;
     *         case client_key_exchange: ClientKeyExchange;
     *         case finished:            Finished;
     *     } body;
     * } Handshake;
     */
    private static SSLConnectionInformationImpl exploreHandshake(
            ByteBuffer input, byte recordMajorVersion,
            byte recordMinorVersion, int recordLength) throws SSLException {

        // What is the handshake type?
        byte handshakeType = input.get();
        if (handshakeType != 0x01) {   // 0x01: client_hello message
            throw ElytronMessages.log.expectedClientHello();
        }

        // What is the handshake body length?
        int handshakeLength = getInt24(input);

        // Theoretically, a single handshake message might span multiple
        // records, but in practice this does not occur.
        if (handshakeLength > recordLength - 4) { // 4: handshake header size
            throw ElytronMessages.log.multiRecordSSLHandshake();
        }

        // Restrict the view to the handshake body so that parsing the hello
        // cannot run into whatever follows it in the buffer.
        input = input.duplicate();
        input.limit(handshakeLength + input.position());
        return exploreClientHello(input,
                                    recordMajorVersion, recordMinorVersion);
    }

    /*
     * struct {
     *     uint32 gmt_unix_time;
     *     opaque random_bytes[28];
     * } Random;
     *
     * opaque SessionID<0..32>;
     *
     * uint8 CipherSuite[2];
     *
     * enum { null(0), (255) } CompressionMethod;
     *
     * struct {
     *     ProtocolVersion client_version;
     *     Random random;
     *     SessionID session_id;
     *     CipherSuite cipher_suites<2..2^16-2>;
     *     CompressionMethod compression_methods<1..2^8-1>;
     *     select (extensions_present) {
     *         case false:
     *             struct {};
     *         case true:
     *             Extension extensions<0..2^16-1>;
     *     };
     * } ClientHello;
     */
    private static SSLConnectionInformationImpl exploreClientHello(
            ByteBuffer input,
            byte recordMajorVersion,
            byte recordMinorVersion) throws SSLException {

        ExtensionInfo info = null;

        // client version
        byte helloMajorVersion = input.get();
        byte helloMinorVersion = input.get();

        // ignore random; a truncated hello raises BufferUnderflowException
        // here instead of an IllegalArgumentException from position(int)
        ignoreByteVector(input, RANDOM_LENGTH);

        // ignore session id
        ignoreByteVector8(input);

        List<String> ciphers = new ArrayList<>();

        // collect the names of the cipher suites known to the database
        int csLen = getInt16(input);
        while (csLen > 0) {
            int byte1 = getInt8(input);
            int byte2 = getInt8(input);
            final MechanismDatabase.Entry entry = database.getCipherSuiteById(byte1, byte2);
            if (entry != null) ciphers.add(entry.getName());
            csLen -= 2;
        }

        // ignore compression methods
        ignoreByteVector8(input);

        // anything left in the hello is the extensions block
        if (input.remaining() > 0) {
            info = exploreExtensions(input);
        }

        final List<SNIServerName> snList = info != null ? info.sni : Collections.emptyList();
        final List<String> alpnProtocols = info != null ? info.alpn : Collections.emptyList();

        return new SSLConnectionInformationImpl(
                recordMajorVersion, recordMinorVersion,
                helloMajorVersion, helloMinorVersion, snList, alpnProtocols,
                ciphers.isEmpty() ? Collections.emptyList() : ciphers);
    }

    /*
     * struct {
     *     ExtensionType extension_type;
     *     opaque extension_data<0..2^16-1>;
     * } Extension;
     *
     * enum {
     *     server_name(0), max_fragment_length(1),
     *     client_certificate_url(2), trusted_ca_keys(3),
     *     truncated_hmac(4), status_request(5), (65535)
     * } ExtensionType;
     */
    private static ExtensionInfo exploreExtensions(ByteBuffer input)
            throws SSLException {

        List<SNIServerName> sni = Collections.emptyList();
        List<String> alpn = Collections.emptyList();

        int length = getInt16(input);           // length of extensions
        while (length > 0) {
            int extType = getInt16(input);      // extension type
            int extLen = getInt16(input);       // length of extension data

            if (extType == 0x00) {      // 0x00: type of server name indication
                sni = exploreSNIExt(input, extLen);
            } else if (extType == 0x10) { // 0x10: type of alpn
                alpn = exploreALPN(input, extLen);
            } else {                    // ignore other extensions
                ignoreByteVector(input, extLen);
            }

            length -= extLen + 4;       // 4: extension type and length fields
        }

        return new ExtensionInfo(sni, alpn);
    }

    /*
     * opaque ProtocolName<1..2^8-1>;
     *
     * struct {
     *     ProtocolName protocol_name_list<2..2^16-1>
     * } ProtocolNameList;
     *
     */
    private static List<String> exploreALPN(ByteBuffer input,
            int extLen) throws SSLException {
        final ArrayList<String> strings = new ArrayList<>();

        int rem = extLen;
        if (extLen >= 2) {
            int listLen = getInt16(input);
            if (listLen == 0 || listLen + 2 != extLen) {
                throw ElytronMessages.log.invalidTlsExt();
            }

            rem -= 2;     // 0x02: the length field of protocol_name_list
            while (rem > 0) {
                int len = getInt8(input);
                // the entry occupies its length byte plus len name bytes
                if (len + 1 > rem) {
                    throw ElytronMessages.log.notEnoughData();
                }
                byte[] b = new byte[len];
                input.get(b);
                strings.add(new String(b, StandardCharsets.UTF_8));

                rem -= len + 1;
            }
        }

        // An extension body of exactly one byte can hold no list length;
        // rejecting it keeps the extension loop in sync with the buffer.
        if (rem != 0) {
            throw ElytronMessages.log.invalidTlsExt();
        }

        return strings.isEmpty() ? Collections.emptyList() : strings;
    }

    /*
     * struct {
     *     NameType name_type;
     *     select (name_type) {
     *         case host_name: HostName;
     *     } name;
     * } ServerName;
     *
     * enum {
     *     host_name(0), (255)
     * } NameType;
     *
     * opaque HostName<1..2^16-1>;
     *
     * struct {
     *     ServerName server_name_list<1..2^16-1>
     * } ServerNameList;
     */
    private static List<SNIServerName> exploreSNIExt(ByteBuffer input,
            int extLen) throws SSLException {

        // keyed by name type to detect duplicates, insertion-ordered
        Map<Integer, SNIServerName> sniMap = new LinkedHashMap<>();

        int remains = extLen;
        if (extLen >= 2) {     // "server_name" extension in ClientHello
            int listLen = getInt16(input);     // length of server_name_list
            if (listLen == 0 || listLen + 2 != extLen) {
                throw ElytronMessages.log.invalidTlsExt();
            }

            remains -= 2;     // 0x02: the length field of server_name_list
            while (remains > 0) {
                // NameType: 1 byte, HostName length: 2 bytes
                if (remains < 3) {
                    throw ElytronMessages.log.notEnoughData();
                }
                int code = getInt8(input);      // name_type
                int snLen = getInt16(input);    // length field of server name
                if (snLen + 3 > remains) {
                    throw ElytronMessages.log.notEnoughData();
                }
                byte[] encoded = new byte[snLen];
                input.get(encoded);

                SNIServerName serverName;
                switch (code) {
                    case StandardConstants.SNI_HOST_NAME:
                        if (encoded.length == 0) {
                            throw ElytronMessages.log.emptyHostNameSni();
                        }
                        try {
                            serverName = new SNIHostName(encoded);
                        } catch (IllegalArgumentException illegalName) {
                            // The peer sent a host name SNIHostName rejects;
                            // report it as a malformed extension rather than
                            // leaking an unchecked exception to the caller.
                            throw ElytronMessages.log.invalidTlsExt();
                        }
                        break;
                    default:
                        serverName = new UnknownServerName(code, encoded);
                }
                // check for duplicated server name type
                if (sniMap.put(serverName.getType(), serverName) != null) {
                    throw ElytronMessages.log.duplicatedSniServerName(serverName.getType());
                }

                remains -= encoded.length + 3;  // NameType: 1 byte
                                                // HostName length: 2 bytes
            }
        } else if (extLen == 0) {     // "server_name" extension in ServerHello
            throw ElytronMessages.log.invalidTlsExt();
        }

        if (remains != 0) {
            throw ElytronMessages.log.invalidTlsExt();
        }

        return Collections.unmodifiableList(new ArrayList<>(sniMap.values()));
    }

    // Reads one byte as an unsigned 8 bit integer (0..255).
    private static int getInt8(ByteBuffer input) {
        return input.get() & 0xFF;
    }

    // Reads two bytes as an unsigned big-endian 16 bit integer.
    private static int getInt16(ByteBuffer input) {
        return (input.get() & 0xFF) << 8 | input.get() & 0xFF;
    }

    // Reads three bytes as an unsigned big-endian 24 bit integer.
    private static int getInt24(ByteBuffer input) {
        return (input.get() & 0xFF) << 16 | (input.get() & 0xFF) << 8 |
            input.get() & 0xFF;
    }

    // Skips a vector whose length is given by a leading 8 bit field.
    private static void ignoreByteVector8(ByteBuffer input) {
        ignoreByteVector(input, getInt8(input));
    }

    // Skips a vector whose length is given by a leading 16 bit field.
    private static void ignoreByteVector16(ByteBuffer input) {
        ignoreByteVector(input, getInt16(input));
    }

    // Skips a vector whose length is given by a leading 24 bit field.
    private static void ignoreByteVector24(ByteBuffer input) {
        ignoreByteVector(input, getInt24(input));
    }

    // Advances the position by length bytes. Throws BufferUnderflowException
    // when fewer bytes remain, so that truncated input is reported the same
    // way as by the relative get() calls.
    private static void ignoreByteVector(ByteBuffer input, int length) {
        if (length < 0 || length > input.remaining()) {
            throw new BufferUnderflowException();
        }
        input.position(input.position() + length);
    }

    /**
     * A server name whose name type is not {@code host_name}; it only
     * carries the raw type code and encoded value.
     */
    static final class UnknownServerName extends SNIServerName {
        UnknownServerName(int code, byte[] encoded) {
            super(code, encoded);
        }
    }

    /**
     * Holder for the values extracted from the extensions of a ClientHello.
     */
    static final class ExtensionInfo {
        final List<SNIServerName> sni;
        final List<String> alpn;

        ExtensionInfo(final List<SNIServerName> sni, final List<String> alpn) {
            this.sni = sni;
            this.alpn = alpn;
        }
    }

    /**
     * The result of an exploration: record and hello protocol versions, the
     * requested server names, the offered ALPN protocols and the names of
     * the offered cipher suites.
     */
    static final class SSLConnectionInformationImpl implements SSLConnectionInformation {

        private final String recordVersion;
        private final String helloVersion;
        private final List<SNIServerName> sniNames;
        private final List<String> alpnProtocols;
        private final List<String> ciphers;

        SSLConnectionInformationImpl(byte recordMajorVersion, byte recordMinorVersion,
            byte helloMajorVersion, byte helloMinorVersion,
            List<SNIServerName> sniNames, final List<String> alpnProtocols, final List<String> ciphers) {

            this.recordVersion = getVersionString(recordMajorVersion, recordMinorVersion);
            this.helloVersion = getVersionString(helloMajorVersion, helloMinorVersion);
            this.sniNames = sniNames;
            this.alpnProtocols = alpnProtocols;
            this.ciphers = ciphers;
        }

        // Maps a (major, minor) protocol version pair to its protocol name,
        // or to "Unknown-<major>.<minor>" for unrecognised pairs.
        private static String getVersionString(final byte majorVersion, final byte minorVersion) {
            switch (majorVersion) {
                case 0x00: {
                    switch (minorVersion) {
                        case 0x02: return "SSLv2Hello";
                        default: return unknownVersion(majorVersion, minorVersion);
                    }
                }
                case 0x03: {
                    switch (minorVersion) {
                        case 0x00: return "SSLv3";
                        case 0x01: return "TLSv1";
                        case 0x02: return "TLSv1.1";
                        case 0x03: return "TLSv1.2";
                        case 0x04: return "TLSv1.3";
                        default: return unknownVersion(majorVersion, minorVersion);
                    }
                }
                default: return unknownVersion(majorVersion, minorVersion);
            }
        }

        @Override
        public String getRecordVersion() {
            return recordVersion;
        }

        @Override
        public String getHelloVersion() {
            return helloVersion;
        }

        @Override
        public List<SNIServerName> getSNIServerNames() {
            return Collections.unmodifiableList(sniNames);
        }

        @Override
        public List<String> getProtocols() {
            return Collections.unmodifiableList(alpnProtocols);
        }

        @Override
        public List<String> getCipherSuites() {
            return Collections.unmodifiableList(ciphers);
        }

        // Formats an unrecognised version using the unsigned byte values.
        private static String unknownVersion(byte major, byte minor) {
            return "Unknown-" + (major & 0xff) + "." + (minor & 0xff);
        }
    }
}

