/*
 * Decompiled with CFR 0.152.
 */
package edu.ucr.test;

import com.carrotsearch.hppc.DoubleIntHashMap;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import libsvm.svm;
import sfa.classification.Classifier;
import sfa.classification.TEASERClassifier;
import sfa.classification.WEASELClassifier;
import sfa.timeseries.TimeSeries;
import sfa.timeseries.TimeSeriesLoader;

/**
 * Command-line harness for {@link TEASERClassifier}.
 *
 * <p>Fits a classifier on the training set, replays its snapshot-by-snapshot prediction on the
 * test set while recording every intermediate slave prediction, cross-checks the replayed final
 * labels against {@link TEASERClassifier#predict}, and writes the per-sample results as JSON.
 *
 * <p>Usage: {@code TEASERTester <datasetName> <trainPath> <testPath> <outputJsonPath>}
 */
public class TEASERTester {

    /**
     * Replays the TEASER early-classification loop over {@code testSamples}, recording the slave
     * classifier's prediction at every snapshot that has a master model.
     *
     * <p>For each such snapshot {@code s}, the samples are truncated to {@code model.offsets[s]}
     * and classified by the slave model. A snapshot is rejected (no vote counted) unless
     * {@code s >= TEASERClassifier.S}, the offset reaches the sample's length, or the master SVM
     * predicts {@code 1.0}. For non-rejected snapshots, the first one where the count returned by
     * {@code getCount} reaches {@code model.threshold} (or where {@code s >= S} / the whole series
     * is covered) becomes the sample's final prediction.
     *
     * @param teaserClassifier a fitted classifier whose model, slave classifier and helpers are used
     * @param testSamples      samples to classify
     * @return one entry per test sample; an entry stays {@code null} if no snapshot decided it
     */
    private static FullPrediction[] predict(TEASERClassifier teaserClassifier, TimeSeries[] testSamples) {
        int sampleCount = testSamples.length;
        FullPrediction[] fullPredictions = new FullPrediction[sampleCount];
        List<List<FullPrediction.EarlyPrediction>> earlyPredictions = new ArrayList<>(sampleCount);
        List<DoubleIntHashMap> predictionCounts = new ArrayList<>(sampleCount);
        TEASERClassifier.EarlyClassificationModel model = teaserClassifier.model;
        WEASELClassifier slaveClassifier = teaserClassifier.slaveClassifier;
        for (int i = 0; i < sampleCount; ++i) {
            predictionCounts.add(new DoubleIntHashMap());
            earlyPredictions.add(new ArrayList<>());
        }
        for (int s = 0; s < model.slaveModels.length; ++s) {
            if (model.masterModels[s] == null) {
                continue;
            }
            TimeSeries[] snapshots = teaserClassifier.extractUntilOffset(testSamples, model.offsets[s], true);
            slaveClassifier.setModel(model.slaveModels[s]);
            Classifier.Predictions result = slaveClassifier.predictProbabilities(snapshots);
            // Once s reaches TEASERClassifier.S a decision is forced regardless of the master model.
            boolean beyondS = s >= TEASERClassifier.S;
            for (int i = 0; i < snapshots.length; ++i) {
                Double predictedLabel = result.labels[i];
                double[] probabilities = result.probabilities[i];
                double maxProbability = Arrays.stream(probabilities).max().orElse(0.0);
                // Recorded for every snapshot, including those after the sample was decided.
                earlyPredictions.get(i).add(
                        new FullPrediction.EarlyPrediction(model.offsets[s], predictedLabel.intValue(), maxProbability));

                double predictNow = svm.svm_predict(model.masterModels[s],
                        TEASERClassifier.generateFeatures(probabilities, result.realLabels));
                boolean coversWholeSeries = model.offsets[s] >= testSamples[i].getLength();
                if (!beyondS && !coversWholeSeries && predictNow != 1.0) {
                    continue; // master model rejected this snapshot: no vote is counted
                }
                // NOTE(review): getCount presumably records a vote for predictedLabel as a side
                // effect, which is why it is only invoked for non-rejected snapshots — confirm.
                int counts = teaserClassifier.getCount(predictionCounts.get(i), predictedLabel);
                boolean decided = counts >= model.threshold || beyondS || coversWholeSeries;
                if (decided && fullPredictions[i] == null) {
                    // NOTE(review): the list is shared, not copied, so later snapshots keep being
                    // appended to it and appear in the JSON output — confirm this is intended.
                    fullPredictions[i] = new FullPrediction(testSamples[i].getData(), model.offsets[s],
                            testSamples[i].getLabel().intValue(), predictedLabel.intValue(), maxProbability,
                            earlyPredictions.get(i));
                }
            }
        }
        return fullPredictions;
    }

    /**
     * Entry point: fits, replays, verifies, and writes the JSON report.
     *
     * @param args {@code datasetName trainPath testPath outputJsonPath}
     * @throws IOException if a dataset cannot be read or the output file cannot be written
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 4) {
            System.err.println("Usage: TEASERTester <datasetName> <trainPath> <testPath> <outputJsonPath>");
            System.exit(1);
            return;
        }
        String datasetName = args[0];
        String trainPath = args[1];
        String testPath = args[2];
        String outputJsonPath = args[3];
        TimeSeries[] trainData = TimeSeriesLoader.loadDataset(trainPath);
        TimeSeries[] testData = TimeSeriesLoader.loadDataset(testPath);
        TEASERClassifier classifier = new TEASERClassifier();
        classifier.fit(trainData);
        FullPrediction[] fullPredictions = TEASERTester.predict(classifier, testData);
        Double[] earlyPredictions = classifier.predict(testData);

        // Cross-check the replayed labels against the library's own predict().
        boolean verified = true;
        for (int i = 0; i < fullPredictions.length; ++i) {
            FullPrediction replayed = fullPredictions[i];
            Double reference = earlyPredictions[i];
            // A missing prediction on either side counts as a mismatch instead of throwing an NPE.
            if (replayed != null && reference != null && replayed.predictedLabel == reference.intValue()) {
                continue;
            }
            verified = false;
            System.out.println("Not match: " + i);
        }
        System.out.println("Predictions verified: " + (verified ? "Yes" : "No"));

        // Undecided samples are serialized as JSON null so array indices stay aligned with testData.
        String predictionsJson = Arrays.stream(fullPredictions)
                .map(p -> p == null ? "null" : p.toString())
                .collect(Collectors.joining(", "));
        try (PrintWriter outputWriter = new PrintWriter(outputJsonPath, StandardCharsets.UTF_8.name())) {
            outputWriter.println(String.format(Locale.ROOT,
                    "{ \"classifierName\": \"TEASER\", \"datasetName\": \"%s\", \"predictions\": [%s] }",
                    jsonEscape(datasetName), predictionsJson));
        }
        // NOTE(review): explicit exit presumably terminates lingering non-daemon worker threads.
        System.exit(0);
    }

    /**
     * Escapes {@code value} for embedding inside a JSON string literal: quotes, backslashes and
     * control characters are replaced by their JSON escape sequences.
     */
    private static String jsonEscape(String value) {
        StringBuilder escaped = new StringBuilder(value.length());
        for (int k = 0; k < value.length(); ++k) {
            char c = value.charAt(k);
            switch (c) {
                case '"':
                    escaped.append("\\\"");
                    break;
                case '\\':
                    escaped.append("\\\\");
                    break;
                case '\n':
                    escaped.append("\\n");
                    break;
                case '\r':
                    escaped.append("\\r");
                    break;
                case '\t':
                    escaped.append("\\t");
                    break;
                default:
                    if (c < 0x20) {
                        escaped.append(String.format(Locale.ROOT, "\\u%04x", (int) c));
                    } else {
                        escaped.append(c);
                    }
            }
        }
        return escaped.toString();
    }

    /** Final outcome for one test sample, serialized as a JSON object by {@link #toString()}. */
    static class FullPrediction {
        /** Raw values of the test sample. */
        private final double[] exemplar;
        /** Snapshot offset at which the final decision was made. */
        private final int earliness;
        private final int truthLabel;
        private final int predictedLabel;
        /** Maximum slave-class probability at the deciding snapshot. */
        private final double probability;
        /** Per-snapshot slave predictions for this sample. */
        private final List<EarlyPrediction> earlyPredictions;

        FullPrediction(double[] exemplar, int earliness, int truthLabel, int predictedLabel, double probability, List<EarlyPrediction> earlyPredictions) {
            this.exemplar = exemplar;
            this.earliness = earliness;
            this.truthLabel = truthLabel;
            this.predictedLabel = predictedLabel;
            this.probability = probability;
            this.earlyPredictions = earlyPredictions;
        }

        /** Renders this prediction as a JSON object; numbers use {@link Locale#ROOT} so the output is valid JSON in any locale. */
        @Override
        public String toString() {
            String exemplarJson = Arrays.stream(this.exemplar)
                    .mapToObj(d -> String.format(Locale.ROOT, "%.7f", d))
                    .collect(Collectors.joining(","));
            String earlyJson = this.earlyPredictions.stream()
                    .map(EarlyPrediction::toString)
                    .collect(Collectors.joining(", "));
            return String.format(Locale.ROOT,
                    "{ \"exemplar\": [%s], \"earliness\": %d, \"truthLabel\": %d, \"predictedLabel\": %d, \"probability\": %.4f, \"earlyPredictions\": [%s] }",
                    exemplarJson, this.earliness, this.truthLabel, this.predictedLabel, this.probability, earlyJson);
        }

        /** The slave classifier's prediction at a single snapshot. */
        static class EarlyPrediction {
            private final int snapshotLength;
            private final int label;
            /** Maximum class probability reported by the slave classifier at this snapshot. */
            private final double probability;

            EarlyPrediction(int snapshotLength, int label, double probability) {
                this.snapshotLength = snapshotLength;
                this.label = label;
                this.probability = probability;
            }

            /** Renders this snapshot prediction as a JSON object using locale-independent number formatting. */
            @Override
            public String toString() {
                return String.format(Locale.ROOT, "{ \"snapshotLength\": %d, \"label\": %d, \"probability\": %.4f }", this.snapshotLength, this.label, this.probability);
            }
        }
    }
}

