import os
from textwrap import wrap

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from utils import plot_bracket


# Configure matplotlib globally: render through the pgf backend and load the
# project's custom style sheet, which every figure created below inherits.
matplotlib.use("pgf")
matplotlib.style.use(style="style.mplstyle")

# Import the summary statistics from the respective CSV files. Each benchmark
# (Split MNIST first, then Perm MNIST) has one file of per-condition means and
# one of per-condition SEMs (presumably standard error of the mean; they are
# used as error bars in the bar chart).
_SUMMARY_FILE_PREFIXES = (
    "20201214_lifelong-ablation_split-mnist_summary",
    "20201214_lifelong-ablation_perm-mnist_summary",
)


def _load_summaries(statistic):
    """Load and stack the ``<prefix>_<statistic>.csv`` file of every benchmark.

    Args:
        statistic: File-name suffix selecting the statistic, e.g. ``"mean"``
            or ``"sem"``.

    Returns:
        A single DataFrame holding the rows of all benchmarks, in the order of
        ``_SUMMARY_FILE_PREFIXES``. ``ignore_index=True`` gives it a fresh,
        unique 0..n-1 index. Without it, each file's own RangeIndex would be
        kept and the labels would repeat across benchmarks.
    """
    return pd.concat(
        (
            pd.read_csv(os.path.join("data", f"{prefix}_{statistic}.csv"))
            for prefix in _SUMMARY_FILE_PREFIXES
        ),
        ignore_index=True,
    )


data_mean = _load_summaries("mean")
data_sem = _load_summaries("sem")

# Select only a subset of ablation conditions by dropping the ones below from
# both the mean and the SEM tables.
_EXCLUDED_CONDITIONS = ['No Freezing', 'No LR Modulation', 'No Homeostasis']
data_mean = data_mean[~data_mean["key"].isin(_EXCLUDED_CONDITIONS)]
data_sem = data_sem[~data_sem["key"].isin(_EXCLUDED_CONDITIONS)]

# The bar chart pairs bar heights (means) with error bars (SEMs) purely by row
# position, so both tables must list the same conditions in the same order.
if data_mean["key"].tolist() != data_sem["key"].tolist():
    raise ValueError(
        "Mean and SEM summaries list different ablation conditions or list "
        "them in a different order; error bars would be attached to the "
        "wrong bars. Mean keys: "
        f"{data_mean['key'].tolist()}, SEM keys: {data_sem['key'].tolist()}"
    )

# Create one figure holding a single 2.0 x 2.0 inch axes for the bar chart.
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(2.0, 2.0))

# Bar positions: the three Split MNIST conditions, a gap, then the three Perm
# MNIST conditions. ``x`` is reused further down for the significance
# brackets, the tick positions and the group titles.
x = [0, 1, 2, 3.3, 4.3, 5.3]
# One colour per benchmark group (Split MNIST bars, then Perm MNIST bars).
bar_colors = ["#91056A"] * 3 + ["#1F407A"] * 3

# Positions and colours are hard-coded for six conditions; fail with a clear
# message instead of an opaque broadcasting error if the data disagrees.
if len(data_mean) != len(x):
    raise ValueError(
        f"Expected {len(x)} ablation conditions to plot, got {len(data_mean)}: "
        f"{data_mean['key'].tolist()}"
    )

# Plot the bar chart. Pass plain arrays so matplotlib pairs values strictly
# by position and never sees the pandas index labels.
ax.bar(
    x,
    height=data_mean["mean_acc"].to_numpy(),
    width=0.7,
    yerr=data_sem["mean_acc"].to_numpy(),
    error_kw={"capsize": 1, "capthick": 0.4, "linewidth": 0.4},
    color=bar_colors,
)

# Significance brackets, drawn in this order, each described as
# (left bar index, right bar index, bracket level, significance stars).
# The level is passed straight to utils.plot_bracket (presumably the y
# position in data coordinates; confirm there). The first two entries
# compare Split MNIST conditions, the last two compare Perm MNIST conditions.
for left_idx, right_idx, level, stars in (
    (0, 1, 0.88, "*"),
    (0, 2, 0.92, "**"),
    (3, 4, 0.88, "**"),
    (3, 5, 0.92, "**"),
):
    plot_bracket(ax, x[left_idx], x[right_idx], level, stars)

# X/Y labels and ticks: one tick per bar, labelled with the condition name
# wrapped to at most 14 characters per line and rotated by 90 degrees.
ax.set_xticks(x)
ax.set_xticklabels(["\n".join(wrap(label, 14)) for label in data_mean["key"]], rotation=90)
ax.set_ylabel('Average Test Accuracy')
ax.set_ylim(0.5, 1.0)

# Group titles above the axes. The blended x-axis transform uses data
# coordinates for x and axes coordinates for y. Each title is therefore
# centred on its own bars regardless of the x-margins set by the style sheet;
# fixed axes fractions only line up with matplotlib's default margin.
group_title_transform = ax.get_xaxis_transform()
split_bars, perm_bars = x[:3], x[3:]
ax.text(sum(split_bars) / len(split_bars), 1.05, "Split MNIST", ha="center", transform=group_title_transform)
ax.text(sum(perm_bars) / len(perm_bars), 1.05, "Perm MNIST", ha="center", transform=group_title_transform)

# Shrink the layout to zero padding, then export the finished figure as a PDF
# (relative path, so it is written to the current working directory).
fig.tight_layout(pad=0.0)
fig.savefig(fname="lifelong-mlp_bar_ablation.pdf")
