/*
 * Copyright 2016 higherfrequencytrading.com
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.openhft.chronicle.core.jlbh;

import net.openhft.affinity.Affinity;
import net.openhft.affinity.AffinityLock;
import net.openhft.chronicle.core.Jvm;
import net.openhft.chronicle.core.util.Histogram;
import net.openhft.chronicle.core.util.NanoSampler;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.io.PrintStream;
import java.util.*;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/**
 * Java Latency Benchmark Harness
 * The harness is intended to be used for benchmarks where co-ordinated omission is an issue.
 * Typically these would be of the producer/consumer nature where the start time for the benchmark
 * may be on a different thread than the end time.
 * <p></p>
 * This tool was inspired by JMH.
 */
public class JLBH implements NanoSampler {
    // Empty array handed to List.toArray so the summary values come back as a Double[].
    private static final Double[] NO_DOUBLES = {};
    // Histograms of the probes registered through addProbe, keyed (and sorted) by probe name.
    private final SortedMap<String, Histogram> additionHistograms = new ConcurrentSkipListMap<>();
    // wait time between invocations in nanoseconds
    private final long latencyBetweenTasks;
    // Benchmark configuration; also supplies the task that is being measured.
    @NotNull
    private final JLBHOptions jlbhOptions;
    // Destination of all textual output produced by the harness.
    @NotNull
    private final PrintStream printStream;
    // Receives the collected percentiles at the end of start(); null when no consumer was supplied.
    private final Consumer<JLBHResult> resultConsumer;
    // Latencies passed to sample(); reset when warm-up completes and after every run.
    @NotNull
    private Histogram endToEndHistogram = new Histogram();
    // Gaps between consecutive clock readings recorded by the OSJitterMonitor thread.
    @NotNull
    private Histogram osJitterHistogram = new Histogram();
    // Number of samples received by sample(); set back to 0 after every run.
    private long noResultsReturned;
    // Set by sample() once the warm-up samples have arrived; polled by start().
    @NotNull
    private AtomicBoolean warmUpComplete = new AtomicBoolean(false);
    //Use non-atomic when no thread synchronisation is necessary
    private boolean warmedUp;

    /**
     * Creates a harness that prints its output to {@code System.out} and does not pass the
     * results to a result consumer.
     *
     * @param jlbhOptions Options to run the benchmark
     */
    public JLBH(@NotNull JLBHOptions jlbhOptions) {
        this(jlbhOptions, System.out, null);
    }

    /**
     * Use this constructor if you want to test the latencies in a more automated fashion.
     * The result is passed to the result consumer at the end of the JLBH::start method.
     * You can create your own consumer, or use the provided JLBHResultConsumer::newThreadSafeInstance()
     * that allows you to retrieve the result even if the JLBH has been executed in a different thread.
     *
     * @param jlbhOptions    Options to run the benchmark
     * @param printStream    Used to print text output. Use System.out to show the result on your standard out (e.g. screen)
     * @param resultConsumer If provided, accepts the result data to be retrieved after the latencies have been measured
     * @throws IllegalStateException if {@code jlbhOptions.jlbhTask} is not set
     */
    public JLBH(@NotNull JLBHOptions jlbhOptions, @NotNull PrintStream printStream, Consumer<JLBHResult> resultConsumer) {
        this.jlbhOptions = jlbhOptions;
        this.printStream = printStream;
        this.resultConsumer = resultConsumer;
        if (jlbhOptions.jlbhTask == null) throw new IllegalStateException("jlbhTask must be set");
        latencyBetweenTasks = jlbhOptions.throughputTimeUnit.toNanos(1) / jlbhOptions.throughput;
        if (jlbhOptions.warmUpIterations <= 0) {
            // sample() only signals the end of warm-up when the number of received samples equals
            // warmUpIterations, which can never happen for a non-positive value: start() would then
            // wait for the warm-up forever. With nothing to warm up, mark it as complete up front.
            warmedUp = true;
            warmUpComplete.set(true);
        }
    }

    /**
     * Registers a named probe that measures a section of the benchmark. Asking for a name
     * that is already registered returns the probe created earlier.
     *
     * @param name Name of probe
     * @return NanoSampler that records the samples of the probe
     */
    public NanoSampler addProbe(String name) {
        final Histogram probeHistogram = additionHistograms.computeIfAbsent(name, key -> new Histogram());
        return probeHistogram;
    }

    /**
     * Start benchmark.
     * <p>
     * Initialises the task, optionally starts the OS jitter monitor, invokes the task
     * {@code warmUpIterations} times and then performs {@code runs} measured runs of
     * {@code iterations} invocations each. After every run the end-to-end histogram, the
     * probe histograms and (if enabled) the OS jitter histogram are printed and reset.
     * Once all runs are done the per-probe summaries are printed, the percentiles are handed
     * to the result consumer (if any) and the task's {@code complete()} method is called.
     * </p>
     */
    public void start() {
        jlbhOptions.jlbhTask.init(this);
        @NotNull OSJitterMonitor osJitterMonitor = new OSJitterMonitor();
        // percentiles of the end-to-end histogram, one entry per run
        @NotNull List<double[]> percentileRuns = new ArrayList<>();
        // percentiles of every additional probe, keyed by probe name, one entry per run
        @NotNull Map<String, List<double[]>> additionalPercentileRuns = new TreeMap<>();

        if (jlbhOptions.recordOSJitter) {
            osJitterMonitor.setDaemon(true);
            osJitterMonitor.start();
        }

        // warm-up: invoke the task back to back, without any pacing between the invocations
        long warmupStart = System.currentTimeMillis();
        for (int i = 0; i < jlbhOptions.warmUpIterations; i++) {
            jlbhOptions.jlbhTask.run(System.nanoTime());
        }

        AffinityLock lock = Affinity.acquireLock();
        try {
            for (int run = 0; run < jlbhOptions.runs; run++) {
                long runStart = System.currentTimeMillis();
                long startTimeNs = System.nanoTime();
                for (int i = 0; i < jlbhOptions.iterations; i++) {

                    if (i == 0 && run == 0) {
                        // very first measured iteration: wait until sample() has flagged the warm-up
                        // as complete, reporting the number of samples received so far while waiting
                        while (!warmUpComplete.get()) {
                            Jvm.pause(2000);
                            printStream.println("Complete: " + noResultsReturned);
                        }
                        printStream.println("Warm up complete (" + jlbhOptions.warmUpIterations + " iterations took " +
                                ((System.currentTimeMillis() - warmupStart) / 1000.0) + "s)");
                        if (jlbhOptions.pauseAfterWarmupMS != 0) {
                            printStream.println("Pausing after warmup for " + jlbhOptions.pauseAfterWarmupMS + "ms");
                            Jvm.pause(jlbhOptions.pauseAfterWarmupMS);
                        }
                        jlbhOptions.jlbhTask.warmedUp();
                        // restart the clocks so that the wait above is not counted as part of the run
                        runStart = System.currentTimeMillis();
                        startTimeNs = System.nanoTime();
                    } else if (jlbhOptions.accountForCoordinatedOmission) {
                        // fixed schedule: the next start time is the previous one plus the target interval,
                        // no matter how long the previous invocation took
                        startTimeNs += latencyBetweenTasks;
                        // pause for all but about 2ms of the remaining wait, then busy-wait until the scheduled start
                        long millis = (startTimeNs - System.nanoTime()) / 1000000 - 2;
                        if (millis > 0) {
                            Jvm.pause(millis);
                        }
                        Jvm.busyWaitUntil(startTimeNs);

                    } else {
                        // no correction: wait one target interval from now and use the actual time as start time
                        if (latencyBetweenTasks > 2e6) {
                            long end = System.nanoTime() + latencyBetweenTasks;
                            Jvm.pause(latencyBetweenTasks / 1_000_000 - 1);
                            // account for jitter in Thread.sleep() and wait until a fixed point in time
                            Jvm.busyWaitUntil(end);
                        } else {
                            Jvm.busyWaitMicros(latencyBetweenTasks / 1000);
                        }
                        startTimeNs = System.nanoTime();
                    }

                    jlbhOptions.jlbhTask.run(startTimeNs);
                }

                // wait until a sample has been recorded for every iteration of this run
                while (endToEndHistogram.totalCount() < jlbhOptions.iterations) {
                    Thread.yield();
                }
                long totalRunTime = System.currentTimeMillis() - runStart;

                percentileRuns.add(endToEndHistogram.getPercentiles());

                printStream.println("-------------------------------- BENCHMARK RESULTS (RUN " + (run + 1) + ") --------------------------------------------------------");
                printStream.println("Run time: " + totalRunTime / 1000.0 + "s");
                printStream.println("Correcting for co-ordinated:" + jlbhOptions.accountForCoordinatedOmission);
                printStream.println("Target throughput:" + jlbhOptions.throughput + "/" + timeUnitToString(jlbhOptions.throughputTimeUnit) + " = 1 message every " + (latencyBetweenTasks / 1000) + "us");
                printStream.printf("%-48s", String.format("End to End: (%,d)", endToEndHistogram.totalCount()));
                printStream.println(endToEndHistogram.toMicrosFormat());

                // record and print the percentiles of every additional probe
                if (additionHistograms.size() > 0) {
                    additionHistograms.entrySet().forEach(e -> {
                        List<double[]> ds = additionalPercentileRuns.computeIfAbsent(e.getKey(),
                                i -> new ArrayList<>());
                        ds.add(e.getValue().getPercentiles());
                        printStream.printf("%-48s", String.format("%s (%,d) ", e.getKey(), e.getValue().totalCount()));
                        printStream.println(e.getValue().toMicrosFormat());
                    });
                }
                if (jlbhOptions.recordOSJitter) {
                    printStream.printf("%-48s", String.format("OS Jitter (%,d)", osJitterHistogram.totalCount()));
                    printStream.println(osJitterHistogram.toMicrosFormat());
                }
                printStream.println("-------------------------------------------------------------------------------------------------------------------");

                // clear all counters and histograms so the next run starts from scratch
                noResultsReturned = 0;
                endToEndHistogram.reset();
                additionHistograms.values().forEach(Histogram::reset);
                osJitterMonitor.reset();
            }
        } finally {
            // the affinity lock is released whether or not the runs completed normally
            Jvm.pause(5);
            lock.release();
            Jvm.pause(5);
        }

        printPercentilesSummary("end to end", percentileRuns);
        if (additionalPercentileRuns.size() > 0) {
            additionalPercentileRuns.entrySet().forEach(e -> printPercentilesSummary(e.getKey(), e.getValue()));
        }

        consumeResults(percentileRuns, additionalPercentileRuns);

        jlbhOptions.jlbhTask.complete();
    }

    /**
     * Hands the collected percentiles to the result consumer; does nothing when no consumer
     * was supplied.
     *
     * @param percentileRuns           end-to-end percentiles, one array per run
     * @param additionalPercentileRuns percentiles of the additional probes by probe name, one array per run
     */
    private void consumeResults(List<double[]> percentileRuns, Map<String, List<double[]>> additionalPercentileRuns) {
        if (resultConsumer == null) {
            return;
        }

        final JLBHResult.ProbeResult endToEnd = new ImmutableProbeResult(percentileRuns);
        final Map<String, ImmutableProbeResult> probesByName = additionalPercentileRuns.entrySet()
                .stream()
                .collect(Collectors.toMap(
                        entry -> entry.getKey(),
                        entry -> new ImmutableProbeResult(entry.getValue())));
        final ImmutableJLBHResult result = new ImmutableJLBHResult(endToEnd, probesByName);
        resultConsumer.accept(result);
    }

    /**
     * Prints the summary table of one probe: for every percentile one row holding the value of
     * each run (divided by 1e3) followed by the variation between the runs in percent.
     * <p>
     * Whether the first run takes part in the variation is controlled by
     * {@code jlbhOptions.skipFirstRun}. When no run is left to compare, or when all compared
     * values are zero, the variation is reported as 0.
     * </p>
     *
     * @param label          name of the probe, shown in the header of the summary
     * @param percentileRuns percentiles of the probe, one array per run, in run order
     */
    private void printPercentilesSummary(String label, @NotNull List<double[]> percentileRuns) {
        printStream.println("-------------------------------- SUMMARY (" + label + ")------------------------------------------------------------");
        if (percentileRuns.isEmpty()) {
            // no run was recorded, so there are no percentiles to tabulate
            printStream.println("-------------------------------------------------------------------------------------------------------------------");
            return;
        }

        int length = percentileRuns.get(0).length;

        // NOTE(review): the default is derived from the number of percentiles, not from the
        // number of runs - confirm that this is the intended default.
        boolean skipFirst = length > 3;
        if (jlbhOptions.skipFirstRun == JLBHOptions.SKIP_FIRST_RUN.SKIP) {
            skipFirst = true;
        } else if (jlbhOptions.skipFirstRun == JLBHOptions.SKIP_FIRST_RUN.NO_SKIP) {
            skipFirst = false;
        }
        int firstComparedRun = skipFirst ? 1 : 0;

        @NotNull List<Double> consistencies = new ArrayList<>();
        for (int i = 0; i < length; i++) {
            // infinities are neutral start values; Double.MIN_VALUE is the smallest POSITIVE
            // double and would beat a genuine maximum of 0
            double maxValue = Double.NEGATIVE_INFINITY;
            double minValue = Double.POSITIVE_INFINITY;
            for (int run = firstComparedRun; run < percentileRuns.size(); run++) {
                double v = percentileRuns.get(run)[i];
                if (v > maxValue) {
                    maxValue = v;
                }
                if (v < minValue) {
                    minValue = v;
                }
            }
            consistencies.add(variationPercent(maxValue, minValue));
        }

        // flatten into the argument order of the format string: per percentile all runs, then the variation
        @NotNull List<Double> summary = new ArrayList<>();
        for (int i = 0; i < length; i++) {
            for (double[] percentileRun : percentileRuns) {
                summary.add(percentileRun[i] / 1e3);
            }
            summary.add(consistencies.get(i));
        }

        @NotNull StringBuilder sb = new StringBuilder();
        addHeaderToPrint(sb, jlbhOptions.runs);
        printStream.println(sb.toString());

        sb = new StringBuilder();
        addPrToPrint(sb, "50:     ", jlbhOptions.runs);
        addPrToPrint(sb, "90:     ", jlbhOptions.runs);
        addPrToPrint(sb, "99:     ", jlbhOptions.runs);
        addPrToPrint(sb, "99.9:   ", jlbhOptions.runs);
        addPrToPrint(sb, "99.99:  ", jlbhOptions.runs);
        if (jlbhOptions.iterations > 1_000_000) {
            addPrToPrint(sb, "99.999: ", jlbhOptions.runs);
        }
        if (jlbhOptions.iterations > 10_000_000) {
            addPrToPrint(sb, "99.9999:", jlbhOptions.runs);
        }
        addPrToPrint(sb, "worst:  ", jlbhOptions.runs);

        printStream.printf(sb.toString(), summary.toArray(NO_DOUBLES));
        printStream.println("-------------------------------------------------------------------------------------------------------------------");
    }

    /**
     * Variation between the largest and the smallest value of one percentile, in percent.
     *
     * @param maxValue largest compared value, or negative infinity if nothing was compared
     * @param minValue smallest compared value, or positive infinity if nothing was compared
     * @return the variation, or 0 when nothing was compared or the divisor is zero
     */
    private static double variationPercent(double maxValue, double minValue) {
        if (maxValue < minValue) {
            // both start values are untouched: there was no run to compare
            return 0.0;
        }
        // NOTE(review): this divides by max + (min / 2); a mean-based variation would divide by
        // (max + min) / 2. Kept as is so reported figures stay comparable - confirm the intent.
        double divisor = maxValue + minValue / 2;
        if (divisor == 0) {
            return 0.0;
        }
        return 100 * (maxValue - minValue) / divisor;
    }

    /**
     * Appends the format-string row of one percentile: the row label, one right-aligned
     * {@code %12.2f} column per run, a final {@code %12.2f} column and a line break.
     *
     * @param sb   builder that receives the row
     * @param pr   label of the percentile row
     * @param runs number of run columns
     */
    private void addPrToPrint(@NotNull StringBuilder sb, String pr, int runs) {
        sb.append(pr);
        int remainingRuns = runs;
        while (remainingRuns > 0) {
            sb.append("%12.2f ");
            remainingRuns--;
        }
        sb.append("%12.2f").append("%n");
    }

    /**
     * Appends the header line of the summary table: the "Percentile" title, one
     * {@code runN} column title per run (numbered from 1) and the variation column title.
     *
     * @param sb   builder that receives the header
     * @param runs number of run columns
     */
    private void addHeaderToPrint(@NotNull StringBuilder sb, int runs) {
        sb.append("Percentile");
        for (int runNumber = 1; runNumber <= runs; runNumber++) {
            // the first column sits closer to the title than the following ones
            final String columnTitle = runNumber == 1 ? "   run" : "         run";
            sb.append(columnTitle).append(runNumber);
        }
        sb.append("      % Variation");
    }

    /**
     * Returns the short textual form of a time unit as used in the printed results.
     *
     * @param timeUnit unit to abbreviate
     * @return abbreviation of the unit, e.g. "us" for microseconds
     * @throws IllegalArgumentException if the unit has no known abbreviation
     */
    private String timeUnitToString(@NotNull TimeUnit timeUnit) {
        final String abbreviation;
        switch (timeUnit) {
            case NANOSECONDS:
                abbreviation = "ns";
                break;
            case MICROSECONDS:
                abbreviation = "us";
                break;
            case MILLISECONDS:
                abbreviation = "ms";
                break;
            case SECONDS:
                abbreviation = "s";
                break;
            case MINUTES:
                abbreviation = "min";
                break;
            case HOURS:
                abbreviation = "h";
                break;
            case DAYS:
                abbreviation = "day";
                break;
            default:
                throw new IllegalArgumentException("Unrecognized time unit value '" + timeUnit + "'");
        }
        return abbreviation;
    }

    /**
     * NanoSampler entry point; forwards the value unchanged to {@link #sample(long)}.
     *
     * @param nanos value to record, in nanoseconds
     */
    @Override
    public void sampleNanos(long nanos) {
        sample(nanos);
    }

    /**
     * Records one end-to-end value.
     * <p>
     * Every call increments the count of received samples. The sample that brings this count to
     * {@code warmUpIterations} while the harness is not yet warmed up is not recorded: it
     * resets the end-to-end and probe histograms and flags the warm-up as complete instead.
     * Every other sample goes into the end-to-end histogram.
     * </p>
     *
     * @param nanoTime value to record
     */
    public void sample(long nanoTime) {
        noResultsReturned++;
        if (!warmedUp) {
            if (noResultsReturned == jlbhOptions.warmUpIterations) {
                // last warm-up sample: discard everything measured so far and signal start()
                warmedUp = true;
                endToEndHistogram.reset();
                additionHistograms.values().forEach(Histogram::reset);
                warmUpComplete.set(true);
                return;
            }
        }
        endToEndHistogram.sample(nanoTime);
    }

    /**
     * Thread that reads {@code System.nanoTime()} in a tight loop and records every gap between
     * two consecutive readings that is larger than {@code jlbhOptions.recordJitterGreaterThanNs}
     * into the OS jitter histogram.
     */
    private class OSJitterMonitor extends Thread {
        // raised by reset(); the monitor thread itself clears it and resets the histogram
        final AtomicBoolean reset = new AtomicBoolean(false);

        @Override
        public void run() {
            // make sure this thread is not bound by its parent.
            Affinity.setAffinity(AffinityLock.BASE_AFFINITY);
            @Nullable AffinityLock affinityLock = null;
            if (jlbhOptions.jitterAffinity) {
                printStream.println("Jitter thread running with affinity.");
                affinityLock = AffinityLock.acquireLock();
            }

            try {
                long lastTime = System.nanoTime(), start = lastTime;
                while (true) {
                    if (reset.compareAndSet(true, false)) {
                        osJitterHistogram.reset();
                        // the reset itself must not show up as a gap
                        lastTime = System.nanoTime();
                    }
                    for (int i = 0; i < 1000; i++) {
                        long time = System.nanoTime();
                        if (time - lastTime > jlbhOptions.recordJitterGreaterThanNs) {
                            osJitterHistogram.sample(time - lastTime);
                        }
                        lastTime = time;
                    }
                    // once the monitor has been running for more than 60s, pause between batches
                    if (lastTime > start + 60e9) {
                        Jvm.pause(1);
                        // restart the clock: the pause was deliberate and must not be recorded as OS jitter
                        lastTime = System.nanoTime();
                    }
                }
            } finally {
                if (affinityLock != null)
                    affinityLock.release();
            }
        }

        /**
         * Asks the monitor thread to clear the OS jitter histogram before it records the next batch.
         */
        void reset() {
            reset.set(true);
        }
    }
}
