package com.vaadin.copilot.javarewriter;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import com.vaadin.copilot.IdentityHashSet;

import com.github.javaparser.Range;
import com.github.javaparser.ast.CompilationUnit;
import com.github.javaparser.ast.ImportDeclaration;
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.BodyDeclaration;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.TypeDeclaration;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.nodeTypes.NodeWithStatements;
import com.github.javaparser.ast.stmt.BlockStmt;
import com.github.javaparser.ast.stmt.Statement;

/**
 * Handles merging of the original source file with changes performed on the
 * AST. Outputs the original parts of the source file as they were and does its
 * best to make the new changes fit in.
 */
public class JavaRewriterMerger {

    /**
     * A chunk of text to insert into the output at a given zero-based position
     * of the original source. Changes are ordered by position so they can be
     * applied in a single forward pass over the original source.
     */
    private record AddChange(LineColumn lineColumn, String text) implements Comparable<AddChange> {

        @Override
        public int compareTo(AddChange o) {
            // Order by line first and then by column so that several insertions
            // on the same line are applied from left to right
            return lineColumn.compareTo(o.lineColumn);
        }

    }

    /**
     * A zero-based position (line index and column index) in the original
     * source.
     */
    private record LineColumn(int line, int column) implements Comparable<LineColumn> {

        /**
         * Checks whether this position comes strictly before the other one.
         *
         * @param other
         *            the position to compare with
         * @return {@code true} if this position is before {@code other}
         */
        public boolean before(LineColumn other) {
            return compareTo(other) < 0;
        }

        @Override
        public int compareTo(LineColumn o) {
            int byLine = Integer.compare(line, o.line);
            return byLine != 0 ? byLine : Integer.compare(column, o.column);
        }

    }

    /**
     * Applies the given changes to the given source and returns the result.
     *
     * @param addedOrModifiedNodes
     *            the nodes that have been added or modified and need to be written
     *            to the file
     * @param removedRanges
     *            the ranges that have been removed from the original source
     * @param source
     *            the original source
     * @return the original source with removed lines dropped and the added or
     *         modified code inserted
     */
    public static String apply(Set<Node> addedOrModifiedNodes, Set<Range> removedRanges, String source) {
        // Identity-based sets: JavaParser nodes implement structural
        // equals/hashCode, so two distinct but identical nodes (e.g. two equal
        // comments in different places) would otherwise collapse into one and
        // only one of them would be written
        Set<Statement> addedStatements = new IdentityHashSet<>();
        Set<BodyDeclaration<?>> addedDeclarations = new IdentityHashSet<>();
        Set<ImportDeclaration> addedImports = new IdentityHashSet<>();
        Set<Comment> addedComments = new IdentityHashSet<>();

        // Classify each node by the closest enclosing unit that is written out
        // as a whole: the statement, member declaration or import containing it
        for (Node node : addedOrModifiedNodes) {
            if (node instanceof Comment comment) {
                addedComments.add(comment);
                continue;
            }
            if (JavaRewriterUtil.findAncestor(node, CompilationUnit.class) == null) {
                // This probably means code is inefficient as it adds nodes that
                // are not used in the end
                continue;
            }
            Statement statement = JavaRewriterUtil.findAncestor(node, Statement.class);
            if (statement != null) {
                addedStatements.add(statement);
                continue;
            }
            BodyDeclaration<?> declaration = JavaRewriterUtil.findAncestor(node, BodyDeclaration.class);
            if (declaration != null) {
                addedDeclarations.add(declaration);
                continue;
            }
            ImportDeclaration importDeclaration = JavaRewriterUtil.findAncestor(node, ImportDeclaration.class);
            if (importDeclaration != null) {
                addedImports.add(importDeclaration);
            }
            // Anything else is ignored, as in the original classification
        }

        Set<Integer> removedLines = removedRanges.stream().flatMap(JavaRewriterMerger::rangeToLines)
                .collect(Collectors.toCollection(HashSet::new));

        List<AddChange> addChanges = collectChanges(addedStatements, addedDeclarations, addedImports, addedComments)
                .sorted().toList();

        return applyChanges(source, addChanges, removedLines);
    }

    /**
     * Turns each run of consecutive sibling nodes into a single insertion.
     * Note that the given sets are emptied in the process.
     *
     * @param collections
     *            the sets of nodes to insert
     * @return one {@link AddChange} per run of consecutive siblings
     */
    @SafeVarargs
    private static Stream<AddChange> collectChanges(Set<? extends Node>... collections) {
        Stream<AddChange> streams = Stream.empty();

        for (Set<? extends Node> collection : collections) {
            Set<List<Node>> mergedNodes = mergeConsecutiveNodes(collection);
            Stream<AddChange> stream = mergedNodes.stream().map(run -> {
                // The whole run is inserted where its first node belongs
                LineColumn insertAt = findInsertPosition(run.get(0));
                return new AddChange(insertAt, run.stream().map(Object::toString).collect(Collectors.joining("\n")));
            });
            streams = Stream.concat(streams, stream);
        }
        return streams;
    }

    /**
     * Converts a 1-based, end-inclusive source range into the zero-based indexes
     * of all lines it touches.
     */
    private static Stream<Integer> rangeToLines(Range range) {
        return IntStream.rangeClosed(range.begin.line - 1, range.end.line - 1).boxed();
    }

    /**
     * Groups the given nodes into runs of consecutive siblings. The given set is
     * consumed: every node is removed from it.
     *
     * @param nodes
     *            the nodes to group, emptied by this method
     * @return the runs, each ordered as in the parent
     */
    private static Set<List<Node>> mergeConsecutiveNodes(Set<? extends Node> nodes) {
        Set<List<Node>> merged = new IdentityHashSet<>();
        while (!nodes.isEmpty()) {
            Node node = nodes.iterator().next();
            List<Node> list = getConsecutiveNodes(node, nodes);
            merged.add(list);
        }

        return merged;
    }

    /**
     * Collects the run of consecutive siblings around {@code node} that are all
     * contained in {@code nodes}, removing them from {@code nodes}.
     *
     * @param node
     *            the node to start from
     * @param nodes
     *            the remaining candidate nodes, modified by this method
     * @return the run, ordered as in the parent
     */
    private static List<Node> getConsecutiveNodes(Node node, Set<? extends Node> nodes) {
        List<Node> list = new ArrayList<>();
        list.add(node);
        nodes.remove(node);

        Optional<Node> next = getNextSibling(node);
        Optional<Node> prev = getPreviousSibling(node);
        if (prev.isPresent() && nodes.contains(prev.get())) {
            list.addAll(0, getConsecutiveNodes(prev.get(), nodes));
        }
        if (next.isPresent() && nodes.contains(next.get())) {
            list.addAll(getConsecutiveNodes(next.get(), nodes));
        }
        return list;
    }

    /**
     * Writes the original source, skipping removed lines, with the given changes
     * inserted at their positions.
     *
     * @param source
     *            the original source
     * @param addChanges
     *            the insertions, sorted by position
     * @param removedLines
     *            zero-based indexes of original lines to leave out
     * @return the merged source
     */
    private static String applyChanges(String source, List<AddChange> addChanges, Set<Integer> removedLines) {
        StringBuilder result = new StringBuilder();

        // Limit -1 keeps trailing empty strings, so blank lines at the end of the
        // file are preserved and the array is never empty (e.g. for a source that
        // consists only of line breaks)
        String[] sourceLines = source.split("\n", -1);

        LineColumn nextFromOriginal = new LineColumn(0, 0);
        for (AddChange addChange : addChanges) {
            LineColumn insertAt = addChange.lineColumn();
            writeLines(nextFromOriginal, insertAt, sourceLines, result, removedLines);
            // NOTE(review): "DELETE_THIS" looks like a placeholder marker produced
            // elsewhere in the rewriter; it is stripped from the inserted text here
            result.append(addChange.text().replace("DELETE_THIS", "")).append("\n");
            nextFromOriginal = insertAt;
        }

        int lastLine = sourceLines.length - 1;
        LineColumn endOfFile = new LineColumn(lastLine, sourceLines[lastLine].length());
        if (nextFromOriginal.before(endOfFile)) {
            writeLines(nextFromOriginal, endOfFile, sourceLines, result, removedLines);
        }
        return result.toString();
    }

    /**
     * Copies the original text between two positions to the output, leaving out
     * removed lines. A partial line at either end is terminated with a line
     * break so that inserted code starts on its own line.
     *
     * @param startInclusive
     *            the first position to copy
     * @param endExclusive
     *            the position to stop at (not copied)
     * @param from
     *            the original source lines
     * @param to
     *            the output
     * @param removedLines
     *            zero-based indexes of original lines to leave out
     */
    private static void writeLines(LineColumn startInclusive, LineColumn endExclusive, String[] from, StringBuilder to,
            Set<Integer> removedLines) {
        if (!startInclusive.before(endExclusive)) {
            // Nothing between the positions (e.g. two insertions at the same spot)
            return;
        }
        int startLine = startInclusive.line();
        int endLine = endExclusive.line();

        if (startLine == endLine) {
            // Both positions are on the same line: write only the text between them
            if (startLine < from.length && !removedLines.contains(startLine)) {
                to.append(from[startLine], startInclusive.column(), endExclusive.column()).append("\n");
            }
            return;
        }

        int firstFullLine = startLine;
        if (startInclusive.column() != 0) {
            firstFullLine++;
            if (!removedLines.contains(startLine)) {
                // Rest of the first line; more original lines follow, so its line
                // break has to be kept
                to.append(from[startLine].substring(startInclusive.column())).append("\n");
            }
        }
        // Full lines
        for (int i = firstFullLine; i < endLine; i++) {
            if (!removedLines.contains(i)) {
                to.append(from[i]).append("\n");
            }
        }
        // Beginning of the last line, up to (but excluding) the end column
        if (from.length > endLine && endExclusive.column() > 0 && !removedLines.contains(endLine)) {
            to.append(from[endLine], 0, endExclusive.column()).append("\n");
        }
    }

    /**
     * Gets the zero-based position (line and column) in the original source at
     * which the node should be inserted.
     *
     * @param nodeToInsert
     *            the node to insert
     * @return the insert position
     * @throws IllegalArgumentException
     *             if the node is the first child of its parent and no suitable
     *             position can be derived from the parent
     */
    private static LineColumn findInsertPosition(Node nodeToInsert) {
        Optional<Node> maybePrev = getPreviousSibling(nodeToInsert);
        Optional<Node> maybeParent = nodeToInsert.getParentNode();

        if (maybeParent.isPresent()) {
            Node parent = maybeParent.get();
            if (maybePrev.isEmpty() || nodeToInsert instanceof ImportDeclaration) {
                // Add where the first child is
                Optional<Range> firstChildFromSource = getOrderedChildCollection(nodeToInsert, parent).stream()
                        .filter(child -> child.getRange().isPresent()).map(child -> child.getRange().get()).findFirst();
                if (firstChildFromSource.isPresent()) {
                    // Before the first child (source lines are 1-based)
                    return new LineColumn(firstChildFromSource.get().begin.line - 1, 0);
                } else {
                    // No children in the source
                    if (parent instanceof ClassOrInterfaceDeclaration classOrInterfaceDeclaration) {
                        Optional<Range> range = classOrInterfaceDeclaration.getRange();
                        if (range.isPresent()) {
                            // Before the ending } (1-based column -> 0-based index)
                            return new LineColumn(range.get().end.line - 1, range.get().end.column - 1);
                        }
                    } else if (parent instanceof BlockStmt blockStmt) {
                        // Empty constructor or method
                        Optional<Range> range = blockStmt.getRange();
                        if (range.isPresent()) {
                            // Before the ending } (1-based column -> 0-based index)
                            return new LineColumn(range.get().end.line - 1, range.get().end.column - 1);
                        }
                    } else if (nodeToInsert instanceof ImportDeclaration
                            && parent instanceof CompilationUnit compilationUnit) {
                        // No existing imports
                        NodeList<TypeDeclaration<?>> types = compilationUnit.getTypes();
                        if (!types.isEmpty()) {
                            Optional<Range> range = types.get(0).getRange();
                            if (range.isPresent()) {
                                // Before the first type declaration
                                return new LineColumn(range.get().begin.line - 1, 0);
                            }
                        } else {
                            // Assume the second line (index 1) of the file, i.e.
                            // right after a package declaration on the first line
                            return new LineColumn(1, 0);
                        }
                    }
                }
                throw new IllegalArgumentException("Unclear where to add code for " + nodeToInsert);
            }
        }

        if (maybePrev.isPresent()) {
            Node prev = maybePrev.get();
            Optional<Range> prevRange = prev.getRange();
            if (prevRange.isPresent()) {
                // If the previous statement ends on line 12 (1-based) we should
                // insert the new statement on line 13 (1-based) or line 12 (0-based)
                return new LineColumn(prevRange.get().end.line, 0);
            }
        }
        if (nodeToInsert instanceof Comment comment) {
            Optional<Node> commentedNode = comment.getCommentedNode();
            if (commentedNode.isPresent()) {
                return findInsertPosition(commentedNode.get());
            }
        }

        // At the beginning of the file, most often wrong
        return new LineColumn(0, 0);
    }

    private static Optional<Node> getPreviousSibling(Node n) {
        return getSibling(n, -1);
    }

    private static Optional<Node> getNextSibling(Node n) {
        return getSibling(n, 1);
    }

    /**
     * Gets the sibling {@code i} positions away from {@code n} in the parent's
     * ordered child collection, if any.
     */
    private static Optional<Node> getSibling(Node n, int i) {
        return n.getParentNode().flatMap(parent -> getSibling(getOrderedChildCollection(n, parent), n, i));
    }

    /**
     * Gets the list of the parent's children in which {@code n} is expected to
     * appear, in source order.
     */
    private static List<? extends Node> getOrderedChildCollection(Node n, Node parent) {
        if (n instanceof Statement && parent instanceof NodeWithStatements<?> nodeWithStatements) {
            return nodeWithStatements.getStatements();
        }
        if (parent instanceof ClassOrInterfaceDeclaration classOrInterfaceDeclaration) {
            return classOrInterfaceDeclaration.getMembers();
        }
        if (parent instanceof CompilationUnit compilationUnit && n instanceof ImportDeclaration) {
            return compilationUnit.getImports();
        }

        return parent.getChildNodes();
    }

    /**
     * Finds {@code n} by identity in {@code nodes} and returns the element
     * {@code i} positions away from it, if that index is within bounds.
     *
     * @throws IllegalArgumentException
     *             if {@code n} is not in {@code nodes}
     */
    private static Optional<Node> getSibling(List<? extends Node> nodes, Node n, int i) {
        int index = -1;
        for (int k = 0; k < nodes.size(); k++) {
            if (nodes.get(k) == n) {
                index = k;
                break;
            }
        }
        if (index == -1) {
            throw new IllegalArgumentException("Node not found in parent");
        }
        int siblingIndex = index + i;
        if (siblingIndex >= 0 && siblingIndex < nodes.size()) {
            return Optional.of(nodes.get(siblingIndex));
        }

        return Optional.empty();
    }
}
