/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wildfly.security.asn1;

import static org.wildfly.security._private.ElytronMessages.log;
import static org.wildfly.security.asn1.ASN1.*;

import java.io.UnsupportedEncodingException;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.NoSuchElementException;

import org.wildfly.security.util.ByteIterator;

/**
 * A class used to decode ASN.1 values that have been encoded using the Distinguished Encoding Rules (DER).
 * <p>
 * The decoder reads from a {@link ByteIterator} and keeps a stack of the constructed elements (sequences, sets and
 * explicitly tagged elements) that have been started but not yet ended. Truncated input is consistently reported
 * as an {@link ASN1Exception} rather than as a raw iterator or array exception.
 *
 * @author <a href="mailto:fjuma@redhat.com">Farah Juma</a>
 */
public class DERDecoder implements ASN1Decoder {

    private final ByteIterator bi;

    // Constructed elements that have been started but not yet ended; the innermost one is last
    private final ArrayDeque<DecoderState> states = new ArrayDeque<>();

    // Tag to expect in place of the universal tag of the next element, or -1 if no implicit tag is pending
    private int implicitTag = -1;

    /**
     * Create a DER decoder that will decode values from the given byte array.
     *
     * @param buf the byte array to decode
     */
    public DERDecoder(byte[] buf) {
        this.bi = ByteIterator.ofBytes(buf);
    }

    /**
     * Create a DER decoder that will decode values from the given byte array.
     *
     * @param buf the byte array to decode
     * @param offset the offset in the byte array of the first byte to read
     * @param length the maximum number of bytes to read from the byte array
     */
    public DERDecoder(byte[] buf, int offset, int length) {
        this.bi = ByteIterator.ofBytes(buf, offset, length);
    }

    /**
     * Create a DER decoder that will decode values from the given {@code ByteIterator}.
     *
     * @param bi the {@code ByteIterator} from which DER encoded values will be decoded
     */
    public DERDecoder(ByteIterator bi) {
        this.bi = bi;
    }

    @Override
    public void startSequence() throws ASN1Exception {
        startConstructedElement(SEQUENCE_TYPE);
    }

    @Override
    public void endSequence() throws ASN1Exception {
        DecoderState lastState = states.peekLast();
        if ((lastState == null) || (lastState.getTag() != SEQUENCE_TYPE)) {
            throw log.noSequenceToEnd();
        }
        endConstructedElement(lastState.getNextElementIndex());
        states.removeLast();
    }

    @Override
    public void startSet() throws ASN1Exception {
        startConstructedElement(SET_TYPE);
    }

    @Override
    public void endSet() throws ASN1Exception {
        DecoderState lastState = states.peekLast();
        if ((lastState == null) || (lastState.getTag() != SET_TYPE)) {
            throw log.noSetToEnd();
        }
        endConstructedElement(lastState.getNextElementIndex());
        states.removeLast();
    }

    @Override
    public void startSetOf() throws ASN1Exception {
        startSet();
    }

    @Override
    public void endSetOf() throws ASN1Exception {
        endSet();
    }

    @Override
    public void startExplicit(int number) throws ASN1Exception {
        startExplicit(CONTEXT_SPECIFIC_MASK, number);
    }

    @Override
    public void startExplicit(int clazz, int number) throws ASN1Exception {
        startConstructedElement(clazz | CONSTRUCTED_MASK | number);
    }

    @Override
    public void endExplicit() throws ASN1Exception {
        DecoderState lastState = states.peekLast();
        if ((lastState == null) || (lastState.getTag() == SEQUENCE_TYPE)
                || (lastState.getTag() == SET_TYPE) || ((lastState.getTag() & CONSTRUCTED_MASK) == 0)) {
            throw log.noExplicitlyTaggedElementToEnd();
        }
        endConstructedElement(lastState.getNextElementIndex());
        states.removeLast();
    }

    /**
     * Read the header of a constructed element with the given tag and push a new state that records
     * where the element's contents end.
     *
     * @param tag the expected tag of the constructed element
     * @throws ASN1Exception if the next tag does not match or the header is truncated
     */
    private void startConstructedElement(int tag) throws ASN1Exception {
        readTag(tag);
        int length = readLength();
        states.add(new DecoderState(tag, bi.offset() + length));
    }

    /**
     * Advance the iterator to the end of the current constructed element, discarding any of its
     * elements that have not been read yet.
     *
     * @param nextElementIndex the offset of the first byte after the constructed element
     * @throws ASN1Exception if the input ends before {@code nextElementIndex} is reached
     * @throws IllegalStateException if the iterator is already past {@code nextElementIndex}
     */
    private void endConstructedElement(int nextElementIndex) throws ASN1Exception {
        int remaining = nextElementIndex - bi.offset();
        if (remaining < 0) {
            // Shouldn't happen
            throw new IllegalStateException();
        }
        // Any elements in this constructed element that have not yet been read will be discarded
        if (skip(remaining) != remaining) {
            throw log.asnUnexpectedEndOfInput();
        }
    }

    @Override
    public byte[] decodeOctetString() throws ASN1Exception {
        return readPrimitive(OCTET_STRING_TYPE);
    }

    @Override
    public String decodeOctetStringAsString() throws ASN1Exception {
        return decodeOctetStringAsString(StandardCharsets.UTF_8.name());
    }

    @Override
    public String decodeOctetStringAsString(String charSet) throws ASN1Exception {
        byte[] octets = readPrimitive(OCTET_STRING_TYPE);
        try {
            return new String(octets, charSet);
        } catch (UnsupportedEncodingException e) {
            throw new ASN1Exception(e);
        }
    }

    @Override
    public String decodeIA5String() throws ASN1Exception {
        byte[] octets = decodeIA5StringAsBytes();
        return new String(octets, StandardCharsets.US_ASCII);
    }

    @Override
    public byte[] decodeIA5StringAsBytes() throws ASN1Exception {
        return readPrimitive(IA5_STRING_TYPE);
    }

    /**
     * Decode a bit string. If the encoding declares unused bits, the whole value is shifted right by that
     * many bits so the unused bits are removed; the result always has one byte per content octet that
     * follows the initial "unused bits" octet.
     */
    @Override
    public byte[] decodeBitString() throws ASN1Exception {
        readTag(BIT_STRING_TYPE);
        int length = readLength();
        int numUnusedBits = readUnusedBitCount(length);
        byte[] result = new byte[length - 1];

        if (numUnusedBits == 0) {
            for (int i = 0; i < result.length; i++) {
                result[i] = (byte) nextOctet();
            }
        } else {
            // Any unused bits will be removed: each octet is shifted right and receives the
            // low-order bits of the octet before it (nothing is carried into the first octet)
            int leftShift = 8 - numUnusedBits;
            int previous = 0;
            for (int i = 0; i < result.length; i++) {
                int next = nextOctet();
                result[i] = (byte) ((next >>> numUnusedBits) | (previous << leftShift));
                previous = next;
            }
        }
        return result;
    }

    @Override
    public BigInteger decodeBitStringAsInteger() {
        // decodeInteger() verifies that the wrapped element really is an INTEGER
        return new DERDecoder(decodeBitString()).decodeInteger();
    }

    /**
     * Decode a bit string as a string of {@code '0'} and {@code '1'} characters, most significant bit
     * first, leaving out the trailing unused bits.
     */
    @Override
    public String decodeBitStringAsString() throws ASN1Exception {
        readTag(BIT_STRING_TYPE);
        int length = readLength();
        int numUnusedBits = readUnusedBitCount(length);

        // Clamp at zero: an encoding with no content octets but a non-zero unused bit count has no bits at all
        int numBits = Math.max(0, (length - 1) * 8 - numUnusedBits);
        StringBuilder result = new StringBuilder(numBits);
        int bitsWritten = 0;
        for (int i = 0; i < (length - 1); i++) {
            int next = nextOctet();
            for (int bit = 7; bit >= 0 && bitsWritten < numBits; bit--) {
                result.append((next & (1 << bit)) != 0 ? '1' : '0');
                bitsWritten++;
            }
        }
        return result.toString();
    }

    /**
     * Read and validate the initial octet of a bit string, which holds the number of unused bits
     * in the final content octet.
     *
     * @param length the content length of the bit string, as read from its header
     * @return the number of unused bits, between 0 and 7 inclusive
     * @throws ASN1Exception if the contents are empty (so the initial octet is missing), the input is
     * truncated, or the number of unused bits is out of range
     */
    private int readUnusedBitCount(int length) throws ASN1Exception {
        if (length == 0) {
            // The contents end before the mandatory initial octet
            throw log.asnUnexpectedEndOfInput();
        }
        int numUnusedBits = nextOctet();
        if (numUnusedBits < 0 || numUnusedBits > 7) {
            throw log.asnInvalidNumberOfUnusedBits();
        }
        return numUnusedBits;
    }

    @Override
    public String decodePrintableString() throws ASN1Exception {
        return new String(decodePrintableStringAsBytes(), StandardCharsets.US_ASCII);
    }

    @Override
    public byte[] decodePrintableStringAsBytes() throws ASN1Exception {
        readTag(PRINTABLE_STRING_TYPE);
        final int length = readLength();
        byte[] result = new byte[length];
        for (int i = 0; i < length; i++) {
            final int b = nextOctet();
            validatePrintableByte(b);
            result[i] = (byte) b;
        }
        return result;
    }

    /**
     * Decode an object identifier into its dotted-decimal string form.
     * <p>
     * The first sub-identifier encodes the first two arcs as {@code 40 * arc1 + arc2}. Since {@code arc1}
     * can only be 0, 1 or 2 and {@code arc2} is below 40 unless {@code arc1} is 2, any first sub-identifier
     * of 80 or more belongs to arc 2.
     */
    @Override
    public String decodeObjectIdentifier() throws ASN1Exception {
        readTag(OBJECT_IDENTIFIER_TYPE);
        int length = readLength();
        int octet;
        long value = 0;
        BigInteger bigInt = null;
        boolean processedFirst = false;
        StringBuilder objectIdentifierStr = new StringBuilder();

        for (int i = 0; i < length; i++) {
            octet = nextOctet();
            if (value < 0x80000000000000L) {
                // Shifting by 7 more bits still fits in a long
                value = (value << 7) + (octet & 0x7f);
                if ((octet & 0x80) == 0) {
                    // Reached the end of a component value
                    if (!processedFirst) {
                        if (value < 40) {
                            objectIdentifierStr.append('0');
                        } else if (value < 80) {
                            objectIdentifierStr.append('1');
                            value -= 40;
                        } else {
                            objectIdentifierStr.append('2');
                            value -= 80;
                        }
                        processedFirst = true;
                    }
                    objectIdentifierStr.append('.');
                    objectIdentifierStr.append(value);

                    // Reset for the next component value
                    value = 0;
                }
            } else {
                // The component no longer fits in a long; value stays large until the component ends
                if (bigInt == null) {
                    bigInt = BigInteger.valueOf(value);
                }
                bigInt = bigInt.shiftLeft(7).add(BigInteger.valueOf(octet & 0x7f));
                if ((octet & 0x80) == 0) {
                    // Reached the end of a component value
                    if (!processedFirst) {
                        // A first sub-identifier this large can only belong to arc 2
                        objectIdentifierStr.append('2');
                        bigInt = bigInt.subtract(BigInteger.valueOf(80));
                        processedFirst = true;
                    }
                    objectIdentifierStr.append('.');
                    objectIdentifierStr.append(bigInt);

                    // Reset for the next component value
                    bigInt = null;
                    value = 0;
                }
            }
        }
        return objectIdentifierStr.toString();
    }

    @Override
    public BigInteger decodeInteger() throws ASN1Exception {
        // readTag(int) is used (rather than peekType) so that a pending implicit tag is honoured,
        // as it is for every other primitive type
        byte[] contents = readPrimitive(INTEGER_TYPE);
        try {
            return new BigInteger(contents);
        } catch (NumberFormatException e) {
            // Zero-length contents are not a valid integer encoding
            throw new ASN1Exception(e);
        }
    }

    @Override
    public void decodeNull() throws ASN1Exception {
        readTag(NULL_TYPE);
        int length = readLength();
        if (length != 0) {
            throw log.asnNonZeroLengthForNullTypeTag();
        }
    }

    @Override
    public void decodeImplicit(int number) {
        decodeImplicit(CONTEXT_SPECIFIC_MASK, number);
    }

    @Override
    public void decodeImplicit(int clazz, int number) {
        // An implicit tag that is already pending is not overwritten
        if (implicitTag == -1) {
            implicitTag = clazz | number;
        }
    }

    @Override
    public boolean isNextType(int clazz, int number, boolean isConstructed) {
        try {
            return peekType() == (clazz | (isConstructed ? CONSTRUCTED_MASK : 0x00) | number);
        } catch (ASN1Exception e) {
            return false;
        }
    }

    @Override
    public int peekType() throws ASN1Exception {
        int startOffset = bi.offset();
        try {
            return readTag();
        } finally {
            // Restore the position even if the tag turned out to be truncated or malformed
            rewindTo(startOffset);
        }
    }

    @Override
    public void skipElement() throws ASN1Exception {
        readTag();
        int length = readLength();
        if (skip(length) != length) {
            throw log.asnUnexpectedEndOfInput();
        }
    }

    @Override
    public boolean hasNextElement() {
        DecoderState lastState = states.peekLast();
        if ((lastState != null) && (bi.offset() >= lastState.getNextElementIndex())) {
            // The innermost constructed element has been fully consumed
            return false;
        }
        return hasCompleteElement();
    }

    /**
     * Determine whether a complete element (tag, length and all content bytes) is available at the
     * current position. The iterator position is restored before returning.
     */
    private boolean hasCompleteElement() {
        boolean complete;
        int startOffset = bi.offset();
        try {
            readTag();
            int length = readLength();
            complete = (skip(length) == length);
        } catch (ASN1Exception e) {
            complete = false;
        }
        rewindTo(startOffset);
        return complete;
    }

    @Override
    public byte[] drainElementValue() throws ASN1Exception {
        // Any pending implicit tag is discarded; the tag of the drained element is not checked
        implicitTag = -1;
        readTag();
        return readContents(readLength());
    }

    @Override
    public byte[] drainElement() throws ASN1Exception {
        // Any pending implicit tag is discarded; the tag of the drained element is not checked
        implicitTag = -1;
        int startOffset = bi.offset();
        readTag();
        int valueLength = readLength();
        int length = (bi.offset() - startOffset) + valueLength;
        if (length < 0) {
            // Header plus contents overflow an int, so the input cannot possibly hold the element
            throw log.asnUnexpectedEndOfInput();
        }
        // Go back to the start of the header so that it is included in the result
        rewindTo(startOffset);
        return readContents(length);
    }

    /**
     * Read the next byte from the input.
     *
     * @return the next byte, as returned by the underlying {@code ByteIterator}
     * @throws ASN1Exception if there is no more input
     */
    private int nextOctet() throws ASN1Exception {
        try {
            return bi.next();
        } catch (NoSuchElementException e) {
            throw log.asnUnexpectedEndOfInput();
        }
    }

    /**
     * Skip up to {@code count} bytes of input.
     *
     * @param count the number of bytes to skip
     * @return the number of bytes actually skipped, which is less than {@code count} only if the input ended
     */
    private int skip(int count) {
        int skipped = 0;
        while ((skipped < count) && bi.hasNext()) {
            bi.next();
            skipped++;
        }
        return skipped;
    }

    /**
     * Move the iterator backwards until it is at the given offset or cannot move back any further.
     *
     * @param offset the offset to return to
     */
    private void rewindTo(int offset) {
        while ((bi.offset() != offset) && bi.hasPrev()) {
            bi.prev();
        }
    }

    /**
     * Read exactly {@code length} bytes from the input.
     *
     * @param length the number of bytes to read
     * @return a new array holding the bytes that were read
     * @throws ASN1Exception if fewer than {@code length} bytes are available
     */
    private byte[] readContents(int length) throws ASN1Exception {
        byte[] contents = new byte[length];
        if ((length != 0) && (bi.drain(contents, 0, length) != length)) {
            throw log.asnUnexpectedEndOfInput();
        }
        return contents;
    }

    /**
     * Read an element with the given expected tag and return its raw content bytes.
     *
     * @param expectedTag the tag the element must have (subject to any pending implicit tag)
     * @return the content bytes of the element
     * @throws ASN1Exception if the tag does not match or the input is truncated
     */
    private byte[] readPrimitive(int expectedTag) throws ASN1Exception {
        readTag(expectedTag);
        return readContents(readLength());
    }

    /**
     * Read a tag, in either low or high tag number form.
     *
     * @return the class bits, constructed bit and tag number combined with bitwise OR
     * @throws ASN1Exception if the input ends inside the tag or the high tag number form is invalid
     */
    private int readTag() throws ASN1Exception {
        int tag = nextOctet();
        int tagClass = tag & CLASS_MASK;
        int constructed = tag & CONSTRUCTED_MASK;
        int tagNumber = tag & TAG_NUMBER_MASK;
        if (tagNumber == 0x1f) {
            // High-tag-number form
            tagNumber = 0;
            int octet = nextOctet();
            if ((octet & 0x7f) == 0) {
                // Bits 7 to 1 of the first subsequent octet cannot be 0
                throw log.asnInvalidHighTagNumberForm();
            }
            // Every octet except the last has its high bit set and contributes 7 bits
            while ((octet >= 0) && ((octet & 0x80) != 0)) {
                tagNumber |= (octet & 0x7f);
                tagNumber <<= 7;
                octet = nextOctet();
            }
            tagNumber |= (octet & 0x7f);
        }
        return (tagClass | constructed | tagNumber);
    }

    /**
     * Read a tag and check it against the expected one. A pending implicit tag replaces the class and
     * number of {@code expectedTag} (its constructed bit is kept) and is cleared by this call. If the
     * tags differ, the iterator is moved back to where the tag started.
     *
     * @param expectedTag the tag that the next element is expected to have
     * @throws ASN1Exception if the tag read does not match or the input is truncated
     */
    private void readTag(int expectedTag) throws ASN1Exception {
        if (implicitTag != -1) {
            expectedTag = implicitTag | (expectedTag & CONSTRUCTED_MASK);
            implicitTag = -1;
        }
        int startOffset = bi.offset();
        int actualTag = readTag();
        if (actualTag != expectedTag) {
            rewindTo(startOffset);
            throw log.asnUnexpectedTag();
        }
    }

    /**
     * Read a length in short form or in long form with at most 4 length octets.
     *
     * @return the decoded length, never negative
     * @throws ASN1Exception if the input is truncated, more than 4 length octets are used, or the
     * length does not fit in a non-negative {@code int}
     */
    private int readLength() throws ASN1Exception {
        int length = nextOctet();
        if (length > 127) {
            // Long form
            int numOctets = length & 0x7f;
            if (numOctets > 4) {
                throw log.asnLengthEncodingExceeds4bytes();
            }
            length = 0;
            for (int i = 0; i < numOctets; i++) {
                length = (length << 8) + nextOctet();
            }
            if (length < 0) {
                // Four length octets with the high bit set overflow an int; no input this decoder
                // can index is that long, so treat it like any other truncated element
                throw log.asnUnexpectedEndOfInput();
            }
        }
        return length;
    }

    /**
     * Decodes an OID and resolve its corresponding key algorithm.
     *
     * @return the key algorithm associated with the OID or null if no algorithm could be resolved
     */
    public String decodeObjectIdentifierAsKeyAlgorithm() {
        return keyAlgorithmFromOid(decodeObjectIdentifier());
    }

    /**
     * A class used to maintain state information during DER decoding.
     */
    private static final class DecoderState {
        // Tag number for a constructed element
        private final int tag;

        // The position of the first character in the encoded buffer that occurs after
        // the encoding of the constructed element
        private final int nextElementIndex;

        DecoderState(int tag, int nextElementIndex) {
            this.tag = tag;
            this.nextElementIndex = nextElementIndex;
        }

        int getTag() {
            return tag;
        }

        int getNextElementIndex() {
            return nextElementIndex;
        }
    }
}
