/*
 * Decompiled with CFR 0.152.
 */
package ai.djl;

import ai.djl.engine.Engine;
import ai.djl.util.cuda.CudaUtils;
import java.util.Objects;

/**
 * Identifies a compute device by its type (for example {@link Type#CPU} or {@link Type#GPU}) and,
 * for non-CPU devices, a numeric device id.
 *
 * <p>Both fields are assigned exactly once in the constructor and never change afterwards, so
 * instances can be shared (see {@link #cpu()} and {@link #gpu()}) and used as hash keys.
 */
public class Device {
    private static final Device CPU = new Device(Type.CPU);
    private static final Device GPU = new Device(Type.GPU, 0);

    private final String deviceType;
    // -1 for CPU devices (which have no id); otherwise the id given to the constructor, or 0.
    private final int deviceId;

    /**
     * Creates a device of the given type with its default id: -1 internally for {@code "cpu"},
     * 0 for any other type.
     *
     * @param deviceType the device type, e.g. {@link Type#CPU} or {@link Type#GPU}
     */
    public Device(String deviceType) {
        this.deviceType = deviceType;
        this.deviceId = Type.CPU.equals(deviceType) ? -1 : 0;
    }

    /**
     * Creates a non-CPU device with an explicit id.
     *
     * @param deviceType the device type; must not be {@code "cpu"}
     * @param deviceId the device id
     * @throws IllegalArgumentException if {@code deviceType} is {@code "cpu"}
     */
    public Device(String deviceType, int deviceId) {
        if (Type.CPU.equals(deviceType)) {
            throw new IllegalArgumentException("CPU doesn't have device id, please use new Device(\"cpu\") instead");
        }
        this.deviceType = deviceType;
        this.deviceId = deviceId;
    }

    /** Returns {@code true} when this device's type is {@code "cpu"}. */
    private boolean isCpu() {
        return Type.CPU.equals(this.deviceType);
    }

    /**
     * Returns the device type string this device was created with.
     *
     * @return the device type
     */
    public String getDeviceType() {
        return this.deviceType;
    }

    /**
     * Returns the id of this (non-CPU) device.
     *
     * @return the device id
     * @throws IllegalStateException if this is a CPU device
     */
    public int getDeviceId() {
        if (isCpu()) {
            throw new IllegalStateException("CPU doesn't have device id");
        }
        return this.deviceId;
    }

    /**
     * Formats the device as {@code type()} for CPU, or {@code type(id)} otherwise.
     *
     * @return the string form, e.g. {@code "cpu()"} or {@code "gpu(1)"}
     */
    @Override
    public String toString() {
        if (isCpu()) {
            return this.deviceType + "()";
        }
        return this.deviceType + '(' + this.deviceId + ')';
    }

    /**
     * Two devices are equal when they are the same class and type, and, for non-CPU devices,
     * also have the same id.
     *
     * @param o the object to compare with
     * @return whether {@code o} denotes the same device
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Device device = (Device) o;
        if (isCpu()) {
            // A CPU carries no meaningful id, so the type alone decides equality.
            return Objects.equals(this.deviceType, device.deviceType);
        }
        return this.deviceId == device.deviceId && Objects.equals(this.deviceType, device.deviceType);
    }

    /**
     * Hashes type and id.
     *
     * <p>This is consistent with {@link #equals(Object)}: every CPU instance has id -1, because
     * the one-argument constructor sets it and the two-argument constructor rejects
     * {@code "cpu"}. Equal CPU devices therefore always hash alike.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        return Objects.hash(this.deviceType, this.deviceId);
    }

    /**
     * Returns the shared CPU device.
     *
     * @return the CPU device
     */
    public static Device cpu() {
        return CPU;
    }

    /**
     * Returns the shared GPU device with id 0.
     *
     * @return GPU 0
     */
    public static Device gpu() {
        return GPU;
    }

    /**
     * Creates a GPU device with the given id.
     *
     * @param deviceId the GPU id
     * @return a new GPU device
     */
    public static Device gpu(int deviceId) {
        return new Device(Type.GPU, deviceId);
    }

    /**
     * Returns up to {@code maxGpus} GPU devices with ids {@code 0..n-1}, where {@code n} is the
     * smaller of {@code maxGpus} and {@link #getGpuCount()}. Falls back to a single-element array
     * holding the CPU device when {@code maxGpus <= 0} or no GPU is available.
     *
     * @param maxGpus the maximum number of GPUs to return
     * @return the selected devices; never empty
     */
    public static Device[] getDevices(int maxGpus) {
        int available = Device.getGpuCount();
        if (maxGpus <= 0 || available <= 0) {
            return new Device[] {CPU};
        }
        int count = Math.min(maxGpus, available);
        Device[] devices = new Device[count];
        for (int i = 0; i < count; ++i) {
            devices[i] = gpu(i);
        }
        return devices;
    }

    /**
     * Returns the number of GPUs visible to the default engine. This is the CUDA GPU count when
     * the engine reports the {@code "CUDA"} capability, and 0 otherwise.
     *
     * @return the GPU count
     */
    public static int getGpuCount() {
        return cudaGpuCount(Engine.getInstance());
    }

    /**
     * Returns the number of GPUs visible to the named engine. This is the CUDA GPU count when
     * that engine reports the {@code "CUDA"} capability, and 0 otherwise.
     *
     * <p>NOTE(review): behaviour for an unknown {@code engineName} depends on
     * {@code Engine.getEngine}; confirm whether it throws or returns null.
     *
     * @param engineName the engine name
     * @return the GPU count
     */
    public static int getGpuCount(String engineName) {
        return cudaGpuCount(Engine.getEngine(engineName));
    }

    /** Shared logic for both {@code getGpuCount} overloads: the CUDA GPU count, gated on the engine's capability. */
    private static int cudaGpuCount(Engine engine) {
        return engine.hasCapability("CUDA") ? CudaUtils.getGpuCount() : 0;
    }

    /**
     * Returns GPU 0 when at least one GPU is available to the default engine, else the CPU.
     *
     * @return the default device
     */
    public static Device defaultDevice() {
        return Device.getGpuCount() > 0 ? Device.gpu() : Device.cpu();
    }

    /**
     * Returns {@code device}, or {@link #defaultDevice()} when it is {@code null}.
     *
     * @param device the preferred device, possibly {@code null}
     * @return a non-null device
     */
    public static Device defaultIfNull(Device device) {
        if (device != null) {
            return device;
        }
        return Device.defaultDevice();
    }

    /**
     * Returns {@code device} if non-null, otherwise {@code def} if non-null, otherwise
     * {@link #defaultDevice()}.
     *
     * @param device the preferred device, possibly {@code null}
     * @param def the fallback device, possibly {@code null}
     * @return a non-null device
     */
    public static Device defaultIfNull(Device device, Device def) {
        if (device != null) {
            return device;
        }
        return Device.defaultIfNull(def);
    }

    /** Well-known device type strings. */
    public static interface Type {
        public static final String CPU = "cpu";
        public static final String GPU = "gpu";
    }
}

