import os
# Work from the project directory -- presumably so the project-local
# `core.helpers` import resolves (TODO confirm). Raw string: the backslashes
# are Windows path separators, not escape sequences.
os.chdir(r'C:\EAfiles\EAdetection')

import pickle
import numpy as np
import core.helpers as hf
from scipy import signal
import pandas as pd
import matplotlib.cm as cmx
import matplotlib.colors as colors

#%%#############################################
# PREPARE STUFF
# analysis parameters
sr = 10000.0   # sampling rate (samples per second; also passed as fs below)
pre = 0.02     # seconds of trace kept before each event
post = 0.06    # seconds of trace kept after each event

# 4th-order Chebyshev type-I lowpass, 1 dB passband ripple, 300 Hz cut-off.
# Passing fs=sr makes Wn an absolute frequency in Hz instead of a normalised one.
cheby_b, cheby_a = signal.cheby1(N=4, rp=1, Wn=300, btype='lowpass', fs=sr)

#function for extracting snippets
def get_spiketrace(spiket, datatrace, pre=pre, post=post, sr=10000.0, filt_on=False):
    """Return the slice of ``datatrace`` surrounding the event at ``spiket``.

    Parameters
    ----------
    spiket : float
        Event time in seconds.
    datatrace : array-like
        Signal sampled at ``sr`` samples per second.
    pre, post : float
        Seconds of trace to keep before / after the event. The defaults are
        the module-level ``pre``/``post`` captured when this def executed.
    sr : float
        Sampling rate in samples per second.
    filt_on : bool
        If True, smooth the snippet with ``hf.savitzky_golay`` (window 21,
        order 2) before returning it.

    Returns
    -------
    ``datatrace[spikept - int(pre*sr) : spikept + int(post*sr)]``. It may be
    shorter than ``int(pre*sr) + int(post*sr)`` samples if the window runs
    past the end of the trace.

    Raises
    ------
    ValueError
        If the window would start before the first sample. A negative slice
        start would otherwise silently return an empty or wrapped-around array.
    """
    # Window lengths are truncated to whole samples -- the same rule used to
    # size snippet matrices, so snippet lengths stay consistent.
    cutout_pre = int(pre * sr)
    cutout_post = int(post * sr)
    # Round (don't truncate) the event sample: floating-point products can
    # land a hair below an integer (e.g. 0.29 * 100 == 28.999999999999996),
    # and truncation would then shift the window by one sample.
    spikept = int(round(spiket * sr))
    start = spikept - cutout_pre
    if start < 0:
        raise ValueError('window for event at %r s starts %d samples before '
                         'the beginning of the trace' % (spiket, -start))
    spiketrace = datatrace[start:spikept + cutout_post]
    if filt_on:
        return hf.savitzky_golay(spiketrace, 21, 2)  # window size 21, order 2
    return spiketrace
    
def get_snipmat_auc_scalarMap(pre, post, datatrace, spiketimes, colorlist, sr=10000.0):
    """Cut a snippet around every spike and build an AUC-based colour scale.

    Parameters
    ----------
    pre, post : float
        Seconds of trace to keep before / after each spike time.
    datatrace : array-like
        Signal sampled at ``sr`` from which the snippets are cut.
    spiketimes : sequence of float
        Spike/event times in seconds.
    colorlist : iterable of colormaps (names or Colormap objects)
        Only the LAST entry is used for the returned ScalarMappable.
    sr : float
        Sampling rate in samples per second.

    Returns
    -------
    snipmat : ndarray, shape (len(spiketimes), int(pre*sr) + int(post*sr))
        One snippet per row.
    auc : ndarray, shape (len(spiketimes),)
        Sum of absolute values of each snippet.
    scalarMap : matplotlib.cm.ScalarMappable
        Maps AUC values onto the chosen colormap. vmin is lowered by 15 % of
        the AUC range so the smallest AUC is not drawn at the colormap's
        lowest colour.

    Raises
    ------
    ValueError
        If ``spiketimes`` or ``colorlist`` is empty, or (from the row
        assignment) if a window runs off either end of ``datatrace``.
    """
    # Same truncation rule get_spiketrace uses, so every row has this width.
    cutout_pre = int(pre * sr)
    cutout_post = int(post * sr)
    snippts = cutout_pre + cutout_post

    n_spikes = len(spiketimes)
    if n_spikes == 0:
        raise ValueError('spiketimes is empty; cannot build an AUC colour scale')
    colormaps = list(colorlist)
    if not colormaps:
        raise ValueError('colorlist is empty; need at least one colormap')

    snipmat = np.zeros((n_spikes, snippts))
    auc = np.zeros(n_spikes)
    for tt, spiket in enumerate(spiketimes):
        # Cut from the trace passed in (not a module global) and forward the
        # window/sampling rate so the snippet length matches snipmat.
        spiketrace = get_spiketrace(spiket, datatrace, pre=pre, post=post, sr=sr)
        snipmat[tt] = spiketrace
        auc[tt] = np.sum(np.abs(spiketrace))

    minval = np.min(auc)
    maxval = np.max(auc)
    # Shift the lower bound down by 15 % of the range so the smallest AUC
    # maps to a visibly coloured (not lowest) entry of the colormap.
    minval = minval - (maxval - minval) * .15
    cNorm = colors.Normalize(vmin=minval, vmax=maxval)
    # The previous loop over colorlist overwrote the mapper each time, so only
    # the last colormap was ever returned; keep that contract explicitly.
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=colormaps[-1])
    return snipmat, auc, scalarMap
    
#%% later LOAD 

# Recording-ID -> .smr file name mapping (the values are later joined with the
# recordings folder and read as .smr files).
filename = r'C:\EAfiles\MyDicts\Seizure_10Hz_IDs_smr'
# Context manager guarantees the file is closed even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open(filename, 'rb') as infile:  # read binary
    stimRecIDs_10Hz_smr = pickle.load(infile)

recIDs_10Hz = stimRecIDs_10Hz_smr.keys()
filenames_10Hz = stimRecIDs_10Hz_smr.values()

#%% create dictionary with stimulation start

# Lookup sheet with one row per recording (indexed by 'recID') holding the
# 10 Hz stimulation onset and the stimulation length in seconds.
excelpath = r'C:\EAfiles\Excel\EEG-10Hz-Stimstart.xlsx'
# Read and index the workbook once: it does not depend on the recording.
df = pd.read_excel(excelpath)
df.set_index('recID', inplace=True)

Stimstart_10Hz_dict = {}
for recID in stimRecIDs_10Hz_smr.keys():
    Stimstart_10Hz = float(df.loc[recID, 'Stimstart_10Hz'])
    Stimlen_s = float(df.loc[recID, 'Stimlen'])
    Stimstart_10Hz_dict[recID] = {'Stimstart_10Hz': Stimstart_10Hz,
                                  'Stimlen_s': Stimlen_s}

#%% create dictionary with arrays for 10 Hz spiketimes

spike10Hz_dict = {}
pulse_interval_s = 0.1  # 10 Hz stimulation -> one pulse every 0.1 s

for recID in stimRecIDs_10Hz_smr.keys():
    pulsenumber = Stimstart_10Hz_dict[recID]['Stimlen_s'] * 10.
    Stimstart_10Hz = Stimstart_10Hz_dict[recID]['Stimstart_10Hz']

    # Pulse times onset + 0.1*i for i = 0, 1, ... below pulsenumber.
    # Assigned unconditionally so a recording with zero pulses still gets an
    # (empty) entry instead of a missing key during feature extraction.
    spike10Hz_dict[recID] = [Stimstart_10Hz + pulse_interval_s * i
                             for i in np.arange(0, pulsenumber, 1)]

#%% extract features
 
folder = r'G:\Enya\EEG\recordings_generalized_seizures'
chanlist = ['HCi1', 'HCi2', 'HCc']

# featuredict_10Hz[recID][chan][pulse_index] = {'stime': time_s, 'auc': value}
featuredict_10Hz = {}

# Iterate (recID, filename) pairs directly. A reverse lookup via
# keys()[values().index(filename)] is Python-2-only, O(n) per file, and picks
# the wrong recID when two IDs share a file name.
for recID, filename in stimRecIDs_10Hz_smr.items():
    print(recID)

    featuredict_10Hz[recID] = {}

    ddict = hf.extract_smrViaNeo(os.path.join(folder, filename), chanlist=chanlist)
    # Pulse times are the same for every channel of this recording.
    spiketimes = spike10Hz_dict[recID]

    for chan in chanlist:

        featuredict_10Hz[recID][chan] = {}

        # z-score the whole channel (zero mean, unit standard deviation).
        datatrace = ddict[chan]['trace']
        zTrace = (datatrace - np.mean(datatrace)) / np.std(datatrace)
        datatrace = zTrace

        for ii, spiket in enumerate(spiketimes):

            spiketrace = get_spiketrace(spiket, datatrace)

            # Forward-backward (zero-phase) 300 Hz Chebyshev lowpass; padlen=0
            # because the snippet is short. Then rectify.
            filtered = signal.filtfilt(cheby_b, cheby_a, spiketrace, padlen=0)
            filtered = np.abs(filtered)

            auc = np.sum(filtered)  # already rectified above

            featuredict_10Hz[recID][chan][ii] = {'stime': spiket, 'auc': auc}

# Persist the per-pulse features; the context manager closes the file even if
# pickling raises.
filename = r'C:\EAfiles\MyDicts\latency_feature_dict_10Hz'
with open(filename, 'wb') as outfile:
    pickle.dump(featuredict_10Hz, outfile)

#%% calculate mean AUC

aucdict_10Hz = {}

for recID in featuredict_10Hz.keys():

    # Only channel HCi1 enters the summary. Per-pulse entries are keyed
    # 0..n-1 by the enumerate() in the extraction loop.
    feat1 = 'auc'
    hci1_features = featuredict_10Hz[recID]['HCi1']
    auc_values = np.array([hci1_features[ii][feat1]
                           for ii in range(len(hci1_features))])

    # Average THIS recording's per-pulse AUCs. The collected array must be
    # used here: a bare `auc` would be the stale scalar left over from the
    # last pulse of the extraction loop.
    mean_auc = np.mean(auc_values)
    aucdict_10Hz[recID] = {'mean_auc': mean_auc}
    
#%% export mean AUC per recording

flavour = 'mean_auc'
# Materialise the keys once so IDs and values are guaranteed to line up, and
# so np.array receives a real list (a Python 3 dict view would become a 0-d
# object array).
recID_list = list(aucdict_10Hz.keys())
recIDs = np.array(recID_list)
mean_auc = np.array([aucdict_10Hz[recID][flavour] for recID in recID_list])

#save in excel
os.chdir(r'C:\EAfiles\RESULTS')

# The values are means, so they belong in the declared 'mean_auc' column
# (writing them to 'median_auc' left 'mean_auc' empty and mislabelled them).
df_sr = pd.DataFrame({'recID': recIDs, 'mean_auc': mean_auc},
                     columns=['recID', 'mean_auc'])

df_sr.to_excel('Enya_AUC_results_10Hz.xlsx')