/*
 * Decompiled with CFR 0.152.
 */
package sfa.classification;

import com.carrotsearch.hppc.DoubleArrayList;
import com.carrotsearch.hppc.DoubleDoubleHashMap;
import com.carrotsearch.hppc.DoubleIntHashMap;
import com.carrotsearch.hppc.cursors.DoubleDoubleCursor;
import de.bwaldvogel.liblinear.SolverType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Random;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_print_interface;
import libsvm.svm_problem;
import sfa.classification.Classifier;
import sfa.classification.WEASELClassifier;
import sfa.timeseries.TimeSeries;

/**
 * Early time series classifier (TEASER).
 *
 * <p>Training cuts every series into prefixes whose lengths are {@code round(s * max / S)} for
 * {@code s = 2..S}. For each prefix that is at least {@code max(3, MIN_WINDOW_LENGTH)} values long,
 * two models are fitted:
 * <ul>
 *   <li>a WEASEL "slave" classifier, trained on the prefixes;</li>
 *   <li>a one-class SVM "master", trained on the slave's class probabilities for the correctly
 *       classified training samples.</li>
 * </ul>
 *
 * <p>Prediction visits the prefixes in increasing order. A label is emitted once the master accepts
 * the slave's prediction and the same label has been accepted {@code threshold} times in a row, or
 * once the last prefix (or a prefix covering the whole series) is reached.
 */
public class TEASERClassifier
extends Classifier {
    /** libsvm kernel type used by the master SVMs (libsvm's RBF kernel). */
    public static int SVM_KERNEL = svm_parameter.RBF;
    /** Candidate RBF gammas; the one with the best cross-validated accuracy is used per prefix. */
    public static double[] SVM_GAMMAS = new double[]{100.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.5, 1.0};
    /** The {@code nu} parameter of the one-class SVM. */
    public static double SVM_NU = 0.05;
    /** Number of prefix steps; prefix {@code s} has length {@code round(s * max / S)}. */
    public static double S = 20.0;
    /** When true, testing prints each sample's decision offset and correctness. */
    public static boolean PRINT_EARLINESS = false;
    public static int MIN_WINDOW_LENGTH = 2;
    public static int MAX_WINDOW_LENGTH = 250;
    public EarlyClassificationModel model;
    public WEASELClassifier slaveClassifier = new WEASELClassifier();

    /**
     * Creates a TEASER classifier.
     *
     * <p>This also configures the static WEASEL settings used by the slave classifier.
     */
    public TEASERClassifier() {
        WEASELClassifier.lowerBounding = true;
        WEASELClassifier.solverType = SolverType.L2R_LR;
        WEASELClassifier.MAX_WINDOW_LENGTH = 250;
    }

    /**
     * Trains on {@code trainSamples}, then predicts {@code testSamples} early.
     *
     * @param trainSamples labelled training series
     * @param testSamples labelled test series
     * @return training/testing accuracy counts, sizes and the average earliness on the test set
     */
    @Override
    public Classifier.Score eval(TimeSeries[] trainSamples, TimeSeries[] testSamples) {
        long startTime = System.currentTimeMillis();
        Classifier.Score score = this.fit(trainSamples);
        if (DEBUG) {
            System.out.println("TEASER Training:\t");
            TEASERClassifier.outputResult(score.training, startTime, trainSamples.length);
        }
        OffsetPrediction pred = this.predict(testSamples, true);
        int correctTesting = pred.getCorrect();
        if (DEBUG) {
            System.out.println("TEASER Testing:\t");
            TEASERClassifier.outputResult(correctTesting, startTime, testSamples.length);
            System.out.println("");
        }
        score.avgOffset = pred.offset / (double)testSamples.length;
        score.testing = correctTesting;
        score.testSize = testSamples.length;
        score.trainSize = trainSamples.length;
        return score;
    }

    /**
     * Fits the early-classification model and stores it in {@link #model}.
     *
     * @param trainSamples labelled training series
     * @return the training score of the fitted model
     * @throws IllegalStateException if training failed (the cause is printed by {@link #fitTeaser})
     */
    @Override
    public Classifier.Score fit(TimeSeries[] trainSamples) {
        this.model = this.fitTeaser(trainSamples);
        if (this.model == null) {
            // fitTeaser prints the failure and returns null; fail clearly instead of with an NPE.
            throw new IllegalStateException("TEASER training failed; see the stack trace printed above");
        }
        return this.model.score;
    }

    /**
     * Trains the slave/master model pairs for every prefix, then picks the consecutive-agreement
     * threshold (2..5) that maximizes the harmonic mean of training accuracy and earliness.
     *
     * @param samples labelled training series
     * @return the fitted model, or {@code null} if an exception occurred (its stack trace is printed)
     */
    public EarlyClassificationModel fitTeaser(TimeSeries[] samples) {
        try {
            int min = Math.max(3, MIN_WINDOW_LENGTH);
            // NOTE(review): getMax is defined elsewhere; presumably the longest series length capped
            // at MAX_WINDOW_LENGTH. TODO confirm.
            int max = this.getMax(samples, MAX_WINDOW_LENGTH);
            double step = (double)max / S;
            this.model = new EarlyClassificationModel();
            for (int s = 2; (double)s <= S; ++s) {
                this.model.offsets[s] = (int)Math.round(step * (double)s);
                if (this.model.offsets[s] < min) {
                    // Prefix too short to train on: no models exist for this offset.
                    continue;
                }
                TimeSeries[] data = this.extractUntilOffset(samples, this.model.offsets[s], true);
                this.slaveClassifier.fit(data);
                Classifier.Predictions result = this.slaveClassifier.predictProbabilities(data);
                this.model.slaveModels[s] = this.slaveClassifier.getModel();
                this.model.masterModels[s] = this.fitSVM(samples, result.labels, result.probabilities, result.realLabels);
            }
            double bestF1 = -1.0;
            int bestCount = 1;
            for (int i = 2; i <= 5; ++i) {
                this.model.threshold = i;
                // NOTE(review): testing == false means extractUntilOffset does NOT truncate. The slave
                // models therefore see the full training series here, not the prefixes they were
                // trained on. Confirm this is intended.
                OffsetPrediction off = this.predict(samples, false);
                double correct = (double)off.getCorrect() / (double)off.N;
                double earliness = 1.0 - off.offset / (double)off.N;
                double harmonic_mean = 2.0 * correct * earliness / (correct + earliness);
                System.out.println("Prediction:\t" + this.model.threshold + "\t" + off + "\t" + harmonic_mean);
                if (bestF1 < harmonic_mean) {
                    bestF1 = harmonic_mean;
                    bestCount = i;
                    this.model.score.training = off.getCorrect();
                    this.model.score.trainSize = samples.length;
                }
            }
            System.out.println("Best Repetition: " + bestCount);
            this.model.threshold = bestCount;
            return this.model;
        }
        catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Trains the one-class SVM master for one prefix length.
     *
     * <p>Only samples whose slave prediction was correct become training points. The RBF gamma is
     * chosen from {@link #SVM_GAMMAS} by 10-fold cross-validated accuracy.
     *
     * @param samples labelled training series (their labels are the ground truth)
     * @param predictedLabels slave predictions, aligned with {@code samples}
     * @param probs slave class probabilities, aligned with {@code samples}
     * @param probsLabels the class label of each probability column
     * @return the trained master SVM
     */
    public svm_model fitSVM(TimeSeries[] samples, Double[] predictedLabels, double[][] probs, int[] probsLabels) {
        ArrayList<double[]> probabilities = new ArrayList<double[]>();
        ArrayList<int[]> labels = new ArrayList<int[]>();
        DoubleArrayList correct = new DoubleArrayList();
        for (int ind = 0; ind < samples.length; ++ind) {
            if (!this.compareLabels(samples[ind].getLabel(), predictedLabels[ind])) {
                continue;
            }
            labels.add(probsLabels);
            probabilities.add(probs[ind]);
            correct.add(1.0);
        }
        svm_problem problem_one_class = TEASERClassifier.initProblem(
                probabilities.toArray(new double[0][]),
                labels.toArray(new int[0][]),
                correct.toArray());
        svm_parameter best_parameter = null;
        double bestCorrect = -1.0;
        for (double gamma : SVM_GAMMAS) {
            svm_parameter parameter = this.initSVMParameters(gamma);
            String parameterError = svm.svm_check_parameter(problem_one_class, parameter);
            if (parameterError != null) {
                System.out.println(parameterError);
            }
            Double[] predictions = new Double[problem_one_class.l];
            TEASERClassifier.trainSVMOneClass(problem_one_class, parameter, 10, predictions, new Random(1L));
            // Guard 0/0: with no correctly classified samples the accuracy would be NaN and no
            // gamma would ever be selected, leaving best_parameter null.
            double correct2 = problem_one_class.l > 0
                    ? (double)this.evalLabels(problem_one_class.y, predictions).correct.get() / (double)problem_one_class.l
                    : 0.0;
            if (correct2 > bestCorrect) {
                best_parameter = parameter;
                bestCorrect = correct2;
            }
        }
        if (best_parameter == null) {
            throw new IllegalStateException("No SVM parameters evaluated; SVM_GAMMAS must not be empty");
        }
        return svm.svm_train(problem_one_class, best_parameter);
    }

    /**
     * Returns the prefixes {@code [0, offset)} of all samples when {@code testing} is true, and
     * the samples unchanged otherwise.
     *
     * @param samples the series to cut
     * @param offset prefix length
     * @param testing whether to truncate
     * @return a new array with one entry per sample
     */
    public TimeSeries[] extractUntilOffset(TimeSeries[] samples, int offset, boolean testing) {
        ArrayList<TimeSeries> offsetSamples = new ArrayList<TimeSeries>();
        for (TimeSeries sample : samples) {
            if (testing) {
                offsetSamples.add(sample.getSubsequence(0, offset));
            } else {
                offsetSamples.add(sample);
            }
        }
        return offsetSamples.toArray(new TimeSeries[0]);
    }

    /**
     * Records one more occurrence of {@code prediction}.
     *
     * <p>If {@code prediction} is not currently counted, all other counts are discarded first, so
     * the result is the length of the current run of identical predictions.
     *
     * @param counts per-sample run counter (mutated)
     * @param prediction the predicted label
     * @return the updated run length for {@code prediction}
     */
    public int getCount(DoubleIntHashMap counts, double prediction) {
        int count = counts.get(prediction);
        if (count == 0) {
            counts.clear();
        }
        return counts.addTo(prediction, 1);
    }

    /** Predicts {@code testSamples} early and evaluates the labels against their true labels. */
    @Override
    public Classifier.Predictions score(TimeSeries[] testSamples) {
        Double[] labels = this.predict(testSamples);
        return this.evalLabels(testSamples, labels);
    }

    /** Predicts labels for {@code testSamples} using prefix truncation. */
    @Override
    public Double[] predict(TimeSeries[] testSamples) {
        return this.predict(testSamples, true).labels;
    }

    /**
     * Runs early classification over all prefix lengths.
     *
     * @param testSamples series to classify (their labels are used for the accuracy count)
     * @param testing if true, series are truncated to each prefix and per-class earliness is printed
     * @return predicted labels, number correct, and the summed earliness fractions
     */
    private OffsetPrediction predict(TimeSeries[] testSamples, boolean testing) {
        double avgOffset = 0.0;
        int correct = 0;
        int count = 0;
        Double[] predictedLabels = new Double[testSamples.length];
        int[] offsets = new int[testSamples.length];
        DoubleIntHashMap[] predictions = new DoubleIntHashMap[testSamples.length];
        for (int i = 0; i < testSamples.length; ++i) {
            predictions[i] = new DoubleIntHashMap();
        }
        DoubleDoubleHashMap perClassEarliness = new DoubleDoubleHashMap();
        DoubleIntHashMap perClassCount = new DoubleIntHashMap();
        prefixes: for (int s = 0; s < this.model.slaveModels.length; ++s) {
            if (this.model.masterModels[s] == null) {
                continue;
            }
            TimeSeries[] data = this.extractUntilOffset(testSamples, this.model.offsets[s], testing);
            this.slaveClassifier.setModel(this.model.slaveModels[s]);
            Classifier.Predictions result = this.slaveClassifier.predictProbabilities(data);
            for (int ind = 0; ind < data.length; ++ind) {
                if (predictedLabels[ind] == null) {
                    double predictedLabel = result.labels[ind];
                    double[] probabilities = result.probabilities[ind];
                    double predictNow = svm.svm_predict(this.model.masterModels[s], TEASERClassifier.generateFeatures(probabilities, result.realLabels));
                    // The last prefix, or a prefix covering the whole series, forces a decision.
                    boolean mustDecide = (double)s >= S || this.model.offsets[s] >= testSamples[ind].getLength();
                    // getCount only runs (and updates the run length) when the master accepts or a
                    // decision is forced.
                    if ((mustDecide || predictNow == 1.0)
                            && (this.getCount(predictions[ind], predictedLabel) >= this.model.threshold || mustDecide)) {
                        predictedLabels[ind] = predictedLabel;
                        double earliness = Math.min(1.0, (double)this.model.offsets[s] / (double)testSamples[ind].getLength());
                        avgOffset += earliness;
                        offsets[ind] = this.model.offsets[s];
                        perClassEarliness.addTo(testSamples[ind].getLabel(), earliness);
                        perClassCount.addTo(testSamples[ind].getLabel(), 1);
                        if (this.compareLabels(testSamples[ind].getLabel(), predictedLabel)) {
                            ++correct;
                        }
                        ++count;
                    }
                }
                if (count == testSamples.length) {
                    break prefixes;
                }
            }
        }
        if (testing) {
            for (DoubleDoubleCursor c : perClassEarliness) {
                System.out.println("Class\t" + c.key + "\t Earliness \t" + c.value / (double)perClassCount.get(c.key));
            }
        }
        if (testing && PRINT_EARLINESS) {
            for (int ind = 0; ind < offsets.length; ++ind) {
                int e = offsets[ind];
                System.out.print("[" + e + "," + (this.compareLabels(predictedLabels[ind], testSamples[ind].getLabel()) ? "True" : "False") + "],");
            }
            System.out.println("");
        }
        return new OffsetPrediction(avgOffset, predictedLabels, correct, testSamples.length);
    }

    /** Builds one-class SVM parameters with the given RBF {@code gamma}. */
    public svm_parameter initSVMParameters(double gamma) {
        svm_parameter parameter2 = new svm_parameter();
        parameter2.eps = 1.0E-4;
        parameter2.nu = SVM_NU;
        parameter2.gamma = gamma;
        parameter2.kernel_type = SVM_KERNEL;
        parameter2.cache_size = 40.0;
        parameter2.svm_type = svm_parameter.ONE_CLASS;
        return parameter2;
    }

    /**
     * Builds a libsvm problem from probability vectors.
     *
     * <p>This also silences libsvm's console output and reseeds libsvm's random generator.
     */
    public static svm_problem initProblem(double[][] probabilities, int[][] labels, double[] correctPrediction) {
        svm.svm_set_print_string_function(new svm_print_interface(){

            @Override
            public void print(String s) {
            }
        });
        svm.rand.setSeed(1L);
        svm_problem problem = new svm_problem();
        svm_node[][] features = TEASERClassifier.initLibSVM(probabilities, labels);
        problem.y = correctPrediction;
        problem.l = features.length;
        problem.x = features;
        return problem;
    }

    /** Converts each probability vector into a libsvm feature vector. */
    public static svm_node[][] initLibSVM(double[][] probabilities, int[][] labels) {
        svm_node[][] featuresTrain = new svm_node[probabilities.length][];
        for (int a = 0; a < probabilities.length; ++a) {
            featuresTrain[a] = TEASERClassifier.generateFeatures(probabilities[a], labels[a]);
        }
        return featuresTrain;
    }

    /**
     * Returns the smallest gap between the largest probability and any other probability.
     *
     * <p>Returns 1.0 when there is only one entry. The running maximum starts at 0.0.
     */
    protected static double getMinDiff(double[] probabilities) {
        int maxId = 0;
        double max = 0.0;
        for (int i = 0; i < probabilities.length; ++i) {
            if (probabilities[i] > max) {
                max = probabilities[i];
                maxId = i;
            }
        }
        double minDiff = 1.0;
        for (int i = 0; i < probabilities.length; ++i) {
            if (maxId != i) {
                minDiff = Math.min(minDiff, max - probabilities[i]);
            }
        }
        return minDiff;
    }

    /**
     * Builds the master SVM's feature vector.
     *
     * <p>Each class probability gets feature index {@code 2 + label}. One extra feature at index
     * {@code maxIndex + 4} holds {@link #getMinDiff}. Nodes are sorted by ascending index, because
     * libsvm expects sparse feature indices in ascending order.
     */
    public static svm_node[] generateFeatures(double[] probabilities, int[] labels) {
        svm_node[] features = new svm_node[probabilities.length + 1];
        int maxLabel = 0;
        for (int i = 0; i < probabilities.length; ++i) {
            features[i] = new svm_node();
            features[i].index = 2 + labels[i];
            features[i].value = probabilities[i];
            maxLabel = Math.max(features[i].index, maxLabel);
        }
        features[features.length - 1] = new svm_node();
        features[features.length - 1].index = maxLabel + 4;
        features[features.length - 1].value = TEASERClassifier.getMinDiff(probabilities);
        Arrays.sort(features, new Comparator<svm_node>(){

            @Override
            public int compare(svm_node o1, svm_node o2) {
                return Integer.compare(o1.index, o2.index);
            }
        });
        return features;
    }

    /**
     * Result of one early-classification pass.
     *
     * <p>{@code offset} is the sum, over decided samples, of the fraction of each series consumed
     * before deciding.
     */
    static class OffsetPrediction {
        double offset;
        Double[] labels;
        int correct;
        int N;

        public OffsetPrediction(double offset, Double[] labels, int correct, int N) {
            this.offset = offset;
            this.correct = correct;
            this.labels = labels;
            this.N = N;
        }

        public int getCorrect() {
            return this.correct;
        }

        @Override
        public String toString() {
            return "Avg. Offset\t" + this.offset
                    + "\tacc: " + String.format("%.02f", (double)this.getCorrect() / (double)this.N)
                    + "\tearliness: " + String.format("%.02f", this.offset / (double)this.N);
        }
    }

    /**
     * Per-prefix slave and master models, with prefix lengths and the agreement threshold.
     *
     * <p>Arrays are indexed by prefix step {@code s} in {@code [0, S]}. A null model or an offset of
     * -1 means no model was trained for that step.
     */
    public static class EarlyClassificationModel
    extends Classifier.Model {
        public svm_model[] masterModels;
        public WEASELClassifier.WEASELModel[] slaveModels;
        public int[] offsets = new int[(int)S + 1];
        public int threshold;

        public EarlyClassificationModel() {
            super("TEASER", 0, 1, 0, 1, false, -1);
            this.masterModels = new svm_model[(int)S + 1];
            this.slaveModels = new WEASELClassifier.WEASELModel[(int)S + 1];
            Arrays.fill(this.offsets, -1);
        }
    }
}

