import os

import matplotlib
import matplotlib.pyplot as plt
import torch
from scipy.stats import kstest

# Module-wide plotting setup: switch matplotlib to its PGF (LaTeX) backend and
# apply the custom style sheet to every figure created afterwards. The style
# path is relative, so the script is expected to run from the directory that
# contains "style.mplstyle".
matplotlib.use("pgf")
plt.style.use("style.mplstyle")


def permute_tensor(input, generator=None):
    """
    Randomly permute all elements of a tensor while keeping its shape.

    One random permutation (``torch.randperm``) is applied to the flattened
    tensor, so any element may end up at any position regardless of its
    original row/column. A new tensor is returned; ``input`` is not modified.

    Args:
        input: tensor of any shape, contiguous or not. (The name shadows the
            builtin but is kept so keyword callers keep working.)
        generator: optional ``torch.Generator`` used to draw the permutation,
            e.g. for reproducible shuffles. ``None`` uses the global RNG.

    Returns:
        Tensor with the same shape, dtype and device as ``input`` holding the
        same values in random order.
    """
    # reshape (not view): slices of larger tensors, such as selecting single
    # indices on outer and inner axes, are usually non-contiguous and
    # ``view(-1)`` raises on them.
    flat = input.reshape(-1)

    # The index tensor must live on the same device as the data for index_select.
    permutation = torch.randperm(flat.numel(), generator=generator, device=flat.device)
    shuffled = torch.index_select(flat, 0, permutation)

    # Restore the original dimensions
    return shuffled.reshape(input.shape)


def _important_weight_counts(probs, threshold):
    """Count entries of the 2-D tensor ``probs`` that are strictly above ``threshold``.

    Returns ``(per_row, per_col)``: ``per_row[i]`` is the number of such entries in
    row ``i`` (length ``probs.shape[0]``) and ``per_col[j]`` the number in column
    ``j`` (length ``probs.shape[1]``). Rows/columns without any get a count of 0.
    """
    # A single nonzero() call serves both axes: column 0 of the result holds
    # row indices, column 1 holds column indices.
    idxs = torch.nonzero(probs > threshold)
    per_row = torch.bincount(idxs[:, 0], minlength=probs.shape[0])
    per_col = torch.bincount(idxs[:, 1], minlength=probs.shape[1])
    return per_row, per_col


def _plot_connectivity(ax, hist, hist_shuffled, n_synapses, n_neurons):
    """Overlay bar plots of the true and shuffled connectivity histograms on ``ax``.

    ``hist[k]`` is the number of neurons with exactly ``k`` important synapses.
    The x axis is rescaled to the fraction of the ``n_synapses`` possible
    synapses per neuron, and bar heights to the fraction of the ``n_neurons``
    neurons.
    """
    for counts, label in ((hist, "True Connectivity"), (hist_shuffled, "Shuffled Connectivity")):
        ax.bar(
            x=torch.arange(len(counts)) / n_synapses,
            height=counts / n_neurons,
            width=0.9 / n_synapses,
            alpha=0.8,
            label=label,
        )


# Import the data from the respective pt file. map_location="cpu" lets the
# script run on machines without the device the tensors were saved from;
# matplotlib and scipy below can only consume CPU tensors anyway.
release_probs = torch.load(
    os.path.join("data", "20201022_energy_dyn-probs_all-release-probs.pt"),
    map_location="cpu",
)

# Set parameters of the figure
high_probability = 0.9  # release probability above which a weight counts as "important"
epoch = -1  # index into the first axis of each stored tensor (-1 = last entry)
run_id = 0  # index into the last axis of each stored tensor

# Prepare the figure to use the full page. squeeze=False keeps ``axs`` 2-D even
# when the file holds a single layer, so ``axs[layer][col]`` always works.
fig, axs = plt.subplots(nrows=len(release_probs), ncols=2, figsize=[5.626, 4], squeeze=False)

for layer, probs in enumerate(release_probs.values()):
    # Extract the 2-D release-probability matrix for the chosen epoch and run
    probs = probs[epoch, :, :, run_id]
    # Null model: the same values randomly reassigned to positions
    probs_shuffled = permute_tensor(probs)

    # Number of important weights per row ("input") and per column ("output").
    # NOTE(review): the labels below presumably assume rows index post-synaptic
    # neurons and columns pre-synaptic ones (nn.Linear's (out, in) layout), so a
    # per-row count is a neuron's important *input* synapses - confirm.
    input_weight_count, output_weight_count = _important_weight_counts(probs, high_probability)
    input_weight_count_shuffled, output_weight_count_shuffled = _important_weight_counts(
        probs_shuffled, high_probability
    )

    # Histograms over the number of connections per neuron:
    # hist[k] = number of neurons with exactly k important weights
    output_weight_hist = torch.bincount(output_weight_count)
    input_weight_hist = torch.bincount(input_weight_count)
    output_weight_hist_shuffled = torch.bincount(output_weight_count_shuffled)
    input_weight_hist_shuffled = torch.bincount(input_weight_count_shuffled)

    # Two-sample Kolmogorov-Smirnov test between shuffled and non-shuffled weight counts
    print("KS-Test for input neurons in layer ", layer, ": ",
          kstest(output_weight_count_shuffled.numpy(), output_weight_count.numpy()))
    print("KS-Test for output neurons in layer ", layer, ": ",
          kstest(input_weight_count_shuffled.numpy(), input_weight_count.numpy()))

    # Right column: per-column counts, each out of probs.shape[0] possible weights
    _plot_connectivity(
        axs[layer][1], output_weight_hist, output_weight_hist_shuffled,
        n_synapses=probs.shape[0], n_neurons=probs.shape[1],
    )
    # Left column: per-row counts, each out of probs.shape[1] possible weights
    _plot_connectivity(
        axs[layer][0], input_weight_hist, input_weight_hist_shuffled,
        n_synapses=probs.shape[1], n_neurons=probs.shape[0],
    )
    axs[layer][0].set_ylabel("Fraction of Neurons")

# X labels only on the bottom row
axs[-1][0].set_xlabel("Fraction of Important Input Synapses")
axs[-1][1].set_xlabel("Fraction of Important Output Synapses")
fig.align_ylabels()

# Add a single legend for all subplots (every axes holds the same two series)
handles, _ = axs[0][0].get_legend_handles_labels()
fig.legend(handles=handles, labels=["True Connectivity", "Shuffled Connectivity"], bbox_to_anchor=(0.5, 0.88), ncol=2)

# Save the figure
fig.savefig("energy-mlp_hist_neuron-connectivity_all.pdf")
