package edu.ucr.test;

/*
 * Copyright 2021 Renjie Wu
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import com.carrotsearch.hppc.DoubleIntHashMap;
import libsvm.svm;
import sfa.classification.TEASERClassifier;
import sfa.timeseries.TimeSeries;
import sfa.timeseries.TimeSeriesLoader;

import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Test driver for {@link TEASERClassifier}.
 *
 * <p>Replays the classifier's early-classification loop so that every intermediate (early)
 * prediction can be captured, checks that the replayed final labels agree with the labels
 * returned by {@code TEASERClassifier.predict}, and writes all predictions to a JSON file.
 */
public class TEASERTester {
    /** Final prediction for one test exemplar, plus every early prediction recorded for it. */
    static class FullPrediction {
        /** Slave-classifier prediction for one exemplar at one snapshot length. */
        static class EarlyPrediction {
            private final int snapshotLength;
            private final int label;
            private final double probability;

            EarlyPrediction(int snapshotLength, int label, double probability) {
                this.snapshotLength = snapshotLength;
                this.label = label;
                this.probability = probability;
            }

            /**
             * Serializes this prediction as a JSON object. Numbers are formatted with
             * {@link Locale#ROOT} so the decimal separator is always '.', whatever the JVM locale.
             */
            @Override
            public String toString() {
                return String.format(Locale.ROOT,
                        "{ \"snapshotLength\": %d, \"label\": %d, \"probability\": %.4f }",
                        snapshotLength, label, probability);
            }
        }

        private final double[] exemplar;
        private final int earliness;
        private final int truthLabel;
        private final int predictedLabel;
        private final double probability;
        // Live reference: predict() keeps appending to this list after the final decision is made,
        // so it ends up holding the early predictions for every evaluated snapshot length.
        private final List<EarlyPrediction> earlyPredictions;

        FullPrediction(double[] exemplar, int earliness, int truthLabel, int predictedLabel, double probability, List<EarlyPrediction> earlyPredictions) {
            this.exemplar = exemplar;
            this.earliness = earliness;
            this.truthLabel = truthLabel;
            this.predictedLabel = predictedLabel;
            this.probability = probability;
            this.earlyPredictions = earlyPredictions;
        }

        /**
         * Serializes this prediction (including the raw exemplar and all early predictions) as a
         * JSON object. All numbers use {@link Locale#ROOT} so the output is valid JSON in any locale.
         */
        @Override
        public String toString() {
            var exemplarJson = Arrays.stream(exemplar)
                    .mapToObj(d -> String.format(Locale.ROOT, "%.7f", d))
                    .collect(Collectors.joining(","));
            var earlyJson = earlyPredictions.stream()
                    .map(EarlyPrediction::toString)
                    .collect(Collectors.joining(", "));
            return String.format(Locale.ROOT,
                    "{ \"exemplar\": [%s], \"earliness\": %d, \"truthLabel\": %d, \"predictedLabel\": %d, \"probability\": %.4f, \"earlyPredictions\": [%s] }",
                    exemplarJson, earliness, truthLabel, predictedLabel, probability, earlyJson);
        }
    }

    /**
     * Replays the TEASER prediction loop of {@code teaserClassifier} on {@code testSamples}.
     *
     * <p>For every slave model index {@code s} that has a master model, each sample is truncated
     * to {@code model.offsets[s]} and classified by the slave classifier. The resulting label and
     * maximum class probability are recorded as an early prediction. The master SVM then decides
     * whether to accept that label. A sample's final prediction is fixed the first time the value
     * returned by {@code getCount} reaches {@code model.threshold}, or when a decision is forced.
     * A decision is forced when {@code s >= TEASERClassifier.S} or when the snapshot already
     * covers the whole series.
     *
     * @param teaserClassifier a fitted classifier whose model is replayed
     * @param testSamples      samples to classify
     * @return one final prediction per sample, index-aligned with {@code testSamples}; an entry
     *         stays {@code null} if no decision was ever reached for that sample
     */
    private static FullPrediction[] predict(TEASERClassifier teaserClassifier, TimeSeries[] testSamples) {
        var fullPredictions = new FullPrediction[testSamples.length];
        var earlyPredictions = new ArrayList<List<FullPrediction.EarlyPrediction>>(testSamples.length);
        var predictionCounts = new ArrayList<DoubleIntHashMap>(testSamples.length);

        var model = teaserClassifier.model;
        var slaveClassifier = teaserClassifier.slaveClassifier;

        for (int i = 0; i < testSamples.length; i++) {
            predictionCounts.add(new DoubleIntHashMap());
            earlyPredictions.add(new ArrayList<>());
        }

        for (int s = 0; s < model.slaveModels.length; s++) {
            // Snapshot lengths without a trained master model are skipped entirely.
            if (model.masterModels[s] == null) {
                continue;
            }

            var offset = model.offsets[s];
            var snapshots = teaserClassifier.extractUntilOffset(testSamples, offset, true);

            slaveClassifier.setModel(model.slaveModels[s]);
            var result = slaveClassifier.predictProbabilities(snapshots);

            for (int i = 0; i < snapshots.length; i++) {
                var predictedLabel = result.labels[i];
                var probabilities = result.probabilities[i];
                var maxProbability = Arrays.stream(probabilities).max().orElse(0.0);

                earlyPredictions.get(i).add(
                        new FullPrediction.EarlyPrediction(offset, predictedLabel.intValue(), maxProbability));

                var predictNow = svm.svm_predict(model.masterModels[s],
                        TEASERClassifier.generateFeatures(probabilities, result.realLabels));

                // A decision is forced at the last slave index or once the snapshot spans the whole series.
                boolean forced = s >= TEASERClassifier.S || offset >= testSamples[i].getLength();
                if (forced || predictNow == 1) {
                    // getCount presumably updates/returns the agreement count for this label — confirm in TEASERClassifier.
                    var counts = teaserClassifier.getCount(predictionCounts.get(i), predictedLabel);

                    // Only the first accepted decision is kept as the final prediction.
                    if ((counts >= model.threshold || forced) && fullPredictions[i] == null) {
                        fullPredictions[i] = new FullPrediction(testSamples[i].getData(), offset,
                                testSamples[i].getLabel().intValue(), predictedLabel.intValue(),
                                maxProbability, earlyPredictions.get(i));
                    }
                }
            }
        }

        return fullPredictions;
    }

    /**
     * Escapes {@code value} for use inside a JSON string literal: quotes, backslashes and control
     * characters are escaped; everything else is copied unchanged.
     */
    private static String escapeJson(String value) {
        var sb = new StringBuilder(value.length());
        for (int k = 0; k < value.length(); k++) {
            char c = value.charAt(k);
            if (c == '"') {
                sb.append("\\\"");
            } else if (c == '\\') {
                sb.append("\\\\");
            } else if (c == '\n') {
                sb.append("\\n");
            } else if (c == '\r') {
                sb.append("\\r");
            } else if (c == '\t') {
                sb.append("\\t");
            } else if (c < 0x20) {
                sb.append(String.format(Locale.ROOT, "\\u%04x", (int) c));
            } else {
                sb.append(c);
            }
        }
        return sb.toString();
    }

    /**
     * Trains TEASER on a training set, replays its predictions on a test set, verifies them
     * against {@code TEASERClassifier.predict}, and writes the results as JSON.
     *
     * @param args {@code datasetName trainPath testPath outputJsonPath}
     * @throws IOException if the datasets cannot be read or the output file cannot be written
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 4) {
            System.err.println("Usage: TEASERTester <datasetName> <trainPath> <testPath> <outputJsonPath>");
            System.exit(2);
        }

        var datasetName = args[0];
        var trainPath = args[1];
        var testPath = args[2];
        var outputJsonPath = args[3];

        var trainData = TimeSeriesLoader.loadDataset(trainPath);
        var testData = TimeSeriesLoader.loadDataset(testPath);

        var classifier = new TEASERClassifier();
        classifier.fit(trainData);

        var fullPredictions = predict(classifier, testData);
        var classifierPredictions = classifier.predict(testData);

        boolean verified = true;
        for (int i = 0; i < fullPredictions.length; i++) {
            var full = fullPredictions[i];
            // A null entry means the replay never reached a decision, which also counts as a mismatch.
            if (full == null || full.predictedLabel != classifierPredictions[i].intValue()) {
                verified = false;
                System.out.println("Not match: " + i);
            }
        }

        System.out.println("Predictions verified: " + (verified ? "Yes" : "No"));

        var predictionsJson = Arrays.stream(fullPredictions)
                .map(p -> Objects.toString(p, "null"))
                .collect(Collectors.joining(", "));

        try (var outputWriter = new PrintWriter(outputJsonPath, StandardCharsets.UTF_8)) {
            outputWriter.println(String.format(Locale.ROOT,
                    "{ \"classifierName\": \"TEASER\", \"datasetName\": \"%s\", \"predictions\": [%s] }",
                    escapeJson(datasetName), predictionsJson));
        }

        // Force termination: without this the JVM was observed to hang after main returned
        // (presumably a non-daemon thread started by a library — TODO confirm which one).
        System.exit(0);
    }
}
