/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.pooling;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.LambdaBlock;
import ai.djl.nn.pooling.PoolingConvention;

/**
 * Static factory methods that build pooling {@link Block}s (max, average and Lp pooling, in
 * windowed and global variants).
 *
 * <p>Every factory returns a new {@link LambdaBlock}. The wrapped function takes the single
 * array out of the incoming {@link NDList} via {@code singletonOrThrow()}, delegates to the
 * matching operation on {@code NDArray.getNDArrayInternal()} and wraps the result in a new
 * {@link NDList}. The pooling arithmetic itself is performed by that engine-level call, not by
 * this class.
 *
 * <p>The windowed factories validate their shapes eagerly, at block-creation time: kernel,
 * stride and pad must be non-null and their {@code Shape.dimension()} must equal the
 * dimensionality named by the factory (1, 2 or 3).
 */
public final class Pool {

    /** Not instantiable: this class only hosts static factories. */
    private Pool() {
    }

    // ---------------------------------------------------------------------------------------
    // Low-level delegates to the engine
    // ---------------------------------------------------------------------------------------

    /** Forwards to the engine-level {@code maxPool} of {@code data}. */
    private static NDArray maxPool(NDArray data, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return data.getNDArrayInternal().maxPool(kernel, stride, pad, poolingConvention);
    }

    /** {@link NDList} adapter for max pooling; the list is read with {@code singletonOrThrow()}. */
    private static NDList maxPool(NDList list, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return new NDList(maxPool(list.singletonOrThrow(), kernel, stride, pad, poolingConvention));
    }

    /** Forwards to the engine-level {@code globalMaxPool} of {@code data}. */
    private static NDArray globalMaxPool(NDArray data) {
        return data.getNDArrayInternal().globalMaxPool();
    }

    /** Forwards to the engine-level {@code avgPool} of {@code data}. */
    private static NDArray avgPool(NDArray data, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, boolean countIncludePad) {
        return data.getNDArrayInternal().avgPool(kernel, stride, pad, poolingConvention, countIncludePad);
    }

    /**
     * Applies average pooling to {@code data} using {@link PoolingConvention#VALID} and
     * {@code countIncludePad = true}.
     *
     * <p>No shape validation is performed here; the arguments are handed to the engine as-is.
     *
     * @param data the array to pool
     * @param kernel the pooling window shape
     * @param stride the stride shape
     * @param pad the padding shape
     * @return the array produced by the engine-level {@code avgPool} call
     */
    public static NDArray avgPool(NDArray data, Shape kernel, Shape stride, Shape pad) {
        return avgPool(data, kernel, stride, pad, PoolingConvention.VALID, true);
    }

    /** {@link NDList} adapter for average pooling; the list is read with {@code singletonOrThrow()}. */
    private static NDList avgPool(NDList list, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, boolean countIncludePad) {
        return new NDList(avgPool(list.singletonOrThrow(), kernel, stride, pad, poolingConvention, countIncludePad));
    }

    /** Forwards to the engine-level {@code globalAvgPool} of {@code data}. */
    private static NDArray globalAvgPool(NDArray data) {
        return data.getNDArrayInternal().globalAvgPool();
    }

    /** Forwards to the engine-level {@code lpPool} of {@code data}. */
    private static NDArray lpPool(NDArray data, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        return data.getNDArrayInternal().lpPool(kernel, stride, pad, poolingConvention, pValue);
    }

    /** {@link NDList} adapter for Lp pooling; the list is read with {@code singletonOrThrow()}. */
    private static NDList lpPool(NDList list, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        return new NDList(lpPool(list.singletonOrThrow(), kernel, stride, pad, poolingConvention, pValue));
    }

    /** Forwards to the engine-level {@code globalLpPool} of {@code data}. */
    private static NDArray globalLpPool(NDArray data, int pValue) {
        return data.getNDArrayInternal().globalLpPool(pValue);
    }

    // ---------------------------------------------------------------------------------------
    // Shared validation and block construction
    // ---------------------------------------------------------------------------------------

    /**
     * Checks that kernel, stride and pad are all present and all have the expected
     * {@code Shape.dimension()}.
     *
     * @param blockName the name used in error messages (e.g. {@code "maxPool1DBlock"})
     * @param dimensions the required value of {@code dimension()} for each shape
     * @param kernel the pooling window shape
     * @param stride the stride shape
     * @param pad the padding shape
     * @throws IllegalArgumentException if any shape is null or has a different dimension
     */
    private static void validateShapes(String blockName, int dimensions, Shape kernel, Shape stride, Shape pad) {
        if (kernel == null) {
            throw new IllegalArgumentException("Kernel cannot be null for " + blockName + " Block");
        }
        // Stride and pad are dereferenced below; reject null explicitly so callers get the same
        // exception type (and a message) as for a null kernel instead of a bare NPE.
        if (stride == null) {
            throw new IllegalArgumentException("Stride cannot be null for " + blockName + " Block");
        }
        if (pad == null) {
            throw new IllegalArgumentException("Pad cannot be null for " + blockName + " Block");
        }
        boolean allMatch = kernel.dimension() == dimensions
                && stride.dimension() == dimensions
                && pad.dimension() == dimensions;
        if (!allMatch) {
            throw new IllegalArgumentException("Kernel , Stride and Pad dimensions for " + blockName + " layer should be " + dimensions);
        }
    }

    /** Validates the shapes and wraps max pooling in a {@link LambdaBlock}. */
    private static Block maxPoolBlock(String blockName, int dimensions, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        validateShapes(blockName, dimensions, kernel, stride, pad);
        return new LambdaBlock(ndList -> maxPool(ndList, kernel, stride, pad, poolingConvention));
    }

    /** Validates the shapes and wraps average pooling in a {@link LambdaBlock}. */
    private static Block avgPoolBlock(String blockName, int dimensions, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, boolean countIncludePad) {
        validateShapes(blockName, dimensions, kernel, stride, pad);
        return new LambdaBlock(ndList -> avgPool(ndList, kernel, stride, pad, poolingConvention, countIncludePad));
    }

    /** Validates the shapes and wraps Lp pooling in a {@link LambdaBlock}. */
    private static Block lpPoolBlock(String blockName, int dimensions, Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        validateShapes(blockName, dimensions, kernel, stride, pad);
        return new LambdaBlock(ndList -> lpPool(ndList, kernel, stride, pad, poolingConvention, pValue));
    }

    /** Wraps global max pooling in a new {@link LambdaBlock}. */
    private static Block globalMaxPoolBlock() {
        return new LambdaBlock(ndList -> new NDList(globalMaxPool(ndList.singletonOrThrow())));
    }

    /** Wraps global average pooling in a new {@link LambdaBlock}. */
    private static Block globalAvgPoolBlock() {
        return new LambdaBlock(ndList -> new NDList(globalAvgPool(ndList.singletonOrThrow())));
    }

    /** Wraps global Lp pooling with the given {@code pValue} in a new {@link LambdaBlock}. */
    private static Block globalLpPoolBlock(int pValue) {
        return new LambdaBlock(ndList -> new NDList(globalLpPool(ndList.singletonOrThrow(), pValue)));
    }

    // ---------------------------------------------------------------------------------------
    // Max pooling
    // ---------------------------------------------------------------------------------------

    /**
     * Creates a 1-D max pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @param poolingConvention the convention forwarded to the engine
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block maxPool1DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return maxPoolBlock("maxPool1DBlock", 1, kernel, stride, pad, poolingConvention);
    }

    /**
     * Creates a 1-D max pooling block using {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block maxPool1DBlock(Shape kernel, Shape stride, Shape pad) {
        return maxPool1DBlock(kernel, stride, pad, PoolingConvention.VALID);
    }

    /**
     * Creates a 1-D max pooling block with pad {@code (0)} and {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block maxPool1DBlock(Shape kernel, Shape stride) {
        return maxPool1DBlock(kernel, stride, new Shape(0L), PoolingConvention.VALID);
    }

    /**
     * Creates a 1-D max pooling block whose stride equals the kernel, with pad {@code (0)} and
     * {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 1
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 1
     */
    public static Block maxPool1DBlock(Shape kernel) {
        return maxPool1DBlock(kernel, kernel, new Shape(0L), PoolingConvention.VALID);
    }

    /**
     * Creates a 2-D max pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @param poolingConvention the convention forwarded to the engine
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block maxPool2DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return maxPoolBlock("maxPool2DBlock", 2, kernel, stride, pad, poolingConvention);
    }

    /**
     * Creates a 2-D max pooling block using {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block maxPool2DBlock(Shape kernel, Shape stride, Shape pad) {
        return maxPool2DBlock(kernel, stride, pad, PoolingConvention.VALID);
    }

    /**
     * Creates a 2-D max pooling block with pad {@code (0, 0)} and
     * {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block maxPool2DBlock(Shape kernel, Shape stride) {
        return maxPool2DBlock(kernel, stride, new Shape(0L, 0L), PoolingConvention.VALID);
    }

    /**
     * Creates a 2-D max pooling block whose stride equals the kernel, with pad {@code (0, 0)}
     * and {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 2
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 2
     */
    public static Block maxPool2DBlock(Shape kernel) {
        return maxPool2DBlock(kernel, kernel, new Shape(0L, 0L), PoolingConvention.VALID);
    }

    /**
     * Creates a 3-D max pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @param poolingConvention the convention forwarded to the engine
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block maxPool3DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return maxPoolBlock("maxPool3DBlock", 3, kernel, stride, pad, poolingConvention);
    }

    /**
     * Creates a 3-D max pooling block using {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block maxPool3DBlock(Shape kernel, Shape stride, Shape pad) {
        return maxPool3DBlock(kernel, stride, pad, PoolingConvention.VALID);
    }

    /**
     * Creates a 3-D max pooling block with pad {@code (0, 0, 0)} and
     * {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block maxPool3DBlock(Shape kernel, Shape stride) {
        return maxPool3DBlock(kernel, stride, new Shape(0L, 0L, 0L), PoolingConvention.VALID);
    }

    /**
     * Creates a 3-D max pooling block whose stride equals the kernel, with pad
     * {@code (0, 0, 0)} and {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 3
     * @return a new {@link LambdaBlock} applying max pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 3
     */
    public static Block maxPool3DBlock(Shape kernel) {
        return maxPool3DBlock(kernel, kernel, new Shape(0L, 0L, 0L), PoolingConvention.VALID);
    }

    /**
     * Creates a global max pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @return a new {@link LambdaBlock} applying global max pooling
     */
    public static Block globalMaxPool1DBlock() {
        return globalMaxPoolBlock();
    }

    /**
     * Creates a global max pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @return a new {@link LambdaBlock} applying global max pooling
     */
    public static Block globalMaxPool2DBlock() {
        return globalMaxPoolBlock();
    }

    /**
     * Creates a global max pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @return a new {@link LambdaBlock} applying global max pooling
     */
    public static Block globalMaxPool3DBlock() {
        return globalMaxPoolBlock();
    }

    // ---------------------------------------------------------------------------------------
    // Average pooling
    // ---------------------------------------------------------------------------------------

    /**
     * Creates a 1-D average pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @param poolingConvention the convention forwarded to the engine
     * @param countIncludePad flag forwarded unchanged to the engine-level {@code avgPool}
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block avgPool1DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, boolean countIncludePad) {
        return avgPoolBlock("avgPool1DBlock", 1, kernel, stride, pad, poolingConvention, countIncludePad);
    }

    /**
     * Creates a 1-D average pooling block with {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @param poolingConvention the convention forwarded to the engine
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block avgPool1DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return avgPool1DBlock(kernel, stride, pad, poolingConvention, true);
    }

    /**
     * Creates a 1-D average pooling block with {@link PoolingConvention#VALID} and
     * {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block avgPool1DBlock(Shape kernel, Shape stride, Shape pad) {
        return avgPool1DBlock(kernel, stride, pad, PoolingConvention.VALID, true);
    }

    /**
     * Creates a 1-D average pooling block with pad {@code (0)}, {@link PoolingConvention#VALID}
     * and {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block avgPool1DBlock(Shape kernel, Shape stride) {
        return avgPool1DBlock(kernel, stride, new Shape(0L), PoolingConvention.VALID, true);
    }

    /**
     * Creates a 1-D average pooling block whose stride equals the kernel, with pad {@code (0)},
     * {@link PoolingConvention#VALID} and {@code countIncludePad = true}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 1
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 1
     */
    public static Block avgPool1DBlock(Shape kernel) {
        return avgPool1DBlock(kernel, kernel, new Shape(0L), PoolingConvention.VALID, true);
    }

    /**
     * Creates a 2-D average pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @param poolingConvention the convention forwarded to the engine
     * @param countIncludePad flag forwarded unchanged to the engine-level {@code avgPool}
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block avgPool2DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, boolean countIncludePad) {
        // The null-kernel message previously named avgPool3DBlock here by mistake.
        return avgPoolBlock("avgPool2DBlock", 2, kernel, stride, pad, poolingConvention, countIncludePad);
    }

    /**
     * Creates a 2-D average pooling block with {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @param poolingConvention the convention forwarded to the engine
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block avgPool2DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return avgPool2DBlock(kernel, stride, pad, poolingConvention, true);
    }

    /**
     * Creates a 2-D average pooling block with {@link PoolingConvention#VALID} and
     * {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block avgPool2DBlock(Shape kernel, Shape stride, Shape pad) {
        return avgPool2DBlock(kernel, stride, pad, PoolingConvention.VALID, true);
    }

    /**
     * Creates a 2-D average pooling block with pad {@code (0, 0)},
     * {@link PoolingConvention#VALID} and {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block avgPool2DBlock(Shape kernel, Shape stride) {
        return avgPool2DBlock(kernel, stride, new Shape(0L, 0L), PoolingConvention.VALID, true);
    }

    /**
     * Creates a 2-D average pooling block whose stride equals the kernel, with pad
     * {@code (0, 0)}, {@link PoolingConvention#VALID} and {@code countIncludePad = true}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 2
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 2
     */
    public static Block avgPool2DBlock(Shape kernel) {
        return avgPool2DBlock(kernel, kernel, new Shape(0L, 0L), PoolingConvention.VALID, true);
    }

    /**
     * Creates a 3-D average pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @param poolingConvention the convention forwarded to the engine
     * @param countIncludePad flag forwarded unchanged to the engine-level {@code avgPool}
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block avgPool3DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, boolean countIncludePad) {
        return avgPoolBlock("avgPool3DBlock", 3, kernel, stride, pad, poolingConvention, countIncludePad);
    }

    /**
     * Creates a 3-D average pooling block with {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @param poolingConvention the convention forwarded to the engine
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block avgPool3DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention) {
        return avgPool3DBlock(kernel, stride, pad, poolingConvention, true);
    }

    /**
     * Creates a 3-D average pooling block with {@link PoolingConvention#VALID} and
     * {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block avgPool3DBlock(Shape kernel, Shape stride, Shape pad) {
        return avgPool3DBlock(kernel, stride, pad, PoolingConvention.VALID, true);
    }

    /**
     * Creates a 3-D average pooling block with pad {@code (0, 0, 0)},
     * {@link PoolingConvention#VALID} and {@code countIncludePad = true}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block avgPool3DBlock(Shape kernel, Shape stride) {
        return avgPool3DBlock(kernel, stride, new Shape(0L, 0L, 0L), PoolingConvention.VALID, true);
    }

    /**
     * Creates a 3-D average pooling block whose stride equals the kernel, with pad
     * {@code (0, 0, 0)}, {@link PoolingConvention#VALID} and {@code countIncludePad = true}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 3
     * @return a new {@link LambdaBlock} applying average pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 3
     */
    public static Block avgPool3DBlock(Shape kernel) {
        return avgPool3DBlock(kernel, kernel, new Shape(0L, 0L, 0L), PoolingConvention.VALID, true);
    }

    /**
     * Creates a global average pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @return a new {@link LambdaBlock} applying global average pooling
     */
    public static Block globalAvgPool1DBlock() {
        return globalAvgPoolBlock();
    }

    /**
     * Creates a global average pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @return a new {@link LambdaBlock} applying global average pooling
     */
    public static Block globalAvgPool2DBlock() {
        return globalAvgPoolBlock();
    }

    /**
     * Creates a global average pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @return a new {@link LambdaBlock} applying global average pooling
     */
    public static Block globalAvgPool3DBlock() {
        return globalAvgPoolBlock();
    }

    // ---------------------------------------------------------------------------------------
    // Lp pooling
    // ---------------------------------------------------------------------------------------

    /**
     * Creates a 1-D Lp pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @param poolingConvention the convention forwarded to the engine
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block lpPool1DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        return lpPoolBlock("lpPool1D", 1, kernel, stride, pad, poolingConvention, pValue);
    }

    /**
     * Creates a 1-D Lp pooling block using {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pad the padding; {@code dimension()} must be 1
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block lpPool1DBlock(Shape kernel, Shape stride, Shape pad, int pValue) {
        return lpPool1DBlock(kernel, stride, pad, PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 1-D Lp pooling block with pad {@code (0)} and {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 1
     * @param stride the stride; {@code dimension()} must be 1
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 1
     */
    public static Block lpPool1DBlock(Shape kernel, Shape stride, int pValue) {
        return lpPool1DBlock(kernel, stride, new Shape(0L), PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 1-D Lp pooling block whose stride equals the kernel, with pad {@code (0)} and
     * {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 1
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 1
     */
    public static Block lpPool1DBlock(Shape kernel, int pValue) {
        return lpPool1DBlock(kernel, kernel, new Shape(0L), PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 2-D Lp pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @param poolingConvention the convention forwarded to the engine
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block lpPool2DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        return lpPoolBlock("lpPool2D", 2, kernel, stride, pad, poolingConvention, pValue);
    }

    /**
     * Creates a 2-D Lp pooling block using {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pad the padding; {@code dimension()} must be 2
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block lpPool2DBlock(Shape kernel, Shape stride, Shape pad, int pValue) {
        return lpPool2DBlock(kernel, stride, pad, PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 2-D Lp pooling block with pad {@code (0, 0)} and
     * {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 2
     * @param stride the stride; {@code dimension()} must be 2
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 2
     */
    public static Block lpPool2DBlock(Shape kernel, Shape stride, int pValue) {
        return lpPool2DBlock(kernel, stride, new Shape(0L, 0L), PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 2-D Lp pooling block whose stride equals the kernel, with pad {@code (0, 0)}
     * and {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 2
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 2
     */
    public static Block lpPool2DBlock(Shape kernel, int pValue) {
        return lpPool2DBlock(kernel, kernel, new Shape(0L, 0L), PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 3-D Lp pooling block.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @param poolingConvention the convention forwarded to the engine
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block lpPool3DBlock(Shape kernel, Shape stride, Shape pad, PoolingConvention poolingConvention, int pValue) {
        return lpPoolBlock("lpPool3D", 3, kernel, stride, pad, poolingConvention, pValue);
    }

    /**
     * Creates a 3-D Lp pooling block using {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pad the padding; {@code dimension()} must be 3
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block lpPool3DBlock(Shape kernel, Shape stride, Shape pad, int pValue) {
        return lpPool3DBlock(kernel, stride, pad, PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 3-D Lp pooling block with pad {@code (0, 0, 0)} and
     * {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window; {@code dimension()} must be 3
     * @param stride the stride; {@code dimension()} must be 3
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if a shape is null or its dimension is not 3
     */
    public static Block lpPool3DBlock(Shape kernel, Shape stride, int pValue) {
        return lpPool3DBlock(kernel, stride, new Shape(0L, 0L, 0L), PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a 3-D Lp pooling block whose stride equals the kernel, with pad
     * {@code (0, 0, 0)} and {@link PoolingConvention#VALID}.
     *
     * @param kernel the pooling window, also used as the stride; {@code dimension()} must be 3
     * @param pValue the p value forwarded unchanged to the engine-level {@code lpPool}
     * @return a new {@link LambdaBlock} applying Lp pooling
     * @throws IllegalArgumentException if {@code kernel} is null or its dimension is not 3
     */
    public static Block lpPool3DBlock(Shape kernel, int pValue) {
        return lpPool3DBlock(kernel, kernel, new Shape(0L, 0L, 0L), PoolingConvention.VALID, pValue);
    }

    /**
     * Creates a global Lp pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @param pValue the p value forwarded unchanged to the engine-level {@code globalLpPool}
     * @return a new {@link LambdaBlock} applying global Lp pooling
     */
    public static Block globalLpPool1DBlock(int pValue) {
        return globalLpPoolBlock(pValue);
    }

    /**
     * Creates a global Lp pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @param pValue the p value forwarded unchanged to the engine-level {@code globalLpPool}
     * @return a new {@link LambdaBlock} applying global Lp pooling
     */
    public static Block globalLpPool2DBlock(int pValue) {
        return globalLpPoolBlock(pValue);
    }

    /**
     * Creates a global Lp pooling block. The 1-D, 2-D and 3-D variants build the same
     * function; no input-rank check is performed by this class.
     *
     * @param pValue the p value forwarded unchanged to the engine-level {@code globalLpPool}
     * @return a new {@link LambdaBlock} applying global Lp pooling
     */
    public static Block globalLpPool3DBlock(int pValue) {
        return globalLpPoolBlock(pValue);
    }
}

