import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr


# Matplotlib setup: select the PGF backend, then apply the custom style
# sheet stored in "style.mplstyle"
matplotlib.use("pgf")
plt.style.use("style.mplstyle")


# Helper functions
def permute_tensor(input):
    """
    Randomly shuffle the elements of a tensor while keeping its shape.

    A new random permutation over all element positions is drawn with
    ``torch.randperm`` on every call, so the result depends on the state of
    torch's random number generator.

    Args:
        input: Tensor of arbitrary shape. It does not have to be contiguous.

    Returns:
        A new tensor with the same shape as ``input`` containing the same
        elements in random order. ``input`` itself is not modified.
    """
    # Draw the permutation on the input's own device, because index_select
    # needs the index tensor and the indexed tensor on the same device
    permutation = torch.randperm(input.numel(), device=input.device)

    # reshape (unlike view) also flattens non-contiguous inputs, e.g.
    # transposed tensors, by copying only when it has to
    output_flat = torch.index_select(input.reshape(-1), 0, permutation)

    # index_select returns a fresh contiguous tensor, so view is safe here
    return output_flat.view(input.size())


# Load the stored release probabilities from the .pt file in the data folder
data_file = os.path.join("data", "20201022_energy_dyn-probs_all-release-probs.pt")
release_probs = torch.load(data_file)

# Analysis settings
high_probability = 0.9  # probabilities above this value are counted as "important"
epoch = -1              # index into the first axis of each entry (-1 selects the last one)
run_id = 0              # index into the last axis of each entry

# Single row with one panel per pair of consecutive entries in release_probs;
# the figure size is chosen to span the full page
n_panels = len(release_probs) - 1
fig, axs = plt.subplots(nrows=1, ncols=n_panels, figsize=[5.623, 1.75])
labels = ["True Connectivity", "Shuffled Connectivity"]

# For every pair of consecutive entries in release_probs, relate each neuron's
# fraction of important (high release probability) input synapses to its
# fraction of important output synapses. This is done once for the stored
# connectivity and once for a shuffled control.
layer_names = list(release_probs)

for shuffle in [False, True]:
    for layer_idx, (layer_current, layer_next) in enumerate(zip(layer_names, layer_names[1:])):
        probs_current = release_probs[layer_current][epoch, :, :, run_id]
        probs_next = release_probs[layer_next][epoch, :, :, run_id]

        if shuffle:
            # Control condition: scramble the entries of both matrices independently
            probs_current = permute_tensor(probs_current)
            probs_next = permute_tensor(probs_next)

        # Fraction of important input and output weights per neuron.
        # Taking the mean of the float mask avoids dividing an integer tensor
        # with `/`, whose behaviour differed between PyTorch releases. The
        # results are converted to NumPy because SciPy, NumPy and matplotlib
        # consume them below.
        input_weight_count = (probs_current > high_probability).float().mean(dim=1).cpu().numpy()
        output_weight_count = (probs_next > high_probability).float().mean(dim=0).cpu().numpy()

        # Compute the Pearson correlation
        corr_coeff, p_value = pearsonr(input_weight_count, output_weight_count)
        print(
            "Hidden layer {} has a correlation between the number of important input and output weights per neuron of r={:.4f} (p={:.4f})".format(
                layer_idx, corr_coeff, p_value
            )
        )

        # Create the scatter plot. The label is chosen by condition (true vs.
        # shuffled), not by layer index, so it stays valid for any panel count.
        ax = axs[layer_idx]
        ax.scatter(input_weight_count, output_weight_count, alpha=0.7, label=labels[int(shuffle)])

        if not shuffle:
            # Add a first-order least-squares regression line spanning the
            # x-limits the axis has at this point
            regression_line = np.poly1d(np.polyfit(input_weight_count, output_weight_count, 1))
            x_lim = ax.get_xlim()
            ax.plot(x_lim, regression_line(x_lim), alpha=0.7)

            # Add annotation with correlation coefficient
            ax.annotate(
                "$r={:.4f}$".format(corr_coeff),
                (0.85, 0.1),
                xycoords='axes fraction',
                ha='center',
                va='center',
                bbox=dict(facecolor='white', edgecolor='none', boxstyle='square'),
            )

        # X label for every panel
        ax.set_xlabel("Fraction of Important Input Synapses")

# Only the first panel carries the y-axis label
axs[0].set_ylabel("Fraction of Important\nOutput Synapses")
# Fix the y-range of the second panel slightly beyond [0, 1]
axs[1].set_ylim(bottom=-0.1, top=1.1)

# One figure-level legend for all subplots, built from the first panel's handles
legend_labels = ["True Connectivity", "Shuffled Connectivity"]
legend_handles = axs[0].get_legend_handles_labels()[0]
fig.legend(handles=legend_handles, labels=legend_labels, ncol=2, bbox_to_anchor=(0.5, 0.89))

# Raise the bottom edge of the panels, then write the figure to disk
fig.subplots_adjust(bottom=0.2)
fig.savefig("energy-mlp_scatter_neuron-connectivity_all.pdf")
