import os

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

# Select the PGF backend first, then apply the project's custom style sheet
# (looked up as "style.mplstyle" relative to the working directory).
matplotlib.use("pgf")
plt.style.use("style.mplstyle")

# Weight-decay setting shown in the figure (chosen because it yields
# reasonable energy values).
WEIGHT_DECAY = 0.05
# Absolute tolerance used when matching the weight_decay column.
WEIGHT_DECAY_TOLERANCE = 1e-9
# Columns that every summary file must provide.
SUMMARY_COLUMNS = ['model', 'weight_decay', 'inf_per_energy_weight']


def _load_metric(filename, weight_decay=WEIGHT_DECAY):
    """Return 'inf_per_energy_weight' for the rows of one weight-decay setting.

    Args:
        filename: Name of a summary CSV file inside the ``data`` directory.
        weight_decay: Value of the ``weight_decay`` column to select.

    Returns:
        A pandas Series with the 'inf_per_energy_weight' values of the
        matching rows, in file order and keeping the file's row index.

    Raises:
        KeyError: If one of the required columns is missing from the file.
        ValueError: If no row matches ``weight_decay``.
    """
    summary = pd.read_csv(os.path.join("data", filename))[SUMMARY_COLUMNS]

    # Match with a small tolerance instead of ``==``: floats parsed from CSV
    # text are not guaranteed to be bit-identical to the Python literal for
    # every pandas float parser, and an exact comparison would then silently
    # drop the rows.
    is_selected = (summary['weight_decay'] - weight_decay).abs() <= WEIGHT_DECAY_TOLERANCE
    if not is_selected.any():
        raise ValueError(
            f"no rows with weight_decay == {weight_decay} found in {filename!r}"
        )
    return summary.loc[is_selected, 'inf_per_energy_weight']


# Bar heights and error-bar sizes; both come from identically structured files
# and are filtered the same way so that their rows line up.
data_mean = _load_metric("20210406_energy_summary_mean.csv")
data_sem = _load_metric("20210406_energy_summary_sem.csv")

# Styling of the bars and their error bars
bar_colors = ['#91056A', '#1F407A', '#6F6F6F']
errorbar_style = {"capsize": 1, "capthick": 0.4, "linewidth": 0.4}

# Draw one bar per entry of data_mean on a small figure; data_sem supplies
# the error-bar sizes.
fig, ax = plt.subplots(figsize=[2.6, 1.75])
data_mean.plot(
    kind="bar",
    ax=ax,
    rot=0,
    yerr=data_sem,
    error_kw=errorbar_style,
    color=bar_colors,
)

# Axis annotation. No legend is drawn; each bar is identified by its x tick
# label instead. The labels are assigned by position, so they rely on the row
# order of the input data.
ax.set_ylabel('Mutual Information per Energy')
ax.set_xticklabels([
    "Stochastic Release\n(Plastic)",
    "Deterministic\nRelease",
    "Stochastic Release\n(Fixed)",
])

# Remove the padding around the axes and write the figure to disk
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_bar_information-energy.pdf")
