package weka.knowledgeflow.steps;

import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.List;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.misc.InputMappedClassifier;
import weka.core.Drawable;
import weka.core.EnvironmentHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.LogHandler;
import weka.core.OptionHandler;
import weka.core.OptionMetadata;
import weka.core.SerializationHelper;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.core.WekaException;
import weka.gui.FilePropertyMetadata;
import weka.gui.ProgrammaticProperty;
import weka.gui.knowledgeflow.KnowledgeFlowApp;
import weka.knowledgeflow.Data;
import weka.knowledgeflow.LoggingLevel;
import weka.knowledgeflow.SingleThreadedExecution;
import weka.knowledgeflow.StepManager;
import weka.knowledgeflow.steps.PairedDataHelper;

/**
 * Knowledge Flow step that wraps a Weka {@link weka.classifiers.Classifier}.
 *
 * <p>The processing mode is chosen by the first piece of data received after
 * {@link #stepInit()}:
 * <ul>
 *   <li>"instance" connections: the current model is emitted with each instance and, if it is
 *       an {@link UpdateableClassifier} and updating is enabled, updated afterwards;</li>
 *   <li>training set connections: one model is built per training set/fold through a
 *       {@link PairedDataHelper} and emitted (paired with the matching test set when test set
 *       connections exist);</li>
 *   <li>test set connections only: a copy of a model loaded from file (or of a configured
 *       {@link InputMappedClassifier}) is emitted with each test set.</li>
 * </ul>
 */
@KFStep(name = "Classifier", category = "Classifiers", toolTipText = "Weka classifier wrapper", iconPath = KnowledgeFlowApp.KnowledgeFlowGeneralDefaults.LAF, resourceIntensive = true)
public class Classifier extends WekaAlgorithmWrapper implements PairedDataHelper.PairedProcessor<weka.classifiers.Classifier> {
    private static final long serialVersionUID = 8326706942962123155L;

    /** Copy of the wrapped classifier taken in stepInit(); new models are copied from this. */
    protected weka.classifiers.Classifier m_classifierTemplate;
    /** Model used for streaming / test-only processing (loaded from file or built in this run). */
    protected weka.classifiers.Classifier m_trainedClassifier;
    /** Structure of the data the current model corresponds to; may be null if unknown. */
    protected Instances m_trainedClassifierHeader;
    protected boolean m_resetIncrementalClassifier;
    /** True once the first incoming data of a run arrived on an "instance" connection. */
    protected boolean m_streaming;
    protected boolean m_classifierIsIncremental;
    /** Pairs training and test sets by set number; only created when training sets are connected. */
    protected transient PairedDataHelper<weka.classifiers.Classifier> m_trainTestHelper;
    /** True until the first incoming data after stepInit() has been inspected. */
    protected boolean m_isReset;
    protected File m_loadModelFileName = new File(KnowledgeFlowApp.KnowledgeFlowGeneralDefaults.LAF);
    protected boolean m_updateIncrementalClassifier = true;
    /** Output container reused for every incremental-classifier event. */
    protected Data m_incrementalData = new Data(StepManager.CON_INCREMENTAL_CLASSIFIER);

    @Override
    public Class getWrappedAlgorithmClass() {
        return weka.classifiers.Classifier.class;
    }

    @Override
    public void setWrappedAlgorithm(Object obj) {
        super.setWrappedAlgorithm(obj);
        this.m_defaultIconPath = "weka/gui/knowledgeflow/icons/DefaultClassifier.gif";
    }

    /** @return the wrapped classifier */
    public weka.classifiers.Classifier getClassifier() {
        return (weka.classifiers.Classifier) getWrappedAlgorithm();
    }

    /** @param classifier the classifier to wrap */
    @ProgrammaticProperty
    public void setClassifier(weka.classifiers.Classifier classifier) {
        setWrappedAlgorithm(classifier);
    }

    /**
     * Resets per-run state, copies the wrapped classifier into a template and, when no training
     * set is connected, loads the configured model file (if any).
     *
     * @throws WekaException if the classifier cannot be copied or the model cannot be loaded
     */
    @Override
    public void stepInit() throws WekaException {
        try {
            m_trainedClassifier = null;
            m_trainedClassifierHeader = null;
            m_trainTestHelper = null;
            // The mode is re-detected from the first incoming data of every run; without this a
            // previous instance-stream run would route batch data into processStreaming().
            m_streaming = false;
            m_incrementalData = new Data(StepManager.CON_INCREMENTAL_CLASSIFIER);
            m_classifierTemplate = AbstractClassifier.makeCopy((weka.classifiers.Classifier) getWrappedAlgorithm());
            if (m_classifierTemplate instanceof EnvironmentHandler) {
                ((EnvironmentHandler) m_classifierTemplate).setEnvironment(getStepManager().getExecutionEnvironment().getEnvironmentVariables());
            }
            if (m_classifierTemplate.getClass().getAnnotation(SingleThreadedExecution.class) != null) {
                getStepManager().logBasic(getClassifier().getClass().getCanonicalName() + " will be executed in the single threaded executor");
                getStepManager().setStepMustRunSingleThreaded(true);
            }

            int numTraining = getStepManager().numIncomingConnectionsOfType(StepManager.CON_TRAININGSET);
            if (numTraining > 0) {
                String secondary = getStepManager().numIncomingConnectionsOfType(StepManager.CON_TESTSET) > 0 ? StepManager.CON_TESTSET : null;
                m_trainTestHelper = new PairedDataHelper<>(this, this, StepManager.CON_TRAININGSET, secondary);
            }
            m_isReset = true;
            m_classifierIsIncremental = m_classifierTemplate instanceof UpdateableClassifier;

            // A model file is only used when nothing will be trained in this run.
            File modelFile = getLoadClassifierFileName();
            if (modelFile != null && modelFile.toString().length() > 0 && numTraining == 0) {
                String resolvedPath = getStepManager().environmentSubstitute(modelFile.toString());
                getStepManager().logBasic("Loading classifier: " + resolvedPath);
                loadModel(resolvedPath);
            }

            if (m_trainedClassifier != null && getStepManager().numIncomingConnectionsOfType("instance") > 0 && !m_classifierIsIncremental) {
                getStepManager().logWarning("Loaded classifier is not an incremental one - will only be able to evaluate, and not update, on the incoming instance stream.");
            }
        } catch (WekaException e) {
            throw e;
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /** @return the path of the model file to load at execution time (may be empty) */
    public File getLoadClassifierFileName() {
        return this.m_loadModelFileName;
    }

    /** @param file path of a serialized classifier to load at execution time */
    // fileChooserDialogType 0 == javax.swing.JFileChooser.OPEN_DIALOG
    @FilePropertyMetadata(fileChooserDialogType = 0, directoriesOnly = false)
    @OptionMetadata(displayName = "Classifier model to load", description = "Optional Path to a classifier to load at execution time (only applies when using testSet or instance connections)")
    public void setLoadClassifierFileName(File file) {
        this.m_loadModelFileName = file;
    }

    /** @return true if an incremental classifier is rebuilt at the start of an instance stream */
    public boolean getResetIncrementalClassifier() {
        return this.m_resetIncrementalClassifier;
    }

    /** @param z true to rebuild an incremental classifier at the start of an instance stream */
    @OptionMetadata(displayName = "Reset incremental classifier", description = "Reset classifier (if it is incremental) at the start of the incoming instance stream")
    public void setResetIncrementalClassifier(boolean z) {
        this.m_resetIncrementalClassifier = z;
    }

    /** @return true if an incremental classifier is updated with each streamed instance */
    public boolean getUpdateIncrementalClassifier() {
        return this.m_updateIncrementalClassifier;
    }

    /** @param z true to update an incremental classifier with each streamed instance */
    @OptionMetadata(displayName = "Update incremental classifier", description = " Update an incremental classifier on incoming instance stream")
    public void setUpdateIncrementalClassifier(boolean z) {
        this.m_updateIncrementalClassifier = z;
    }

    /**
     * Entry point for all incoming data. The first piece of data after stepInit() selects the
     * processing mode; every piece is then dispatched to that mode's handler.
     *
     * @param data the incoming data
     * @throws WekaException on structure mismatch or any processing failure
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        try {
            getStepManager().processing();
            if (m_isReset) {
                m_isReset = false;
                handleFirstIncoming(data);
            }
            if (m_streaming) {
                processStreaming(data);
            } else if (m_trainTestHelper != null) {
                m_trainTestHelper.process(data);
            } else {
                processOnlyTestSet(data);
            }
        } catch (WekaException e) {
            throw e;
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /** Inspects the first data of a run: fixes a missing class index and sets up the mode. */
    private void handleFirstIncoming(Data data) throws Exception {
        String connection = data.getConnectionName();
        boolean isInstanceStream = connection.equals("instance");
        Instances structure = isInstanceStream
            ? new Instances(((Instance) data.getPayloadElement("instance")).dataset(), 0)
            : (Instances) data.getPayloadElement(connection);

        // classAttribute() throws when no class is set, so test the index directly.
        if (structure.classIndex() < 0) {
            getStepManager().logWarning("No class index is set in the data - using last attribute as class");
            structure.setClassIndex(structure.numAttributes() - 1);
        }

        if (isInstanceStream) {
            prepareIncrementalModel(structure);
        } else if (connection.equals(StepManager.CON_TRAININGSET)) {
            // Keep only the structure, not a reference to the whole first training set.
            m_trainedClassifierHeader = new Instances(structure, 0);
        } else if (connection.equals(StepManager.CON_TESTSET)
            && getStepManager().numIncomingConnectionsOfType(StepManager.CON_TRAININGSET) == 0
            && m_classifierTemplate instanceof InputMappedClassifier) {
            m_trainedClassifier = AbstractClassifier.makeCopy(m_classifierTemplate);
            // NOTE(review): presumably makes the InputMappedClassifier load its configured
            // model file up front - confirm against InputMappedClassifier.getModelHeader.
            ((InputMappedClassifier) m_trainedClassifier).getModelHeader(null);
        }
        ensureHeaderCompatible(structure);
    }

    /** Switches to streaming mode and builds/resets the incremental model as configured. */
    private void prepareIncrementalModel(Instances structure) throws Exception {
        m_streaming = true;
        if (m_trainedClassifier == null) {
            m_trainedClassifier = AbstractClassifier.makeCopy(m_classifierTemplate);
            getStepManager().logBasic("Initialising incremental classifier");
            m_trainedClassifier.buildClassifier(structure);
            m_trainedClassifierHeader = structure;
        } else if (m_resetIncrementalClassifier && m_classifierIsIncremental) {
            m_trainedClassifier = AbstractClassifier.makeCopy(m_classifierTemplate);
            m_trainedClassifierHeader = structure;
            getStepManager().logBasic("Resetting incremental classifier");
            m_trainedClassifier.buildClassifier(m_trainedClassifierHeader);
        }
        if (m_trainedClassifier instanceof LogHandler) {
            ((LogHandler) m_trainedClassifier).setLog(getStepManager().getLog());
        }
        getStepManager().logBasic((m_updateIncrementalClassifier && m_classifierIsIncremental) ? "Training incrementally" : "Predicting incrementally");
    }

    /**
     * Throws if {@code incoming} does not match the known model header, unless the classifier
     * is an {@link InputMappedClassifier} (which maps incoming attributes itself).
     */
    private void ensureHeaderCompatible(Instances incoming) throws WekaException {
        if (m_trainedClassifierHeader != null
            && !incoming.equalHeaders(m_trainedClassifierHeader)
            && !(m_classifierTemplate instanceof InputMappedClassifier)) {
            throw new WekaException("Structure of incoming data does not match that of the trained classifier");
        }
    }

    /**
     * Builds a model on one training set/fold. When no test set is connected the model is
     * emitted immediately as a batch classifier.
     *
     * @param setNum the set/fold number of this training set
     * @param maxSetNum the total number of sets/folds
     * @param data the training set data
     * @param helper the paired-data helper (used to remember the training split)
     * @return the model built on this set (unbuilt copy if a stop was requested)
     * @throws WekaException if building or output fails
     */
    @Override
    public weka.classifiers.Classifier processPrimary(Integer setNum, Integer maxSetNum, Data data, PairedDataHelper<weka.classifiers.Classifier> helper) throws WekaException {
        Instances train = (Instances) data.getPrimaryPayload();
        if (m_trainedClassifierHeader == null) {
            m_trainedClassifierHeader = new Instances(train, 0);
        }
        try {
            weka.classifiers.Classifier model = AbstractClassifier.makeCopy(m_classifierTemplate);
            String canonicalName = model.getClass().getCanonicalName();
            String desc = canonicalName.substring(canonicalName.lastIndexOf('.') + 1);
            if (model instanceof OptionHandler) {
                desc = desc + TestInstances.DEFAULT_SEPARATORS + Utils.joinOptions(((OptionHandler) model).getOptions());
            }
            if (model instanceof EnvironmentHandler) {
                ((EnvironmentHandler) model).setEnvironment(getStepManager().getExecutionEnvironment().getEnvironmentVariables());
            }
            if (model instanceof LogHandler) {
                ((LogHandler) model).setLog(getStepManager().getLog());
            }
            // processSecondary() pairs the test set with this training split later.
            helper.addIndexedValueToNamedStore("trainingSplits", setNum, train);
            if (!isStopRequested()) {
                getStepManager().logBasic("Building " + desc + " on " + train.relationName() + " for fold/set " + setNum + " out of " + maxSetNum);
                if (getStepManager().getLoggingLevel().compareTo(LoggingLevel.LOW) > 0) {
                    getStepManager().statusMessage("Building " + desc + " on fold/set " + setNum);
                }
                if (maxSetNum.intValue() == 1) {
                    m_trainedClassifier = model;
                }
                model.buildClassifier(train);
                getStepManager().logDetailed("Finished building " + desc + " on " + train.relationName() + " for fold/set " + setNum + " out of " + maxSetNum);
                outputTextData(model, setNum.intValue());
                outputGraphData(model, setNum.intValue());
                if (getStepManager().numIncomingConnectionsOfType(StepManager.CON_TESTSET) == 0) {
                    Data out = new Data(StepManager.CON_BATCH_CLASSIFIER, model);
                    out.setPayloadElement(StepManager.CON_AUX_DATA_TRAININGSET, train);
                    out.setPayloadElement(StepManager.CON_AUX_DATA_SET_NUM, setNum);
                    out.setPayloadElement(StepManager.CON_AUX_DATA_MAX_SET_NUM, maxSetNum);
                    out.setPayloadElement(StepManager.CON_AUX_DATA_LABEL, getName());
                    out.setPayloadElement(StepManager.CON_AUX_DATA_PRIMARY_PAYLOAD_NOT_THREAD_SAFE, true);
                    getStepManager().outputData(out);
                }
            }
            return model;
        } catch (WekaException e) {
            throw e;
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /**
     * Emits the model built for {@code setNum} together with its training split and this
     * test set.
     *
     * @throws WekaException if the test set structure does not match the model header
     */
    @Override
    public void processSecondary(Integer setNum, Integer maxSetNum, Data data, PairedDataHelper<weka.classifiers.Classifier> helper) throws WekaException {
        weka.classifiers.Classifier model = helper.getIndexedPrimaryResult(setNum.intValue());
        Instances test = (Instances) data.getPrimaryPayload();
        ensureHeaderCompatible(test);
        Instances train = (Instances) helper.getIndexedValueFromNamedStore("trainingSplits", setNum);
        getStepManager().logBasic("Dispatching model for set " + setNum + " out of " + maxSetNum + " to output");
        Data out = new Data(StepManager.CON_BATCH_CLASSIFIER, model);
        out.setPayloadElement(StepManager.CON_AUX_DATA_TRAININGSET, train);
        out.setPayloadElement(StepManager.CON_AUX_DATA_TESTSET, test);
        out.setPayloadElement(StepManager.CON_AUX_DATA_SET_NUM, setNum);
        out.setPayloadElement(StepManager.CON_AUX_DATA_MAX_SET_NUM, maxSetNum);
        out.setPayloadElement(StepManager.CON_AUX_DATA_LABEL, getName());
        out.setPayloadElement(StepManager.CON_AUX_DATA_PRIMARY_PAYLOAD_NOT_THREAD_SAFE, true);
        getStepManager().outputData(out);
    }

    /**
     * Emits a copy of the pre-existing model with an incoming test set (no training connected).
     *
     * @param data the test set data
     * @throws WekaException if no model is available or output fails
     */
    protected void processOnlyTestSet(Data data) throws WekaException {
        if (m_trainedClassifier == null) {
            // Without this, a null model would be passed downstream and fail there instead.
            throw new WekaException("No trained classifier is available to apply to the test set - "
                + "specify a model file to load or connect a training set");
        }
        try {
            weka.classifiers.Classifier model = AbstractClassifier.makeCopy(m_trainedClassifier);
            Data out = new Data(StepManager.CON_BATCH_CLASSIFIER);
            out.setPayloadElement(StepManager.CON_BATCH_CLASSIFIER, model);
            out.setPayloadElement(StepManager.CON_AUX_DATA_TESTSET, data.getPayloadElement(StepManager.CON_TESTSET));
            out.setPayloadElement(StepManager.CON_AUX_DATA_SET_NUM, data.getPayloadElement(StepManager.CON_AUX_DATA_SET_NUM, 1));
            out.setPayloadElement(StepManager.CON_AUX_DATA_MAX_SET_NUM, data.getPayloadElement(StepManager.CON_AUX_DATA_MAX_SET_NUM, 1));
            out.setPayloadElement(StepManager.CON_AUX_DATA_LABEL, getName());
            out.setPayloadElement(StepManager.CON_AUX_DATA_PRIMARY_PAYLOAD_NOT_THREAD_SAFE, true);
            getStepManager().outputData(out);
            if (isStopRequested()) {
                getStepManager().interrupted();
            } else {
                getStepManager().finished();
            }
        } catch (WekaException e) {
            throw e;
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /**
     * Handles one instance-stream event. At end of stream, finalises the model, emits text and
     * graph output and signals throughput completion. Otherwise emits the current model with
     * the instance first, then updates an incremental model if enabled and the class is known.
     *
     * @param data the streamed instance (or end-of-stream marker)
     * @throws WekaException if finalising or updating the model fails
     */
    protected void processStreaming(Data data) throws WekaException {
        if (isStopRequested()) {
            return;
        }
        Instance instance = (Instance) data.getPayloadElement("instance");
        if (getStepManager().isStreamFinished(data)) {
            if (m_trainedClassifier instanceof UpdateableBatchProcessor) {
                try {
                    ((UpdateableBatchProcessor) m_trainedClassifier).batchFinished();
                } catch (Exception e) {
                    throw new WekaException(e);
                }
            }
            m_incrementalData.setPayloadElement(StepManager.CON_INCREMENTAL_CLASSIFIER, m_trainedClassifier);
            m_incrementalData.setPayloadElement(StepManager.CON_AUX_DATA_TEST_INSTANCE, null);
            outputTextData(m_trainedClassifier, -1);
            outputGraphData(m_trainedClassifier, 0);
            if (!isStopRequested()) {
                getStepManager().throughputFinished(m_incrementalData);
            }
            return;
        }

        // Emit before updating so downstream steps see a prediction made without this instance.
        m_incrementalData.setPayloadElement(StepManager.CON_AUX_DATA_TEST_INSTANCE, instance);
        m_incrementalData.setPayloadElement(StepManager.CON_INCREMENTAL_CLASSIFIER, m_trainedClassifier);
        getStepManager().outputData(m_incrementalData.getConnectionName(), m_incrementalData);

        getStepManager().throughputUpdateStart();
        if (m_classifierIsIncremental && m_updateIncrementalClassifier && !instance.classIsMissing()) {
            try {
                ((UpdateableClassifier) m_trainedClassifier).updateClassifier(instance);
            } catch (Exception e) {
                throw new WekaException(e);
            }
        }
        getStepManager().throughputUpdateEnd();
    }

    /** @return the relation name of the model header, or a placeholder if it is unknown */
    private String modelRelationName() {
        // The header can be null when a model file without a stored header was loaded.
        return m_trainedClassifierHeader != null ? m_trainedClassifierHeader.relationName() : "<unknown>";
    }

    /**
     * Emits the textual model description on "text" connections (if any).
     *
     * @param classifier the model to describe
     * @param i set number to attach, or -1 for none
     * @throws WekaException if output fails
     */
    protected void outputTextData(weka.classifiers.Classifier classifier, int i) throws WekaException {
        if (getStepManager().numOutgoingConnectionsOfType("text") == 0) {
            return;
        }
        Data data = new Data("text");
        String modelText = classifier.toString();
        String className = classifier.getClass().getName();
        String scheme = className.substring(className.lastIndexOf('.') + 1);
        String text = "=== Classifier model ===\n\nScheme:   " + scheme + "\nRelation: " + modelRelationName() + "\n\n" + modelText;
        data.setPayloadElement("text", text);
        data.setPayloadElement(StepManager.CON_AUX_DATA_TEXT_TITLE, "Model: " + scheme);
        if (i != -1) {
            data.setPayloadElement(StepManager.CON_AUX_DATA_SET_NUM, Integer.valueOf(i));
        }
        getStepManager().outputData(data);
    }

    /**
     * Emits the model's graph on graph connections, if the model is {@link Drawable} and any
     * graph connection exists.
     *
     * @param classifier the model to draw
     * @param i set number used in the graph title
     * @throws WekaException if producing the graph or output fails
     */
    protected void outputGraphData(weka.classifiers.Classifier classifier, int i) throws WekaException {
        if (!(classifier instanceof Drawable) || getStepManager().numOutgoingConnectionsOfType(StepManager.CON_GRAPH) == 0) {
            return;
        }
        try {
            Drawable drawable = (Drawable) classifier;
            String graph = drawable.graph();
            int graphType = drawable.graphType();
            String canonicalName = classifier.getClass().getCanonicalName();
            String title = "Set " + i + " (" + modelRelationName() + ") " + canonicalName.substring(canonicalName.lastIndexOf('.') + 1);
            Data data = new Data(StepManager.CON_GRAPH);
            data.setPayloadElement(StepManager.CON_GRAPH, graph);
            data.setPayloadElement(StepManager.CON_AUX_DATA_GRAPH_TITLE, title);
            data.setPayloadElement(StepManager.CON_AUX_DATA_GRAPH_TYPE, Integer.valueOf(graphType));
            getStepManager().outputData(data);
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /**
     * Instance connections are mutually exclusive with training/test set connections; at most
     * one of each connection type is accepted.
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        List<String> result = new ArrayList<>();
        int numTraining = getStepManager().numIncomingConnectionsOfType(StepManager.CON_TRAININGSET);
        int numTest = getStepManager().numIncomingConnectionsOfType(StepManager.CON_TESTSET);
        int numInstance = getStepManager().numIncomingConnectionsOfType("instance");
        if (numTraining == 0 && numTest == 0) {
            result.add("instance");
        }
        if (numInstance == 0 && numTraining == 0) {
            result.add(StepManager.CON_TRAININGSET);
        }
        if (numInstance == 0 && numTest == 0) {
            result.add(StepManager.CON_TESTSET);
        }
        if (getStepManager().numIncomingConnectionsOfType(StepManager.CON_INFO) == 0) {
            result.add(StepManager.CON_INFO);
        }
        return result;
    }

    /** Outgoing types depend on which incoming connections are present. */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        List<String> result = new ArrayList<>();
        if (getStepManager().numIncomingConnections() > 0) {
            int numTraining = getStepManager().numIncomingConnectionsOfType(StepManager.CON_TRAININGSET);
            int numTest = getStepManager().numIncomingConnectionsOfType(StepManager.CON_TESTSET);
            if (getStepManager().numIncomingConnectionsOfType("instance") > 0) {
                result.add(StepManager.CON_INCREMENTAL_CLASSIFIER);
            } else if (numTraining > 0 || numTest > 0) {
                result.add(StepManager.CON_BATCH_CLASSIFIER);
            }
            result.add("text");
            if ((getClassifier() instanceof Drawable) && numTraining > 0) {
                result.add(StepManager.CON_GRAPH);
            }
        }
        result.add(StepManager.CON_INFO);
        return result;
    }

    /**
     * Loads a serialized classifier (optionally followed by an {@link Instances} header) from a
     * file. The classifier must be of the same class as the wrapped one.
     *
     * <p>SECURITY NOTE: this uses Java-native deserialization, which can execute code embedded
     * in the file - only load model files from trusted sources.
     *
     * @param str path of the model file
     * @throws Exception if the file cannot be read or contains a different classifier type
     */
    protected void loadModel(String str) throws Exception {
        // try-with-resources closes the file stream even if wrapping it in an object stream fails.
        try (FileInputStream fileIn = new FileInputStream(new File(str));
             ObjectInputStream objectIn = SerializationHelper.getObjectInputStream(fileIn)) {
            weka.classifiers.Classifier loaded = (weka.classifiers.Classifier) objectIn.readObject();
            String loadedName = loaded.getClass().getCanonicalName();
            String expectedName = getClassifier().getClass().getCanonicalName();
            if (!loadedName.equals(expectedName)) {
                throw new Exception("The loaded model '" + loadedName + "' is not a '" + expectedName + "'");
            }
            m_trainedClassifier = loaded;
            try {
                m_trainedClassifierHeader = (Instances) objectIn.readObject();
            } catch (Exception e) {
                // Best effort: older model files may not store a header.
                getStepManager().logWarning("Model file '" + str + "' does not seem to contain an Instances header");
            }
        }
    }
}
