import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Route all rendering through the LaTeX-based PGF backend, then load the
# project-local style sheet that sets fonts, sizes, etc. for every figure.
matplotlib.use(backend="pgf")
plt.style.use("style.mplstyle")

# Import the data from the respective csv files
keys = ['model', 'weight_decay', 'energy_weight', 'mutual_inf', 'inf_per_energy_weight']
data_mean = pd.read_csv(os.path.join("data", "20210406_energy_summary_mean.csv"))[keys]
data_sem = pd.read_csv(os.path.join("data", "20210406_energy_summary_sem.csv"))[keys]

# Filter for weight_decay values with reasonable energy values
data_mean = data_mean[data_mean['weight_decay'] > 0.001]
data_sem = data_sem[data_sem['weight_decay'] > 0.001]

# One set of axes; figsize is (width, height) in inches.
fig, ax = plt.subplots(figsize=(2.6, 1.75))

# Per-model styling, keyed by the values of the 'model' column:
# a line colour and the human-readable legend entry.
colours = {
    "dyn": "#91056A",
    "mlp": "#1F407A",
    "prob-mlp": "#6F6F6F",
}
labels = {
    "dyn": "Stochastic Release (Plastic)",
    "mlp": "Deterministic Release",
    "prob-mlp": "Stochastic Release (Fixed)",
}

for name in data_mean['model'].unique():
    # Rows of the mean and SEM tables that belong to the current model.
    mean = data_mean[data_mean['model'] == name]
    sem = data_sem[data_sem['model'] == name]

    # matplotlib pairs x/y values with xerr/yerr by position, not by index,
    # so both tables must list the same weight_decay settings in the same
    # order. Fail loudly instead of silently drawing misplaced error bars.
    if len(mean) != len(sem) or not np.allclose(mean['weight_decay'], sem['weight_decay']):
        raise ValueError(
            f"Mean and SEM rows for model {name!r} do not line up by weight_decay"
        )

    # Dashed line with circle markers and thin error bars in both directions.
    # Use the model's configured colour and label; a model missing from the
    # maps falls back to matplotlib's colour cycle and its raw name.
    ax.errorbar(
        mean['energy_weight'], mean['mutual_inf'],
        xerr=sem['energy_weight'], yerr=sem['mutual_inf'],
        color=colours.get(name), label=labels.get(name, name),
        marker='o', linestyle='dashed', capsize=1, capthick=0.4, elinewidth=0.4,
    )

# Boxed, single-column legend in the lower-right corner.
ax.legend(loc="lower right", ncol=1, frameon=True)

# y-range: log2(10) is an upper bound on the mutual information considered here.
ax.set_ylim(bottom=0.0, top=np.log2(10))

# Axis titles.
ax.set(xlabel="Energy", ylabel="Mutual Information")

# Remove the padding around the axes, then write the figure to disk.
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_scatter_information-energy-l2.pdf")