# -*- coding: utf-8 -*-
'''
Runs the model for all permutations of 1-year treatments with alendronate, romosozumab
and denosumab as described in the manuscript.

Author: David J. Joerg
'''
import sys, os
from itertools import permutations
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from tabulate import tabulate

sys.path.insert(0, os.pardir)
from model import OsteoporosisModel, ParameterSet
from tools import param_file, results_path, nft, plot_settings, colours, format_label

def medseq(a_dur, t_int, dose):
    '''
    Generates an array of constant dose administrations in fixed intervals.

    Administrations start at time 0 and repeat every ``t_int`` days while the
    time stays below ``a_dur * 365`` (the end of the episode is exclusive).

    Args:
        a_dur (int): Duration of the medication episode (in years).
        t_int (int): Interval between consecutive administrations (in days).
        dose (float): Drug dose.

    Returns:
        Float array of shape (n, 2) listing administrations as
        (admin. time, dose). An episode without administrations yields an
        array of shape (0, 2), so column indexing like ``arr[:, 0]`` always
        works.
    '''
    times = np.arange(0, a_dur * 365, t_int, dtype=float)
    # column_stack keeps the (n, 2) layout even when n == 0
    return np.column_stack((times, np.full_like(times, float(dose))))

def _build_treatments(test_admins, ages):
    '''
    Builds every sequential ordering of the given single-drug schemes.

    Args:
        test_admins (dict): Drug name -> administration array of
            (admin. time, dose) rows, times relative to the episode start.
        ages (list): Start age (in years) of each consecutive episode, paired
            in order with the drugs of each permutation.

    Returns:
        Tuple ``(treatments, seq_abbr)``. ``treatments`` maps a label such as
        ``'A -> R -> D'`` to a dict of per-drug administration arrays shifted
        to absolute time (days); ``seq_abbr`` maps the same label to a compact
        abbreviation such as ``'ARD'``.
    '''
    treatments = {}
    seq_abbr = {}
    for drugs in permutations(test_admins.keys()):
        # shift a copy so the template schemes stay relative to time 0
        admins = deepcopy(test_admins)
        for age, drug in zip(ages, drugs):
            admins[drug][:, 0] += age * 365
        initials = [drug[0].upper() for drug in drugs]
        key = ' -> '.join(initials)
        treatments[key] = admins
        seq_abbr[key] = ''.join(initials)
    return treatments, seq_abbr


def _simulate(params, treatments, t_min, t_max, t_sim):
    '''
    Runs the model once per treatment sequence.

    Returns:
        Tuple ``(t_all, y_all)`` mapping each treatment label to the times and
        outputs returned by ``OsteoporosisModel.propagate``.
    '''
    t_all, y_all = {}, {}
    for label, admins in treatments.items():
        # flush so the progress message shows while the solver runs
        print('Simulating {}...'.format(label), end='', flush=True)
        # equilibrate
        avatar = OsteoporosisModel(params, admins=admins, init_state='equilibrium')
        # max solver step 1 day within the treatment region
        max_step = OsteoporosisModel.piecewise_max_step(t_min, t_max)
        t_all[label], y_all[label] = avatar.propagate(t_sim, dt=1, max_step=max_step)
        print('done.')
    print('')
    return t_all, y_all


def _bmd_changes(labels, t_all, y_all, t_ref, t_min, t_max):
    '''
    Computes summary BMD metrics per treatment.

    Returns:
        List of ``[max. BMD change from t_min on, BMD change at t_max + 10 yrs]``
        per label (in the order of ``labels``), both as fractions relative to
        the BMD at ``t_ref``.
    '''
    tab = []
    for label in labels:
        t, y = t_all[label], y_all[label]
        # normalise BMD to reference time point
        bmd_change = nft(t, y['bmd'], t_ref) - 1.
        max_bmd_change = np.max(bmd_change[t >= t_min])
        # NOTE(review): labelled "10 yrs after treatment end", but this is
        # evaluated 10 years after t_max (end of the fine-step window), not
        # after the last 1-year treatment episode ends -- confirm intent.
        bmd_change_10yrs = bmd_change[t >= t_max + 10 * 365][0]
        # stored as fractions; the plots format them as percentages
        tab.append([max_bmd_change, bmd_change_10yrs])
    return tab


def _plot_globals(labels, plot_tab):
    '''
    Bar charts of the summary BMD metrics per treatment sequence; saved as
    ``treatment_scenario_globals.pdf`` in ``results_path`` and displayed.

    Args:
        labels (list): Tick label (sequence abbreviation) per bar.
        plot_tab (ndarray): One row per bar; column 0 is the max. BMD change,
            column 1 the later BMD change (both fractions).
    '''
    fig, axs = plt.subplots(nrows=2, figsize=(3.5, 6), sharex=False)
    x_glob = range(len(plot_tab))

    axs[0].bar(x_glob, plot_tab[:, 0], linewidth=1, edgecolor='black', color='#87deaa')
    lb, ub = 0.03, 0.055
    axs[0].set_ylim(lb, ub)
    yticks = np.arange(lb, ub, 0.01)
    axs[0].set_yticks(yticks)
    axs[0].set_yticklabels([str(int(np.round(100. * z))) + '%' for z in yticks])

    axs[1].bar(x_glob, plot_tab[:, 1], linewidth=1, edgecolor='black', color='#87decd')
    # -0.049999 rather than -0.05 so the stop-exclusive arange keeps the -5% tick
    lb, ub = -0.09, -0.049999
    axs[1].set_ylim(lb, ub)
    yticks = np.arange(lb, ub, 0.01)
    axs[1].set_yticks(yticks)
    # typeset the minus sign in math mode
    axs[1].set_yticklabels(
        [str(int(np.round(100. * z))).replace('-', '$-$') + '%' for z in yticks]
    )

    axs[0].set_ylabel('Max. BMD\n(change from baseline)')
    axs[1].set_ylabel('BMD 10 yrs after treatment end\n(change from baseline)')

    for ax in axs:
        ax.set_xticks(x_glob)
        ax.set_xticklabels(labels, fontsize=10)
    axs[1].set_xlabel('Sequence')

    fig.align_ylabels(axs)
    fig.tight_layout()
    fig.savefig(os.path.join(results_path, 'treatment_scenario_globals.pdf'))
    plt.show()
    plt.close(fig)


def _plot_treatment_region(ranked, t_all, y_all, obs, other, drugs,
                           t_min, t_max, t_ref):
    '''
    Time courses in the treatment window, one column per treatment sequence;
    saved as ``treatment_scenario.pdf`` in ``results_path`` and displayed.

    Args:
        ranked (list): Treatment labels in column order.
        t_all, y_all (dict): Simulation times / outputs per treatment label.
        obs (dict): Output key -> y-axis label; one row each, normalised to
            the value at ``t_ref``.
        other (list): Extra rows; each entry lists output keys that are
            plotted together without normalisation.
        drugs (iterable): Drug names; such series are rescaled to unit max.
        t_min, t_max (float): Displayed time window (days).
        t_ref (float): Reference time for normalisation (days).
    '''
    nrows = len(obs) + len(other)
    ncols = len(ranked)
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        figsize=(ncols * 2, nrows * 2.),
        sharey='row',
        sharex='col',
        squeeze=False,  # keep axs 2-D even with a single row or column
    )

    # label the time axis in years of age
    range_age = np.arange(0, 100, 1)
    for ax in axs[:len(obs)].flat:
        ax.set_xticks(range_age * 365)
        ax.set_xticklabels(range_age)

    for j, label in enumerate(ranked):
        t, y = t_all[label], y_all[label]
        axs[0, j].set_xlim(t_min, t_max)

        for i, (c_sim, c_exp) in enumerate(obs.items()):
            axs[i, j].plot(t, nft(t, y[c_sim], t_ref), color=colours[c_sim], linewidth=2)
            axs[i, j].grid(color='#d0d0d0')
            axs[i, 0].set_ylabel(c_exp)

        for k, cols in enumerate(other):
            i = k + len(obs)
            ax = axs[i, j]
            ax.set_ylabel('')

            for c_sim in cols:
                if c_sim not in y:
                    # series not produced by this simulation: nothing to plot
                    continue
                color = '#000000' if c_sim not in colours else colours[c_sim]
                # copy so the rescaling below leaves the stored output intact
                sim = np.array(y[c_sim], dtype=float)

                # rescale medications to unit max
                if c_sim in drugs:
                    nrm = np.max(sim)
                    if nrm > 0:
                        sim /= nrm
                    axs[i, 0].set_ylabel('serum levels\n(frac. of maximum)')

                # rescale turnover rates to %/year
                if c_sim in ['bone formation rate', 'bone resorption rate']:
                    sim *= 365. * 100.
                    axs[i, 0].set_ylabel('% of peak BMD/year')

                ax.plot(t, sim, color=color, linewidth=2, label=format_label(c_sim))
                ax.grid(color='#d0d0d0')

            # only add a legend when something labelled was drawn
            if ax.get_legend_handles_labels()[0]:
                ax.legend(fontsize=8)

    # rows after the first observable are relative to baseline
    for i in range(1, len(obs)):
        axs[i, 0].set_yticks(np.arange(0., 5., 0.5))
        axs[i, 0].set_ylim(0, 2.)

    for j in range(ncols):
        axs[0, j].set_yticks(np.arange(0.9, 1.1, 0.02))
        axs[0, j].set_ylim(0.99, 1.06)
        axs[-1, j].set_xlabel('Age (years)')

    fig.align_ylabels(axs[:, 0])
    fig.tight_layout()

    # write to file and display
    fig.savefig(os.path.join(results_path, 'treatment_scenario.pdf'))
    plt.show()
    plt.close(fig)


def main():
    '''
    Simulates every ordering of 1-year alendronate, romosozumab and denosumab
    episodes starting at ages 67, 68 and 69, prints a table of BMD changes and
    writes the summary and time-course figures to ``results_path``.
    '''
    # simulation time (days)
    t_min, t_max = 66 * 365, 72 * 365
    t_ref = 67 * 365
    t_sim = 365. * 85

    # drug-individual administration schemes
    test_admins = {
        'alendronate': medseq(a_dur=1, t_int=7, dose=70),
        'romosozumab': medseq(a_dur=1, t_int=30, dose=140),
        'denosumab': medseq(a_dur=1, t_int=180, dose=60)
    }
    ages = [67, 68, 69]

    treatments, seq_abbr = _build_treatments(test_admins, ages)

    # read parameters
    params = ParameterSet(param_file)
    t_all, y_all = _simulate(params, treatments, t_min, t_max, t_sim)

    # compute BMD changes
    tab = _bmd_changes(treatments.keys(), t_all, y_all, t_ref, t_min, t_max)
    print_tab = [[key] + val for key, val in zip(treatments.keys(), tab)]
    print(tabulate(print_tab, headers=['Treatment', 'Max. BMD change', 'BMD change 10 yrs\nafter treatment']))

    # plot
    plot_settings()
    obs = {
        'bmd': 'BMD total hip\n(rel. to $t_0$)',
        'ctx': 'CTX\n(rel. to baseline)',
        'p1np': 'P1NP\n(rel. to baseline)'
    }
    # optional extra rows of raw outputs (lists of output keys)
    other = []

    # rank sequences by max. BMD change (ascending, stable for ties)
    ranked_rows = sorted(zip(treatments.keys(), tab), key=lambda row: row[1][0])
    ranked = [label for label, _ in ranked_rows]

    _plot_globals(
        [seq_abbr[label] for label in ranked],
        np.array([vals for _, vals in ranked_rows]),
    )
    _plot_treatment_region(
        ranked, t_all, y_all, obs, other, test_admins.keys(),
        t_min, t_max, t_ref,
    )
            
if __name__ == "__main__":
    # run every treatment scenario only when executed as a script
    main()