#!/usr/bin/env python
"""
findtau.py
Read in a table of peak heights from Sparky's rh command, with ncyc specified correctly within Sparky.
Calculate R2_eff for each peak at each ncyc, then use these values to fit tauE, R2 and Rex.
Report a reduced chi squared value for each fit.
"""
from pandas import Series, DataFrame
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.optimize import curve_fit
from math import pi,ceil
from scipy.stats import norm
import re

# User-editable values


# Sample name: used for plot titles and as the prefix of every output file
# written by this script (spreadsheets and PDFs), so it ends up in file names.
sample_name = '20180703 AdcR L57M Apo BB CPMG 800'

# Input file information

# NOTE(review): `noise` is not referenced by any function in this file;
# presumably a spectral noise estimate in peak-height units -- confirm
# before relying on it.
noise = 1000

# Location of input file (the peak-height table passed to parsepeaktable)
peaktable = '/Users/…/20180703_AdcR_L57M_Apo.rh'

def fitFunc(nu,R2,Rex,t):
    """Model curve fitted to R2_eff as a function of CPMG frequency.

    Evaluates R2 + Rex*(1 - x*tanh(1/x)) with x = 2*t*nu.

    nu  -- CPMG frequency (scalar or numpy array); must be non-zero because
           the model divides by 2*t*nu.
    R2  -- baseline relaxation rate (the value returned when Rex is 0).
    Rex -- amplitude of the frequency-dependent contribution.
    t   -- time constant (reported as "Tau" by the fitting code).

    Returns the model value(s), with the same shape as nu.
    """
    x = 2*t*nu
    return R2 + Rex*(1 - x*np.tanh(1/x))

# one_letter["SER"] will now return "S"
one_letter = {
    'VAL': 'V', 'ILE': 'I', 'LEU': 'L', 'GLU': 'E', 'GLN': 'Q',
    'ASP': 'D', 'ASN': 'N', 'HIS': 'H', 'TRP': 'W', 'PHE': 'F',
    'TYR': 'Y', 'ARG': 'R', 'LYS': 'K', 'SER': 'S', 'THR': 'T',
    'MET': 'M', 'ALA': 'A', 'GLY': 'G', 'PRO': 'P', 'CYS': 'C',
}

# Inverse lookup: three_letter["S"] returns "SER"
three_letter = {short: long_name for long_name, short in one_letter.items()}

def format_label(orig):
    """Reformat a '<residue>-<atom>' label as tab-separated fields.

    The part before the '-' is read as a one-letter residue code followed by
    the residue number, e.g. 'S57-N' becomes '57<TAB>SER<TAB>N'.

    Raises ValueError if orig does not contain exactly one '-', and KeyError
    if the leading letter is not a key of three_letter.
    """
    residue, atom = orig.split('-')
    code, number = residue[0], residue[1:]
    return '\t'.join((number, three_letter[code], atom))

#def formatAssignment(ass):
#    #Condense the original sparky assignment string to remove the H
#    return re.sub(r'(\d+)C',r'\1-C',ass.split('-')[0])
#def greekFormatAssignment(ass):
#    return r'$%s$'%ass.replace('-CB',' \\beta').replace('-CE',' \epsilon').replace('-CG',' \gamma').replace('-CD',' \delta')


def name2ncyc(col_name):
    """Return the ncyc part of a '<filename>/<ncyc>' column name, as a string.

    Names without a '/' are returned unchanged. A name with more than one
    '/' raises ValueError.
    """
    if '/' not in col_name:
        return col_name
    _, cycles = col_name.split('/')
    return cycles

def name2Hz(col_name):
    """Convert a '<filename>/<ncyc>' column name to an integer, 25*ncyc.

    Names without a '/' are returned unchanged. A name with more than one
    '/', or a non-integer ncyc, raises ValueError.
    """
    if '/' not in col_name:
        return col_name
    _, cycles = col_name.split('/')
    return int(cycles) * 25
    
def export_data(df,name):
    """Write df to an Excel file named '<sample_name>_<name>.xls'.

    df   -- pandas DataFrame to save.
    name -- short tag appended to the module-level sample_name to build the
            output file name (written relative to the working directory).

    NOTE(review): writing the legacy .xls format depends on the Excel writer
    shipped with the installed pandas -- confirm it is still available
    before upgrading pandas.
    """
    # The original also built an unused '<sample_name>_pymol.txt' file name
    # here; that dead local has been removed.
    df.to_excel('%s_%s.xls'%(sample_name,name))

def parsepeaktable(filepath):
    """Load a whitespace-delimited peak-height table into a DataFrame.

    filepath -- path to the table; it must have an 'Assignment' column, which
                becomes the index, and height columns named
                '<filename>/<ncyc>'.

    The 'T-decay' and 'SD' columns are discarded. When several columns share
    the same ncyc only the first one is kept. The remaining columns are
    relabelled with name2Hz (25*ncyc), the table is exported with the tag
    'input', and the relabelled DataFrame is returned.

    Raises ValueError if any other column name does not contain exactly
    one '/'.
    """
    # sep=r'\s+' is the documented equivalent of the deprecated
    # delim_whitespace=True and is accepted by old and new pandas alike.
    raw = pd.read_table(filepath, sep=r'\s+', index_col='Assignment')
    kept_columns = []
    seen_ncyc = set()
    for col in raw.columns:
        if col in ('T-decay', 'SD'):
            continue
        _, ncyc = col.split('/')
        if ncyc in seen_ncyc:
            # Repeat measurement at an ncyc we already have: keep the first.
            continue
        seen_ncyc.add(ncyc)
        kept_columns.append(col)
    heights = raw[kept_columns]
    heights.columns = [name2Hz(col) for col in kept_columns]
    export_data(heights,'input')
    return heights

def calcR2eff(heightsDF):
    """Convert peak heights to R2_eff = -25*ln(I/I0) for every non-zero column.

    heightsDF -- DataFrame of peak heights indexed by assignment; the column
                 labelled 0 holds the reference heights I0.

    Returns a DataFrame with the same index and one column per non-reference
    column of heightsDF, and also exports it with the tag 'R2eff'.
    Height ratios that are zero or negative yield inf/NaN entries (numpy log
    semantics); they are not filtered here.

    Raises KeyError if heightsDF has no column labelled 0.
    """
    reference = heightsDF[0]
    pulsed_columns = [col for col in heightsDF.columns if col != 0]
    # Divide every column by the reference column row-by-row (axis=0 aligns
    # the Series on the index) instead of first building a DataFrame that
    # repeats the reference column once per frequency.
    ratios = heightsDF[pulsed_columns].div(reference, axis=0)
    R2effs = -25*np.log(ratios)
    export_data(R2effs,'R2eff')
    return R2effs

def fitCurves(R2effs):
    """Fit fitFunc to each row of R2effs and tabulate the results.

    R2effs -- DataFrame indexed by assignment whose column labels are the
              CPMG frequencies and whose values are R2_eff.

    Each row is fitted with scipy's curve_fit under the bounds
    0 <= R2 <= 100, 0 <= Rex <= 100 and 0.0005 <= Tau <= 0.005.

    Returns a float DataFrame indexed by assignment with columns R2, R2_err,
    Rex, Rex_err, Tau, Tau_err and reduced_chi_squared; the *_err values are
    the square roots of the diagonal of the fit covariance matrix. The table
    is also exported with the tag 'RexTauFits'.

    Note: reduced_chi_squared is the unweighted sum of squared residuals
    divided by (number of frequencies - number of parameters); the residuals
    are not scaled by any measurement uncertainty.
    """
    freqs = R2effs.columns.values
    assignments = R2effs.index
    all_fits = pd.DataFrame(columns=['R2','R2_err','Rex','Rex_err','Tau','Tau_err','reduced_chi_squared'],index=assignments)
    for ass in assignments:
        # .loc replaces the .ix indexer, which newer pandas no longer provides.
        R2efflist = R2effs.loc[ass].tolist()
        fitParams, fitCovariances = curve_fit(fitFunc, freqs, R2efflist,bounds=([0.,0.,0.0005],[100,100,0.005]))
        R2,Rex,Tau = fitParams
        R2_err,Rex_err,Tau_err = np.sqrt(np.diag(fitCovariances))
        residuals = fitFunc(freqs, *fitParams)-np.asarray(R2efflist)
        chi_squared = np.sum(residuals**2)
        reduced_chi_squared = chi_squared/(len(freqs)-len(fitParams))
        all_fits.loc[ass]=[R2,R2_err,Rex,Rex_err,Tau,Tau_err,reduced_chi_squared]
    # The frame is created empty, so its columns are object dtype; make them
    # numeric so later comparisons, means and plots work on real floats.
    all_fits = all_fits.astype(float)
    export_data(all_fits,'RexTauFits')
    return all_fits

def plotfakecurve(assignment,R2effs,ax):
    """Fill an unused panel with an invisible (white) copy of a real curve.

    assignment -- index label of the R2effs row to draw.
    R2effs     -- DataFrame of R2_eff values with frequencies as columns.
    ax         -- matplotlib Axes to draw into.

    The x tick labels are rotated to vertical, as on the real panels.
    """
    freqs = R2effs.columns.values
    # .loc replaces the .ix indexer, which newer pandas no longer provides.
    R2efflist = R2effs.loc[assignment].tolist()
    ax.plot(freqs,R2efflist,'w')
    plt.setp(ax.get_xticklabels(),rotation='vertical')

def _save_curve_page(fig, page):
    """Add the shared title and axis labels to one page of panels and save it.

    A transparent, frameless Axes is laid over the whole figure to carry the
    title and the common x/y labels; the figure is then written to
    '<sample_name>_curves_<page>.pdf'.
    """
    big_ax = fig.add_subplot(111)
    # Setting the background patch directly works on every matplotlib
    # version (set_axis_bgcolor is gone from newer releases).
    big_ax.patch.set_facecolor('none')
    # Booleans rather than 'off' strings: newer matplotlib treats any
    # non-empty string as true, which would switch the ticks on.
    big_ax.tick_params(labelcolor='none', top=False, bottom=False,
                       left=False, right=False)
    for side in ('top', 'bottom', 'left', 'right'):
        big_ax.spines[side].set_color('none')
    big_ax.set_title(sample_name)
    big_ax.set_ylabel('$R2_{eff}$'+' (Hz)')
    big_ax.set_xlabel(r'$\nu_{cp}$' +'(Hz)', labelpad=20)
    fig.savefig('%s_curves_%d.pdf'%(sample_name,page))


def plotCurves(R2effs,all_fits):
    """Plot measured R2_eff and the fitted curve for every assignment.

    R2effs   -- DataFrame of R2_eff values indexed by assignment, with the
                CPMG frequencies as column labels.
    all_fits -- DataFrame indexed by assignment whose rows hold, in order,
                R2, R2_err, Rex, Rex_err, Tau, Tau_err and the reduced
                chi squared.

    Panels are laid out 5 rows x 4 columns per page and filled row by row.
    Each page is saved as '<sample_name>_curves_<page>.pdf', with pages
    numbered from 0. Unused panels on the last page are filled with an
    invisible curve via plotfakecurve. Each assignment is printed as it is
    plotted.

    Unlike the previous implementation, no extra blank page is written when
    the number of assignments is an exact multiple of 20 (or zero).
    """
    freqs = R2effs.columns.values
    assignments = list(R2effs.index)
    rows = 5
    cols = 4
    per_page = rows*cols
    for page, start in enumerate(range(0, len(assignments), per_page)):
        on_page = assignments[start:start+per_page]
        f, axes = plt.subplots(rows,cols,sharex=True,sharey=True)
        f.set_size_inches(8,10.5)
        f.subplots_adjust(wspace=0.05,hspace=0.05)
        panels = list(axes.flat)  # row-major: left to right, top to bottom
        for ax, ass in zip(panels, on_page):
            print(ass)
            # .loc replaces the .ix indexer, which newer pandas no longer has.
            R2,R2_err,Rex,Rex_err,Tau,Tau_err,chisq = all_fits.loc[ass]
            R2efflist = R2effs.loc[ass].tolist()
            Tau_ms=Tau*1000
            exp1 = r'$R_{2} = %.1f$'%R2 + ' Hz'
            exp2 = r'$R_{ex} = %.1f$'%Rex + ' Hz'
            exp3 = r'$\tau = %.1f ms$'%Tau_ms
            exp4 = r'$\chi^{2} = %.2f$'%chisq
            ax.plot(freqs, R2efflist, '.')
            ax.plot(freqs, fitFunc(freqs, R2, Rex, Tau))
            ax.set_ylim(0, 50)
            plt.setp(ax.get_xticklabels(),rotation='vertical')
            ax.annotate(ass+'\n'+exp1+'\n'+exp2+'\n'+exp3+'\n'+exp4,xy=(10,110),xycoords='axes points',
                        horizontalalignment='left',verticalalignment='top')
        # Pad a partially filled page so every panel gets drawn.
        for ax in panels[len(on_page):]:
            plotfakecurve(on_page[-1],R2effs,ax)
        _save_curve_page(f, page)

def Rexbarplot(all_fits):
    """Draw a bar chart of Rex (with Rex_err error bars) for every assignment.

    all_fits -- DataFrame indexed by assignment with 'Rex' and 'Rex_err'
                columns.

    The chart is saved as '<sample_name>_bar.pdf' and then displayed with
    plt.show().
    """
    Rexvalues = all_fits['Rex'].values
    assignments = all_fits.index
    Rexerrors = all_fits['Rex_err'].values
    fig,ax = plt.subplots(figsize=(20,5))
    # range() instead of xrange() so this runs on Python 2 and Python 3.
    h = plt.bar(range(len(assignments)),
                  Rexvalues,
                  color='r',
                  label=assignments,
                  yerr=Rexerrors)
    plt.subplots_adjust(bottom=0.3)
    # Centre each tick on its bar regardless of how the bars are aligned.
    xticks_pos = [0.5*patch.get_width() + patch.get_xy()[0] for patch in h]
    plt.xticks(xticks_pos, assignments, ha='right', rotation=45)
    ax.set_ylabel('$R_{ex}$'+' (Hz)')
    ax.set_title(sample_name)
    plt.savefig(sample_name+'_bar.pdf')
    plt.show()

def plotTaus(all_fits):
    """Scatter-plot Tau (in ms) against Rex for the fits with Rex > 5.

    all_fits -- DataFrame with 'Rex' and 'Tau' columns.

    The figure is saved as '<sample_name>_tau_v_rex.pdf'. Returns the mean
    of the selected Tau values (x1000) as a one-element pandas Series.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Keep only the rows whose fitted Rex exceeds 5.
    selected = all_fits[all_fits['Rex'] > 5]
    culled_taus = 1000*selected[['Tau']]
    culled_rexs = selected[['Rex']]
    ax.plot(culled_rexs,culled_taus,'.')
    ax.set_ylabel('Tau (ms)')
    ax.set_xlabel('Rex (Hz)',labelpad=10)
    ax.set_title('tau vs Rex')
    fig.savefig('%s_tau_v_rex.pdf'%sample_name)
    return culled_taus.mean()
    
def main():
    """Run the full analysis on the configured peak table.

    Parses the table at the module-level `peaktable` path, computes R2_eff,
    fits every curve, then writes the per-peak curve pages, the Rex bar
    chart and the Tau-vs-Rex plot.
    """
    mainDataFrame = parsepeaktable(peaktable)
    R2effs = calcR2eff(mainDataFrame)
    all_fits = fitCurves(R2effs)
    plotCurves(R2effs,all_fits)
    Rexbarplot(all_fits)
    plotTaus(all_fits)


# Only run the analysis when executed as a script, so the functions above
# can be imported without reading files or opening plot windows.
if __name__ == '__main__':
    main()
