package org.deeplearning4j.spark.parameterserver.python;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/python/DataSetDescriptor.class */
/**
 * Serializable description of a {@link DataSet}: one {@link ArrayDescriptor} for each of the
 * features, labels, features-mask and labels-mask arrays, plus the data set's pre-processed flag.
 *
 * <p>Any of the four descriptors may be {@code null}, which means the corresponding array is
 * absent. {@link #getDataSet()} passes {@code null} through for such arrays.
 *
 * <p>No {@code serialVersionUID} is declared here, so the instance fields are intentionally kept
 * non-final and with their existing names: the default serial version UID is computed from them.
 */
public class DataSetDescriptor implements Serializable {
    private ArrayDescriptor features;
    private ArrayDescriptor labels;
    private ArrayDescriptor featuresMask;
    private ArrayDescriptor labelsMask;
    private boolean preProcessed;

    /**
     * Creates a descriptor from already-built array descriptors. The pre-processed flag is not
     * set by this constructor and therefore stays {@code false}.
     *
     * @param features     descriptor of the features array, or {@code null} if absent
     * @param labels       descriptor of the labels array, or {@code null} if absent
     * @param featuresMask descriptor of the features mask array, or {@code null} if absent
     * @param labelsMask   descriptor of the labels mask array, or {@code null} if absent
     */
    public DataSetDescriptor(ArrayDescriptor features, ArrayDescriptor labels, ArrayDescriptor featuresMask, ArrayDescriptor labelsMask) {
        this.features = features;
        this.labels = labels;
        this.featuresMask = featuresMask;
        this.labelsMask = labelsMask;
    }

    /**
     * Creates a descriptor for the arrays of the given data set and records whether the data set
     * is marked as pre-processed. An array that the data set does not have ({@code null}) is
     * recorded as a {@code null} descriptor.
     *
     * @param dataSet the data set to describe; must not be {@code null}
     * @throws Exception if creating an {@link ArrayDescriptor} for one of the arrays fails
     */
    public DataSetDescriptor(DataSet dataSet) throws Exception {
        // Same evaluation order as before: features, labels, features mask, labels mask.
        this.features = describe(dataSet.getFeatures());
        this.labels = describe(dataSet.getLabels());
        this.featuresMask = describe(dataSet.getFeaturesMaskArray());
        this.labelsMask = describe(dataSet.getLabelsMaskArray());
        this.preProcessed = dataSet.isPreProcessed();
    }

    /**
     * Builds a new {@link DataSet} from the arrays returned by the stored descriptors and marks
     * it as pre-processed if this descriptor carries that flag.
     *
     * <p>Absent ({@code null}) descriptors are passed to the {@link DataSet} constructor as
     * {@code null} arrays instead of being dereferenced.
     * NOTE(review): whether {@code DataSet} accepts {@code null} features/labels depends on the
     * ND4J version in use - confirm against the version this module is built with.
     *
     * @return a new data set built from the described arrays
     */
    public DataSet getDataSet() {
        DataSet dataSet = new DataSet(
                toArray(this.features),
                toArray(this.labels),
                toArray(this.featuresMask),
                toArray(this.labelsMask));
        if (this.preProcessed) {
            dataSet.markAsPreProcessed();
        }
        return dataSet;
    }

    /** @return the features descriptor, or {@code null} if there is none */
    public ArrayDescriptor getFeatures() {
        return this.features;
    }

    /** @return the labels descriptor, or {@code null} if there is none */
    public ArrayDescriptor getLabels() {
        return this.labels;
    }

    /** @return the features mask descriptor, or {@code null} if there is none */
    public ArrayDescriptor getFeaturesMask() {
        return this.featuresMask;
    }

    /** @return the labels mask descriptor, or {@code null} if there is none */
    public ArrayDescriptor getLabelsMask() {
        return this.labelsMask;
    }

    /** Wraps {@code array} in an {@link ArrayDescriptor}, mapping a {@code null} array to {@code null}. */
    private static ArrayDescriptor describe(INDArray array) throws Exception {
        return array == null ? null : new ArrayDescriptor(array);
    }

    /** Returns the array of {@code descriptor}, mapping a {@code null} descriptor to {@code null}. */
    private static INDArray toArray(ArrayDescriptor descriptor) {
        return descriptor == null ? null : descriptor.getArray();
    }
}
