package smile.regression;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.matrix.QRDecomposition;
import smile.math.matrix.SingularValueDecomposition;
import smile.math.special.Beta;

/* loaded from: input_file:smile/regression/OLS.class */
/**
 * Ordinary least squares linear regression with an intercept term, {@code y ~ w . x + b}.
 *
 * <p>The model is fitted on the design matrix {@code [x | 1]}. QR decomposition is used by
 * default. SVD is used when requested, or when the QR solve throws (e.g. a rank-deficient
 * design). Besides the fit, the constructor computes a summary in the style of R's
 * {@code summary.lm}:
 * <ul>
 *   <li>residuals and RSS;</li>
 *   <li>residual standard error and degrees of freedom;</li>
 *   <li>R-squared and adjusted R-squared;</li>
 *   <li>the overall F-test;</li>
 *   <li>a per-coefficient t-test table.</li>
 * </ul>
 */
public class OLS implements Regression<double[]> {
    private static final Logger logger = LoggerFactory.getLogger(OLS.class);

    /**
     * Absolute threshold below which a singular value is treated as zero. The same threshold is
     * used for the null-space weight that marks a coefficient as non-identifiable.
     */
    private static final double SVD_ZERO_TOLERANCE = 1.0E-10d;

    /** Number of explanatory variables (the intercept is not counted). */
    private int p;
    /** Intercept. */
    private double b;
    /** Linear weights, one per explanatory variable. */
    private double[] w;
    /**
     * Coefficient table with {@code p + 1} rows: variables {@code 0..p-1}, then the intercept
     * at row {@code p}. Columns are estimate, standard error, t value and two-sided p-value.
     */
    private double[][] coefficients;
    /** Training residuals {@code y - (w . x + b)}. */
    private double[] residuals;
    /** Residual sum of squares. */
    private double RSS;
    /** Residual standard error, {@code sqrt(RSS / df)}. */
    private double error;
    /** Residual degrees of freedom, {@code n - p - 1}. */
    private int df;
    /** Coefficient of determination. */
    private double RSquared;
    /** R-squared adjusted for the number of explanatory variables. */
    private double adjustedRSquared;
    /** Overall F-statistic on {@code p} and {@code df} degrees of freedom. */
    private double F;
    /** p-value of the overall F-test. */
    private double pvalue;

    /** Trainer that fits an {@link OLS} model with the default (QR) solver. */
    public static class Trainer extends RegressionTrainer<double[]> {
        @Override // smile.regression.RegressionTrainer
        public OLS train(double[][] x, double[] y) {
            return new OLS(x, y);
        }
    }

    /**
     * Fits the model with QR decomposition, falling back to SVD if the QR solve fails.
     *
     * @param x training samples, one row per sample
     * @param y response values, one per sample
     * @throws IllegalArgumentException if the sizes disagree, the input is empty or ragged, or
     *     there are not more samples than variables
     */
    public OLS(double[][] x, double[] y) {
        this(x, y, false);
    }

    /**
     * Fits the model.
     *
     * @param x training samples, one row per sample
     * @param y response values, one per sample
     * @param useSVD if true, solve with SVD. Otherwise solve with QR and fall back to SVD when
     *     the QR solve throws.
     * @throws IllegalArgumentException if the sizes disagree, the input is empty or ragged, or
     *     there are not more samples than variables
     */
    public OLS(double[][] x, double[] y, boolean useSVD) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        int n = x.length;
        this.p = x[0].length;
        // NOTE: n == p + 1 is accepted and yields df == 0, so the summary statistics are
        // NaN/Infinity in that case.
        if (n <= this.p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, this.p));
        }

        double[][] design = new double[n][p + 1];
        fillDesignMatrix(x, design, p);

        double[] beta = new double[p + 1];
        QRDecomposition qr = null;
        SingularValueDecomposition svd = null;
        if (useSVD) {
            svd = SingularValueDecomposition.decompose(design);
            svd.solve(y, beta);
        } else {
            try {
                qr = new QRDecomposition(design, true);
                qr.solve(y, beta);
            } catch (RuntimeException e) {
                logger.warn("Matrix is not of full rank, try SVD instead");
                logger.debug("QR solve failed", e);
                qr = null;
                Arrays.fill(beta, 0.0d);
                // The QR decomposition was built in place (overwrite flag = true), so the
                // design matrix must be rebuilt before the SVD.
                fillDesignMatrix(x, design, p);
                svd = SingularValueDecomposition.decompose(design);
                svd.solve(y, beta);
            }
        }

        this.b = beta[p];
        this.w = Arrays.copyOf(beta, p);

        double[] fitted = new double[n];
        Math.ax(x, w, fitted);
        double mean = Math.mean(y);
        double TSS = 0.0d;
        this.RSS = 0.0d;
        this.residuals = new double[n];
        for (int i = 0; i < n; i++) {
            double r = (y[i] - fitted[i]) - b;
            residuals[i] = r;
            RSS += Math.sqr(r);
            TSS += Math.sqr(y[i] - mean);
        }

        this.df = n - p - 1;
        this.error = Math.sqrt(RSS / df);
        this.RSquared = 1.0d - (RSS / TSS);
        this.adjustedRSquared = 1.0d - (((1.0d - RSquared) * (n - 1)) / df);
        this.F = ((TSS - RSS) * df) / (RSS * p);
        this.pvalue = Beta.regularizedIncompleteBetaFunction(0.5d * df, 0.5d * p, df / (df + (p * F)));

        // diag((X'X)^-1) for QR, or diag(pinv(X'X)) for SVD; NaN marks a non-identifiable coefficient.
        double[] unscaledVariance = (qr != null)
                ? diagonal(qr.toCholesky().inverse())
                : pseudoInverseDiagonal(svd, p + 1);
        this.coefficients = new double[p + 1][4];
        for (int i = 0; i <= p; i++) {
            double[] row = coefficients[i];
            row[0] = beta[i];
            if (Double.isNaN(unscaledVariance[i])) {
                // Coefficient is not determined by the data.
                row[1] = Double.NaN;
                row[2] = 0.0d;
                row[3] = 1.0d;
            } else {
                double se = error * Math.sqrt(unscaledVariance[i]);
                double t = beta[i] / se;
                row[1] = se;
                row[2] = t;
                // Two-sided p-value of Student's t with df degrees of freedom.
                row[3] = Beta.regularizedIncompleteBetaFunction(0.5d * df, 0.5d, df / (df + (t * t)));
            }
        }
    }

    /** Copies {@code x} into the first {@code p} columns of {@code design}; the last column is 1. */
    private static void fillDesignMatrix(double[][] x, double[][] design, int p) {
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != p) {
                throw new IllegalArgumentException(String.format("Row %d has %d columns, expected %d", i, x[i].length, p));
            }
            System.arraycopy(x[i], 0, design[i], 0, p);
            design[i][p] = 1.0d;
        }
    }

    /** Returns the main diagonal of a square matrix. */
    private static double[] diagonal(double[][] a) {
        double[] d = new double[a.length];
        for (int i = 0; i < a.length; i++) {
            d[i] = a[i][i];
        }
        return d;
    }

    /**
     * Returns the diagonal of {@code V S^+^2 V'}, the pseudo-inverse of {@code X'X}, from the SVD
     * {@code X = U S V'}.
     *
     * <p>Entry {@code i} is NaN when coefficient {@code i} has non-negligible weight in the
     * numerical null space of {@code X}. Such a coefficient is not identifiable.
     */
    private static double[] pseudoInverseDiagonal(SingularValueDecomposition svd, int k) {
        double[] s = svd.getSingularValues();
        double[][] V = svd.getV();
        double[] d = new double[k];
        for (int i = 0; i < k; i++) {
            double variance = 0.0d;
            double nullWeight = 0.0d;
            for (int j = 0; j < s.length; j++) {
                double v2 = V[i][j] * V[i][j];
                if (Math.isZero(s[j], SVD_ZERO_TOLERANCE)) {
                    nullWeight += v2;
                } else {
                    variance += v2 / (s[j] * s[j]);
                }
            }
            d[i] = nullWeight > SVD_ZERO_TOLERANCE ? Double.NaN : variance;
        }
        return d;
    }

    /**
     * Returns the t-test table: {@code p + 1} rows (variables, then the intercept) of estimate,
     * standard error, t value and p-value. The internal array is returned, not a copy.
     */
    public double[][] ttest() {
        return this.coefficients;
    }

    /** Returns the linear weights (without intercept). The internal array is returned, not a copy. */
    public double[] coefficients() {
        return this.w;
    }

    /** Returns the intercept. */
    public double intercept() {
        return this.b;
    }

    /** Returns the training residuals. The internal array is returned, not a copy. */
    public double[] residuals() {
        return this.residuals;
    }

    /** Returns the residual sum of squares. */
    public double RSS() {
        return this.RSS;
    }

    /** Returns the residual standard error. */
    public double error() {
        return this.error;
    }

    /** Returns the residual degrees of freedom, {@code n - p - 1}. */
    public int df() {
        return this.df;
    }

    /** Returns R-squared. */
    public double RSquared() {
        return this.RSquared;
    }

    /** Returns adjusted R-squared. */
    public double adjustedRSquared() {
        return this.adjustedRSquared;
    }

    /** Returns the overall F-statistic. */
    public double ftest() {
        return this.F;
    }

    /** Returns the p-value of the overall F-test. */
    public double pvalue() {
        return this.pvalue;
    }

    /**
     * Predicts the response for {@code x}.
     *
     * @throws IllegalArgumentException if {@code x.length != p}
     */
    @Override // smile.regression.Regression
    public double predict(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        return this.b + Math.dot(x, this.w);
    }

    /** Maps a p-value to R-style significance stars. */
    private String significance(double pval) {
        if (pval < 0.001d) {
            return "***";
        }
        if (pval < 0.01d) {
            return "**";
        }
        if (pval < 0.05d) {
            return "*";
        }
        if (pval < 0.1d) {
            return ".";
        }
        return "";
    }

    /** Returns an R-like model summary. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Linear Model:\n");

        // Work on a copy: the quantile helpers may reorder their input.
        double[] r = this.residuals.clone();
        sb.append("\nResiduals:\n");
        sb.append("\t       Min\t        1Q\t    Median\t        3Q\t       Max\n");
        sb.append(String.format("\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.4f%n", Math.min(r), Math.q1(r), Math.median(r), Math.q3(r), Math.max(r)));

        sb.append("\nCoefficients:\n");
        sb.append("            Estimate        Std. Error        t value        Pr(>|t|)\n");
        double[] ic = this.coefficients[this.p];
        sb.append(String.format("Intercept%11.4f%18.4f%15.4f%16.4f %s%n", ic[0], ic[1], ic[2], ic[3], significance(ic[3])));
        for (int i = 0; i < this.p; i++) {
            double[] c = this.coefficients[i];
            sb.append(String.format("Var %d\t %11.4f%18.4f%15.4f%16.4f %s%n", i + 1, c[0], c[1], c[2], c[3], significance(c[3])));
        }

        sb.append("---------------------------------------------------------------------\n");
        sb.append("Significance codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1\n");
        sb.append(String.format("\nResidual standard error: %.4f on %d degrees of freedom%n", this.error, this.df));
        sb.append(String.format("Multiple R-squared: %.4f,    Adjusted R-squared: %.4f%n", this.RSquared, this.adjustedRSquared));
        sb.append(String.format("F-statistic: %.4f on %d and %d DF,  p-value: %.4g%n", this.F, this.p, this.df, this.pvalue));
        return sb.toString();
    }
}
