import os
import textwrap

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from utils import independent_ttest, plot_bracket

# Global matplotlib setup: select the PGF backend, then apply the custom style sheet
matplotlib.use("pgf")
matplotlib.style.use("style.mplstyle")


def _load_row(csv_name, row=5):
    """Read data/<csv_name>, drop its "Model" column and return the row labelled `row`.

    The frame is transposed before indexing, so the returned Series is indexed
    by the remaining column names of the CSV file.
    """
    table = pd.read_csv(os.path.join("data", csv_name))
    return table.drop(columns="Model").T[row]


# Import the data from the respective csv files (the *_mean and *_sem tables)
data_mean = _load_row("20201002_lifelong_final-acc-split-mnist_mean.csv")
data_sem = _load_row("20201002_lifelong_final-acc-split-mnist_sem.csv")

# Prepare a single-axes figure
fig, ax = plt.subplots(figsize=(2.5, 1.5))

# Bar chart of every entry except "Joint Training", which is drawn as a
# dashed horizontal reference line instead of a bar
colors = ['#91056A', '#1F407A', '#6F6F6F', '#D39BC3']
errorbar_style = {"capsize": 1, "capthick": 0.4, "elinewidth": 0.4}
bar_values = data_mean.drop(labels='Joint Training', axis=0)
bar_values.plot.bar(ax=ax, rot=0, yerr=data_sem, color=colors, error_kw=errorbar_style)
ax.axhline(data_mean['Joint Training'], color='black', linestyle='--')

# Indicate significance: one bracket per (second position, level, marker) triple,
# each starting at position 0
# NOTE(review): argument meaning is inferred from usage - confirm against utils.plot_bracket
for second, level, marker in ((1, 0.86, "*"), (2, 0.90, "*"), (3, 0.94, "**")):
    plot_bracket(ax, 0, second, level, marker)

# X/Y labels and ticks: wrap the tick labels at 14 characters
wrapped_labels = [
    textwrap.fill(name, 14)
    for name in data_mean.index
    if name != "Joint Training"
]
ax.set_xticklabels(wrapped_labels)
ax.set_ylim(0.5, 1.0)
ax.set_ylabel("Average Test Accuracy")

# Save the figure without extra padding around the axes
fig.tight_layout(pad=0.0)
fig.savefig("lifelong-mlp_bar_final-acc.pdf")

# Run T-tests between every ordered pair of distinct entries (excluding "Joint Training")
# The set of compared keys does not change inside the loops, so compute it once.
compared_keys = data_mean.drop(labels='Joint Training', axis=0).keys()
for key1 in compared_keys:
    for key2 in compared_keys:
        if key1 == key2:
            # Testing an entry against itself is degenerate (zero mean
            # difference) and only adds noise to the printed report.
            continue

        # NOTE(review): the trailing 3, 3 are presumably the per-group sample
        # sizes - confirm against utils.independent_ttest
        tval, pval = independent_ttest(data_mean[key1], data_mean[key2], data_sem[key1], data_sem[key2], 3, 3)

        print("T-test between", key1, " and ", key2, "t=", tval, "(p=", pval, ")")