import os

import matplotlib
import matplotlib.pyplot as plt
import torch
import numpy as np


# Global matplotlib configuration: render through the PGF (LaTeX) backend and
# apply the custom style sheet. The style-sheet path is relative, so the script
# is expected to run from the directory that contains style.mplstyle.
matplotlib.use("pgf")
plt.style.use("style.mplstyle")


# Helper functions
def permute_tensor(input, generator=None):
    """
    Randomly permute all elements of a tensor while keeping its shape.

    A random permutation of ``input.numel()`` indices is drawn with
    ``torch.randperm`` and applied to the flattened tensor, so every element
    may end up at any position (not only within its own row or column).

    Args:
        input: tensor to shuffle. It may be non-contiguous (e.g. a transpose
            or a strided slice) and may live on any device.
        generator: optional CPU ``torch.Generator`` used to draw the
            permutation, for reproducible shuffles. Defaults to torch's
            global RNG, which is what earlier versions always used.

    Returns:
        A new tensor with the same shape, dtype and device as ``input`` whose
        elements are a random permutation of ``input``'s elements.
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as transposes
    flat = input.reshape(-1)

    # Draw the permutation on the CPU RNG (same random stream as before), then
    # move it to the input's device so indexing also works for CUDA tensors.
    permutation = torch.randperm(flat.numel(), generator=generator).to(flat.device)
    output_flat = torch.index_select(flat, 0, permutation)

    # Restore the original dimensions
    return output_flat.reshape(input.shape)


# Import the data from the respective pt file
release_probs = torch.load(os.path.join("data", "20201022_energy_dyn-probs_all-release-probs.pt"))

# Set parameters of the figure
high_probability = 0.9  # release probability above which a weight counts as "important"
epoch = -1  # index along the first tensor axis (-1 selects the last entry)
layer = "layer1"  # key of release_probs that is analysed
n_runs = 3  # number of runs (entries along the last tensor axis) to aggregate

# Prepare a small single-panel figure (figsize is width, height in inches)
fig, ax = plt.subplots(figsize=[1.5, 1.75])

# Per-run neuron counts; each list receives one entry per run
many_imp_outputs = []
many_imp_outputs_shuffled = []
rest = []
rest_shuffled = []
few_imp_outputs = []
few_imp_outputs_shuffled = []

# Iterate over runs
for run_id in range(n_runs):
    # NOTE(review): assumes release_probs[layer] is 4-D (epoch, rows, cols, run),
    # so that probs is a 2-D weight matrix — confirm against the data file.
    probs = release_probs[layer][epoch, :, :, run_id]
    # Shuffled copy: same values at random positions; its statistics serve as
    # the reference thresholds below.
    probs_shuffled = permute_tensor(probs)

    # Number of important (> high_probability) weights per neuron, i.e. per
    # index along dim 0. Equivalent to a bincount over the row indices of
    # torch.nonzero, but computed directly as a row-wise sum.
    n_neurons = probs.shape[0]
    output_weight_count = (probs > high_probability).flatten(start_dim=1).sum(dim=1)
    output_weight_count_shuffled = (probs_shuffled > high_probability).flatten(start_dim=1).sum(dim=1)

    # Mean and (sample) std of the per-neuron counts in the shuffled matrix
    shuffled_count_mean = torch.mean(output_weight_count_shuffled.float())
    shuffled_count_std = torch.std(output_weight_count_shuffled.float())

    # Classify neurons: "many" above mean + 2*std, "few" below mean - 2*std,
    # everything in between is "rest"
    many = shuffled_count_mean + 2 * shuffled_count_std
    few = shuffled_count_mean - 2 * shuffled_count_std
    many_imp_outputs.append(torch.sum(output_weight_count > many).item())
    many_imp_outputs_shuffled.append(torch.sum(output_weight_count_shuffled > many).item())
    few_imp_outputs.append(torch.sum(output_weight_count < few).item())
    few_imp_outputs_shuffled.append(torch.sum(output_weight_count_shuffled < few).item())
    rest.append(n_neurons - many_imp_outputs[-1] - few_imp_outputs[-1])
    rest_shuffled.append(n_neurons - many_imp_outputs_shuffled[-1] - few_imp_outputs_shuffled[-1])

# Calculate mean and stds of measurements across runs, in the order
# [many, rest, few]; np.std is the population std (ddof=0).
many_few_means = [np.mean(many_imp_outputs), np.mean(rest), np.mean(few_imp_outputs)]
many_few_stds = [np.std(many_imp_outputs), np.std(rest), np.std(few_imp_outputs)]
many_few_shuffled_means = [np.mean(many_imp_outputs_shuffled), np.mean(rest_shuffled), np.mean(few_imp_outputs_shuffled)]
many_few_shuffled_stds = [np.std(many_imp_outputs_shuffled), np.std(rest_shuffled), np.std(few_imp_outputs_shuffled)]

# Create bar plots
h = 0.1  # bar width
a, b, c = 0.3, 0.6, 0.9  # bar locations
x_axis = np.asarray([a, b, c])
# Thin error bars with small caps, shared by both bar series
error_style = {"capsize": 1, "capthick": 0.4, "linewidth": 0.4}

# Reverse [many, rest, few] so the bars run few -> rest -> many from left to
# right, matching the Low/Medium/High tick labels below. Each call gets its
# own copy of error_style so the two series never share a mutable dict.
ax.bar(x_axis - h / 2, many_few_means[::-1], yerr=many_few_stds[::-1], width=h, label='True Connectivity', error_kw=dict(error_style))
ax.bar(x_axis + h / 2, many_few_shuffled_means[::-1], yerr=many_few_shuffled_stds[::-1], width=h, label='Shuffled Connectivity', error_kw=dict(error_style))

# Legend currently disabled; uncomment to label the two bar series
#ax.legend(ncol=1, bbox_to_anchor=(0.5, 1.0))

# X/Y Labels and ticks
ax.set_xticks([a, b, c])
ax.set_xticklabels(['Low\nImportance', 'Medium\nImportance', 'High\nImportance'], rotation=45)

# Bars are raw neuron counts; label the ticks 0, n/2, n as fractions of
# n_neurons (the value left over from the last run processed).
y_ticks = np.linspace(0, n_neurons, 3)
ax.set_yticks(y_ticks)
ax.set_yticklabels(np.round(y_ticks / n_neurons, 2))

ax.set_ylabel('Fraction of Neurons')

# Save the figure
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_bar_neuron-activity.pdf")
