# -*- coding: utf-8 -*-
'''
Runs the model and creates plots comparing validation datasets and model results.
Script 1 of 2.

Author: David J. Joerg
'''
import sys, os

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tabulate import tabulate

sys.path.insert(0, os.pardir)
from model import OsteoporosisModel, ParameterSet, ClinicalTrialData
from tools import goodness_measures, nfc, param_file, data_path, results_path, \
    colours, format_label, plot_settings

def _simulate(fit_data, params, t_sim):
    '''
    Simulate the model once per clinical dataset.

    Each run starts from the model equilibrium, applies the dataset's drug
    administrations and limits the solver step inside the dataset's treatment
    period.

    Parameters
    ----------
    fit_data : dict
        Maps dataset labels to ``ClinicalTrialData`` objects.
    params : ParameterSet
        Model parameters.
    t_sim : float
        Simulated time span in days (passed to ``propagate``).

    Returns
    -------
    (dict, dict)
        Time points and simulated trajectories (indexable by variable name),
        both keyed like ``fit_data``.
    '''
    t_all, y_all = {}, {}
    for key, trial in fit_data.items():
        # flush so the progress message is visible while the simulation runs
        print('Simulating {}...'.format(key.replace('\n', ' ')), end='', flush=True)
        # equilibrate
        avatar = OsteoporosisModel(params, admins=trial.admins, init_state='equilibrium')
        # max solver step 1 day within the treatment region
        max_step = OsteoporosisModel.piecewise_max_step(*trial.treatment_period)
        t_all[key], y_all[key] = avatar.propagate(t_sim, dt=1, max_step=max_step)
        print('done.')
    print('')
    return t_all, y_all


def _goodness_table(fit_data, obs, t_all, y_all):
    '''
    Compute goodness-of-fit measures for every (dataset, observable) pair
    that has clinical data.

    The simulated trajectory is passed through ``nfc`` together with the
    measured series before ``goodness_measures`` is evaluated.

    Returns
    -------
    list of list
        Rows of ``[dataset, observable, MAD, MAPE, R2, WMAPE]``.
    '''
    rows = []
    for key, trial in fit_data.items():
        t = t_all[key]
        for c_sim in obs:
            if c_sim not in trial.data:
                continue
            t_data, values = np.transpose(trial.data[c_sim])
            sim = nfc(t, y_all[key][c_sim], t_data, values)
            mad, mape, r2, wmape = goodness_measures(t, sim, t_data, values)
            rows.append([key.replace('\n', ' '), c_sim, mad, mape, r2, wmape])
    return rows


def _set_age_ticks(ax, t_min, t_max):
    '''
    Label a time axis (in days) with ages in whole years.

    Every year is labelled for spans up to 10 years, every second year
    otherwise, to avoid crowded tick labels.
    '''
    ages = np.arange(0, 100, 1)
    if (t_max - t_min) / 365. > 10:
        ages = ages[::2]
    ax.set_xticks(ages * 365)
    ax.set_xticklabels(ages)


def _plot_validation(fit_data, obs, t_all, y_all, other=(), markersize=5):
    '''
    Plot simulations against clinical data.

    The figure has one column per dataset. It has one row per observable in
    ``obs``, followed by one row per entry of ``other``. Each entry of
    ``other`` is an iterable of simulated variable names drawn together.

    Returns
    -------
    matplotlib.figure.Figure
        The laid-out figure (not yet saved or shown).
    '''
    btms = ['ctx', 'p1np', 'bsap']
    medications = ['blosozumab', 'romosozumab', 'alendronate', 'denosumab', 'teriparatide']
    turnover_rates = ['bone formation rate', 'bone resorption rate']
    grid_colour = '#d0d0d0'

    nrows = len(obs) + len(other)
    ncols = len(fit_data)
    fig, axs = plt.subplots(
        ncols=ncols,
        nrows=nrows,
        figsize=(ncols * 2.5, nrows * 2.),
        sharey='row',
        sharex='col',
        squeeze=False,  # keep axs 2-D even with a single row or column
    )

    # Most recently plotted data series; used to normalise simulations of
    # observables that have no data of their own (see NOTE below).
    t_ref = ref = None
    for j, (key, trial) in enumerate(fit_data.items()):
        t = t_all[key]
        y = y_all[key]

        # show the treatment period plus 20 % of its length before and 100 % after
        tt0, tt1 = trial.treatment_period
        delta = tt1 - tt0
        t_min, t_max = tt0 - 0.2 * delta, tt1 + 1. * delta
        for i in range(len(obs)):
            _set_age_ticks(axs[i, j], t_min, t_max)
        axs[0, j].set_xlim(t_min, t_max)

        for i, (c_sim, ylabel) in enumerate(obs.items()):
            ax = axs[i, j]
            has_data = c_sim in trial.data
            if has_data:
                t_ref, ref = np.transpose(trial.data[c_sim])
                err_low, err_high = np.transpose(trial.errors[c_sim])
                ax.errorbar(t_ref, ref, yerr=[err_low, err_high], marker='o',
                            linestyle='', color='#000000', markersize=markersize)
            # NOTE(review): when this observable has no data, the simulation is
            # passed to nfc with the most recently plotted data series. That
            # series may belong to a different observable or dataset. This
            # reproduces the previous behaviour, which reused leftover loop
            # variables; confirm it is intended.
            if c_sim in y and t_ref is not None:
                ax.plot(t, nfc(t, y[c_sim], t_ref, ref),
                        color=colours[c_sim], linewidth=2)
            ax.grid(color=grid_colour)

            if c_sim in btms:
                if not has_data:
                    ax.text(t_max - (t_max - t_min) * 0.025, 3.8,
                            '(no data available)',
                            va='top', ha='right', fontsize=9, color='#606060')
                axs[i, 0].set_ylim(0, 4.)
            axs[i, 0].set_ylabel(ylabel)

        for k, cols in enumerate(other):
            i = k + len(obs)
            ax = axs[i, j]
            ax.set_ylabel('')

            for c_sim in cols:
                # copy, so rescaling below never modifies the stored results
                sim = np.array(y[c_sim], dtype=float) if c_sim in y else None

                # rescale medications to unit max
                if c_sim in medications:
                    if sim is not None:
                        nrm = np.max(sim)
                        if nrm > 0:
                            sim /= nrm
                    axs[i, 0].set_ylabel('Systemic drug levels\n(frac. of maximum)')
                    axs[i, 0].set_ylim(0, 1.2)

                # rescale turnover rates to %/year
                if c_sim in turnover_rates:
                    if sim is not None:
                        sim *= 365. * 100.
                    axs[i, 0].set_ylabel('% of peak BMD/year')

                if sim is None:
                    continue
                color = colours[c_sim] if c_sim in colours else '#000000'
                ax.plot(t, sim, color=color, linewidth=2, label=format_label(c_sim))
                ax.grid(color=grid_colour)

            ax.legend(fontsize=8)

    # fixed y-range for the first observable row (BMD total hip in main())
    for ax in axs[0]:
        ax.set_ylim(0.78, 0.94)
    for ax in axs[-1]:
        ax.set_xlabel('Age (years)')

    fig.align_ylabels(axs[:, 0])
    fig.tight_layout()
    return fig


def main():
    '''
    Simulate all validation datasets and compare them with the clinical data.

    The goodness-of-fit measures are printed and also written to
    ``validation_gm_main.csv`` in ``results_path``. The comparison figure is
    saved to ``validation_main.pdf`` in ``results_path`` and shown
    interactively.
    '''
    # simulation time (days; 85 years)
    t_sim = 365. * 85

    # clinical observables to plot, mapped to their y-axis labels
    obs = {
        'bmd': 'BMD total hip\n(rel. to $t_0$)',
        'ctx': 'CTX\n(rel. to baseline)',
        'p1np': 'P1NP\n(rel. to baseline)'
    }

    # extra panel rows below the observables; each entry is a list of
    # simulated variable names drawn together in one row
    other = []

    # list of datasets to simulate and compare with (label -> file in data_path)
    datasets = {
        'P (mcclung2018)':
            'looker1998_and_mcclung2018_placebo_to_placebo.csv',
        'P -> D60 Q6M (mcclung2018)':
            'looker1998_and_mcclung2018_placebo_to_denosumab.csv',
        'A70 Q1W -> R140 Q1M -> D60 Q6M (mcclung2018)':
            'looker1998_and_mcclung2018_alendronate_to_romosozumab_to_denosumab.csv',
        'B180 Q2W (recknor2015)':
            'looker1998_and_recknor2015_blosozumab_180_q2w.csv',
        'D60 Q6M -> T20 Q1D (leder2015)':
            'looker1998_and_leder2015_denosumab_to_teriparatide.csv',
        'T20 Q1D + D60 Q6M -> D60 Q6M (leder2015)':
            'looker1998_and_leder2015_combination_to_denosumab.csv',
    }
    datasets = {key: os.path.join(data_path, fname) for key, fname in datasets.items()}

    # read datasets and parameters
    fit_data = {key: ClinicalTrialData(path) for key, path in datasets.items()}
    params = ParameterSet(param_file)

    # simulate
    t_all, y_all = _simulate(fit_data, params, t_sim)

    # analysis: print and store goodness-of-fit measures
    res_gm = _goodness_table(fit_data, obs, t_all, y_all)
    columns = ['Dataset', 'Observable', 'MAD', 'MAPE', 'R2', 'WMAPE']
    print(tabulate(res_gm, headers=columns))
    pd.DataFrame(res_gm, columns=columns).to_csv(
        os.path.join(results_path, 'validation_gm_main.csv'),
        index=False,
    )

    # plot
    plot_settings()
    fig = _plot_validation(fit_data, obs, t_all, y_all, other=other)
    fig.savefig(os.path.join(results_path, 'validation_main.pdf'))

    plt.show()
    plt.close(fig)
            
if __name__ == "__main__":
    # Script entry point: run the full validation workflow.
    main()