import os
from textwrap import wrap

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Global configuration of matplotlib: render through the PGF backend and apply
# the project's custom style sheet.
matplotlib.use('pgf')
matplotlib.style.use("style.mplstyle")

# Import the per-condition summary statistics (means and standard errors of
# the mean) from the respective csv files.
data_mean = pd.read_csv(os.path.join("data", "20201214_lifelong-ablation_split-mnist_summary_mean.csv"))
data_sem = pd.read_csv(os.path.join("data", "20201214_lifelong-ablation_split-mnist_summary_sem.csv"))

# Select only a subset of ablation conditions. A single shared list keeps the
# mean and SEM tables filtered identically.
EXCLUDED_CONDITIONS = ['No Freezing', 'No LR Modulation', 'No Homeostasis']
data_mean = data_mean[~data_mean["key"].isin(EXCLUDED_CONDITIONS)]
data_sem = data_sem[~data_sem["key"].isin(EXCLUDED_CONDITIONS)]

# Prepare the figure
fig, ax = plt.subplots(figsize=[2.6, 1.75])

# Match each bar's error to its condition by "key" instead of by row index.
# Pandas aligns a Series yerr on the data's index, so relying on row numbers
# would silently mis-assign SEMs if the two CSVs list conditions in a different
# order. A condition missing from the SEM table gets NaN (no error bar drawn).
sem_by_key = data_sem.set_index("key")["mean_acc"].reindex(data_mean["key"]).to_numpy()

# Plot the bar chart
data_mean['mean_acc'].plot.bar(ax=ax, rot=-90, yerr=sem_by_key, error_kw={"capsize": 1, "capthick": 0.4, "linewidth": 0.4})

# Dashed reference line for the no-consolidation baseline accuracy.
# NOTE(review): the value is hard-coded rather than read from data; keep it in
# sync with the corresponding experiment manually.
NO_CONSOLIDATION_BASELINE = 0.7768
print("WARNING: No consolidation baseline is hardcoded")
ax.axhline(NO_CONSOLIDATION_BASELINE, color='black', linestyle='--')

# Indicate significance
def plot_bracket(ax, x1, x2, y, label, *, height=0.01, linewidth=0.5,
                 color="black", fontsize="x-small"):
    """Draw a significance bracket between two x positions and label it.

    The bracket is a three-segment line: up from ``y`` at ``x1`` to
    ``y + height``, across to ``x2``, and back down to ``y``. ``label`` is
    centred horizontally between ``x1`` and ``x2`` and sits on the bracket's
    top edge.

    Args:
        ax: Axes-like object providing ``plot`` and ``text``.
        x1, x2: Data x-coordinates of the two bars being compared.
        y: Data y-coordinate where the bracket legs start.
        label: Text drawn above the bracket (e.g. ``"*"``).
        height: Vertical extent of the bracket legs, in data units.
        linewidth: Width of the bracket line.
        color: Colour used for both the line and the label.
        fontsize: Font size of the label.
    """
    top = y + height
    ax.plot([x1, x1, x2, x2], [y, top, top, y], linewidth=linewidth, color=color)
    ax.text((x1 + x2) * 0.5, top, label, ha='center', va='bottom', color=color, fontsize=fontsize)

# Significance brackets from the bar at x=0 to each of the next three bars,
# given as (right-hand bar index, bracket base height, star label).
for right_bar, bracket_y, stars in ((1, 0.88, "*"), (2, 0.92, "**"), (3, 0.96, "**")):
    plot_bracket(ax, 0, right_bar, bracket_y, stars)

# Axis ticks, labels and limits: one tick per plotted condition, with each
# condition name wrapped to lines of at most 14 characters.
condition_names = data_mean["key"]
tick_positions = np.arange(len(condition_names))
wrapped_names = ["\n".join(wrap(name, 14)) for name in condition_names]
ax.set_xticks(tick_positions)
ax.set_xticklabels(wrapped_names)
ax.set_ylabel('Average Test Accuracy')
ax.set_ylim(0.5, 1.0)

# Write the figure to disk with no padding around the tight layout.
fig.tight_layout(pad=0.0)
output_path = "lifelong-mlp_bar_ablation_split-mnist.pdf"
fig.savefig(output_path)

##########################################################
# Significance test between two ablation conditions, computed from the summary
# mean and SEM tables loaded at the top of this script.
from utils import independent_ttest

# Significance level used to phrase the printed verdict.
ALPHA = 0.05
# Runs per condition passed to the t-test.
# NOTE(review): hard-coded; confirm it matches the number of runs summarised in
# the CSV files.
N_RUNS = 3
CONDITION_A = 'Full Model'
CONDITION_B = 'No Presynaptic Plasticity'

# .item() raises unless exactly one row matches each condition.
mean1 = data_mean[data_mean["key"] == CONDITION_A]["mean_acc"].item()
mean2 = data_mean[data_mean["key"] == CONDITION_B]["mean_acc"].item()

sem1 = data_sem[data_sem["key"] == CONDITION_A]["mean_acc"].item()
sem2 = data_sem[data_sem["key"] == CONDITION_B]["mean_acc"].item()

# NOTE(review): argument order (mean1, mean2, sem1, sem2, n1, n2) follows the
# existing call; the helper's definition lives in utils and is not visible here.
tval, pval = independent_ttest(mean1, mean2, sem1, sem2, N_RUNS, N_RUNS)

# Decide the wording from the p-value so the message never claims
# significance for a result that does not reach ALPHA.
verdict = "significant" if pval < ALPHA else "not significant"
print(f"Difference between ablations {CONDITION_A} and {CONDITION_B} is {verdict} (alpha={ALPHA}) with t=", tval, "(p=", pval, ")")
