import os
# Work from the project folder; presumably needed so that `core.helpers`
# (imported below) can be found -- confirm. Raw string: '\E' in a normal
# literal is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
os.chdir(r'C:\EAfiles\EAdetection')

import pickle
import matplotlib.pyplot as plt
import numpy as np
import core.helpers as hf
import matplotlib as mpl
from cycler import cycler
from scipy import signal

from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d as mpl3d
import matplotlib.cm as cmx
import matplotlib.colors as colors

import pandas as pd

#%% load dictionaries
# NOTE: pickle.load executes arbitrary code from the file -- only load these
# trusted, locally produced dictionaries.

#XL and L periods defined by Katharina Heinings code
filename = r'C:\EAfiles\MyDicts\srdict_stim'
with open(filename, 'rb') as infile:  # read binary; closed even on error
    srdict = pickle.load(infile)

# stimulation pulse times
filename = r'C:\EAfiles\MyDicts\Theospike_dict'
with open(filename, 'rb') as infile:
    Theospike_dict = pickle.load(infile)

#auc median for every electrode position
# (this loaded version is replaced by the freshly computed one further below)
filename = r'C:\EAfiles\MyDicts\auc_dict'
with open(filename, 'rb') as infile:
    aucdict = pickle.load(infile)

#%%#############################################
#preparation 

# Global analysis settings shared by the functions below.
# Snippet window around each stimulation event, in seconds:
pre = 0.1   # kept before the event
post = 0.2  # kept after the event

# Sampling rate of the recordings in Hz (converts seconds <-> samples).
sr = 1e4

# Lowpass applied before cutting snippets: 4th-order Chebyshev type I,
# 1 dB passband ripple, 300 Hz cut-off, designed for sampling rate sr.
cheby_b, cheby_a = signal.cheby1(N=4, rp=1, Wn=300, btype='lowpass', fs=sr)

#function for extracting snippets
def get_spiketrace(spiket,datatrace,pre=pre,post=post,sr=10000.0,filt_on=False):
    """Cut the snippet of *datatrace* surrounding one event.

    Parameters
    ----------
    spiket : float
        Event time in seconds.
    datatrace : array-like
        Continuous signal sampled at *sr* Hz.
    pre, post : float
        Length (s) of the window kept before / after the event. Defaults are
        the module-level ``pre``/``post`` at definition time.
    sr : float
        Sampling rate in Hz.
    filt_on : bool
        If True, smooth the snippet with ``hf.savitzky_golay`` (window 101,
        polynomial order 2) before returning it.

    Returns
    -------
    ``datatrace[event - int(pre*sr) : event + int(post*sr)]``. Close to the
    start or end of the recording the slice can be shorter than expected
    (a negative start index wraps around), so callers building fixed-size
    matrices will fail loudly there.
    """
    # Builtin int: np.int was only an alias of it and was removed in NumPy 1.24.
    cutout_pre = int(pre*sr)    # window lengths in samples
    cutout_post = int(post*sr)
    # Round the event sample instead of truncating it: e.g. 0.29 s * 10000
    # evaluates to 2899.999..., which int() alone would floor to 2899.
    spikept = int(round(spiket*sr))
    spiketrace = datatrace[spikept-cutout_pre:spikept+cutout_post]
    if filt_on:  # Savitzky-Golay smoothing is off by default
        return hf.savitzky_golay(spiketrace,101,2)
    return spiketrace

#collecting all spiketraces(snippets) in a matrix(snipmat), auc of every event and scalar Map for auc-color code of 3D PLot
def get_snipmat_auc_scalarMap(pre, post, datatrace, spiketimes, sr=10000.0):
    """Collect one snippet per event and a colour scale for their "auc".

    Parameters
    ----------
    pre, post : float
        Window (s) kept before / after each event.
    datatrace : array-like
        Continuous signal sampled at *sr* Hz.
    spiketimes : sequence of float
        Event times in seconds.
    sr : float
        Sampling rate in Hz.

    Returns
    -------
    snipmat : ndarray, shape (len(spiketimes), int(pre*sr) + int(post*sr))
        One snippet per row.
    auc : ndarray, shape (len(spiketimes),)
        ``sum(|snippet|)`` per event (in sample units, not scaled by 1/sr).
    scalarMap : matplotlib ScalarMappable
        'Greys' colour map normalised to the auc range.

    Raises
    ------
    ValueError
        If *spiketimes* is empty.
    """
    if len(spiketimes) == 0:
        raise ValueError("spiketimes is empty; no snippets to collect")

    cutout_pre = int(pre*sr)    # np.int was removed in NumPy 1.24
    cutout_post = int(post*sr)
    Nspikes = len(spiketimes)
    snippts = cutout_pre+cutout_post
    snipmat = np.zeros((Nspikes,snippts))
    auc = np.zeros((Nspikes))
    for tt,template_spike in enumerate(spiketimes):
        # Forward the window explicitly: previously get_spiketrace used its own
        # defaults, so non-default pre/post/sr here produced snippets whose
        # length did not match snipmat's columns.
        spiketrace = get_spiketrace(template_spike, datatrace, pre=pre, post=post, sr=sr)
        snipmat[tt] = spiketrace
        auc[tt] = np.sum(np.abs(spiketrace))

    minval = np.min(auc)
    maxval = np.max(auc)
    # Move the lower colour limit 15 % of the range below the smallest auc so
    # the weakest response maps to light grey rather than white.
    minval = minval - (maxval-minval)*.15
    cNorm = colors.Normalize(vmin=minval, vmax=maxval)
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap='Greys')

    return snipmat, auc, scalarMap

def _collect_burst_segments(auc, timepoints, starts, stops, period):
    """Gather the aucs/timepoints of all events inside the given bursts.

    Burst borders are converted to event indices with ``int(t / period)``;
    each burst covers indices ``[start, stop)``.

    Returns
    -------
    seg_aucs, seg_times : ndarray
        Concatenation of the per-burst slices (empty float arrays if there are
        no bursts).
    mask : bool ndarray, same length as *auc*
        True for every event that lies inside at least one burst.
    """
    if len(starts) != len(stops):
        raise ValueError("burst starts and stops differ in length (%d vs %d)"
                         % (len(starts), len(stops)))
    mask = np.zeros(len(auc), dtype=bool)
    auc_parts = []
    time_parts = []
    for t_start, t_stop in zip(starts, stops):
        start = int(t_start/period)
        stop = int(t_stop/period)
        auc_parts.append(auc[start:stop])
        time_parts.append(timepoints[start:stop])
        mask[start:stop] = True
    if not auc_parts:
        return np.array([]), np.array([]), mask
    return np.concatenate(auc_parts), np.concatenate(time_parts), mask


#get timepoints where XL or L bursts occured using srdict
def get_burst_aucs(auc, srdict, recID, period): #period = 1. for 1Hz, 2 for 0.5Hz, 5 for 0.2Hz
    """Split per-event aucs into L-burst, XL-burst, ictal and interictal parts.

    Parameters
    ----------
    auc : 1-D array
        One value per stimulation event; event k is taken to occur at
        ``k * period``.
    srdict : dict
        ``srdict[recID]`` must contain 'L_burst_starts', 'L_burst_stops',
        'XL_burst_starts' and 'XL_burst_stops', in the same time unit as
        ``k * period`` (seconds in this script).
    recID : hashable
        Key of the recording in *srdict*.
    period : float
        Time between two events (1. for 1 Hz, 2. for 0.5 Hz, 5. for 0.2 Hz).

    Returns
    -------
    L_aucs_all, L_timepoints_all, XL_aucs_all, XL_timepoints_all,
    interictal_aucs, ictal_aucs, interictal_timepoints, ictal_timepoints

    An event is ictal when it lies inside any L or XL burst and interictal
    otherwise. Classification is by event INDEX; the previous version
    compared auc values, so an interictal event whose value equalled a burst
    value was misclassified, and events in both an L and an XL burst were
    counted twice.
    """
    auc = np.asarray(auc)
    # arange(n)*period always has exactly len(auc) elements; arange with a
    # float stop can produce one extra.
    timepoints = np.arange(len(auc)) * period

    rec = srdict[recID]
    L_aucs_all, L_timepoints_all, L_mask = _collect_burst_segments(
        auc, timepoints, rec['L_burst_starts'], rec['L_burst_stops'], period)
    XL_aucs_all, XL_timepoints_all, XL_mask = _collect_burst_segments(
        auc, timepoints, rec['XL_burst_starts'], rec['XL_burst_stops'], period)

    ictal = L_mask | XL_mask
    interictal = ~ictal

    return L_aucs_all, L_timepoints_all, XL_aucs_all, XL_timepoints_all,\
             auc[interictal], auc[ictal], timepoints[interictal], timepoints[ictal]


#get 3D plot with aucs color coded
def plot_3D_AUC_burst(snipmat, auc, scalarMap, start, stop, step, L_timepoints_all, XL_timepoints_all, seconds_per_x_unit=60.): #0,60,1.0/60 for 1Hz
    """Draw all event snippets as the lines of a 3D waterfall plot.

    Row ``snipmat[ii]`` is drawn at x = ``nx[ii]`` with
    ``nx = np.arange(start, stop, step)``, against the snippet time axis
    (y, seconds) and amplitude (z). Lines are coloured by *scalarMap* for
    ``|auc[ii]|``, except that events inside an L or XL burst are orange.

    Parameters
    ----------
    snipmat : 2-D array, one snippet per row.
    auc : 1-D array, one value per row of *snipmat*.
    scalarMap : matplotlib ScalarMappable for non-burst lines.
    start, stop, step : float
        x positions of the rows. In this script the x axis is recording time
        in minutes, e.g. ``0, 60, period/60``.
    L_timepoints_all, XL_timepoints_all : array-like
        Event times (seconds) inside L / XL bursts, as returned by
        get_burst_aucs.
    seconds_per_x_unit : float
        Seconds per x-axis unit (60 for minutes); maps burst timepoints to
        rows.
    """
    plt.rcParams["figure.figsize"] = (18, 8) # (w, h)
    # One figure only: the old code created an extra empty figure and then
    # called fig.gca(projection='3d'), which Matplotlib >= 3.6 rejects.
    fig = plt.figure(figsize=(18, 8))
    ax = fig.add_subplot(111, projection='3d')
    nx = np.arange(start, stop, step)
    ny = np.arange(-(pre),post-0.00001,1./sr)  # time within a snippet [s]

    # Rows of burst events. The old code compared the timepoint in seconds
    # with the row index, which only matched when period == 1 s.
    burst_times = np.concatenate([np.ravel(L_timepoints_all), np.ravel(XL_timepoints_all)])
    burst_rows = set(int(round((t / float(seconds_per_x_unit) - start) / step))
                     for t in burst_times)

    # np.arange with a float step can yield one more x position than rows
    n_rows = min(nx.shape[0], snipmat.shape[0])
    for ii in range(n_rows):
        if ii in burst_rows:
            color = 'orange'
        else:
            color = scalarMap.to_rgba(np.abs(auc[ii]))
        ax.plot(nx[ii]*np.ones(snipmat.shape[1]), ny, snipmat[ii,:], color=color)

    ax.set_zlim3d(-2.0,0.5) #for EP32 and EP72
    ax.set_xlim3d(0,60) #for EP32 and EP72
    ax.set_ylim3d(-0.1,0.2) #for EP32 and EP72

    ax.tick_params(axis="x", labelsize=16)
    ax.tick_params(axis="y", labelsize=16)
    ax.tick_params(axis="z", labelsize=16)
    ax.xaxis.set_tick_params(width=2)
    ax.yaxis.set_tick_params(width=2)
    ax.zaxis.set_tick_params(width=2)
    
# get mean spiketrace of the 4 quarters of a recording
def plot_4Quaters_interictal_mean(period, snipmat, ictal_timepoints, Q1, Q2, Q3, Q4):
    """Plot the mean interictal snippet of each quarter of a recording.

    Rows of *snipmat* belonging to ictal events (``ictal_timepoints``,
    row = timepoint / period) are excluded. The remaining rows selected by
    each index array Q1..Q4 are averaged and plotted as one line.

    The caller's *snipmat* is left unchanged; the previous version wrote NaNs
    into it in place.
    """
    # float copy: keeps the caller's array intact and can hold NaN
    snipmat = np.array(snipmat, dtype=float)
    for tp in ictal_timepoints:
        snipmat[int(round(tp / period)), :] = np.nan   # exclude ictal events

    # nanmean over axis 0: mean over the (non-excluded) events of a quarter
    quarter_means = [np.nanmean(snipmat[Q], 0) for Q in (Q1, Q2, Q3, Q4)]

    plt.rcParams["figure.figsize"] = (6, 5)  # kept: later figures inherit this default
    # Pass the size explicitly; setting rcParams after plt.figure() (as
    # before) only affected the next figure, not this one.
    plt.figure(figsize=(6, 5))
    timevec = np.arange(-(pre),post-0.00001,1./sr)
    for mean_trace, col in zip(quarter_means, ('navy', 'mediumblue', 'royalblue', 'cornflowerblue')):
        plt.plot(timevec, mean_trace, '-', color=col, linewidth=2)
    plt.legend(['0-15 min', '15-30 min', '30-45 min', '45-60 min'])
    plt.xlabel('time [s]')
    plt.ylabel('amplitude [mV]')
    plt.xlim(-0.025,0.2)
    plt.ylim(-2.1,0.3) #for EP32
    plt.xticks(fontsize = 16)
    plt.yticks(fontsize = 16)
    # rcParams only affect axes created later, so style the current axes too
    ax = plt.gca()
    ax.tick_params(width=2)
    for spine in ax.spines.values():
        spine.set_linewidth(2)
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2
        
        
def plot_burst_auc_smoothn(interictal_timepoints,interictal_aucs,\
                L_timepoints_all, L_aucs_all, XL_timepoints_all, XL_aucs_all): 
    """Scatter-plot per-event auc over recording time with a smooth trend.

    Interictal events are grey and L/XL burst events orange. A degree-9
    polynomial fitted to the interictal points is drawn in blue when there
    are enough of them. Timepoints are given in seconds and plotted in
    minutes.
    """
    plt.rcParams["figure.figsize"] = (6, 5)  # kept: later figures inherit this default
    plt.figure(figsize=(6, 5))  # rcParams set after plt.figure() would not apply here

    x_inter = np.asarray(interictal_timepoints, dtype=float)
    y_inter = np.asarray(interictal_aucs, dtype=float)
    plt.plot(x_inter/60., y_inter, 'o', color='grey', linewidth=1, markersize=2)
    plt.plot(np.asarray(L_timepoints_all)/60., L_aucs_all, 'o', color='orange', linewidth=1, markersize=2)
    plt.plot(np.asarray(XL_timepoints_all)/60., XL_aucs_all, 'o', color='orange', linewidth=1, markersize=2)

    # Polynomial.fit rescales x to [-1, 1] internally; np.polyfit of degree 9
    # on raw seconds (0..3600) is badly conditioned. A degree-9 fit needs at
    # least 10 points.
    deg = 9
    if x_inter.size > deg:
        trend = np.polynomial.Polynomial.fit(x_inter, y_inter, deg)
        plt.plot(x_inter/60., trend(x_inter), '-', color = 'blue')

    plt.xlabel('rec time [min]')  # x values are seconds/60
    plt.ylabel('auc') #area under the curve
    plt.ylim(200., 1200.)
    plt.xlim(0,60.)

    plt.xticks(fontsize = 16)
    plt.yticks(fontsize = 16)
    # rcParams only affect axes created later, so style the current axes too
    ax = plt.gca()
    ax.tick_params(width=2)
    for spine in ax.spines.values():
        spine.set_linewidth(2)
    mpl.rcParams['xtick.major.width'] = 2
    mpl.rcParams['ytick.major.width'] = 2
    mpl.rcParams['axes.linewidth'] = 2


#%% more preparation

#get arrays with pulsetimes for 1 Hz, 0.5 Hz and 0.2 Hz 
#(for the recordings that had a pulser) saved in txt files


def _read_pulsetimes(path):
    """Return the pulse times stored one per line in *path* as a float array.

    Blank lines are skipped. (The old ``np.asarray(map(float, ...))`` only
    worked on Python 2; on Python 3 it yields a 0-d object array.)
    """
    with open(path, 'r') as f:
        return np.array([float(line) for line in f.read().splitlines() if line.strip()])


# raw strings: '\E', '\P' in normal literals are invalid escape sequences
pathtoPulsetimes_1Hz = r'C:\EAfiles\Pulsetimes\Pulsetimes_1Hz.txt'
Pulsetimes_1Hz = _read_pulsetimes(pathtoPulsetimes_1Hz)

pathtoPulsetimes_05Hz = r'C:\EAfiles\Pulsetimes\Pulsetimes_05Hz.txt'
Pulsetimes_05Hz = _read_pulsetimes(pathtoPulsetimes_05Hz)

pathtoPulsetimes_02Hz = r'C:\EAfiles\Pulsetimes\Pulsetimes_02Hz.txt'
Pulsetimes_02Hz = _read_pulsetimes(pathtoPulsetimes_02Hz)

#%% get recording IDs


def _read_id_list(path):
    """Return the lines of the text file *path* (one recording ID per line)."""
    with open(path, 'r') as f:
        return f.read().splitlines()


# Raw strings: '\E', '\I' in normal literals are invalid escape sequences
# (SyntaxWarning on Python >= 3.12); the resulting paths are unchanged.
pathtostim1HzIDspul = r'C:\EAfiles\ID_Dateien\stim1HzIDs_pulser.txt'
stim1HzIDs_pulser = _read_id_list(pathtostim1HzIDspul)

pathtostim1HzIDspris = r'C:\EAfiles\ID_Dateien\stim1HzIDs_prismatix.txt'
stim1HzIDs_prismatix = _read_id_list(pathtostim1HzIDspris)

pathtostim05HzIDspul = r'C:\EAfiles\ID_Dateien\stim05HzIDs_pulser.txt'
stim05HzIDs_pulser = _read_id_list(pathtostim05HzIDspul)

pathtostim05HzIDspris = r'C:\EAfiles\ID_Dateien\stim05HzIDs_prismatix.txt'
stim05HzIDs_prismatix = _read_id_list(pathtostim05HzIDspris)

pathtostim02HzIDspul = r'C:\EAfiles\ID_Dateien\stim02HzIDs_pulser.txt'
stim02HzIDs_pulser = _read_id_list(pathtostim02HzIDspul)

pathtostim02HzIDspris = r'C:\EAfiles\ID_Dateien\stim02HzIDs_prismatix.txt'
stim02HzIDs_prismatix = _read_id_list(pathtostim02HzIDspris)

#%% ....more recording IDs
# Recording files grouped by condition (con1 / con2), stimulation frequency
# and whether the pulses came from a pulser.

filenames_con1_all_1Hz = [
    'EP10_1Hz1.smr',
    'EP10_1Hz2.smr',
    'EP11_1Hz3.smr',
    'EP28_1Hz1-EP30_pre1-1.smr',
    'EP29_1Hz1-EP28_pre1-1.smr',
    'EP29_1Hz2-EP28_pre2-1.smr',
    'EP30_1Hz2-EP28_p2-2.smr',
    'EP32_1Hz1-EP33_pre1-1.smr',
    'EP32_1Hz2-EP33_pre2-1.smr',
    'EP3_1Hz1-EP4_pre1-1.smr',
    'EP3_1Hz2-EP5_pre2-1.smr',
    'EP46_1Hz1-EP44_pre1-1.smr',
    'EP46_1Hz2-EP44_pre2-1.smr',
    'EP4_1Hz2.smr',
    'EP5_1Hz1.smr',
    'EP62_1Hz1.smr',
    'EP62_1Hz2.smr',
    'EP63_1Hz1-EP64_pre1-1.smr',
    'EP63_1Hz2-EP64_pre2-1.smr',
    'EP69_1Hz2-EP70_20d000.smr',
    'EP6_1Hz1.smr',
    'EP70_1Hz2-EP72_p3-1.smr',
    'EP72_1Hz1-EP71_pre1-1.smr',
    'EP72_1Hz2-EP71_pre2-1.smr',
    'EP7_1Hz2-EP8_p2-1.smr',
    'EP9_1Hz1_EP11_pre1-1.smr',
    'EP9_1Hz2-EP11_pre2-1.smr',
    'PJ225-226_21d_stim-225.smr',
    'PJ225-226_22d_225-stim000.smr',
]

filenames_con2_all_1Hz = [
    'EP9_p1-1_EP11_1Hz1.smr',
    'EP9_p2-1-EP11_1Hz2.smr',
    'EP29_p2-1-EP28_1Hz2.smr',
    'EP28_p1-1-EP30_1Hz1.smr',
    'EP32_p1-1-EP33_1Hz1.smr',
    'EP32_p2-1-EP33_1Hz2.smr',
    'EP46_p1-1-EP44_1Hz1.smr',
    'EP46_p2-2-EP44_1Hz2.smr',
    'EP3_p1-1-EP4_1Hz1.smr',
    'EP3_p2-1-EP5_1Hz2.smr',
    'EP63_p1-1-EP64_1Hz1.smr',
    'EP63_p2-1-EP64_1Hz2.smr',
    'EP70_pre1-1-EP69_1Hz1.smr',
    'EP6_1Hz2.smr',
    'EP72_p1-1-EP71_1Hz1.smr',
    'EP72_p2-1-EP71_1Hz2.smr',
    'EP70_pre2-1-EP72_1Hz3.smr',
    'EP7_p1-1-EP8_1Hz1.smr',
    'EP7_pre2-1-EP8_1Hz2.smr',
    'PJ225-226_21d_stim-226.smr',
    'PJ225-226_22d_226-stim000.smr',
    'PJ227-228_21d_228-stim000.smr',
    'PJ227-228_22d_228-stim000.smr',
    'EP69_pre1-3-EP70_1Hz1.smr',
]

filenames_con1_pulser_1Hz = [
    'EP28_1Hz1-EP30_pre1-1.smr',
    'EP29_1Hz1-EP28_pre1-1.smr',
    'EP29_1Hz2-EP28_pre2-1.smr',
    'EP30_1Hz2-EP28_p2-2.smr',
    'EP32_1Hz1-EP33_pre1-1.smr',
    'EP32_1Hz2-EP33_pre2-1.smr',
    'EP62_1Hz1.smr',
    'EP62_1Hz2.smr',
    'EP63_1Hz1-EP64_pre1-1.smr',
    'EP63_1Hz2-EP64_pre2-1.smr',
    'EP69_1Hz2-EP70_20d000.smr',
    'EP70_1Hz2-EP72_p3-1.smr',
    'EP72_1Hz1-EP71_pre1-1.smr',
    'EP72_1Hz2-EP71_pre2-1.smr',
    'EP74_1Hz1.smr',
    'EP74_1Hz2.smr',
]

filenames_con2_pulser_1Hz = [
    'EP29_p2-1-EP28_1Hz2.smr',
    'EP69_pre1-3-EP70_1Hz1.smr',
    'EP28_p1-1-EP30_1Hz1.smr',
    'EP32_p1-1-EP33_1Hz1.smr',
    'EP32_p2-1-EP33_1Hz2.smr',
    'EP46_p1-1-EP44_1Hz1.smr',
    'EP46_p2-2-EP44_1Hz2.smr',
    'EP63_p1-1-EP64_1Hz1.smr',
    'EP63_p2-1-EP64_1Hz2.smr',
    'EP70_pre1-1-EP69_1Hz1.smr',
    'EP72_p1-1-EP71_1Hz1.smr',
    'EP72_p2-1-EP71_1Hz2.smr',
    'EP70_pre2-1-EP72_1Hz3.smr',
]

filenames_con1_pulser_05Hz = [
    'EP32_05Hz1.smr',
]

filenames_con1_pulser_02Hz = [
    'EP32_02Hz1.smr',
]
  

    
#%% LOAD !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
#....more recording IDs
# Dicts mapping recID -> .smr filename, one per stimulation frequency.
# Raw-string paths: '\E', '\M', '\s' in normal literals are invalid escapes.

filename = r'C:\EAfiles\MyDicts\stimRecIDs_1Hz_smr'
with open(filename, 'rb') as infile:  # read binary
    stimRecIDs_1Hz_smr = pickle.load(infile)

filename = r'C:\EAfiles\MyDicts\stimRecIDs_05Hz_smr'
with open(filename, 'rb') as infile:
    stimRecIDs_05Hz_smr = pickle.load(infile)

filename = r'C:\EAfiles\MyDicts\stimRecIDs_02Hz_smr'
with open(filename, 'rb') as infile:
    stimRecIDs_02Hz_smr = pickle.load(infile)

# list(...): on Python 3 keys()/values() are views without .remove/indexing;
# on Python 2 this is the same list the old code obtained.
recIDs_1Hz = list(stimRecIDs_1Hz_smr.keys())
recIDs_05Hz = list(stimRecIDs_05Hz_smr.keys())
recIDs_02Hz = list(stimRecIDs_02Hz_smr.keys())

filenames_1Hz = list(stimRecIDs_1Hz_smr.values())
filenames_05Hz = list(stimRecIDs_05Hz_smr.values())
filenames_02Hz = list(stimRecIDs_02Hz_smr.values())

# Recordings excluded from the 1 Hz analysis (list.remove raises ValueError
# if one of them is missing from the dict).
_excluded_1Hz = ['PJ225-226_21d_stim-226.smr',
                 'PJ225-226_21d_stim-225.smr',
                 'PJ227-228_21d_228-stim000.smr',
                 'PJ227-228_22d_228-stim000.smr',
                 'EP6_1Hz1.smr',
                 'EP6_1Hz2.smr',
                 'EP10_1Hz1.smr',
                 'EP10_1Hz2.smr',
                 'EP70_pre1-1-EP69_1Hz1.smr',
                 'EP69_1Hz2-EP70_20d000.smr',
                 'EP70_1Hz2-EP72_p3-1.smr',
                 'EP69_pre1-3-EP70_1Hz1.smr',
                 'EP32_p2-1-EP33_1Hz2.smr',
                 'EP32_p1-1-EP33_1Hz1.smr']
for _excluded in _excluded_1Hz:
    filenames_1Hz.remove(_excluded)

filenames_05Hz.remove('EP63_05Hz1-EP64_pre1-2.smr')
    
 #%%   plot mean responses (only interictal)

folder = r'G:\Enya\EEG\recordings'

chanlist = [ 'HCi1']
#chanlist = [ 'HCi1', 'HCi2','HCc']
chanlist_2 = [ 'HCi1_2', 'HCi2_2','HCc_2']

# Protocols that get the full set of plots:
# (filenames, time between pulses [s], {recID: filename}, pulse times [s]).
# filenames_con2_pulser_1Hz recordings are not plotted yet.
_plot_protocols = [
    (filenames_con1_pulser_1Hz, 1., stimRecIDs_1Hz_smr, Pulsetimes_1Hz),
    (filenames_con1_pulser_05Hz, 2., stimRecIDs_05Hz_smr, Pulsetimes_05Hz),
    (filenames_con1_pulser_02Hz, 5., stimRecIDs_02Hz_smr, Pulsetimes_02Hz),
]

#for filename in filenames:
for filename in ['EP74_1Hz1.smr','EP32_1Hz1-EP33_pre1-1.smr','EP32_05Hz1.smr','EP32_02Hz1.smr']:

    for protocol_files, period, recID_dict, spiketimes in _plot_protocols:
        if filename not in protocol_files:
            continue

        # reverse lookup filename -> recID (list() also works on Python 3)
        recID = list(recID_dict.keys())[list(recID_dict.values()).index(filename)]

        ddict = hf.extract_smrViaNeo(os.path.join(folder, filename), chanlist=chanlist)

        # Four quarters of 15 min (900 s) each: 900, 450 and 180 events for
        # periods of 1, 2 and 5 s.
        n_quarter = int(round(900. / period))
        Q1, Q2, Q3, Q4 = [np.arange(k*n_quarter, (k+1)*n_quarter) for k in range(4)]

        for chan in chanlist:

            datatrace = ddict[chan]['trace']
            filtered = signal.filtfilt(cheby_b, cheby_a, datatrace, padlen = 0)

            snipmat, auc, scalarMap = get_snipmat_auc_scalarMap(pre, post, filtered, spiketimes)

            (L_aucs_all, L_timepoints_all, XL_aucs_all, XL_timepoints_all,
             interictal_aucs, ictal_aucs, interictal_timepoints,
             ictal_timepoints) = get_burst_aucs(auc, srdict, recID, period)

            # x axis in minutes: one event every period/60 min
            plot_3D_AUC_burst(snipmat, auc, scalarMap, 0, 60, period/60., L_timepoints_all, XL_timepoints_all)

            plot_4Quaters_interictal_mean(period, snipmat, ictal_timepoints, Q1, Q2, Q3, Q4)

            plot_burst_auc_smoothn(interictal_timepoints, interictal_aucs,
                                   L_timepoints_all, L_aucs_all, XL_timepoints_all, XL_aucs_all)

            
            
#%% get median auc for all chan and all recs

folder = r'G:\Enya\EEG\recordings'

aucdict = {}


def _interictal_auc_median(trace, spiketimes, recID, period):
    """Return the median per-event auc of *trace*, ignoring ictal events.

    Snippets are cut around *spiketimes* with get_snipmat_auc_scalarMap
    (global pre/post window). Events inside L/XL bursts of ``srdict[recID]``
    are set to NaN before taking ``np.nanmedian``.
    """
    snipmat, auc, scalarMap = get_snipmat_auc_scalarMap(pre, post, trace, spiketimes)
    ictal_timepoints = get_burst_aucs(auc, srdict, recID, period)[7]
    for tp in ictal_timepoints:
        auc[int(round(tp / period))] = np.nan   # exclude ictal events
    return np.nanmedian(auc)


# One entry per stimulation frequency.
_median_protocols = [
    {'filenames': filenames_1Hz, 'period': 1., 'recIDs': stimRecIDs_1Hz_smr,
     'pulser_ids': stim1HzIDs_pulser, 'prismatix_ids': stim1HzIDs_prismatix,
     'pulsetimes': Pulsetimes_1Hz,
     'skip': ['PJ225_1Hz1_HCi1','PJ226_1Hz1_HCi1_2','PJ228_1Hz1_HCi1_2',
              'PJ228_1Hz2_HCi1_2',''],
     'zscore': True},
    {'filenames': filenames_05Hz, 'period': 2., 'recIDs': stimRecIDs_05Hz_smr,
     'pulser_ids': stim05HzIDs_pulser, 'prismatix_ids': stim05HzIDs_prismatix,
     'pulsetimes': Pulsetimes_05Hz,
     'skip': ['EP63_05Hz1_HCi1','EP9_05Hz2_HCi1'],
     'zscore': True},
    # NOTE(review): unlike 1 Hz and 0.5 Hz, the 0.2 Hz traces were never
    # z-scored, so these medians are on a different scale. Kept as is --
    # confirm whether this is intended.
    {'filenames': filenames_02Hz, 'period': 5., 'recIDs': stimRecIDs_02Hz_smr,
     'pulser_ids': stim02HzIDs_pulser, 'prismatix_ids': stim02HzIDs_prismatix,
     'pulsetimes': Pulsetimes_02Hz,
     'skip': [],
     'zscore': False},
]

for proto in _median_protocols:
    period = proto['period']
    recID_dict = proto['recIDs']

    for filename in proto['filenames']:

        # reverse lookup filename -> recID (list() also works on Python 3)
        recID = list(recID_dict.keys())[list(recID_dict.values()).index(filename)]
        print(recID)

        if recID in proto['skip']:
            continue

        # Pulse times: prismatix entries take precedence (as before). The old
        # code silently reused the previous recording's spiketimes when a
        # recID was in neither list.
        if recID in proto['prismatix_ids']:
            spiketimes = Theospike_dict[recID]
        elif recID in proto['pulser_ids']:
            spiketimes = proto['pulsetimes']
        else:
            raise ValueError('no pulse times known for recording %s' % recID)

        aucdict[recID] = {}

        # recIDs ending in '2' refer to the second animal's channels
        if recID.endswith('2'):
            chanlist = [ 'HCi1_2', 'HCi2_2', 'HCc_2' ]
        else:
            chanlist = [ 'HCi1', 'HCi2', 'HCc' ]

        ddict = hf.extract_smrViaNeo(os.path.join(folder, filename), chanlist=chanlist)

        for chan in chanlist:
            datatrace = ddict[chan]['trace']
            filtered = signal.filtfilt(cheby_b, cheby_a, datatrace, padlen = 0)
            if proto['zscore']:
                filtered = (filtered-np.mean(filtered))/np.std(filtered)

            auc_median = _interictal_auc_median(filtered, spiketimes, recID, period)
            print('%s %s' % (chan, auc_median))
            aucdict[recID][chan] = {'auc_median': auc_median}

#save dictionary
filename = r'C:\EAfiles\MyDicts\auc_dict'
with open(filename, 'wb') as outfile:  # write binary, overwrites an existing file
    pickle.dump(aucdict, outfile)
# later just load the dict
