# -*- coding: utf-8 -*-
'''
Helper functions for comparing model with data.
Tools and naming conventions for plotting and other outputs.

Author: David J. Joerg
'''

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# File locations. These are relative paths, so they are resolved against the
# current working directory whenever they are used, not against the location
# of this module.
# presumably the table of fitted model parameters (CSV) -- TODO confirm
param_file = os.path.join(os.path.pardir, 'parameters', 'fit.csv')
# data directory; format_title reads 'study_metadata.csv' from here
data_path = os.path.join(os.path.pardir, 'data')
# presumably the output directory for generated results -- TODO confirm
results_path = os.path.join(os.path.curdir, 'results')

# abbreviations
# one-letter drug code -> full drug name; format_title expands the first
# letter of every drug entry in a title with this table
names = {
    'P': 'Placebo',
    'R': 'Romosozumab',
    'B': 'Blosozumab',
    'D': 'Denosumab',
    'A': 'Alendronate',
    'T': 'Teriparatide'
}

# lowercase drug name -> one-letter code; format_label replaces every
# occurrence of a key (as a plain substring) by its value
abbr_legend = {
    'romosozumab': 'R',
    'blosozumab': 'B',
    'denosumab': 'D',
    'alendronate': 'A',
    'teriparatide': 'T'
}

# standard units for drugs
# keyed by one-letter drug code; format_title appends the unit directly
# after the dose. Placebo ('P') has no unit.
std_unit = {
    'P': '',
    'R': 'mg',
    'B': 'mg',
    'D': 'mg',
    'A': 'mg',
    'T': 'mcg'
}

# standard colours for plots
# hex colour strings keyed by lowercase names of plotted quantities, drugs,
# cell types and signals
colours = {
    'bone mineral fraction': '#00c080',
    'bone density': '#000000',
    'bmd': '#000000',
    'p1np': '#5c94bd',
    'ctx': '#ff6768',
    'bsap': '#266492',
    'bone formation rate': '#5c94bd',
    'bone resorption rate': '#ff6768', 
    'blosozumab': '#e080a0',
    'romosozumab': '#e080a0',
    'alendronate': '#5c94bd',
    'denosumab': '#80c0a0',
    'teriparatide': '#a09040',
    'pre-osteoclasts': '#ffa0a0',
    'osteoclasts': '#ff6768',
    'pre-osteoblasts': '#8cc4ed',
    'osteoblasts': '#5c94bd',
    'osteocytes': '#000000',
    'estrogen': '#5c94bd',
    'sclerostin': '#ff6768',
    'sclerostin*': '#404040',
    'runx2': '#e080a0',
    'creb/bad': '#5c94bd',
    'pth': '#000000',
    'resorption signal': '#000000'
}

def plot_settings():
    '''Applies the standard matplotlib settings: sans-serif font family and a
    13 pt base font size (modifies the global ``plt.rcParams``).'''
    plt.rcParams.update({'font.family': 'sans-serif', 'font.size': 13})
    
def format_label(label):
    '''Abbreviates drug names in a label.

    `label` is first converted with ``str()``; every key of `abbr_legend`
    occurring in it is then replaced by its one-letter code, in the order of
    the dictionary.
    '''
    text = str(label)
    for full_name in abbr_legend:
        text = text.replace(full_name, abbr_legend[full_name])
    return text

def format_title(title):
    '''Formats a plot title in TeX format.

    `title` is a treatment string optionally followed by a publication id in
    parentheses, e.g. ``'R210 -> D60 (X)'``. Within the treatment string,
    ``'->'`` separates consecutive treatment phases and ``'+'`` separates
    drugs given in parallel. Each drug entry starts with a one-letter drug
    code (a key of `names`) directly followed by the dose; further
    space-separated words of the entry are kept verbatim.

    Args:
        title (str): Title string as described above.

    Returns:
        ftitle (str): Title with expanded drug names and dose units; phases
            are joined by a TeX right arrow.
        cite_key: Citation key for the publication id, looked up in
            'study_metadata.csv' under `data_path`, or None if `title`
            contains no publication id.

    Raises:
        KeyError: If a drug code is unknown or the publication id is not
            listed in the study metadata.
    '''
    # split off the trailing parenthesised publication id; rpartition takes
    # the last ' (' so parentheses inside the treatment part do not break it
    treatment, sep, pub = title.rpartition(' (')
    if sep:
        pub = pub.rstrip()
        pub_key = pub[:-1] if pub.endswith(')') else pub
    else:
        treatment, pub_key = title, None

    def _format_drug(entry):
        'Expands one drug entry, e.g. "R210" -> "Romosozumab 210mg".'
        parts = entry.split(' ')
        code, dose = parts[0][0], parts[0][1:]
        label = names[code]
        if dose:
            # a unit only makes sense next to a dose (placebo has neither)
            label = f'{label} {dose}{std_unit[code]}'
        return ' '.join([label] + parts[1:])

    # phases in sequence ('->'), drugs within a phase in parallel ('+')
    phases = [' + '.join(_format_drug(entry.strip()) for entry in phase.split('+'))
              for phase in treatment.split('->')]
    ftitle = r' $\to$ '.join(phases)

    if pub_key is None:
        return ftitle, None

    # only touch the metadata file when a citation key is actually needed
    df_metadata = pd.read_csv(os.path.join(data_path, 'study_metadata.csv'), engine='python')
    cite_keys = dict(df_metadata[['Study ID', 'Citation key']].values)
    return ftitle, cite_keys[pub_key]
    
def common_points(t1, x1, t2, x2):
    '''
    Returns the values of two time series at coincident time points.

    Time points are matched by exact equality. If a time point occurs more
    than once in a series, its first occurrence is used. All element access
    is positional, so pandas Series with arbitrary indices behave like plain
    arrays.

    Args:
        t1, x1 (array-like): Time series #1.
        t2, x2 (array-like): Time series #2.

    Returns:
        common_times (ndarray): Sorted array of times present in both `t1`
            and `t2`.
        x1_out, x2_out (ndarray): Arrays of data points in `x1` or `x2`,
            respectively, at the time points `common_times`.

    Raises:
        RuntimeError: If `t1` and `t2` have no time point in common.
    '''
    # plain Python lists give uniform hashing/equality for lists, ndarrays
    # and pandas Series alike
    times = [np.asarray(t).tolist() for t in (t1, t2)]

    # obtain common time points
    ordered = sorted(set(times[0]).intersection(times[1]))
    common_times = np.array(ordered)

    if len(common_times) == 0:
        raise RuntimeError('Data and simulation results have no common time points.')

    def _first_indices(t):
        'Maps every time point in `t` to the index of its first occurrence.'
        index = {}
        for i, s in enumerate(t):
            index.setdefault(s, i)
        return index

    # extract values positionally at the common time points
    x_out = []
    for t, x in zip(times, (x1, x2)):
        first = _first_indices(t)
        ind = [first[s] for s in ordered]
        x_out.append(np.asarray(x)[ind])
    x1_out, x2_out = x_out

    return common_times, x1_out, x2_out

def goodness_measures(t_sim, sim, t_data, data):
    '''Computes various goodness measures for reporting.

    Args:
        t_sim, sim (array-like): Simulated time series.
        t_data, data (array-like): Measured time series.

    Returns:
        mad (float): Mean absolute deviation between data and simulation at
            the time points common to both series.
        mape (float): Mean absolute deviation relative to the data values at
            those common time points.
        r2 (float): Coefficient of determination at the common time points.
        wmape (float): For every data point, the smallest absolute difference
            to a simulated value within half the median data spacing in time
            (the time-nearest simulated value is always a candidate),
            divided by the data value; averaged over all data points.

    Raises:
        RuntimeError: If simulation and data share no time points.
    '''
    # arrays are required below for elementwise comparison and masking
    t_sim, sim = np.asarray(t_sim), np.asarray(sim)
    t_data, data = np.asarray(t_data), np.asarray(data)

    _, sim_c, data_c = common_points(t_sim, sim, t_data, data)

    # mad
    mad = np.mean(np.abs(data_c - sim_c))
    # standard mape
    mape = np.mean(np.abs(data_c - sim_c) / data_c)
    # R-squared
    r2 = 1. - np.sum(np.power(data_c - sim_c, 2.)) \
        / np.sum(np.power(data_c - np.mean(data_c), 2.))

    # minimal distance within typical time spacing; with fewer than two data
    # points there is no spacing, so only the nearest simulated point counts
    dt = 0.5 * np.median(np.diff(t_data)) if len(t_data) > 1 else 0.
    distances = []
    for t, x in zip(t_data, data):
        mask = (t_sim >= t - dt) & (t_sim <= t + dt)
        # the time-nearest simulated point keeps the window non-empty
        # (np.min of an empty array raises)
        mask[np.argmin(np.abs(t_sim - t))] = True
        distances.append(np.min(np.abs(sim[mask] - x)))
    wmape = np.mean(np.array(distances) / data)

    return mad, mape, r2, wmape
    
def nfc(t1, x1, t2, x2):
    '''Normalises the data of a time series `t1`, `x1` to the first data point of
    another time series `t2`, `x2` (required overlap of time series).

    The value of `x1` at the time point of `t1` closest to `t2[0]` is scaled
    to `x2[0]`; all other values of `x1` are scaled by the same factor. The
    overlap is not checked: the closest time point is used regardless of
    its distance. Element access is positional, so pandas Series with
    arbitrary indices are handled like arrays.

    Returns:
        The rescaled `x1`: same type as `x1` for ndarray/Series input, an
        ndarray for list/tuple input.
    '''
    if isinstance(x1, (list, tuple)):
        x1 = np.asarray(x1)
    t_ref = np.asarray(t2)[0]
    x_ref = np.asarray(x2)[0]
    i0 = np.argmin(np.abs(np.asarray(t1) - t_ref))
    return x1 / np.asarray(x1)[i0] * x_ref

def nft(t, x, t0):
    '''Normalises a time series to its value at a time point `t0` (requires `t0` to be
    covered by the time series).

    The value at the time point of `t` closest to `t0` becomes 1. Coverage is
    not checked: the closest time point is used regardless of its distance.
    Element access is positional, so pandas Series with arbitrary indices are
    handled like arrays.

    Returns:
        The normalised `x`: same type as `x` for ndarray/Series input, an
        ndarray for list/tuple input.
    '''
    if isinstance(x, (list, tuple)):
        x = np.asarray(x)
    i0 = np.argmin(np.abs(np.asarray(t) - t0))
    return x / np.asarray(x)[i0]