package name.remal.gradle_plugins.dsl.utils

import name.remal.ASM_API
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.INVOKESTATIC
import org.objectweb.asm.Type
import org.objectweb.asm.Type.OBJECT
import org.objectweb.asm.Type.getArgumentTypes
import org.objectweb.asm.Type.getInternalName
import org.objectweb.asm.Type.getMethodDescriptor
import org.objectweb.asm.Type.getType
import org.objectweb.asm.tree.AbstractInsnNode
import org.objectweb.asm.tree.LdcInsnNode
import org.objectweb.asm.tree.MethodInsnNode
import org.objectweb.asm.tree.MethodNode
import org.objectweb.asm.tree.analysis.Analyzer
import org.objectweb.asm.tree.analysis.BasicInterpreter
import org.objectweb.asm.tree.analysis.BasicValue
import org.objectweb.asm.tree.analysis.Interpreter
import java.util.ServiceLoader

/**
 * [ClassVisitor] that collects the internal names of classes passed as class constants
 * to the static `ServiceLoader.load*` methods of the visited class.
 *
 * Every method is buffered into a [MethodNode], analyzed with [Analyzer] when the method ends,
 * and then replayed into the delegate's method visitor (if there is one).
 *
 * @param api the ASM API version passed to [ClassVisitor]
 * @param delegate the class visitor all events are forwarded to, may be `null`
 */
class UsedServicesCollectorClassVisitor(api: Int, delegate: ClassVisitor?) : ClassVisitor(api, delegate) {

    companion object {
        private val logger = getGradleLogger(UsedServicesCollectorClassVisitor::class.java)

        /** Internal name of [ServiceLoader], computed once instead of on every visited call instruction. */
        private val SERVICE_LOADER_INTERNAL_NAME: String = getInternalName(ServiceLoader::class.java)
    }

    /** Creates a visitor that uses the [ASM_API] version. */
    constructor(delegate: ClassVisitor?) : this(ASM_API, delegate)

    /** Internal name of the class being visited; assigned in [visit]. */
    private lateinit var internalClassName: String

    override fun visit(version: Int, access: Int, name: String, signature: String?, superName: String?, interfaces: Array<String>?) {
        internalClassName = name
        super.visit(version, access, name, signature, superName, interfaces)
    }


    /** Sorted snapshot of the internal class names collected so far. */
    val usedServiceInternalClassNames: Set<String> get() = _usedServiceInternalClassNames.toSortedSet()
    private val _usedServiceInternalClassNames = hashSetOf<String>()

    /**
     * Buffers the method into a [MethodNode]. When the method ends, the buffered code is analyzed
     * with [UsedServicesInterpreter] and then forwarded to the delegate's method visitor, if any.
     */
    override fun visitMethod(access: Int, name: String, descriptor: String, signature: String?, exceptions: Array<String>?): MethodVisitor? {
        val delegate = super.visitMethod(access, name, descriptor, signature, exceptions)

        return object : MethodNode(api, access, name, descriptor, signature, exceptions) {
            override fun visitEnd() {
                super.visitEnd()

                val interpreter = UsedServicesInterpreter(api, internalClassName, this)
                val analyzer = Analyzer(interpreter)
                analyzer.analyze(internalClassName, this)

                // Replay the buffered method into the delegate only after the analysis has finished
                delegate?.let(this::accept)
            }
        }
    }

    /**
     * Interpreter that behaves like [BasicInterpreter], except that class constants loaded by `LDC`
     * are tracked as [ClassValue] and static `ServiceLoader.load*` calls are inspected.
     *
     * @param owner internal name of the class that declares [methodNode] (used in log messages)
     * @param methodNode the method being analyzed (used in log messages)
     */
    @Suppress("NULLABILITY_MISMATCH_BASED_ON_JAVA_ANNOTATIONS", "HasPlatformType")
    private inner class UsedServicesInterpreter(api: Int, private val owner: String, private val methodNode: MethodNode) : Interpreter<BasicValue>(api) {
        private val basicInterpreter = BasicInterpreter()

        /** Returns a [ClassValue] for an `LDC` of an object type; everything else is delegated. */
        override fun newOperation(insn: AbstractInsnNode): BasicValue? {
            if (insn is LdcInsnNode) {
                val value = insn.cst
                if (value is Type && value.sort == OBJECT) {
                    return ClassValue(value)
                }
            }

            return basicInterpreter.newOperation(insn)
        }

        /**
         * Records the service type of static `ServiceLoader.load*` calls, then delegates.
         *
         * A warning is logged when the called method can't be found via reflection in the running JRE,
         * or when the argument at the `Class` parameter position is not a tracked class constant.
         */
        override fun naryOperation(insn: AbstractInsnNode, values: List<BasicValue?>): BasicValue? {
            if (insn is MethodInsnNode && insn.opcode == INVOKESTATIC) {
                if (insn.owner == SERVICE_LOADER_INTERNAL_NAME && insn.name.startsWith("load")) {
                    val method = ServiceLoader::class.java.methods.firstOrNull { it.name == insn.name && getMethodDescriptor(it) == insn.desc }
                    if (method == null) {
                        logger.warn(
                            "Method not found in current JRE: {}.{}{}",
                            SERVICE_LOADER_INTERNAL_NAME,
                            insn.name,
                            insn.desc
                        )
                        return basicInterpreter.naryOperation(insn, values)
                    }

                    // indexOfFirst returns -1 if the overload has no Class parameter;
                    // getOrNull turns that into the warning below instead of an IndexOutOfBoundsException
                    val valuesIndex = method.parameterTypes.indexOfFirst { it == Class::class.java }
                    val value = values.getOrNull(valuesIndex)
                    if (value is ClassValue) {
                        _usedServiceInternalClassNames.add(value.clazz.internalName)

                    } else {
                        logger.warn(
                            "Not a constant class value passed to ServiceLoader.{} method: {}.{}({})",
                            insn.name,
                            owner.replace('/', '.'),
                            methodNode.name,
                            getArgumentTypes(methodNode.desc).joinToString(", ") { it.className }
                        )
                    }
                }
            }

            return basicInterpreter.naryOperation(insn, values)
        }

        override fun newValue(type: Type?) = basicInterpreter.newValue(type)
        override fun ternaryOperation(insn: AbstractInsnNode, value1: BasicValue?, value2: BasicValue?, value3: BasicValue?) = basicInterpreter.ternaryOperation(insn, value1, value2, value3)
        override fun merge(value1: BasicValue?, value2: BasicValue?) = basicInterpreter.merge(value1, value2)
        override fun returnOperation(insn: AbstractInsnNode, value: BasicValue?, expected: BasicValue?) = basicInterpreter.returnOperation(insn, value, expected)
        override fun unaryOperation(insn: AbstractInsnNode, value: BasicValue?) = basicInterpreter.unaryOperation(insn, value)
        override fun binaryOperation(insn: AbstractInsnNode, value1: BasicValue?, value2: BasicValue?) = basicInterpreter.binaryOperation(insn, value1, value2)
        override fun copyOperation(insn: AbstractInsnNode, value: BasicValue?) = basicInterpreter.copyOperation(insn, value)
    }

    /**
     * Value of a class constant; [clazz] is the type of the loaded class.
     *
     * Equality takes [clazz] into account. [BasicValue.equals] compares only the value type, which is
     * `java.lang.Class` for every instance of this class, so without this override two different
     * class constants would be treated as the same value when frames are merged.
     */
    private class ClassValue(val clazz: Type) : BasicValue(getType(Class::class.java)) {
        override fun equals(other: Any?): Boolean = this === other || (other is ClassValue && clazz == other.clazz)
        override fun hashCode(): Int = clazz.hashCode()
    }

}
