import matplotlib
matplotlib.use('TkAgg')
import seaborn as sns
from pylab import *
import numpy as np
import math, random
import multiprocessing as mp
from mpl_toolkits.mplot3d import Axes3D

## Plot appearance. Order matters: the seaborn defaults are applied first,
## then the ggplot style sheet, then the husl colour palette on top.
sns.set(font_scale=1.0)
plt.style.use('ggplot')
sns.set_palette("husl")

## Font type 42 (TrueType) for both vector backends, instead of matplotlib's
## default Type 3 fonts.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})



def calcShannonIndex(uniqueSet):
    '''Compute the Shannon index (in bits) of a sequence population.

    Args:
        uniqueSet: Mapping of sequence -> copy number.

    Returns:
        H = -sum(p_i * log2(p_i)), where p_i is the fraction of the pool made
        up of sequence i. Entries with a zero count contribute nothing, and a
        pool without any copies has H = 0.

    Side effects:
        Prints the number of unique sequences, the pool size and H.
    '''
    n_unique = len(uniqueSet)
    counts = np.fromiter(uniqueSet.values(), dtype=float, count=n_unique)
    pool_size = counts.sum()

    ## 0*log2(0) is taken as 0 (its limit). Evaluating it directly gives NaN,
    ## which would poison the whole sum, so empty entries are dropped first.
    occupied = counts[counts > 0]

    if occupied.size:
        fractions = occupied / pool_size
        ## Adding 0.0 turns the IEEE negative zero of a single-sequence pool
        ## into a plain 0.0 so it is not reported as "-0.000000".
        H = float(-np.sum(fractions * np.log2(fractions))) + 0.0
    else:
        H = 0.0

    print("Unique sequences: %d\tTotal sequences: %d\nShannon index: %f\n" % (n_unique, pool_size, H))

    return H


def buildUniqueSet(code=4, length=10, unique=10, total=100):
    '''Construct a dict of random sequences based on the input parameters.

    Args:
        code: Alphabet size; every position holds a digit in [0, code).
            Digits are rendered with str(), so sequences are only fixed-width
            (and guaranteed distinct) for code <= 10.
        length: Number of positions per sequence.
        unique: Number of distinct sequences to generate. Capped at
            code**length, the number of sequences that exist.
        total: Total number of copies in the pool. The unique sequences are
            amplified uniformly at random until this total is reached; if it
            is smaller than the number of unique sequences, every sequence
            simply keeps its single copy.

    Returns:
        Dict mapping each sequence string to its copy number (>= 1).

    Raises:
        ValueError: If sequences are requested from an empty alphabet, or if
            copies are requested but there is no sequence to amplify.
    '''
    n_symbols = int(code)
    seq_len = max(int(length), 0)
    n_unique = int(unique)
    n_total = int(math.ceil(total))

    if n_unique > 0 and seq_len > 0 and n_symbols < 1:
        raise ValueError("code must be at least 1 to generate sequences")

    ## Every sequence corresponds to one integer in [0, code**length).
    ## Distinct integers are drawn so that the pool really contains the
    ## requested number of unique sequences; independent random draws would
    ## silently lose every duplicate.
    space = n_symbols ** seq_len
    n_unique = min(n_unique, space)

    chosen = set()
    while len(chosen) < n_unique:
        chosen.add(random.randrange(space))

    ##Decode each integer into its base-`code` digit string.
    uniqueSet = {}
    while chosen:
        index = chosen.pop()
        digits = []
        for _ in range(seq_len):
            index, digit = divmod(index, n_symbols)
            digits.append(str(digit))
        uniqueSet[''.join(digits)] = 1

    ##Randomly amplify the unique set to give total sequences.
    n_extra = n_total - len(uniqueSet)
    if n_extra > 0:
        if not uniqueSet:
            raise ValueError("cannot amplify an empty set of sequences")

        ## Handing out n_extra copies one uniform pick at a time is a
        ## multinomial draw with equal weights; do it in a single call.
        weights = np.full(len(uniqueSet), 1.0 / len(uniqueSet))
        extra_copies = np.random.multinomial(n_extra, weights)
        for seq, extra in zip(uniqueSet, extra_copies.tolist()):
            uniqueSet[seq] += extra

    print("Built population. Code=%d Length=%d Unique=%e Total=%e" % (code, length, unique, total))

    return uniqueSet



def _hydrolyse(pool, hydrolysis_rate, mean_cut):
    '''Cut a fraction of `pool` once each and return the resulting fragments.

    `pool` must already be shuffled; molecules are taken from its end and the
    ones that end up uncut are put back into it.
    '''
    n_to_hydrolyse = min(int(len(pool) * hydrolysis_rate), len(pool))

    fragments = []
    uncut = []
    for _ in range(n_to_hydrolyse):
        seq = pool.pop()

        ## Cut position drawn from a normal distribution around mean_cut.
        cut = int(abs(random.gauss(mean_cut, len(seq) * 0.15)))

        if 0 < cut < len(seq):
            fragments.append(seq[:cut])
            fragments.append(seq[cut:])
        else:
            ## The cut point fell outside the molecule: it stays intact
            ## instead of vanishing from the population.
            uncut.append(seq)

    pool.extend(uncut)
    return fragments


def _ligate(pool, fragments, ligation_rate):
    '''Join fragments to random partners; all molecules end up back in `pool`.

    Both lists must already be shuffled. `fragments` is consumed from its end.
    '''
    n_molecules = len(pool) + len(fragments)
    if n_molecules == 0:
        return

    n_to_ligate = int(n_molecules * ligation_rate)

    ## Chance that the partner is itself a fragment, fixed at the composition
    ## of the population before ligation starts.
    prob_hyd = len(fragments) / n_molecules

    ligated = []
    set_aside = []
    while len(ligated) < n_to_ligate and fragments:
        seq1 = fragments.pop()

        ## Fragments shorter than 2 nt are not used to start a ligation.
        if len(seq1) < 2:
            set_aside.append(seq1)
            continue

        if fragments and pool:
            partners = fragments if random.random() < prob_hyd else pool
        elif fragments:
            partners = fragments
        elif pool:
            partners = pool
        else:
            ## Nothing left to join with; keep the molecule as it is.
            set_aside.append(seq1)
            break

        ligated.append(seq1 + partners.pop())

    pool.extend(ligated)
    pool.extend(fragments)
    pool.extend(set_aside)


def simulateRecombination(args):

    '''Simulates the hydrolysis and ligation reaction (recombination)

    Args:
        args: Sequence of (uniqueSet, hydrolysis_rate, hydrolysis_mean_nt,
            ligation_rate, n_cycles), packed into one argument so the function
            can be handed to a process pool.
            uniqueSet maps sequence -> copy number and is not modified.
            hydrolysis_rate is the fraction of molecules cut per cycle.
            hydrolysis_mean_nt is halved to give the mean cut position.
            ligation_rate is the number of ligations per cycle as a fraction
            of the number of molecules present after hydrolysis.
            n_cycles is the number of hydrolysis + ligation rounds.

    Returns:
        Shannon index of the recombined population minus that of uniqueSet.

    All random draws use Python's `random` module: unlike NumPy's legacy
    global state it is reseeded in forked worker processes, so parallel tasks
    do not replay the same random stream.
    '''

    uniqueSet, hydrolysis_rate, hydrolysis_mean_nt, ligation_rate, n_cycles = args

    hydrolysis_mean_nt = hydrolysis_mean_nt/2.0

    ## Build a list containing entire population distribution.
    rolloutSet = []
    for seq, copies in uniqueSet.items():
        rolloutSet.extend([seq] * int(copies))

    for _ in range(n_cycles):

        random.shuffle(rolloutSet) ## Shuffle the set.

        ## Perform hydrolysis first using supplied parameters.
        hydrolysed = _hydrolyse(rolloutSet, hydrolysis_rate, hydrolysis_mean_nt)

        ## Shuffle the population again.
        random.shuffle(rolloutSet)
        random.shuffle(hydrolysed)

        ## Perform ligation.
        _ligate(rolloutSet, hydrolysed, ligation_rate)

    ## Collapse the population back into sequence -> copy number.
    newSet = {}
    for seq in rolloutSet:
        newSet[seq] = newSet.get(seq, 0) + 1

    print("Recombined initial set with: hydrolysis_rate=%f hydrolysis_mean_nt=%d ligation_rate=%f" % (hydrolysis_rate, hydrolysis_mean_nt, ligation_rate))

    print("Shannon index of initial set.")
    s_initial = calcShannonIndex(uniqueSet)
    print("Shannon index of recombined set")
    s_recombined = calcShannonIndex(newSet)
    s_delta = s_recombined-s_initial
    print("Delta Shannon index: %f\n" % (s_delta))

    return s_delta


######## Main ##########

## Simulation settings.
n_cpus = 16          ## Worker processes used for the grid search.
n_cycles = 1         ## Hydrolysis + ligation rounds per simulation.
set_length = 10      ## Sequence length (nt).

## Populations to simulate: alphabet size, number of unique sequences and
## total pool size.
set_code_space = [4]
set_unique_space = np.around(np.geomspace(1e3, 1e6, num=4), decimals=2)
set_total_space = np.around(np.geomspace(1e3, 1e7, num=5), decimals=2)

## For a single quick run use e.g. set_unique_space = [1e5] and
## set_total_space = [1e7].

## Hydrolysis / ligation rates scanned for every population.
h_rate_space = np.around(np.arange(0.0, 1.0, 0.1), decimals=2)
l_rate_space = np.around(np.arange(0.0, 1.0, 0.1), decimals=2)


def _reseed_worker():
    '''Give a pool worker its own NumPy random stream.

    Forked workers start with a copy of the parent's global NumPy RNG state
    and would otherwise all draw the same numbers.
    '''
    np.random.seed()


## The guard keeps worker processes from re-running the whole script when
## they import this module (spawn start method on Windows and macOS).
if __name__ == "__main__":

    for set_code in set_code_space:
        for set_unique in set_unique_space:
            for set_total in set_total_space:

                ## A pool cannot hold fewer copies than unique sequences.
                if set_total < set_unique:
                    continue

                initial_set = buildUniqueSet(code=set_code, length=set_length, unique=set_unique, total=set_total)

                ## Grid search: hydrolysis rate varies slowest so that the
                ## results reshape to rows=hydrolysis, columns=ligation.
                task_queue = [
                    [initial_set, h_rate, set_length, l_rate, n_cycles]
                    for h_rate in h_rate_space
                    for l_rate in l_rate_space
                ]

                ## Run the task queue and block until complete; map() returns
                ## the results in submission order.
                with mp.Pool(processes=n_cpus, initializer=_reseed_worker) as pool:
                    results = pool.map(simulateRecombination, task_queue, chunksize=1)

                results_shaped = np.array(results).reshape(len(h_rate_space), len(l_rate_space))

                sns.heatmap(results_shaped, xticklabels=l_rate_space, yticklabels=h_rate_space, center=0, cmap=sns.diverging_palette(10, 250, s=99, l=30, sep=10, as_cmap=True),
                vmin=-5.0, vmax=5.0, square=True, linewidths=.5, cbar_kws={"shrink": .5}, annot=True, annot_kws={"size": 6})
                plt.title("Code=%d, length=%d, unique=%.1E, total=%.1E, cycles=%d" % (set_code, set_length, set_unique, set_total, n_cycles))
                plt.xlabel("Ligation rate")
                plt.ylabel("Hydrolysis rate")

                plt.savefig('complexity_c%d_l%d_u%0.1E_t%0.1E_cyc%d.pdf' % (set_code, set_length, set_unique, set_total, n_cycles), dpi=300, bbox_inches='tight')

                ## Start the next heatmap from an empty figure.
                plt.clf()


