package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IteratedSingleClassifierEnhancer;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.trees.DecisionStump;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.UnassignedClassException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/* loaded from: classes2.dex */
/**
 * Meta classifier that improves a regression base classifier by additive modelling.
 *
 * <p>Training starts from a constant prediction: the class mean, or the class median when absolute
 * error is minimized. Each iteration then fits a fresh copy of the base classifier to the residuals
 * left by the ensemble built so far. A prediction is the initial value plus the sum of all member
 * predictions, each scaled by the shrinkage rate.
 *
 * <p>Training is exposed through the {@link IterativeClassifier} protocol:
 * {@link #initializeClassifier}, then {@link #next} until it returns {@code false}, then
 * {@link #done}.
 */
public class AdditiveRegression extends IteratedSingleClassifierEnhancer implements OptionHandler, AdditionalMeasureProducer, WeightedInstancesHandler, TechnicalInformationHandler, IterativeClassifier {
    static final long serialVersionUID = -2368937577670527151L;

    /** Ensemble members in the order they were built; recreated by initializeClassifier. */
    protected ArrayList<Classifier> m_Classifiers;
    /** Working copy of the training data whose class values hold the current residuals; cleared by done(). */
    protected Instances m_Data;
    /** Reduction of the residual error sum achieved by the most recent iteration. */
    protected double m_Diff;
    /** Weighted sum of absolute or squared residuals for the current ensemble. */
    protected double m_Error;
    /** Constant start value of every prediction (class mean or median). */
    protected double m_InitialPrediction;
    /** When true, absolute error is minimized (median start); otherwise squared error (mean start). */
    protected boolean m_MinimizeAbsoluteError;
    /** False when the data held only the class attribute, so only the initial prediction is used. */
    protected boolean m_SuitableData;
    /** Shrinkage (learning rate) applied to every member's prediction. */
    protected double m_shrinkage;

    /** Creates an instance that uses a DecisionStump as base learner. */
    public AdditiveRegression() {
        this(new DecisionStump());
    }

    /**
     * Creates an instance with the given base learner, no shrinkage (1.0) and squared-error mode.
     *
     * @param classifier the base regression learner
     */
    public AdditiveRegression(Classifier classifier) {
        this.m_shrinkage = 1.0d;
        this.m_SuitableData = true;
        this.m_MinimizeAbsoluteError = false;
        this.m_Classifier = classifier;
    }

    /**
     * Command-line entry point.
     *
     * @param strArr command-line options
     */
    public static void main(String[] strArr) {
        runClassifier(new AdditiveRegression(), strArr);
    }

    /**
     * Returns the base learner's prediction for {@code instance}, scaled by the shrinkage rate.
     *
     * @throws UnassignedClassException if the base learner predicts a missing value
     */
    private double shrunkPrediction(Classifier classifier, Instance instance) throws Exception {
        double prediction = classifier.classifyInstance(instance);
        if (Utils.isMissingValue(prediction)) {
            throw new UnassignedClassException("AdditiveRegression: base learner predicted missing value.");
        }
        return prediction * getShrinkage();
    }

    /** Returns a copy of {@code instances} with the constant {@code d} subtracted from every class value. */
    private Instances residualReplace(Instances instances, double d) throws Exception {
        Instances residuals = new Instances(instances);
        for (int i = 0; i < residuals.numInstances(); i++) {
            Instance inst = residuals.instance(i);
            inst.setClassValue(inst.classValue() - d);
        }
        return residuals;
    }

    /**
     * Returns a copy of {@code instances} with the shrunken prediction of {@code classifier}
     * subtracted from every class value.
     */
    private Instances residualReplace(Instances instances, Classifier classifier) throws Exception {
        Instances residuals = new Instances(instances);
        for (int i = 0; i < residuals.numInstances(); i++) {
            Instance inst = residuals.instance(i);
            inst.setClassValue(inst.classValue() - shrunkPrediction(classifier, inst));
        }
        return residuals;
    }

    /**
     * Computes the instance-weighted error sum of the residuals stored as class values: absolute
     * residuals when minimizing absolute error, squared residuals otherwise.
     */
    private double residualError(Instances data) {
        boolean absolute = getMinimizeAbsoluteError();
        double sum = 0.0d;
        for (int i = 0; i < data.numInstances(); i++) {
            Instance inst = data.instance(i);
            double residual = inst.classValue();
            sum += absolute ? inst.weight() * Math.abs(residual) : inst.weight() * residual * residual;
        }
        return sum;
    }

    /**
     * Builds the full ensemble by running the iterative training protocol to completion.
     *
     * @param instances the training data
     * @throws Exception if the data is unsuitable or a base learner fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        initializeClassifier(instances);
        while (next()) {
            // each successful call to next() adds one ensemble member
        }
        done();
    }

    /**
     * Predicts the class value: the initial prediction plus the shrunken predictions of all members.
     *
     * @param instance the instance to predict
     * @return the predicted numeric class value
     * @throws IllegalStateException if the classifier has not been built yet
     * @throws UnassignedClassException if a member predicts a missing value
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        double prediction = this.m_InitialPrediction;
        if (!this.m_SuitableData) {
            return prediction;
        }
        if (this.m_Classifiers == null) {
            throw new IllegalStateException("Classifier hasn't been built yet!");
        }
        for (Classifier member : this.m_Classifiers) {
            prediction += shrunkPrediction(member, instance);
        }
        return prediction;
    }

    /** @return the class name of the default base learner */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    /** Releases the working copy of the training data once training is finished. */
    @Override
    public void done() {
        this.m_Data = null;
    }

    /** @return the names of the additional measures this classifier provides */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> measures = new Vector<>(1);
        measures.addElement("measureNumIterations");
        return measures.elements();
    }

    /**
     * Returns the base learner's capabilities, restricted to numeric and date classes.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        return capabilities;
    }

    /**
     * Returns the value of the named additional measure (case-insensitive).
     *
     * @param str the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String str) {
        if (str.compareToIgnoreCase("measureNumIterations") == 0) {
            return measureNumIterations();
        }
        throw new IllegalArgumentException(str + " not supported (AdditiveRegression)");
    }

    /** @return whether absolute error is minimized instead of squared error */
    public boolean getMinimizeAbsoluteError() {
        return this.m_MinimizeAbsoluteError;
    }

    /**
     * Returns the current option settings: shrinkage, the optional -A flag, then the superclass options.
     *
     * @return the options as an array of strings
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<>();
        options.add("-S");
        options.add(String.valueOf(getShrinkage()));
        if (getMinimizeAbsoluteError()) {
            options.add("-A");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 12091 $");
    }

    /** @return the shrinkage (learning rate) applied to each member's prediction */
    public double getShrinkage() {
        return this.m_shrinkage;
    }

    /** @return the bibliographic reference for this method */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.TECHREPORT);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "J.H. Friedman");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "1999");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "Stochastic Gradient Boosting");
        technicalInformation.setValue(TechnicalInformation.Field.INSTITUTION, "Stanford University");
        technicalInformation.setValue(TechnicalInformation.Field.PS, "http://www-stat.stanford.edu/~jhf/ftp/stobst.ps");
        return technicalInformation;
    }

    /** @return a description of this classifier for display in the GUI */
    public String globalInfo() {
        return " Meta classifier that enhances the performance of a regression base classifier. Each iteration fits a model to the residuals left by the classifier on the previous iteration. Prediction is accomplished by adding the predictions of each classifier. Reducing the shrinkage (learning rate) parameter helps prevent overfitting and has a smoothing effect but increases the learning time.\n\nFor more information see:\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Prepares iterative training.
     *
     * <p>Steps:
     * <ul>
     *   <li>Validate the data and drop instances with a missing class.</li>
     *   <li>Compute the constant initial prediction and start a new, empty ensemble.</li>
     *   <li>Replace the class values of the working copy with the initial residuals.</li>
     * </ul>
     *
     * @param instances the training data
     * @throws Exception if the data fails the capability test
     */
    @Override
    public void initializeClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        this.m_Data = new Instances(instances);
        this.m_Data.deleteWithMissingClass();
        if (getMinimizeAbsoluteError()) {
            // kthSmallestValue takes a 1-based rank: (n + 1) / 2 is the (lower) median. The former
            // n / 2 picked the element below the median for odd n and an invalid rank 0 for n == 1.
            // NOTE(review): this median ignores instance weights, while the error sums are weighted.
            int medianRank = (this.m_Data.numInstances() + 1) / 2;
            this.m_InitialPrediction = this.m_Data.kthSmallestValue(this.m_Data.classIndex(), medianRank);
        } else {
            this.m_InitialPrediction = this.m_Data.meanOrMode(this.m_Data.classIndex());
        }
        // Always start from an empty ensemble so members from an earlier build never leak into
        // measureNumIterations() after re-initialisation on unsuitable data.
        this.m_Classifiers = new ArrayList<>(this.m_NumIterations);
        if (this.m_Data.numAttributes() == 1) {
            System.err.println("Cannot build non-trivial model (only class attribute present in data!).");
            this.m_SuitableData = false;
            return;
        }
        this.m_SuitableData = true;
        this.m_Data = residualReplace(this.m_Data, this.m_InitialPrediction);
        this.m_Error = residualError(this.m_Data);
        // Guarantees the first call to next() is not stopped by the improvement check
        this.m_Diff = Double.MAX_VALUE;
        if (this.m_Debug) {
            if (getMinimizeAbsoluteError()) {
                System.err.println("Sum of absolute residuals (predicting the median) : " + this.m_Error);
            } else {
                System.err.println("Sum of squared residuals (predicting the mean) : " + this.m_Error);
            }
        }
    }

    /**
     * Describes the options specific to this classifier, followed by the superclass options.
     *
     * @return an enumeration of the available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<>(2);
        options.addElement(new Option("\tSpecify shrinkage rate. (default = 1.0, i.e., no shrinkage)", "S", 1, "-S"));
        options.addElement(new Option("\tMinimize absolute error instead of squared error (assumes that base learner minimizes absolute error).", "A", 0, "-A"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /** @return the number of ensemble members built so far (0 if training never started) */
    public double measureNumIterations() {
        return this.m_Classifiers == null ? 0 : this.m_Classifiers.size();
    }

    /** @return tip text for the minimizeAbsoluteError property */
    public String minimizeAbsoluteErrorTipText() {
        return "Minimize absolute error instead of squared error (assume base learner minimizes absolute error)";
    }

    /**
     * Performs one boosting iteration.
     *
     * <p>Fits a new copy of the base learner to the current residuals and appends it to the
     * ensemble. It then updates the residuals and the error sum.
     *
     * <p>No iteration is performed when the data is unsuitable, when the iteration limit is reached,
     * or when the previous iteration reduced the error sum by no more than {@code Utils.SMALL}.
     *
     * @return true if a member was added, false if training has stopped
     * @throws Exception if the base learner fails to build or predicts a missing value
     */
    @Override
    public boolean next() throws Exception {
        if (!this.m_SuitableData || this.m_Classifiers.size() >= this.m_NumIterations || this.m_Diff <= Utils.SMALL) {
            return false;
        }
        // Build before adding, so a failing build does not leave an untrained member in the ensemble
        Classifier member = AbstractClassifier.makeCopy(this.m_Classifier);
        member.buildClassifier(this.m_Data);
        this.m_Classifiers.add(member);
        this.m_Data = residualReplace(this.m_Data, member);
        double error = residualError(this.m_Data);
        if (this.m_Debug) {
            if (getMinimizeAbsoluteError()) {
                System.err.println("Sum of absolute residuals: " + error);
            } else {
                System.err.println("Sum of squared residuals: " + error);
            }
        }
        this.m_Diff = this.m_Error - error;
        this.m_Error = error;
        return true;
    }

    /** @param z whether to minimize absolute error instead of squared error */
    public void setMinimizeAbsoluteError(boolean z) {
        this.m_MinimizeAbsoluteError = z;
    }

    /**
     * Parses the options: {@code -S <shrinkage>}, {@code -A}, then the superclass options.
     *
     * @param strArr the options
     * @throws Exception if an option is invalid or unrecognised options remain
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        String shrinkage = Utils.getOption('S', strArr);
        if (shrinkage.length() != 0) {
            setShrinkage(Double.parseDouble(shrinkage));
        }
        setMinimizeAbsoluteError(Utils.getFlag('A', strArr));
        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /** @param d the shrinkage (learning rate) applied to each member's prediction */
    public void setShrinkage(double d) {
        this.m_shrinkage = d;
    }

    /** @return tip text for the shrinkage property */
    public String shrinkageTipText() {
        return "Shrinkage rate. Smaller values help prevent overfitting and have a smoothing effect (but increase learning time). Default = 1.0, ie. no shrinkage.";
    }

    /** @return a textual description of the model: initial prediction and every ensemble member */
    @Override
    public String toString() {
        if (!this.m_SuitableData) {
            String simpleName = getClass().getName().replaceAll(".*\\.", "");
            return simpleName + "\n"
                    + simpleName.replaceAll(".", "=") + "\n\n"
                    + "Warning: Non-trivial model could not be built, initial prediction is: "
                    + this.m_InitialPrediction;
        }
        if (this.m_Classifiers == null) {
            return "Classifier hasn't been built yet!";
        }
        StringBuilder text = new StringBuilder();
        text.append("Additive Regression\n\n");
        text.append("Initial prediction: ").append(this.m_InitialPrediction).append("\n\n");
        text.append("Base classifier ").append(getClassifier().getClass().getName()).append("\n\n");
        text.append(this.m_Classifiers.size()).append(" models generated.\n");
        for (int i = 0; i < this.m_Classifiers.size(); i++) {
            text.append("\nModel number ").append(i).append("\n\n").append(this.m_Classifiers.get(i)).append("\n");
        }
        return text.toString();
    }
}
