import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch


# Global configuration of matplotlib to use custom style sheet
matplotlib.use('pgf')
matplotlib.style.use("style.mplstyle")

# Import the data from the respective pt files. map_location="cpu" keeps the
# script runnable on CPU-only machines even if the tensors were saved on a GPU.
release_probs = torch.load(
    os.path.join("data", "20210329_energy_random-init_release-probs_all.pt"),
    map_location="cpu",
)

# One histogram row per release-probability snapshot; the snapshots are
# plotted in this order under these labels.
column_labels = ["Before Learning", "After Learning"]
if len(release_probs) != len(column_labels):
    # Fail loudly instead of plotting an all-zero column or hitting an
    # opaque IndexError when the data file has an unexpected number of entries.
    raise ValueError(
        f"Expected {len(column_labels)} release-probability snapshots, "
        f"got {len(release_probs)}"
    )

# Create the normalised histogram with shape (num_snapshots, num_bins).
min_bin, max_bin, num_bins = 0.00, 1.0, 5
bin_width = (max_bin - min_bin) / num_bins
histogram = torch.zeros(len(column_labels), num_bins)

for i, probs in enumerate(release_probs):
    # torch.histc uses num_bins equal-width bins over [min_bin, max_bin] and
    # ignores values outside that range.
    counts = torch.histc(probs, min=min_bin, max=max_bin, bins=num_bins)
    histogram[i] = counts / counts.sum()


def _prob_to_xpos(prob):
    """Map a release probability to the bar chart's categorical x coordinate.

    pandas centres bar group k at x = k, and group k holds the histogram bin
    [min_bin + k * bin_width, min_bin + (k + 1) * bin_width], so bin edges
    fall on half-integer x positions.
    """
    return (prob - min_bin) / bin_width - 0.5


# Prepare the figure
fig, ax = plt.subplots(figsize=[2.6, 1.5])

# Plot the histogram (rows = bins, columns = snapshots)
histogram_df = pd.DataFrame(histogram.t().numpy(), columns=column_labels)
histogram_df.plot.bar(ax=ax, rot=0)

# Add vertical line indicating high probability cut-off. Its position is derived
# from the actual bin geometry. NOTE: with num_bins=5 the 0.9 cut-off lies in
# the middle of the last bin; pick num_bins so it falls on a bin edge if a clean
# separation between bars is wanted.
high_prob_cutoff = 0.9
ax.axvline(_prob_to_xpos(high_prob_cutoff), color='black', linestyle='--')

# X/Y labels and ticks: label each bar group with the centre of the bin it shows.
bin_centres = min_bin + bin_width * (np.arange(num_bins) + 0.5)
ax.set_xticks(np.arange(num_bins))
ax.set_xticklabels(np.round(bin_centres, 3))
ax.set_ylim(0.0, 1.0)
# Show exactly the probability range [min_bin, max_bin]; a left limit of 0.0
# would hide the first bar of the lowest bin.
ax.set_xlim(_prob_to_xpos(min_bin), _prob_to_xpos(max_bin))
ax.set_xlabel("Release Probability")
ax.set_ylabel("Normalised Frequency")

# Add the legend
ax.legend(ncol=1, loc="upper right", frameon=True)

# Save the figure
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_hist_release-probabilities_random-init.pdf")
