import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr


# Global matplotlib setup: select the PGF backend, then apply the project's
# custom style sheet file to every figure created afterwards.
matplotlib.use(backend='pgf')
plt.style.use("style.mplstyle")


# Helper functions
def permute_tensor(input):
    """
    Randomly shuffle all elements of a tensor, keeping its shape.

    Each call draws a fresh random permutation of the flattened tensor with
    ``torch.randperm``. The result therefore holds exactly the same values as
    ``input``, rearranged at random across all positions (not per-axis).

    Args:
        input: tensor to shuffle. The name shadows the builtin ``input`` but is
            kept so existing keyword callers keep working. Non-contiguous
            tensors (e.g. transposed views) and non-CPU tensors are accepted.

    Returns:
        A new tensor with the same shape, dtype and device as ``input``.
    """
    # reshape (unlike view) falls back to a copy for non-contiguous inputs
    flat = input.reshape(-1)

    # Create the permutation on the input's device so index_select does not
    # mix a CPU index with a tensor living on another device.
    permutation = torch.randperm(flat.numel(), device=flat.device)
    shuffled = torch.index_select(flat, 0, permutation)

    # index_select returns a fresh contiguous tensor, so view is safe here
    return shuffled.view(input.size())


# Figure parameters:
#  - high_probability: release-probability threshold above which a weight
#    counts as "important"
#  - epoch / run_id: indices used below to slice the first and last axes of
#    each stored array (epoch = -1 selects the last entry of the first axis,
#    presumably the final epoch — confirm against the data producer)
high_probability = 0.9
epoch = -1
run_id = 0

# Load the stored release probabilities from the .pt file in ./data.
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# trusted files.
release_probs = torch.load(
    os.path.join("data", "20201022_energy_dyn-probs_all-release-probs.pt")
)

# Prepare a small single-panel figure (1.8 x 1.75 inches)
fig, ax = plt.subplots(figsize=[1.8, 1.75])

# 2-D release-probability matrices of the first two layers for the chosen
# epoch and run.
# NOTE(review): assumes the nn.Linear weight layout (out_features, in_features),
# so the rows of layer0 and the columns of layer1 both index the hidden
# neurons — confirm against the data producer.
probs_input = release_probs['layer0'][epoch, :, :, run_id]
probs_output = release_probs['layer1'][epoch, :, :, run_id]

# Fraction of important (p > high_probability) input and output weights per
# neuron. Averaging the boolean mask over the reduced axis normalises each
# count by the number of entries actually counted along that same axis.
# Converted to NumPy on the CPU so scipy/numpy receive plain arrays.
input_weight_count = (
    (probs_input > high_probability).float().mean(dim=1).cpu().numpy()
)
output_weight_count = (
    (probs_output > high_probability).float().mean(dim=0).cpu().numpy()
)

# Compute the Pearson correlation
corr_coeff, p_value = pearsonr(input_weight_count, output_weight_count)

# Create the scatter plot
ax.scatter(input_weight_count, output_weight_count, alpha=0.7)

# Add a least-squares regression line spanning the full x-range. The limits
# are restored afterwards because plotting the line makes autoscaling add
# margins again, so the line would otherwise stop short of the axes.
regression_line = np.poly1d(np.polyfit(input_weight_count, output_weight_count, 1))
x_lim = ax.get_xlim()
ax.plot(x_lim, regression_line(np.asarray(x_lim)), alpha=0.7)
ax.set_xlim(x_lim)

# Add annotation with correlation coefficient
ax.annotate(
    f"$r={corr_coeff:.4f}$",
    (0.3, 0.9),
    xycoords='axes fraction',
    ha='center',
    va='center',
    bbox=dict(facecolor='white', edgecolor='none', boxstyle='square'),
)

# X/Y Labels
ax.set_xlabel("Fraction of Important\nInput Synapses")
ax.set_ylabel("Fraction of Important\nOutput Synapses")

# Save the figure
fig.tight_layout(pad=0.0)
fig.savefig("energy-mlp_scatter_neuron-connectivity.pdf")
