import os
os.chdir('C:\EAfiles\EAdetection')

import pickle
import matplotlib.pyplot as plt
import numpy as np
import core.helpers as hf
import matplotlib as mpl
from cycler import cycler
from scipy import signal
import pandas as pd
import matplotlib.cm as cmx
import matplotlib.colors as colors
import seaborn as sns


#%% load dictionaries

def _load_pickle(path):
    """Unpickle and return the object stored in the file at *path*.

    The file is closed even if unpickling fails.
    NOTE: pickle can execute arbitrary code while loading; only use this on
    trusted, locally generated analysis files such as the ones below.
    """
    with open(path, 'rb') as infile:  # read binary
        return pickle.load(infile)


# burst information
srdict = _load_pickle(r'C:\EAfiles_oLFS\MyDicts\srdict_stim')

# spiketimes
Theospike_dict = _load_pickle(r'C:\EAfiles_oLFS\MyDicts\Theospike_dict')

minfeaturedict = _load_pickle(r'C:\EAfiles_oLFS\MyDicts\minfeature_dict')

latencyfeaturedict = _load_pickle(r'C:\EAfiles_oLFS\MyDicts\latency_feature_dict')

#%%#############################################
#PREPARE STUFF
#variables
pre, post = 0.1, 0.2  # seconds window before and after spike
sr = 10000.0  # sampling rate [Hz]

# 4th order Chebyshev type-I lowpass filter (1 dB passband ripple, cut-off 300 Hz)
cheby_b, cheby_a = signal.cheby1(4, 1, 300, 'lowpass', fs=sr)


def get_spiketrace(spiket, datatrace, pre=pre, post=post, sr=10000.0, filt_on=False):
    """Cut the snippet of *datatrace* around an event at time *spiket*.

    Parameters
    ----------
    spiket : float
        Event time in seconds.
    datatrace : array-like
        Sampled signal, indexable with a slice.
    pre, post : float
        Seconds to keep before / after the event.
    sr : float
        Sampling rate in Hz used to convert seconds to sample indices.
    filt_on : bool
        If True, smooth the snippet with ``hf.savitzky_golay`` using a
        101-sample window and polynomial order 2.

    Returns
    -------
    The slice ``datatrace[int(spiket*sr) - int(pre*sr) : int(spiket*sr) + int(post*sr)]``.
    Sample indices are truncated, not rounded.
    NOTE(review): for events closer than *pre* to the start of the trace the
    start index is negative, so Python's wrap-around slicing returns a wrong
    (usually empty) snippet.
    """
    # Builtin int(): np.int was only an alias and was removed in NumPy 1.24.
    cutout_pre = int(pre * sr)
    cutout_post = int(post * sr)
    spikept = int(spiket * sr)
    spiketrace = datatrace[spikept - cutout_pre:spikept + cutout_post]
    if filt_on:
        return hf.savitzky_golay(spiketrace, 101, 2)  # window size 101, order 2
    return spiketrace

def get_snipmat_auc_scalarMap(pre, post, datatrace, spiketimes, sr=10000.0):
    """Stack the snippets of *datatrace* around every spike and score them.

    Parameters
    ----------
    pre, post : float
        Seconds kept before / after each spike.
    datatrace : array-like
        Trace the snippets are cut from (callers pass the filtered trace).
    spiketimes : sequence of float
        Spike / pulse times in seconds.
    sr : float
        Sampling rate in Hz.

    Returns
    -------
    snipmat : ndarray, shape (len(spiketimes), int(pre*sr) + int(post*sr))
        One snippet per row.
    auc : ndarray, shape (len(spiketimes),)
        Sum of absolute snippet values per spike.
    scalarMap : matplotlib ScalarMappable
        Maps auc values onto the 'Greys' colormap.

    Raises
    ------
    ValueError
        If *spiketimes* is empty (the colour range would be undefined).
    """
    cutout_pre = int(pre * sr)
    cutout_post = int(post * sr)
    n_spikes = len(spiketimes)
    if n_spikes == 0:
        raise ValueError('spiketimes is empty; cannot build snippets or colour map')

    snipmat = np.zeros((n_spikes, cutout_pre + cutout_post))
    auc = np.zeros(n_spikes)
    for tt, template_spike in enumerate(spiketimes):
        # Cut from the trace that was passed in, with this call's own window,
        # so the snippet length always matches the snipmat row length.
        spiketrace = get_spiketrace(template_spike, datatrace, pre=pre, post=post, sr=sr)
        snipmat[tt] = spiketrace
        auc[tt] = np.sum(np.abs(spiketrace))

    minval = np.min(auc)
    maxval = np.max(auc)
    # Extend the lower colour limit by 15 % of the AUC span so the smallest
    # AUC is drawn slightly darker than pure white.
    minval = minval - (maxval - minval) * .15
    cNorm = colors.Normalize(vmin=minval, vmax=maxval)
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap='Greys')

    return snipmat, auc, scalarMap

def _burst_segments(values, timepoints, starts, stops, period, ictal_mask):
    """Collect the values/timepoints inside each burst and flag them as ictal.

    Burst boundaries are in seconds and are converted to indices with
    ``int(t / period)``. Every index in ``[start, stop)`` is set to True in
    *ictal_mask* (modified in place). Returns the concatenated values and
    timepoints, or two empty arrays when there are no bursts.
    """
    value_chunks = []
    time_chunks = []
    for i, burst_start in enumerate(starts):
        first = int(burst_start / period)
        last = int(stops[i] / period)
        value_chunks.append(values[first:last])
        time_chunks.append(timepoints[first:last])
        ictal_mask[first:last] = True
    if not value_chunks:
        return np.array([]), np.array([])
    return np.concatenate(value_chunks), np.concatenate(time_chunks)


def get_burst_aucs(auc, srdict, recID, period):  # period = 1. for 1Hz, 2 for 0.5Hz, 5 for 0.2Hz
    """Split per-pulse AUC values into L-burst, XL-burst, ictal and interictal sets.

    Parameters
    ----------
    auc : array-like
        One AUC value per stimulation pulse; pulse k occurs at ``k * period``.
    srdict : dict
        ``srdict[recID]`` must provide 'L_burst_starts', 'L_burst_stops',
        'XL_burst_starts' and 'XL_burst_stops' (burst limits in seconds).
    recID : hashable
        Recording key into *srdict*.
    period : float
        Inter-pulse interval in seconds.

    Returns
    -------
    (L_aucs_all, L_timepoints_all, XL_aucs_all, XL_timepoints_all,
     interictal_aucs, ictal_aucs, interictal_timepoints, ictal_timepoints)
    The first four are arrays (empty when there are no bursts of that kind).
    The last four are lists in time order. A pulse counts as ictal when its
    *index* lies inside any L or XL burst; each pulse is listed once, even
    where L and XL bursts overlap.
    """
    values = np.asarray(auc)
    timepoints = np.arange(len(values)) * period
    ictal_mask = np.zeros(len(values), dtype=bool)
    rec = srdict[recID]

    L_aucs_all, L_timepoints_all = _burst_segments(
        values, timepoints, rec['L_burst_starts'], rec['L_burst_stops'], period, ictal_mask)
    XL_aucs_all, XL_timepoints_all = _burst_segments(
        values, timepoints, rec['XL_burst_starts'], rec['XL_burst_stops'], period, ictal_mask)

    # Classify by position, not by AUC value: identical AUC values outside a
    # burst must stay interictal.
    interictal_aucs = list(values[~ictal_mask])
    ictal_aucs = list(values[ictal_mask])
    interictal_timepoints = list(timepoints[~ictal_mask])
    ictal_timepoints = list(timepoints[ictal_mask])

    return L_aucs_all, L_timepoints_all, XL_aucs_all, XL_timepoints_all,\
        interictal_aucs, ictal_aucs, interictal_timepoints, ictal_timepoints

def _burst_timepoints(timepoints, starts, stops, period):
    """Concatenate ``timepoints[int(start/period):int(stop/period)]`` over all bursts.

    Returns an empty array when there are no bursts.
    """
    chunks = [timepoints[int(burst_start / period):int(stops[i] / period)]
              for i, burst_start in enumerate(starts)]
    if not chunks:
        return np.array([])
    return np.concatenate(chunks)


def get_ictal_timepoints(srdict, recID, period=1.0, n_timepoints=3600):
    """Return the pulse indices that fall inside L bursts and inside XL bursts.

    Parameters
    ----------
    srdict : dict
        ``srdict[recID]`` must provide 'L_burst_starts', 'L_burst_stops',
        'XL_burst_starts' and 'XL_burst_stops' (burst limits in seconds).
    recID : hashable
        Recording key into *srdict*.
    period : float, optional
        Inter-pulse interval in seconds (default 1.0, the 1 Hz protocol).
        Previously this was read from a module-level ``period`` global.
    n_timepoints : int, optional
        Number of pulses in the recording (default 3600).

    Returns
    -------
    (L_timepoints_all, XL_timepoints_all) : tuple of ndarray
        Pulse indices in ``range(n_timepoints)`` covered by each burst type.
    """
    timepoints = np.arange(0, n_timepoints, 1)
    rec = srdict[recID]
    L_timepoints_all = _burst_timepoints(
        timepoints, rec['L_burst_starts'], rec['L_burst_stops'], period)
    XL_timepoints_all = _burst_timepoints(
        timepoints, rec['XL_burst_starts'], rec['XL_burst_stops'], period)
    return L_timepoints_all, XL_timepoints_all


def plot_latency_scatter_burst(stime, latency_HCi1_HCi2, latency_HCi1_HCc,
                               L_timepoints_all, XL_timepoints_all, title=None):
    """Scatter per-pulse latencies and mark burst timepoints at latency 0.

    Latencies are passed in seconds and plotted in ms against *stime*.
    L-burst timepoints are drawn as orange dots and XL-burst timepoints as red
    dots on the zero line. Everything is drawn on the current axes.

    title : str, optional
        Plot title. Defaults to the first 9 characters of the module-level
        ``recID``, which is what the function always used before.
    """
    if title is None:
        title = recID[:9]

    plt.plot(stime, latency_HCi1_HCi2 * 1000, 'o', color='b', linewidth=1, markersize=2)
    plt.plot(stime, latency_HCi1_HCc * 1000, 'o', color='grey', linewidth=1, markersize=2)

    plt.plot(L_timepoints_all, np.zeros(len(L_timepoints_all)), 'o', color='orange',
             linewidth=1, markersize=2)
    plt.plot(XL_timepoints_all, np.zeros(len(XL_timepoints_all)), 'o', color='red',
             linewidth=1, markersize=2)

    # Only the two latency series are labelled; the burst markers are not.
    plt.legend(['HCi1 to HCi2', 'HCi1 to HCc'])
    plt.xlabel('stimulation pulses')
    plt.ylabel('latency [ms]')
    plt.xlim(0, 3600)
    plt.ylim(-5, 20)
    plt.title(title)

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    # NOTE: rcParams set here only take effect for axes created afterwards.
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2
        
        
def plot_latency_linreg(stime, latency_HCi1_HCi2, latency_HCi1_HCc, latency_HCi2_HCc, title=None):
    """Plot robust linear regressions of latency (ms) against pulse number.

    HCi1->HCi2 latencies are drawn in blue and HCi1->HCc latencies in grey.
    Latencies are given in seconds. *latency_HCi2_HCc* is accepted for
    interface compatibility but is not plotted.

    title : str, optional
        Plot title. Defaults to the first 9 characters of the module-level
        ``recID``, which is what the function always used before.
    """
    if title is None:
        title = recID[:9]

    sns.set()
    # Do not assign to ``sns.axes_style``: it is a seaborn function, and
    # overwriting it with a string breaks later callers.
    sns.set_style("ticks", {"xtick.major.size": 2, "ytick.major.size": 2})
    # x/y passed by keyword: seaborn >= 0.12 rejects positional data arguments.
    sns.regplot(x=stime, y=latency_HCi1_HCi2 * 1000, color='b', scatter_kws={'s': 2}, robust=True)
    sns.regplot(x=stime, y=latency_HCi1_HCc * 1000, color='grey', scatter_kws={'s': 2}, robust=True)

    plt.legend(['HCi1 to HCi2', 'HCi1 to HCc'])
    plt.xlabel('stimulation pulses')
    plt.ylabel('latency [ms]')
    plt.xlim(0, 3600)
    plt.ylim(0, 15)
    plt.title(title)

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2

        
def plot_latency_scatter(stime, latency_HCi1_HCi2, latency_HCi1_HCc, latency_HCi2_HCc, title=None):
    """Scatter per-pulse latencies (ms) in a new 6x5 inch figure.

    HCi1->HCi2 latencies are drawn in blue and HCi1->HCc latencies in grey.
    Latencies are given in seconds. *latency_HCi2_HCc* is accepted for
    interface compatibility but is not plotted.

    title : str, optional
        Plot title. Defaults to the first 9 characters of the module-level
        ``recID``, which is what the function always used before.
    """
    if title is None:
        title = recID[:9]

    # Set the size before creating the figure so it applies to this figure
    # (and, as before, to later ones).
    plt.rcParams["figure.figsize"] = (6, 5)
    plt.figure()
    plt.plot(stime, latency_HCi1_HCi2 * 1000, '+', color='b', linewidth=1, markersize=1.5)
    plt.plot(stime, latency_HCi1_HCc * 1000, '+', color='grey', linewidth=1, markersize=1.5)

    plt.legend(['HCi1 to HCi2', 'HCi1 to HCc'])
    plt.xlabel('stimulation pulses')
    plt.ylabel('latency [ms]')
    plt.xlim(0, 3600)
    plt.ylim(0, 15)
    plt.title(title)

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2
        

#%%
# Get arrays with pulse times for 1 Hz, 0.5 Hz and 0.2 Hz stimulation.

def _read_pulsetimes(path):
    """Return the numbers stored one per line in text file *path* as a float array.

    Blank lines are skipped. Building a list first keeps this working on
    Python 3, where ``np.asarray(map(...))`` would give a 0-d object array.
    """
    with open(path, 'r') as f:
        return np.array([float(line) for line in f.read().splitlines() if line.strip()])


pathtoPulsetimes_1Hz = r'C:\EAfiles_oLFS\Pulsetimes\Pulsetimes_1Hz.txt'
Pulsetimes_1Hz = _read_pulsetimes(pathtoPulsetimes_1Hz)

pathtoPulsetimes_05Hz = r'C:\EAfiles_oLFS\Pulsetimes\Pulsetimes_05Hz.txt'
Pulsetimes_05Hz = _read_pulsetimes(pathtoPulsetimes_05Hz)

pathtoPulsetimes_02Hz = r'C:\EAfiles_oLFS\Pulsetimes\Pulsetimes_02Hz.txt'
Pulsetimes_02Hz = _read_pulsetimes(pathtoPulsetimes_02Hz)

#%%

def _read_id_list(path):
    """Return the lines of text file *path* (line endings stripped) as a list of strings."""
    with open(path, 'r') as f:
        return f.read().splitlines()


# Recording IDs stimulated with the pulser vs. the prismatix device.
pathtostim1HzIDspul = r'C:\EAfiles_oLFS\ID_Dateien\stim1HzIDs_pulser.txt'
stim1HzIDs_pulser = _read_id_list(pathtostim1HzIDspul)

pathtostim1HzIDspris = r'C:\EAfiles_oLFS\ID_Dateien\stim1HzIDs_prismatix.txt'
stim1HzIDs_prismatix = _read_id_list(pathtostim1HzIDspris)

#%% LOAD !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    
# recID -> .smr file name for each stimulation frequency.
with open(r'C:\EAfiles_oLFS\MyDicts\stimRecIDs_1Hz_smr', 'rb') as infile:  # read binary
    stimRecIDs_1Hz_smr = pickle.load(infile)

with open(r'C:\EAfiles_oLFS\MyDicts\stimRecIDs_05Hz_smr', 'rb') as infile:  # read binary
    stimRecIDs_05Hz_smr = pickle.load(infile)

with open(r'C:\EAfiles_oLFS\MyDicts\stimRecIDs_02Hz_smr', 'rb') as infile:  # read binary
    stimRecIDs_02Hz_smr = pickle.load(infile)

# Explicit lists: on Python 3 keys()/values() are views that cannot be
# indexed or have items removed.
recIDs_1Hz = list(stimRecIDs_1Hz_smr.keys())
recIDs_05Hz = list(stimRecIDs_05Hz_smr.keys())
recIDs_02Hz = list(stimRecIDs_02Hz_smr.keys())

filenames_1Hz = list(stimRecIDs_1Hz_smr.values())
filenames_05Hz = list(stimRecIDs_05Hz_smr.values())
filenames_02Hz = list(stimRecIDs_02Hz_smr.values())

# Remove no-virus control animals and broken recordings
# (list.remove raises ValueError if a listed file is missing).
for excluded in ['PJ225-226_21d_stim-226.smr',
                 'PJ225-226_21d_stim-225.smr',
                 'PJ227-228_21d_228-stim000.smr',
                 'PJ227-228_22d_228-stim000.smr',
                 'EP6_1Hz1.smr',
                 'EP6_1Hz2.smr',
                 'EP10_1Hz1.smr',
                 'EP10_1Hz2.smr',
                 'EP70_pre1-1-EP69_1Hz1.smr',
                 'EP69_1Hz2-EP70_20d000.smr',
                 'EP70_1Hz2-EP72_p3-1.smr',
                 'EP69_pre1-3-EP70_1Hz1.smr',
                 'EP32_p2-1-EP33_1Hz2.smr',
                 'EP32_p1-1-EP33_1Hz1.smr']:
    filenames_1Hz.remove(excluded)

filenames_05Hz.remove('EP63_05Hz1-EP64_pre1-2.smr')

#%% plot one spiketrace
    
#default_cycler = (cycler(color=['b','r','grey','darkred'])) 
default_cycler = (cycler(color=['b', 'grey']))

folder = r'G:\Enya\EEG\recordings'

# Plot the mean pulse-evoked response of one example recording
# (iterate over ``filenames_1Hz`` instead to plot every recording).
for filename in ['EP29_1Hz1-EP28_pre1-1.smr']:

    # Reverse lookup of the recording ID (first match, like the former
    # keys()[values().index()]), valid on both Python 2 and 3.
    recID = next(key for key, value in stimRecIDs_1Hz_smr.items() if value == filename)
    period = 1.

    # prismatix membership is checked first, matching the former
    # "last assignment wins" order when an ID is in both lists.
    if recID in stim1HzIDs_prismatix:
        spiketimes = Theospike_dict[recID]
    elif recID in stim1HzIDs_pulser:
        spiketimes = Pulsetimes_1Hz
    else:
        raise ValueError('%s is in neither the pulser nor the prismatix ID list' % recID)

    # Recording IDs ending in '2' use the '_2'-suffixed channel names.
    if recID[-1:] == '2':
        chanlist = ['HCi1_2', 'HCc_2']
    else:
        chanlist = ['HCi1', 'HCc']

    ddict = hf.extract_smrViaNeo(os.path.join(folder, filename), chanlist=chanlist)

    # Snippet time axis relative to the pulse: -pre .. +post seconds.
    pulsetimevec = np.arange(-(pre), post - 0.00001, 1. / sr)
    plt.rc('axes', prop_cycle=default_cycler)

    for chan in chanlist:
        datatrace = ddict[chan]['trace']
        # z-score the whole recording, then apply the zero-phase lowpass.
        datatrace = (datatrace - np.mean(datatrace)) / np.std(datatrace)
        filtered = signal.filtfilt(cheby_b, cheby_a, datatrace, padlen=0)

        snipmat, auc, scalarMap = get_snipmat_auc_scalarMap(pre, post, filtered, spiketimes)

        # Average evoked response over all pulses.
        onespike = np.mean(snipmat, 0)
        plt.plot(pulsetimevec, onespike, '-', linewidth=2)

    plt.title(recID[:9])
    # One label per plotted channel; no threshold-crossing markers are drawn.
    plt.legend(['HCi1', 'HCc'])
    plt.xlabel('time [s]')
    plt.ylabel('amplitude [z]')  # the trace was z-scored above
    plt.ylim(-5., 3.)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2
        

 #%%    extract features
 
folder = r'G:\Enya\EEG\recordings'

# Threshold-crossing search starts 5 ms before the pulse; values are z-units.
threshold = 1.2
start = int((pre - 0.005) * sr)
# Snippet time axis relative to the pulse: -pre .. +post seconds.
pulsetimevec = np.arange(-(pre), post - 0.00001, 1. / sr)


def _first_threshold_crossing(filtered, raw, timevec, start, threshold):
    """Locate the first sample at index >= *start* where ``abs(filtered)`` exceeds *threshold*.

    Returns ``(time, filtered value, raw value)`` at that sample, or three
    NaNs when there is no crossing (all three are always set together).
    """
    crossings = np.flatnonzero(np.abs(filtered[start:]) > threshold)
    if crossings.size == 0:
        return np.nan, np.nan, np.nan
    idx = start + crossings[0]
    return timevec[idx], filtered[idx], raw[idx]


featuredict = {}

for filename in filenames_1Hz:

    # Reverse lookup of the recording ID (first match), valid on Python 2 and 3.
    recID = next(key for key, value in stimRecIDs_1Hz_smr.items() if value == filename)
    print(recID)

    featuredict[recID] = {}

    # Recording IDs ending in '2' use the '_2'-suffixed channel names.
    if recID[-1:] == '2':
        chanlist = ['HCi1_2', 'HCi2_2', 'HCc_2']
    else:
        chanlist = ['HCi1', 'HCi2', 'HCc']

    ddict = hf.extract_smrViaNeo(os.path.join(folder, filename), chanlist=chanlist)

    # Pulse times depend only on the recording. prismatix is checked first,
    # matching the former "last assignment wins" order; unknown IDs fail
    # loudly instead of silently reusing the previous recording's pulses.
    if recID in stim1HzIDs_prismatix:
        spiketimes = Theospike_dict[recID]
    elif recID in stim1HzIDs_pulser:
        spiketimes = Pulsetimes_1Hz
    else:
        raise ValueError('%s is in neither the pulser nor the prismatix ID list' % recID)

    for chan in chanlist:

        chanfeatures = featuredict[recID][chan] = {}

        datatrace = ddict[chan]['trace']
        zTrace = (datatrace - np.mean(datatrace)) / np.std(datatrace)

        for ii, spiket in enumerate(spiketimes):

            spiketrace = get_spiketrace(spiket, zTrace)

            # Rectified, zero-phase lowpass-filtered snippet.
            filtered = np.abs(signal.filtfilt(cheby_b, cheby_a, spiketrace, padlen=0))
            auc = np.sum(filtered)

            threscrosstime, threscrossvalfilt, threscrossvalraw = _first_threshold_crossing(
                filtered, spiketrace, pulsetimevec, start, threshold)

            chanfeatures[ii] = {'stime': spiket,
                                'threscrosstime': threscrosstime,
                                'threscrossvalfilt': threscrossvalfilt,
                                'threscrossvalraw': threscrossvalraw,
                                'auc': auc}

# NOTE(review): saved under C:\EAfiles, but latencyfeaturedict is loaded from
# C:\EAfiles_oLFS at the top of the script — confirm which location is intended.
filename = r'C:\EAfiles\MyDicts\latency_feature_dict'
with open(filename, 'wb') as outfile:
    pickle.dump(featuredict, outfile)


#%% plot scatter threscross value and time

default_cycler = (cycler(color=['b', 'k', 'grey']))
chanlist_all = ['HCi1', 'HCc', 'HCi1_2', 'HCc_2']  # HCi1 = idHC, HCc = cdHC

# Recordings whose crossing times are plotted without the 0-25 ms window.
unfiltered_recIDs = ('EP4_1Hz1_HCi1', 'EP9_1Hz1_HCi1', 'EP9_1Hz2_HCi1', 'EP11_1Hz1_HCi1_2',
                     'EP11_1Hz2_HCi1_2', 'EP11_1Hz3_HCi1', 'PJ225_1Hz2_HCi1')

featuredict = latencyfeaturedict

for recID in featuredict.keys():
    print(recID)

    plt.rc('axes', prop_cycle=default_cycler)

    for chan in chanlist_all:
        if chan not in featuredict[recID]:
            continue

        chanfeatures = featuredict[recID][chan]
        # Per-pulse features are stored under keys 0..N-1.
        threscrosstime = np.array([chanfeatures[ii]['threscrosstime']
                                   for ii in range(len(chanfeatures))])
        threscrossvalfilt = np.array([chanfeatures[ii]['threscrossvalfilt']
                                      for ii in range(len(chanfeatures))])

        if recID not in unfiltered_recIDs:
            # Keep only crossings 0-25 ms after the pulse (negative times
            # occur only for pulser IDs); NaNs compare False and stay NaN.
            with np.errstate(invalid='ignore'):
                outside = (threscrosstime > 0.025) | (threscrosstime < 0.0)
            threscrosstime[outside] = np.nan
            threscrossvalfilt[outside] = np.nan

        plt.plot(threscrosstime, threscrossvalfilt, 'o', linewidth=1, markersize=2)

    # chanlist_all has no HCi2 channel: the plotted series are HCi1 then HCc.
    plt.legend(['HCi1', 'HCc'])
    plt.xlabel('thres cross time[s]')
    plt.ylabel('thres cross value [z]')
    plt.xlim(0.0, 0.025)
    plt.ylim(1.2, 1.7)
    plt.title(recID[:9])

    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2

    plt.savefig(r'C:\EAfiles\PULSEANALYSIS\LATENCY_smr/' + recID[:9] + '_Threscross_1,2.png')
    plt.close()
    
    
#%%

featuredict = latencyfeaturedict

default_cycler = (cycler(color=['b', 'k', 'grey']))
chanlist_all = ['HCi1', 'HCi2', 'HCc', 'HCi1_2', 'HCi2_2', 'HCc_2']
period = 1.  # for 1 Hz

# Recordings whose crossing times are used as-is (no 25 ms cut-off).
unfiltered_recIDs = ('EP4_1Hz1_HCi1', 'EP9_1Hz1_HCi1', 'EP9_1Hz2_HCi1', 'EP11_1Hz1_HCi1_2',
                     'EP11_1Hz2_HCi1_2', 'EP11_1Hz3_HCi1', 'PJ225_1Hz2_HCi1')

# Recordings that get a robust regression line of latency vs. pulse number.
regplot_HCi1_HCi2_recIDs = ('EP8_1Hz2_HCi1_2', 'EP5_1Hz1_HCi1', 'EP5_1Hz2_HCi1_2', 'EP9_1Hz1_HCi1',
                            'EP9_1Hz2_HCi1', 'EP11_1Hz1_HCi1_2', 'EP11_1Hz3_HCi1', 'EP29_1Hz1_HCi1',
                            'EP29_1Hz2_HCi1', 'EP62_1Hz1_HCi1', 'EP64_1Hz1_HCi1_2', 'EP64_1Hz2_HCi1_2',
                            'EP32_1Hz2_HCi1')
regplot_HCi1_HCc_recIDs = ('EP8_1Hz2_HCi1_2', 'EP9_1Hz1_HCi1', 'EP9_1Hz2_HCi1',
                           'EP11_1Hz1_HCi1_2', 'EP29_1Hz1_HCi1',
                           'EP62_1Hz1_HCi1', 'EP64_1Hz1_HCi1_2', 'EP32_1Hz1_HCi1')


def _feature_array(chanfeatures, feat):
    """Return feature *feat* of pulses 0..N-1 of one channel as an array."""
    return np.array([chanfeatures[ii][feat] for ii in range(len(chanfeatures))])


def _nan_outside(values, lower=-np.inf, upper=np.inf):
    """Return a float copy of *values* with entries < *lower* or > *upper* set to NaN.

    Entries that are already NaN stay NaN.
    """
    values = np.array(values, dtype=float)
    with np.errstate(invalid='ignore'):
        values[(values < lower) | (values > upper)] = np.nan
    return values


def _crossing_times(chanfeatures, recID):
    """Return one channel's threshold-crossing times.

    Crossings later than 25 ms are discarded unless *recID* is in
    ``unfiltered_recIDs``.
    """
    times = _feature_array(chanfeatures, 'threscrosstime')
    if recID in unfiltered_recIDs:
        return times
    return _nan_outside(times, upper=0.025)


def _set_regplot_style():
    """Apply seaborn defaults with short tick marks for the regression plots."""
    sns.set()
    sns.set_style("ticks", {"xtick.major.size": 2, "ytick.major.size": 2})


latencydict = {}

for recID in featuredict.keys():
    recfeatures = featuredict[recID]
    # Channels are stored either under plain names or with an '_2' suffix.
    # An explicit check replaces the former bare try/except, which also
    # swallowed unrelated errors.
    suffix = '' if 'HCi1' in recfeatures else '_2'

    stime = _feature_array(recfeatures['HCi1' + suffix], 'stime')
    threscrosstime_HCi1 = _crossing_times(recfeatures['HCi1' + suffix], recID)
    threscrosstime_HCi2 = _crossing_times(recfeatures['HCi2' + suffix], recID)
    threscrosstime_HCc = _crossing_times(recfeatures['HCc' + suffix], recID)

    # Plausible latency windows in seconds, defined after visual inspection;
    # applied to every pulse, not only the first 3600.
    latency_HCi1_HCi2 = _nan_outside(threscrosstime_HCi2 - threscrosstime_HCi1, 0.0, 0.0075)
    latency_HCi1_HCc = _nan_outside(threscrosstime_HCc - threscrosstime_HCi1, 0.005, 0.015)
    latency_HCi2_HCc = _nan_outside(threscrosstime_HCc - threscrosstime_HCi2, 0.0, 0.01)

    print(recID)

    # Summary statistics in ms, ignoring discarded (NaN) pulses.
    latencydict[recID] = {'median_latency_HCi1_HCi2': np.nanmedian(latency_HCi1_HCi2 * 1000),
                          'median_latency_HCi1_HCc': np.nanmedian(latency_HCi1_HCc * 1000),
                          'median_latency_HCi2_HCc': np.nanmedian(latency_HCi2_HCc * 1000),
                          'std_latency_HCi1_HCi2': np.nanstd(latency_HCi1_HCi2 * 1000),
                          'std_latency_HCi1_HCc': np.nanstd(latency_HCi1_HCc * 1000),
                          'std_latency_HCi2_HCc': np.nanstd(latency_HCi2_HCc * 1000)}

    if recID in regplot_HCi1_HCi2_recIDs:
        _set_regplot_style()
        sns.regplot(x=stime, y=latency_HCi1_HCi2 * 1000, marker='o', color='k',
                    scatter=False, robust=True)

    if recID in regplot_HCi1_HCc_recIDs:
        _set_regplot_style()
        sns.regplot(x=stime, y=latency_HCi1_HCc * 1000, marker='o', color='grey',
                    scatter=False, robust=True)

        plt.legend(['HCi1 to HCi2', 'HCi1 to HCc'])
        plt.xlabel('stimulation pulses')
        plt.ylabel('latency [ms]')
        plt.xlim(0, 3600)
        plt.ylim(0, 15)
        plt.title(recID[:9])

        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        mpl.rcParams['xtick.major.width'] = 2
        mpl.rcParams['ytick.major.width'] = 2
        mpl.rcParams['axes.linewidth'] = 2
        
#%%  save
        
# Persist the per-recording latency summary; the file is closed even if
# pickling fails.
filename = r'C:\EAfiles\MyDicts\latency_median_dict'
with open(filename, 'wb') as outfile:
    pickle.dump(latencydict, outfile)

#%% extract latencies 
    
# list() is needed on Python 3, where np.array(dict_keys) is a 0-d object array.
recIDs = np.array(list(latencydict.keys()))


def _latency_column(flavour):
    """Return ``latencydict[recID][flavour]`` for every recording, ordered like ``recIDs``."""
    return np.array([latencydict[key][flavour] for key in latencydict.keys()])


median_latency_HCi1_HCc = _latency_column('median_latency_HCi1_HCc')
median_latency_HCi1_HCi2 = _latency_column('median_latency_HCi1_HCi2')
median_latency_HCi2_HCc = _latency_column('median_latency_HCi2_HCc')
std_latency_HCi1_HCc = _latency_column('std_latency_HCi1_HCc')
std_latency_HCi1_HCi2 = _latency_column('std_latency_HCi1_HCi2')
std_latency_HCi2_HCc = _latency_column('std_latency_HCi2_HCc')

# save in excel

os.chdir(r'C:\EAfiles\RESULTS')

df_sr = pd.DataFrame(
    {'recID': recIDs,
     'median_latency_HCi1_HCc [ms]': median_latency_HCi1_HCc,
     'median_latency_HCi1_HCi2 [ms]': median_latency_HCi1_HCi2,
     'median_latency_HCi2_HCc [ms]': median_latency_HCi2_HCc,
     'std_latency_HCi1_HCc [ms]': std_latency_HCi1_HCc,
     'std_latency_HCi1_HCi2 [ms]': std_latency_HCi1_HCi2,
     'std_latency_HCi2_HCc [ms]': std_latency_HCi2_HCc},
    columns=['recID',
             'median_latency_HCi1_HCc [ms]',
             'median_latency_HCi1_HCi2 [ms]',
             'median_latency_HCi2_HCc [ms]',
             'std_latency_HCi1_HCc [ms]',
             'std_latency_HCi1_HCi2 [ms]',
             'std_latency_HCi2_HCc [ms]'])

df_sr.to_excel('Enya_latency_results.xlsx')