/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.impurity.mse;

import org.apache.ignite.ml.tree.TreeFilter;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.TreeDataIndex;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.ImpurityMeasureCalculator;
import org.apache.ignite.ml.tree.impurity.mse.MSEImpurityMeasure;
import org.apache.ignite.ml.tree.impurity.util.StepFunction;

/**
 * Calculator that builds, for every feature column, a {@link StepFunction} of {@link MSEImpurityMeasure} values.
 *
 * <p>The k-th value of a column's step function describes the split that puts the first k rows (in the order the
 * rows are read for that column) on the left side and the remaining rows on the right side. Each value is built
 * from the running sums of labels and squared labels on both sides.
 */
public class MSEImpurityMeasureCalculator
extends ImpurityMeasureCalculator<MSEImpurityMeasure> {
    private static final long serialVersionUID = 288747414953756824L;

    /**
     * Constructs a new MSE impurity measure calculator.
     *
     * @param useIdx If {@code true}, rows are read through a {@link TreeDataIndex} created from the data;
     *     otherwise the data is filtered and then sorted by each column directly.
     */
    public MSEImpurityMeasureCalculator(boolean useIdx) {
        super(useIdx);
    }

    /**
     * Calculates one step function of MSE impurity values per feature column.
     *
     * @param data Decision tree data to evaluate.
     * @param filter Filter selecting the rows that take part in the calculation.
     * @param depth Tree depth; only used to create the index when {@code useIdx} is set.
     * @return One step function per column, or {@code null} if no rows remain after filtering.
     */
    @Override
    public StepFunction<MSEImpurityMeasure>[] calculate(DecisionTreeData data, TreeFilter filter, int depth) {
        TreeDataIndex idx = null;
        boolean canCalculate;

        if (useIdx) {
            idx = data.createIndexByFilter(depth, filter);
            canCalculate = idx.rowsCount() > 0;
        } else {
            data = data.filter(filter);
            // NOTE(review): getFeatures() is presumably row-major, so its length is the row count — confirm.
            canCalculate = data.getFeatures().length > 0;
        }

        if (!canCalculate) {
            return null;
        }

        int rowsCnt = rowsCount(data, idx);
        int colsCnt = columnsCount(data, idx);

        // Label totals are the same whatever the row order, so column 0's ordering is used once for all columns.
        // They are computed before any sort below, keeping the summation order fixed.
        double totalY = 0.0;
        double totalY2 = 0.0;
        for (int i = 0; i < rowsCnt; i++) {
            double lblVal = getLabelValue(data, idx, 0, i);
            totalY += lblVal;
            totalY2 += Math.pow(lblVal, 2.0);
        }

        // Generic arrays cannot be created directly; every element stored below is a StepFunction<MSEImpurityMeasure>.
        @SuppressWarnings("unchecked")
        StepFunction<MSEImpurityMeasure>[] res = new StepFunction[colsCnt];

        for (int col = 0; col < colsCnt; col++) {
            if (!useIdx) {
                // Without an index the data is reordered in place for the current column before it is scanned
                // (presumably ascending by that feature — confirm in DecisionTreeData#sort).
                data.sort(col);
            }
            res[col] = calculateForColumn(data, idx, col, rowsCnt, totalY, totalY2);
        }

        return res;
    }

    /**
     * Builds the MSE impurity step function for a single column by sweeping rows from the right side to the left side.
     *
     * <p>{@code x[0]} is negative infinity (empty left side) and {@code x[k + 1]} is the feature value of the k-th row;
     * {@code y[k]} is the impurity when the first k rows are on the left side.
     *
     * @param data Decision tree data (already sorted by {@code col} when no index is used).
     * @param idx Index to read rows through, or {@code null} when no index is used.
     * @param col Column to build the step function for.
     * @param rowsCnt Number of rows taking part in the calculation.
     * @param totalY Sum of all labels.
     * @param totalY2 Sum of all squared labels.
     * @return Step function with {@code rowsCnt + 1} points.
     */
    private StepFunction<MSEImpurityMeasure> calculateForColumn(DecisionTreeData data, TreeDataIndex idx, int col,
        int rowsCnt, double totalY, double totalY2) {
        double[] x = new double[rowsCnt + 1];
        MSEImpurityMeasure[] y = new MSEImpurityMeasure[rowsCnt + 1];

        x[0] = Double.NEGATIVE_INFINITY;

        double leftY = 0.0;
        double leftY2 = 0.0;
        double rightY = totalY;
        double rightY2 = totalY2;

        for (int leftSize = 0; leftSize <= rowsCnt; leftSize++) {
            if (leftSize > 0) {
                // Move the row that has just crossed the threshold from the right side to the left side.
                double lblVal = getLabelValue(data, idx, col, leftSize - 1);
                leftY += lblVal;
                leftY2 += Math.pow(lblVal, 2.0);
                rightY -= lblVal;
                rightY2 -= Math.pow(lblVal, 2.0);
            }

            if (leftSize < rowsCnt) {
                x[leftSize + 1] = getFeatureValue(data, idx, col, leftSize);
            }

            y[leftSize] = new MSEImpurityMeasure(leftY, leftY2, leftSize, rightY, rightY2, rowsCnt - leftSize);
        }

        return new StepFunction<>(x, y);
    }
}

