import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import pearsonr


# Global matplotlib setup: render through the PGF (LaTeX) backend, then apply
# the project's custom style sheet (path is relative to the working directory).
matplotlib.use(backend="pgf")
plt.style.use("style.mplstyle")

# Load the saved tensors from the .pt file and convert them to a pandas
# DataFrame. map_location="cpu" lets the file load even if it was saved from
# a GPU and no GPU is available here.
raw_data = torch.load(
    os.path.join("data", "20201007_energy_release-probs_weights.pt"),
    map_location="cpu",
)
# Magnitude of each weight; plotted on the y-axis of the figure.
raw_data["weights_magn"] = torch.abs(raw_data["weights"])

# Convert tensors to NumPy explicitly instead of relying on pandas' implicit
# handling of torch tensors (which fails for tensors that require grad and
# varies across pandas versions). Non-tensor values are passed through as-is.
# NOTE(review): assumes every tensor entry is 1-D and of equal length — confirm.
data = pd.DataFrame({
    key: value.detach().cpu().numpy() if torch.is_tensor(value) else value
    for key, value in raw_data.items()
})

# Synapses whose release probability exceeds this threshold are "important".
RELEASE_PROB_THRESHOLD = 0.9
data["important"] = data["release_probs"] > RELEASE_PROB_THRESHOLD

# Set to True to draw the Pearson r value onto the figure as well.
ANNOTATE_CORRELATION = False

# Small single-panel figure (2.6 x 1.5 in); margins are trimmed by
# tight_layout(pad=0.0) before saving.
fig, ax = plt.subplots(figsize=[2.6, 1.5])

# Pearson correlation between release probability and weight magnitude.
# Always reported on stdout so the computed value is not silently discarded.
corr_coeff, p_value = pearsonr(data["release_probs"], data["weights_magn"])
print(f"Pearson r = {corr_coeff:.4f} (p = {p_value:.3g})")

# Violin plot of weight magnitude split by the boolean "important" column.
# order=[False, True] pins the category positions (False at x=0, True at x=1)
# so the hard-coded tick labels below always match the right violin.
# NOTE(review): seaborn >= 0.13 deprecates `bw`/`scale` in favour of
# `bw_method`/`density_norm`; kept here for compatibility with older versions.
ax = sns.violinplot(
    x="important",
    y="weights_magn",
    data=data,
    order=[False, True],
    bw=1.0,
    cut=0,
    scale='count',
    ax=ax,
    inner=None,
    linewidth=0.0,
)

# Optional annotation with the correlation coefficient.
if ANNOTATE_CORRELATION:
    ax.annotate(
        "$r={:.4f}$".format(corr_coeff),
        (0.2, 0.9),
        xycoords='axes fraction',
        ha='center',
        va='center',
        bbox=dict(facecolor='white', edgecolor='none', boxstyle='square'),
    )

# X/Y labels and ticks. Fix tick positions first so the labels are bound to
# x=0 (False / low) and x=1 (True / high) rather than to whatever the locator
# happens to produce.
ax.set_xticks([0, 1])
ax.set_xticklabels(['Low\nRelease Probability', 'High\nRelease Probability'])
ax.set_ylim(-0.01, 0.3)
ax.set_xlabel(None)
ax.set_ylabel("Synaptic Energy Demand")

# Save the figure
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_violin_release-probs-weights.pdf")
