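"""
Histogram comparing, for the true and a randomly shuffled connectivity, how
many weights per neuron are important (release probability above a threshold).

x-axis: Fraction of important _output_ weights per neuron (e.g. 150/200)
y-axis: Fraction of neurons for which this holds (e.g. 300/784)
"""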

import os

import matplotlib
import matplotlib.pyplot as plt
import torch


# Configure matplotlib once for the whole script: render through the PGF
# backend and apply the project style sheet (the path is resolved relative to
# the current working directory).
_BACKEND = 'pgf'
_STYLE_SHEET = "style.mplstyle"
matplotlib.use(_BACKEND)
matplotlib.style.use(_STYLE_SHEET)


# Helper functions
def permute_tensor(input, generator=None):
    """
    Randomly permute all elements of a tensor while keeping its shape.

    The tensor is flattened, its elements are reordered with a uniformly
    random permutation drawn by ``torch.randperm``, and the result is
    reshaped back to the original dimensions. The set of values is unchanged;
    only their positions are shuffled.

    Args:
        input: Tensor of any shape, contiguous or not. (The name shadows the
            builtin ``input`` but is kept so keyword callers keep working.)
        generator: Optional ``torch.Generator`` used to draw the permutation,
            for reproducible shuffles. ``None`` uses the global RNG.

    Returns:
        A new tensor with the same shape, dtype and device as ``input``.
    """
    # Cache the original dimensions
    dimensions = input.size()

    # reshape (unlike view) also works for non-contiguous inputs such as
    # transposed tensors or strided slices.
    flat = input.reshape(-1)

    # Draw the permutation where the generator lives (CPU by default, which
    # matches the original RNG consumption), then move it to the data's
    # device because index_select needs input and index on the same device.
    perm_device = generator.device if generator is not None else torch.device("cpu")
    permutation = torch.randperm(
        flat.numel(), generator=generator, device=perm_device
    ).to(flat.device)
    output_flat = torch.index_select(flat, 0, permutation)

    # index_select returns a contiguous tensor, so view is safe here
    return output_flat.view(dimensions)


# Import the data from the respective pt file. torch.load unpickles the file,
# so only point it at trusted data.
release_probs = torch.load(os.path.join("data", "20201022_energy_dyn-probs_all-release-probs.pt"))

# Set parameters of the figure
high_probability = 0.9  # weights with a release probability above this count as "important"
epoch = -1  # last entry along dim 0 (presumably the epoch axis — confirm against the data file)
run_id = 0
layer = "layer1"
seed = 0  # fixes the shuffled baseline so the figure is reproducible

# 2-D slice for the chosen layer / epoch / run; each row is treated as one
# neuron (the plot below normalizes by probs.shape[0] as "Fraction of Neurons").
probs = release_probs[layer][epoch, :, :, run_id]

# Shuffled baseline: same values, randomly re-assigned to positions
torch.manual_seed(seed)
probs_shuffled = permute_tensor(probs)

# Counting the number of important weights per neuron (per row)
input_weight_count = (probs > high_probability).sum(dim=1)
input_weight_count_shuffled = (probs_shuffled > high_probability).sum(dim=1)

# Histograms over the integer counts: entry k is the number of neurons with
# exactly k important weights. bincount is exact for integer data and also
# copes with an empty count tensor.
input_weight_hist = torch.bincount(input_weight_count)
input_weight_hist_shuffled = torch.bincount(input_weight_count_shuffled)

# Small single-panel figure (width x height in inches)
fig, ax = plt.subplots(figsize=[1.8, 1.75])

# Normalization: x is the fraction of a row's weights that are important,
# y is the fraction of rows (neurons) with that many important weights.
n_rows = probs.shape[0]
n_cols = probs.shape[1]
bar_width = 0.9 / n_cols

# Draw the true connectivity first, then the shuffled baseline on top
for hist, label in (
    (input_weight_hist, "True Connectivity"),
    (input_weight_hist_shuffled, "Shuffled Connectivity"),
):
    ax.bar(
        x=torch.arange(len(hist)) / n_cols,
        height=hist / n_rows,
        width=bar_width,
        alpha=0.8,
        label=label,
    )

# Axis labels
ax.set_ylabel("Fraction of Neurons")
ax.set_xlabel("Fraction of Important\nInput Synapses")

# Legend in the upper right, without a frame
ax.legend(loc='upper right', ncol=1, frameon=False)

# Tighten the layout and write the figure to disk
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_hist_neuron-connectivity.pdf")
