/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.convolutional;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Base class for N-dimensional convolution blocks.
 *
 * <p>Subclasses supply the layout the input is validated against, the layout string handed to the
 * backend convolution operator, and the total output rank. The block owns a {@code weight}
 * parameter and, when bias is enabled, a {@code bias} parameter.
 */
public abstract class Convolution extends ParameterBlock {
    /**
     * Encoding version written by {@link #saveParameters}. Version 1 (which did not store input
     * shapes) is still accepted by {@link #loadParameters}.
     */
    private static final byte VERSION = 2;

    protected Shape kernel;
    protected Shape stride;
    protected Shape pad;
    protected Shape dilate;
    protected int numFilters;
    protected int numGroups;
    protected boolean includeBias;
    protected Parameter weight;
    protected Parameter bias;

    /**
     * Creates a convolution block from the settings collected by a builder.
     *
     * @param builder the builder holding kernel, stride, pad, dilation, filter/group counts and the
     *     bias flag
     */
    public Convolution(ConvolutionBuilder<?> builder) {
        this.kernel = builder.kernel;
        this.stride = builder.stride;
        this.pad = builder.pad;
        this.dilate = builder.dilate;
        this.numFilters = builder.numFilters;
        this.numGroups = builder.numGroups;
        this.includeBias = builder.includeBias;
        this.weight = new Parameter("weight", this, ParameterType.WEIGHT);
        // The bias parameter exists only when requested; elsewhere "bias != null" is the flag.
        if (this.includeBias) {
            this.bias = new Parameter("bias", this, ParameterType.BIAS);
        }
    }

    /**
     * Returns the layout the first input shape is validated against in {@link #beforeInitialize}.
     *
     * @return the expected input layout
     */
    protected abstract LayoutType[] getExpectedLayout();

    /**
     * Returns the layout string passed to the backend convolution operator.
     *
     * @return the layout string
     */
    protected abstract String getStringLayout();

    /**
     * Returns the rank of the output shape. {@link #getOutputShapes} treats the first two axes as
     * batch and filter axes and the remaining {@code numDimensions() - 2} axes as spatial.
     *
     * @return the total number of output dimensions
     */
    protected abstract int numDimensions();

    /**
     * Runs the convolution on the single input array with the stored weight (and bias, if any).
     *
     * @param parameterStore the store that supplies parameter values on the input's device
     * @param inputs the input list; it must contain exactly one array
     * @param params extra operator parameters forwarded to the backend
     * @return the convolution result
     */
    @Override
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        inputs = this.opInputs(parameterStore, inputs);
        NDArrayEx ex = inputs.head().getNDArrayInternal();
        // The backend flag is "no bias", hence the negation.
        return ex.convolution(inputs, this.kernel, this.stride, this.pad, this.dilate, this.numFilters, this.numGroups, this.getStringLayout(), !this.includeBias, params);
    }

    /**
     * Records the input shapes and validates the first input's layout against the expected one.
     *
     * @param inputs the input shapes
     */
    @Override
    protected void beforeInitialize(Shape[] inputs) {
        this.inputShapes = inputs;
        Shape inputShape = inputs[0];
        Block.validateLayout(this.getExpectedLayout(), inputShape.getLayout());
    }

    /**
     * Computes the output shape.
     *
     * <p>Each spatial axis {@code i} uses
     * {@code (in + 2 * pad[i] - dilate[i] * (kernel[i] - 1) - 1) / stride[i] + 1}. The per-axis
     * stride and dilation are used, so non-uniform strides and dilations are handled.
     *
     * @param manager unused here
     * @param inputs the input shapes; only {@code inputs[0]} is read
     * @return a single-element array containing the output shape
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputs) {
        Shape input = inputs[0];
        int rank = this.numDimensions();
        long[] shape = new long[rank];
        shape[0] = input.get(0);
        shape[1] = this.numFilters;
        for (int i = 0; i < rank - 2; ++i) {
            long effectiveKernel = this.dilate.get(i) * (this.kernel.get(i) - 1L) + 1L;
            shape[2 + i] = (input.get(2 + i) + 2L * this.pad.get(i) - effectiveKernel) / this.stride.get(i) + 1L;
        }
        return new Shape[]{new Shape(shape)};
    }

    /**
     * Returns the shape of a named parameter.
     *
     * <p>The weight has shape {@code (numFilters, inChannels / numGroups, *kernel)}, where
     * {@code inChannels} is axis 1 of the first input shape. Each filter sees only the channels of
     * its own group. The bias has shape {@code (numFilters)}.
     *
     * @param name the parameter name, {@code "weight"} or {@code "bias"}
     * @param inputShapes the input shapes; only {@code inputShapes[0]} is read
     * @return the parameter shape
     * @throws IllegalArgumentException if the name is unknown, or if the input channel count is not
     *     divisible by the group count
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Shape shape = inputShapes[0];
        switch (name) {
            case "weight": {
                long inChannels = shape.get(1);
                if (inChannels % this.numGroups != 0) {
                    throw new IllegalArgumentException(
                            "Input channels (" + inChannels + ") must be divisible by numGroups ("
                                    + this.numGroups + ")");
                }
                return new Shape(this.numFilters, inChannels / this.numGroups).addAll(this.kernel);
            }
            case "bias": {
                return new Shape(this.numFilters);
            }
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    /**
     * Lists this block's own parameters: the weight, followed by the bias when it is enabled.
     *
     * @return a new mutable list of parameters
     */
    @Override
    public List<Parameter> getDirectParameters() {
        List<Parameter> parameters = new ArrayList<>();
        parameters.add(this.weight);
        if (this.includeBias) {
            parameters.add(this.bias);
        }
        return parameters;
    }

    /**
     * Writes the encoding version, the input shapes, the weight, and (if present) the bias.
     *
     * @param os the output stream
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        this.saveInputShapes(os);
        this.weight.save(os);
        if (this.bias != null) {
            this.bias.save(os);
        }
    }

    /**
     * Reads parameters written by {@link #saveParameters}. Version 1 streams carry no input shapes.
     *
     * @param manager the manager used to allocate the loaded arrays
     * @param is the input stream
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoding version is neither 1 nor {@link #VERSION}
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            this.readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        this.weight.load(manager, is);
        if (this.bias != null) {
            this.bias.load(manager, is);
        }
    }

    /** Builds the operator input list: the data array, the weight and, if present, the bias. */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        NDArray data = inputs.singletonOrThrow();
        Device device = data.getDevice();
        NDList ret = new NDList(3);
        ret.add(data);
        ret.add(parameterStore.getValue(this.weight, device));
        if (this.bias != null) {
            ret.add(parameterStore.getValue(this.bias, device));
        }
        return ret;
    }

    /**
     * Base builder for convolution blocks. Group count defaults to 1 and bias is enabled by
     * default. Kernel and filter count are mandatory (see {@link #validate()}).
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public static abstract class ConvolutionBuilder<T extends ConvolutionBuilder> {
        protected Shape kernel;
        protected Shape stride;
        protected Shape pad;
        protected Shape dilate;
        protected int numFilters;
        protected int numGroups = 1;
        protected boolean includeBias = true;

        /**
         * Sets the kernel shape (one entry per spatial axis).
         *
         * @param kernel the kernel shape
         * @return this builder
         */
        public T setKernel(Shape kernel) {
            this.kernel = kernel;
            return this.self();
        }

        /**
         * Sets the stride (one entry per spatial axis).
         *
         * @param stride the stride shape
         * @return this builder
         */
        public T optStride(Shape stride) {
            this.stride = stride;
            return this.self();
        }

        /**
         * Sets the padding (one entry per spatial axis, applied on both sides).
         *
         * @param pad the padding shape
         * @return this builder
         */
        public T optPad(Shape pad) {
            this.pad = pad;
            return this.self();
        }

        /**
         * Sets the dilation (one entry per spatial axis).
         *
         * @param dilate the dilation shape
         * @return this builder
         */
        public T optDilate(Shape dilate) {
            this.dilate = dilate;
            return this.self();
        }

        /**
         * Sets the number of output filters (channels).
         *
         * @param numFilters the number of filters
         * @return this builder
         */
        public T setNumFilters(int numFilters) {
            this.numFilters = numFilters;
            return this.self();
        }

        /**
         * Sets the number of groups the input channels are split into.
         *
         * @param numGroups the group count
         * @return this builder
         */
        public T optNumGroups(int numGroups) {
            this.numGroups = numGroups;
            return this.self();
        }

        /**
         * Sets whether a bias parameter is created and used.
         *
         * @param includeBias {@code true} to include a bias
         * @return this builder
         */
        public T optBias(boolean includeBias) {
            this.includeBias = includeBias;
            return this.self();
        }

        /**
         * Checks that the mandatory settings are present.
         *
         * @throws IllegalArgumentException if the kernel is unset or {@code numFilters} is 0
         */
        protected void validate() {
            if (this.kernel == null || this.numFilters == 0) {
                throw new IllegalArgumentException("Kernel and numFilters must be set");
            }
        }

        /**
         * Returns this builder typed as the concrete subclass, for fluent chaining.
         *
         * @return this builder
         */
        protected abstract T self();
    }
}

