import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Render every figure below through the PGF (LaTeX) backend and apply the
# project's custom style sheet globally.
matplotlib.use('pgf')
plt.style.use("style.mplstyle")

# Load the raw results as DataFrames, keyed by the curve they belong to.
# "mean" files hold the plotted centre line; "sem" files hold the half-width of
# the shaded band (presumably the standard error of the mean).
# Insertion order ("Mean" before "Frozen") decides the plotting order below.
_DATA_DIR = "data"
_FILE_PREFIX = "20201002_lifelong_time-"


def _read_run(metric, statistic):
    """Read the CSV for one *metric* ("mean"/"frozen") and *statistic* ("mean"/"sem")."""
    filename = f"{_FILE_PREFIX}{metric}_prob_perm-mnist_{statistic}.csv"
    return pd.read_csv(os.path.join(_DATA_DIR, filename))


data_mean = {"Mean": _read_run("mean", "mean"), "Frozen": _read_run("frozen", "mean")}
data_sem = {"Mean": _read_run("mean", "sem"), "Frozen": _read_run("frozen", "sem")}

# Flatten each DataFrame into one long 1-D array: all rows of the first column,
# then all rows of the second column, and so on, in column order.
def _stack_columns(frame):
    """Return the values of every column of *frame*, concatenated column after column."""
    long_format = pd.melt(frame, value_vars=frame.columns, value_name='Task')
    return np.array(long_format['Task'])


data_mean_stacked = {name: _stack_columns(frame) for name, frame in data_mean.items()}
data_sem_stacked = {name: _stack_columns(frame) for name, frame in data_sem.items()}

# HACK: the very first data point was not saved with the data, so it is
# prepended by hand (the "Mean" curve starts at 0.25, everything else at 0.0).
_initial_mean = {'Mean': 0.25, 'Frozen': 0.0}
_initial_sem = {'Mean': 0.0, 'Frozen': 0.0}
for name, first in _initial_mean.items():
    data_mean_stacked[name] = np.append([first], data_mean_stacked[name], axis=0)
for name, first in _initial_sem.items():
    data_sem_stacked[name] = np.append([first], data_sem_stacked[name], axis=0)

# Prepare the figure
fig, ax = plt.subplots(figsize=[2.0, 2.0])
labels = {"Mean": "Mean Release Probability", "Frozen": "Ratio of Frozen Release Probabilities"}

for key in data_mean_stacked:
    mean = data_mean_stacked[key]
    sem = data_sem_stacked[key]
    # One x position per stacked sample; shared by the line and its band.
    x = np.arange(0, len(mean), 1)

    # Plot the release-probability statistic over time as a line plot
    line, = ax.plot(x, mean, label=labels[key])

    # Shade mean +/- SEM around the line. The face colour is copied from the
    # line explicitly: by default fill_between takes its colour from its own
    # cycle, which is not tied to this particular line.
    lowerbound = mean - sem
    upperbound = mean + sem
    ax.fill_between(x=x, y1=lowerbound, y2=upperbound, facecolor=line.get_color(), alpha=0.3)

# Axes: a tick every 10 samples along x, numbered 1, 2, ... as tasks;
# x spans [0, 100] and y spans [0, 1].
task_ticks = np.arange(0, 100, 10)
ax.set_xticks(task_ticks)
ax.set_xticklabels(range(1, len(task_ticks) + 1))
ax.set_xlim(0, 100)
ax.set_ylim(0.0, 1.0)
ax.set_xlabel("Task")

# Legend anchored at the top centre of the axes.
# NOTE(review): no `loc` is passed, so which legend corner sits on the anchor
# comes from rcParams["legend.loc"] (possibly set by style.mplstyle, otherwise
# "best") — confirm that is intended.
ax.legend(bbox_to_anchor=(0.5, 1.0), ncol=1)

# Save the figure with no padding around the tight layout
fig.tight_layout(pad=0.0)
fig.savefig("lifelong-mlp_line_time-probs_perm-mnist.pdf")
