# -*- coding: utf-8 -*-
"""
Created on Thu Feb 15 13:12:44 2018

@author: mscholz
"""

from __future__ import division

import numpy as np
from numpy import exp,sqrt
import lmfit
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import PIL.Image as Image
import matplotlib.image as mpimg
import glob
import glob2
import ntpath
import os.path
import pprint

from fretbursts import *
import phconvert as phc
import pycorrelate as pyc

import seaborn as sns
sns.set_style('white')
import scipy.optimize
from scipy.optimize import curve_fit

#%% _______CONVERT______
"""
-------------------------------------------------------------------------------
CONVERT FUNCTIONS                                                       
------------------------------------------------------------------------------- 
"""

class Sett_hdf5(object):
    """ Parameters needed to convert a .sm file into Photon-HDF5.

        Every attribute starts from a default value. A keyword argument whose name
        matches an existing attribute overrides that default; any other keyword is
        reported on stdout and ignored.
    """
    def __init__(self, **kwargs):
        # metadata stored in the HDF5 file
        self.author = 'author'
        self.author_affiliation = 'affiliation'
        self.description = 'description'
        self.sample_name = 'sample_name'
        self.dye_names = 'dye1,dye2'
        self.buffer_name = 'buffer'

        # channel assignment and alternation windows (forwarded to phc.loader.usalex_sm)
        self.donor = 0
        self.acceptor = 1
        self.alex_period = 8000
        self.alex_offset = 10
        self.alex_period_donor = (300, 3300)
        self.alex_period_acceptor = (4400, 7300)

        for name, value in kwargs.items():
            if not hasattr(self, name):
                print('key ' + name + ' not recognized and was skipped')
                continue
            setattr(self, name, value)
                
def setd_PAX(d,**kwargs):
    """ Configure a fretbursts Data() object for a PAX experiment.

        Default settings (ALEX=False, meas_type='PAX', det_donor_accept=(1, 0),
        D_ON=(10, 1800), A_ON=(1800, 4090), offset=0, pax=True) are merged with
        kwargs (kwargs take precedence) and passed to d.add(). Afterwards
        d.setup['excitation_alternated'] and d.setup['excitation_cw'] are set.
    """
    settings = dict(ALEX=False,
                    meas_type='PAX',
                    det_donor_accept=(1, 0),
                    D_ON=(10, 1800),
                    A_ON=(1800, 4090),
                    offset=0,
                    pax=True)  # 'pax' was added 180228 and has to be tested properly
    settings.update(kwargs)
    d.add(**settings)
    d.setup['excitation_alternated'] = np.array([0, 1], dtype=np.uint8)
    d.setup['excitation_cw'] = np.array([1, 0], dtype=np.uint8)
                
                
def get_sett(mode):
    """ Return a Sett_hdf5 object configured according to mode.

        "NAR2014" and "whatever" are (currently identical) presets; any other
        value returns Sett_hdf5() with its default attributes.
    """
    if mode not in ("NAR2014", "whatever"):
        return Sett_hdf5()
    # both presets share the same parameters
    return Sett_hdf5(donor=0, acceptor=1,
                     alex_period=8000, alex_offset=10,
                     alex_period_donor=(300, 3300),
                     alex_period_acceptor=(4400, 7300))
    
    
def convert_sm(filename,sett,flag_save=True, flag_plot_althist=False):
    """ Load a single .sm file using the settings in the Sett_hdf5 object sett.

        Arguments:
            filename (str): path of the .sm file.
            sett (Sett_hdf5): conversion parameters and metadata.
            flag_save (bool): if True, the data are saved as Photon-HDF5.
            flag_plot_althist (bool): if True, the alternation histogram is plotted.

        Returns:
            (fnhdf, dic): fnhdf is the name of the saved .hdf5 file ('' when
            flag_save is False); dic is the dictionary reflecting the structure of
            the Photon-HDF5 file.
            If filename cannot be opened, a message is printed and None is returned.

        Example: convert_sm('file1.sm', get_sett("NAR2014"))
    """
    # Only the file-access check is guarded, so that errors raised later by the
    # loader or by the HDF5 writer are not misreported as "file not found".
    try:
        with open(filename, mode='r'):
            pass
    except IOError:
        print('ATTENTION: Data file not found, please check the filename.\n'
              '           (current value "%s")' % filename)
        return None
    print('Data file found, OK')

    dic = phc.loader.usalex_sm(filename,
                               donor=sett.donor,
                               acceptor=sett.acceptor,
                               alex_period=sett.alex_period,
                               alex_offset=sett.alex_offset,
                               alex_period_donor=sett.alex_period_donor,
                               alex_period_acceptor=sett.alex_period_acceptor)

    if flag_plot_althist:
        phc.plotter.alternation_hist(dic)

    # metadata taken from sett
    dic['description'] = sett.description
    dic['sample'] = dict(
        sample_name=sett.sample_name,
        dye_names=sett.dye_names,
        buffer_name=sett.buffer_name,
        num_dyes=len(sett.dye_names.split(',')))
    dic['identity'] = dict(
        author=sett.author,
        author_affiliation=sett.author_affiliation)

    if flag_save:
        phc.hdf5.save_photon_hdf5(dic, overwrite=True)
        fnhdf = get_path_woext(filename) + '.hdf5'
    else:
        fnhdf = ''

    return fnhdf, dic
              


def convert_sm_multi(sett, pattern="**/*.sm", avoid_list=('sm.bursts.sm',), flag_plot_althist=False):
    """ Convert to hdf5 all files matching pattern (glob2 pattern matching) whose
        filename does not contain any of the substrings listed in avoid_list.

        Arguments:
            sett (Sett_hdf5): parameters of the resulting hdf5 files.
            pattern (str): glob2 pattern of the files to convert.
            avoid_list (iterable of str): files containing any of these substrings are skipped.
            flag_plot_althist (bool): plot the alternation histogram of every file.

        A file that fails to convert is reported and skipped; the remaining files
        are still processed.

        Example: convert_sm_multi(get_sett("NAR2014"), pattern="**/*.sm")
                 # it converts all .sm files in the current folder and its subfolders
    """
    def _allowed(fname):
        # True when fname contains none of the substrings in avoid_list
        return not any(s in fname for s in avoid_list)

    for fn in [f for f in glob2.glob(pattern) if _allowed(f)]:
        print(fn)
        try:
            convert_sm(fn, sett, flag_plot_althist=flag_plot_althist)
        except Exception as err:
            # best effort: keep converting the other files
            print('...unable to convert ({})'.format(err))
            
            
def merge_sm_to_hdf(fnlist,sett,outname='',flag_plot_althist=False):
    """ Merge the photon streams of several .sm files into one Photon-HDF5 file.

        Files are loaded in the order of fnlist. The timestamps of each file are
        shifted so that they continue after the last timestamp of the previous
        file (rounded up to a whole alternation period, minus alex_offset). The
        metadata of the last successfully loaded file are used for the output.

        Arguments:
            fnlist (list of str): .sm files to merge.
            sett (Sett_hdf5): conversion parameters.
            outname (str): output name; if '', the name of the last loaded file with
                '_merged' appended (extension removed) is used.
            flag_plot_althist (bool): plot the alternation histogram of every file.

        Returns:
            The name of the merged .hdf5 file, or None if no file could be loaded.

        Example: merge_sm_to_hdf(['file1.sm','file2.sm'], get_sett("NAR2014"))
        Example: merge_sm_to_hdf(glob.glob('folder/*.sm'), get_sett("NAR2014"))
    """
    Fdet = np.array([], dtype=np.int64)
    Ftime = np.array([], dtype=np.int64)
    dic = {}
    shift = 0
    duration = 0

    for fn in fnlist:
        print(fn)
        res = convert_sm(fn, sett, flag_save=False, flag_plot_althist=flag_plot_althist)
        if res is None:
            # convert_sm has already reported the missing file
            continue
        _, dic = res
        ph = dic['photon_data']
        Fdet = np.concatenate((Fdet, np.int64(ph['detectors'])), axis=0)
        # cast before adding the shift so that the sum is computed in int64
        Ftime = np.concatenate((Ftime, np.asarray(ph['timestamps'], dtype=np.int64) + shift), axis=0)
        duration += dic['acquisition_duration']
        alex_period = ph['measurement_specs']['alex_period']
        alex_offset = ph['measurement_specs']['alex_offset']
        shift = (Ftime[-1]//alex_period + 1)*alex_period - alex_offset

    if not dic:
        print('None of the input files was converted sucessfully and therefore there is nothing to be merged.')
        return None

    if outname == '':
        dic['_filename'] = os.path.splitext(dic['_filename'])[0] + '_merged'
    else:
        dic['_filename'] = get_fname_woext(outname)
    dic['photon_data']['detectors'] = Fdet
    dic['photon_data']['timestamps'] = Ftime
    dic['acquisition_duration'] = duration
    print('saving merged data file')
    phc.hdf5.save_photon_hdf5(dic, overwrite=True)

    return dic['_filename'] + '.hdf5'
    
    
def merge_sm_to_hdf_patt(sett, outname='',patt="*.sm", contains_list=None, avoid_list=None,flag_plot_althist=False):
    """ Search files according to the glob2 pattern patt and merge them into one hdf file.

        The file selection is delegated to find_files: files are meant to contain
        at least one substring from contains_list and none from avoid_list.
        The selected .sm files are then merged with merge_sm_to_hdf.

        Arguments:
            sett (Sett_hdf5): conversion parameters.
            outname (str): output name forwarded to merge_sm_to_hdf.
            patt (str): glob2 pattern.
            contains_list (list of str, optional): defaults to [].
            avoid_list (list of str, optional): defaults to ['sm.bursts.sm'].
            flag_plot_althist (bool): plot the alternation histogram of every file.

        Returns:
            The value returned by merge_sm_to_hdf (name of the merged .hdf5 file).

        Example: merge_sm_to_hdf_patt(get_sett("NAR2014"), patt="**/*.sm", contains_list=['-00','-01','-02'], avoid_list=['sm.bursts.sm'])
    """
    # None sentinels avoid sharing mutable list defaults between calls
    if contains_list is None:
        contains_list = []
    if avoid_list is None:
        avoid_list = ['sm.bursts.sm']
    fnlist = find_files(patt, contains_list=contains_list, avoid_list=avoid_list)
    return merge_sm_to_hdf(fnlist, sett, outname=outname, flag_plot_althist=flag_plot_althist)
        


"""    
def merge_hdf_to_hdf(fnlist,sett):
    # This function is not complete, see the comments at the end.
    Fdet,Ftime = np.array([],dtype=np.bool),np.array([],dtype=np.int64)
    for fn in fnlist:        
        print (fn)
        d = loader.photon_hdf5(fn)
        Fdet = np.concatenate((Fdet, d.A_em[0]),axis=0)
        Ftime= np.concatenate((Ftime, d.ph_times_m[0]),axis=0)
        # Here has to come some function which either saves the Data() object as hdf5,
        # or a function which extracts the corresponding dictionary from Data() object which then can be saved by phc.hdf5.save_photon_hdf5(dic, overwrite=True).
        # I haven't found such functions yet. Creating the dictionary from the Data() attributes would be a pain.
"""


def spc2hdf(fn_spc, donor=1, acceptor=0, hdfoverwr = False):
    """ Convert a Becker&Hickl spc file into Photon-HDF5 and return the hdf5 filename.

        The output name is fn_spc with its extension replaced by '.hdf5'. If that
        file already exists and hdfoverwr is False, the conversion is skipped;
        with hdfoverwr=True the existing file is overwritten.

        NOTE: spc2hdf is defined a second time further down in this module; that
        later definition is the one bound when the module is imported.
    """
    # splitext works for any extension length (slicing off 3 characters did not)
    fn_hdf = os.path.splitext(fn_spc)[0] + '.hdf5'
    if os.path.isfile(fn_hdf) and not hdfoverwr:
        print('hdf already exists, the generation of hdf file will be skipped.')
    else:
        datadic = phc.loader.nsalex_bh(fn_spc,
                                       donor=donor,
                                       acceptor=acceptor)
        datadic[0]['description'] = 'description'
        # overwrite=True so that hdfoverwr actually replaces an existing file
        phc.hdf5.save_photon_hdf5(datadic[0], overwrite=True)
    return fn_hdf

def load_hdf(fn_hdf):
    """ Read the Photon-HDF5 file fn_hdf with fretbursts and return its Data() object. """
    data = loader.photon_hdf5(fn_hdf)
    return data
    

#%%_______B&H SPC FUNCTIONS_________
"""
-------------------------------------------------------------------------------
B&H SPC FUNCTIONS                                                       
------------------------------------------------------------------------------- 
"""

def spc2hdf(fn_spc, donor=1, acceptor=0, hdfoverwr = False):
    """ Convert a Becker&Hickl spc file into Photon-HDF5 and return the hdf5 filename.

        The output name is fn_spc with its extension replaced by '.hdf5'. If that
        file already exists and hdfoverwr is False, the conversion is skipped;
        with hdfoverwr=True the existing file is overwritten.
    """
    # splitext works for any extension length (slicing off 3 characters did not)
    fn_hdf = os.path.splitext(fn_spc)[0] + '.hdf5'
    if os.path.isfile(fn_hdf) and not hdfoverwr:
        print('hdf already exists, the generation of hdf will be skipped.')
    else:
        datadic = phc.loader.nsalex_bh(fn_spc,
                                       donor=donor,
                                       acceptor=acceptor)
        datadic[0]['description'] = 'description'
        # overwrite=True so that hdfoverwr actually replaces an existing file
        phc.hdf5.save_photon_hdf5(datadic[0], overwrite=True)
    return fn_hdf

def setd(d,**kwargs):
    """ Configure a fretbursts Data() object for a PAX experiment
        (same as setd_PAX but without the 'pax' setting).

        Default settings are merged with kwargs (kwargs take precedence) and passed
        to d.add(); d.setup['excitation_alternated'] and d.setup['excitation_cw']
        are set afterwards.
    """
    settings = {'ALEX': False,
                'meas_type': 'PAX',
                'det_donor_accept': (1, 0),
                'D_ON': (10, 1800),
                'A_ON': (1800, 4090),
                'offset': 0}
    settings.update(kwargs)
    d.add(**settings)
    d.setup['excitation_alternated'] = np.array([0, 1], dtype=np.uint8)
    d.setup['excitation_cw'] = np.array([1, 0], dtype=np.uint8)

    
def load_hdf(fn_hdf):
    """ Read the Photon-HDF5 file fn_hdf with fretbursts and return its Data() object. """
    data = loader.photon_hdf5(fn_hdf)
    return data




#%% ______PLOT FUNCTIONS______
"""
-------------------------------------------------------------------------------
PLOT FUNCTIONS                                                       
------------------------------------------------------------------------------- 
"""

def plot_timetrace(d,step_height=50,**kwargs):
    """ Plot the fretbursts timetrace of d (1 ms bins, scrollable) and overlay the
        burst intervals as a grey step curve of height +/- step_height.

        kwargs are forwarded to dplot (e.g. tmin, tmax).
    """
    tt, y = burst_display_fun(d)
    dplot(d, timetrace, binwidth=1e-3, scroll=True, bursts=False, burst_picker=False,
          show_rate_th=True, F=True, legend=True, **kwargs)
    plt.plot(tt, y*step_height, tt, -y*step_height, '-', color='gray', alpha=0.3)
    # a real boolean: recent Matplotlib no longer maps the string 'off' to False,
    # and a non-empty string is truthy, which would switch the grid on
    plt.grid(False)


def plot_ratetrace(d, step_height=50):
    """ Plot the fretbursts ratetrace of d (scrollable) and overlay the burst
        intervals as a grey step curve of height +/- step_height. """
    tt, y = burst_display_fun(d)
    dplot(d, ratetrace, scroll=True, bursts=False, burst_picker=False,
          show_rate_th=True, F=True, legend=True)
    # a real boolean: recent Matplotlib no longer maps the string 'off' to False,
    # and a non-empty string is truthy, which would switch the grid on
    plt.grid(False)
    plt.plot(tt, y*step_height, tt, -y*step_height, '-', color='gray', alpha=0.3)

    
def burst_display_fun(d):
    """ Build a step function that is 1 during bursts and 0 outside them.

        For every burst of channel 0 four points are generated:
        (start-1, 0), (start, 1), (stop, 1), (stop+1, 0), with start/stop in clock ticks.
        This function is typically called by the timetrace/ratetrace plot functions.

        Returns:
            (t, y): t = points multiplied by d.clk_p, y = the 0/1 step values (floats).
    """
    bursts = d.mburst[0]
    start = np.asarray(bursts.start)
    stop = np.asarray(bursts.stop)
    # one row per burst, flattened row by row: s-1, s, e, e+1, s'-1, s', ...
    t = np.column_stack((start - 1, start, stop, stop + 1)).ravel().astype(float)
    y = np.tile([0., 1., 1., 0.], start.size)
    return t*d.clk_p, y


def plot_EShist(d):
    """ Plot the E-S plane of channel 0: a filled seaborn KDE with the individual
        bursts overlaid as small white dots; the title shows the number of bursts.
    """
    # NOTE(review): positional x/y and shade= follow the older seaborn kdeplot API;
    # newer seaborn releases reject/deprecate them - confirm the installed version.
    E, S = d.E[0], d.S[0]
    plt.figure()
    sns.kdeplot(E, S, shade=True, cmap='hot')
    plt.scatter(E, S, s=3, color='white', alpha=0.5)
    plt.xlabel('E')
    plt.ylabel('S')
    n_bursts = d.mburst[0].size
    plt.title('#bu:%d' % (n_bursts))

def plot_Ehist_gaussfit(ds, numgauss=2, model=None):
    """ Fit the FRET histogram of ds with one, two or three Gaussians and plot it.

        Arguments:
            ds: fretbursts Data object after burst selection.
            numgauss (int): 1, 2 or 3, number of Gaussian peaks (ignored if model is given).
            model: model accepted by ds.E_fitter.fit_histogram, used instead of the
                built-in ones (e.g. model = mfit.factory_two_gaussians(add_bridge=True)).

        Returns:
            The first row of ds.E_fitter.params (fit parameters of channel 0).

        Raises:
            ValueError: if model is None and numgauss is not 1, 2 or 3.
    """
    if model is None:
        if numgauss == 1:
            model = mfit.factory_gaussian()
        elif numgauss == 2:
            model = mfit.factory_two_gaussians()
        elif numgauss == 3:
            model = mfit.factory_three_gaussians()
        else:
            # otherwise model stays None and the fit fails with an obscure error
            raise ValueError('numgauss must be 1, 2 or 3 (got {!r})'.format(numgauss))

    # NOTE(review): this first plot is closed immediately; presumably it only serves
    # to set up ds.E_fitter (the E histogram) before fitting - confirm.
    dplot(ds, hist_fret)
    plt.close()
    ds.E_fitter.fit_histogram(model, verbose=False)
    dplot(ds, hist_fret, show_model=True)
    ds.E_fitter.fit_res[0].params.pretty_print(columns=['value', 'stderr'])
    return ds.E_fitter.params.iloc[0]

def plot_hist_width(d,Ebins=None,tbins=None,**kwargs):
    """ Plot the histogram of burst widths (burst time-durations) in milliseconds.

        Arguments:
            d: fretbursts Data object (channel 0 is used).
            Ebins (sequence, optional): edges of E_FRET intervals; one histogram is
                drawn per interval. If None, a single histogram of all bursts is drawn.
            tbins (array, optional): time bin edges in ms; default 0-10 ms in 0.1 ms steps.
            **kwargs: forwarded to plt.hist(...).

        Returns:
            (hist_list, bin_edges_list): counts and bin edges of every histogram.

        This is an extension to 'dplot(ds, hist_width)'.
    """
    if Ebins is None:
        # a single unbounded E interval that selects every burst
        Ebins = [-np.inf, np.inf]
        masks = [np.ones(d.burst_widths[0].size, dtype=bool)]
    else:
        # one mask per E interval: True where the burst belongs to that interval
        masks = [in_interval(d.E[0], (lo, hi)) for lo, hi in zip(Ebins[:-1], Ebins[1:])]

    plt.figure()
    if tbins is None:
        # time bins in milliseconds
        tbins = np.arange(0, 10, 0.1)

    hist_list, bin_edges_list = [], []
    for mask, lo, hi in zip(masks, Ebins[:-1], Ebins[1:]):
        label = "E in (%.2f - %.2f)" % (lo, hi)
        counts, edges, _ = plt.hist(d.burst_widths[0][mask]*1000, bins=tbins, histtype='step',
                                    label=label, lw=2, **kwargs)
        hist_list.append(counts)
        bin_edges_list.append(edges)

    plt.xlabel('time [ms]')
    plt.title('Histogram of burst time-widths (#bu={})'.format(d.burst_widths[0].size))
    plt.legend()
    return hist_list, bin_edges_list


def plot_scatter_fret_width(ds, flag_kde=False):
    """ Scatter plot of burst width (ms) versus E_FRET for channel 0.

        With flag_kde=True, a seaborn KDE is drawn underneath smaller black markers.
        This is an alternative to dplot(d,scatter_fret_width).
    """
    E = ds.E[0]
    width_ms = ds.burst_widths[0]*1000
    plt.figure()
    if not flag_kde:
        plt.plot(E, width_ms, 'o', ms=1., alpha=0.25, color='blue')
    else:
        sns.kdeplot(E, width_ms, shade=True, cmap='rainbow', shade_lowest=False, aspect=2)
        plt.plot(E, width_ms, 'o', ms=0.2, alpha=0.25, color='black')
    plt.xlabel('E_FRET')
    plt.ylabel('burst width [ms]')


def plot_fret_avgwidth(ds,bin_edges_E=None):
    """ Plot the average burst time-width versus E_FRET.

        Arguments:
            ds: fretbursts Data object (channel 0 is used).
            bin_edges_E (array-like, optional): edges of the E_FRET intervals
                (lists/tuples accepted); defaults to 10 equal intervals in [0, 1].

        Returns:
            (x, bu_width_avg): interval centres and the mean burst width of the
            bursts with edge_k <= E < edge_k+1 (same units as ds.burst_widths;
            plotted x1000 as ms). Intervals without bursts give NaN.
    """
    if bin_edges_E is None:
        bin_edges_E = np.linspace(0, 1, 11)
    bin_edges_E = np.asarray(bin_edges_E, dtype=float)

    E = ds.E[0]
    widths = ds.burst_widths[0]
    bu_width_avg = np.full(bin_edges_E.size - 1, np.nan)
    for ibin, (lo, hi) in enumerate(zip(bin_edges_E[:-1], bin_edges_E[1:])):
        mask = (E >= lo) & (E < hi)
        # empty intervals stay NaN (avoids the "mean of empty slice" warning)
        if mask.any():
            bu_width_avg[ibin] = widths[mask].mean()

    x = (bin_edges_E[:-1] + bin_edges_E[1:])/2.
    plt.figure()
    plt.plot(x, bu_width_avg*1000, 'o-')
    plt.xlabel('E')
    plt.ylabel('average bu width [ms]')
    plt.ylim(bottom=0)
    return x, bu_width_avg


def fit_Ehist(ds, fitfun, p0, bins=20,flag_plot=True):
    """ Fit the FRET histogram of channel 0 with fitfun (e.g. Fitfuns.gauss,
        Fitfuns.gauss_two or Fitfuns.gauss_three).

        Arguments:
            ds: fretbursts Data object.
            fitfun: model function fitfun(x, *params).
            p0: initial parameters.
            bins: forwarded to np.histogram.
            flag_plot (bool): if True, draw the histogram, the fit and (for parameters
                given in (A, mu, sigma) triplets) the individual Gaussian components.

        Returns:
            (fitpar, dict(x=x, y=yfit)): fitted parameters and the fitted curve at the
            bin centres. No figure is returned.
    """
    histE, bin_edges = np.histogram(ds.E[0], bins=bins)
    x = (bin_edges[:-1] + bin_edges[1:])/2.
    fitpar, _, yfit = myfit(fitfun, x, histE, p0, flag_plot=False)
    if flag_plot:
        plt.figure()
        plt.hist(ds.E[0], bins=bin_edges, color='blue', alpha=0.3)
        plt.plot(x, yfit, 'k-')
        # components are only meaningful (and callable) when the parameters
        # come in (A, mu, sigma) triplets
        if fitpar.size % 3 == 0:
            for idx in range(0, fitpar.size, 3):
                plt.plot(x, Fitfuns.gauss(x, *fitpar[idx:idx+3]), 'k--')
        plt.xlabel('E_FRET')

    return fitpar, dict(x=x, y=yfit)


def plot_burst_timedistribution(d,tbin_s=100):
    """ Plot the histogram of burst start times with bin width tbin_s seconds.

        Returns:
            (h, bin_edges): counts per time bin and bin edges in seconds.
    """
    plt.figure()
    t_start = d.mburst[0].start*d.clk_p
    # extend the range by one bin: np.arange excludes its stop value, so the bursts
    # after the last full bin (including the last burst) would otherwise be dropped
    bins = np.arange(0, t_start[-1] + tbin_s, tbin_s)
    h, bin_edges, _ = plt.hist(t_start, bins=bins)
    plt.xlabel('time [s]')
    plt.ylabel('#bursts in time-interval')
    return h, bin_edges


def plot_intensity(d,tbin_s=2):
    """ Plot the signal intensity as a function of time: photon counts of
        d.get_ph_times() in consecutive bins of tbin_s seconds.

        Returns:
            (h, bin_edges): counts per bin and bin edges in seconds.
    """
    plt.figure()
    # timestamps and bin edges must both be in seconds
    t_s = d.get_ph_times()*d.clk_p
    # bins cover [0, time_max) in steps of tbin_s; a trailing partial bin is not drawn
    h, bin_edges, _ = plt.hist(t_s, bins=np.arange(0, d.time_max, tbin_s), histtype='step')
    plt.xlabel('time [s]')
    plt.ylabel('binned signal intensity')
    return h, bin_edges

#%% ______FIT FUNCTIONS______
"""
-------------------------------------------------------------------------------
FIT FUNCTIONS
-------------------------------------------------------------------------------
"""

class Fitfuns(object):
    """ Model functions for curve fitting: the first argument is the independent
        variable, the remaining ones are the fit parameters. """

    @staticmethod
    def moexp(t, A, tau, b):
        """ Mono-exponential decay with offset: A*exp(-t/tau) + b. """
        return A*exp(-t/tau) + b

    @staticmethod
    def gauss(t, A, mu, sigma):
        """ Gaussian peak without offset: A*exp(-(t-mu)**2 / (2*sigma**2)). """
        return A*exp(-((t-mu)**2/(2*sigma**2)))

    @staticmethod
    def gauss_b(t, A, mu, sigma, b):
        """ Gaussian peak plus constant offset b. """
        return Fitfuns.gauss(t, A, mu, sigma) + b

    @staticmethod
    def gauss_two(t, A1, mu1, sigma1, A2, mu2, sigma2):
        """ Sum of two Gaussian peaks. """
        first = Fitfuns.gauss(t, A1, mu1, sigma1)
        second = Fitfuns.gauss(t, A2, mu2, sigma2)
        return first + second

    @staticmethod
    def gauss_three(t, A1, mu1, sigma1, A2, mu2, sigma2, A3, mu3, sigma3):
        """ Sum of three Gaussian peaks. """
        first = Fitfuns.gauss(t, A1, mu1, sigma1)
        second = Fitfuns.gauss(t, A2, mu2, sigma2)
        third = Fitfuns.gauss(t, A3, mu3, sigma3)
        return first + second + third

    @staticmethod
    def acf_simple(tau, A, tauD, r):
        """ Autocorrelation model A / (1 + tau/tauD) / sqrt(1 + r**2 * tau/tauD). """
        lateral = 1/(1+tau/tauD)
        axial = 1/np.sqrt(1+(r**2)*tau/tauD)
        return A*lateral*axial

    @staticmethod
    def psame_bursts_1(tau, n, tauD):
        """ see Hoffmann2011-RASP """
        decay = (1+tau/tauD)**(-1.5)
        return 1 - 1/(1 + decay/n)

    @staticmethod
    def diffusion_2d(timelag, tau_diff, A0):
        """ 1 + A0 / (1 + timelag/tau_diff). """
        return 1 + A0 * 1/(1 + timelag/tau_diff)

    @staticmethod
    def diffusion_3d(timelag, tau_diff, A0, waist_z_ratio=0.1):
        """ 1 + A0 / (1 + timelag/tau_diff) / sqrt(1 + waist_z_ratio**2 * timelag/tau_diff). """
        return (1 + A0 * 1/(1 + timelag/tau_diff) *
                1/np.sqrt(1 + waist_z_ratio**2 * timelag/tau_diff))
        

def myfit(fitfun,x,y,p0,fitrng=None,flag_plot=False):
    """ Least-squares fit of fitfun to (x, y) with scipy.optimize.curve_fit.

        Arguments:
            fitfun: model function fitfun(x, *params).
            x, y: data (array-like; converted to numpy arrays).
            p0: initial parameter guess.
            fitrng (tuple, optional): (xmin, xmax); only points with
                xmin <= x <= xmax are used in the fit.
            flag_plot (bool): if True, plot the data (blue), the fit over all x (red)
                and over the fitted points (lime) on a logarithmic x axis.

        Non-finite y values are excluded from the fit.

        Returns:
            (p, perr, yfit): best-fit parameters, their standard errors (square
            root of the covariance diagonal) and fitfun evaluated on all x.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    mask = np.isfinite(y)
    if fitrng is not None:
        mask &= (x >= fitrng[0]) & (x <= fitrng[1])
    p, pcov = curve_fit(fitfun, x[mask], y[mask], p0)
    perr = np.sqrt(np.diag(pcov))
    yfit = fitfun(x, *p)

    if flag_plot:
        plt.figure()
        plt.plot(x, y, 'bo', ms=4)
        plt.plot(x, yfit, 'r-')
        plt.plot(x[mask], yfit[mask], '-', color='lime')
        plt.xscale('log')
        plt.grid()

    return p, perr, yfit


#%% ______ANALYSIS CLASS______
"""
-------------------------------------------------------------------------------
Alex Analysis                                                       
------------------------------------------------------------------------------- 
"""

class Analysis(object):
    """ Basic smFRET analysis pipeline built on fretbursts.

        On construction the Photon-HDF5 file is loaded, the alternation histogram is
        plotted, the alternation period is applied, the background is fitted, a burst
        search is run (optionally followed by burst fusion) and bursts are selected by
        size (th1=30). The intermediate Data objects are kept as
        self.d (after burst search), self.df (after fusion; the same object as d when
        fuse_ms <= 0) and self.ds (after burst selection). Figures are collected in
        self.figs, numerical results in self.res.
    """
    def __init__(self, filename,
                 meas_type='ALEX', settPAX=None,
                 leakage=0.0, dir_ex=0.0, gamma=1.0,
                 bg_time_s=60, bg_tail_min_us='auto', F_bg=1.7,
                 busearch_m=10, busearch_L=10, busearch_F=5, busearch_computefret=True,
                 fuse_ms=0):
        """ Arguments:
                filename (str): Photon-HDF5 file to analyse.
                meas_type (str): 'ALEX' or 'PAX'; 'PAX' configures the Data object with setd_PAX.
                settPAX (dict, optional): keyword arguments forwarded to setd_PAX.
                leakage, dir_ex, gamma: correction factors assigned to the Data object.
                bg_time_s, bg_tail_min_us, F_bg: forwarded to d.calc_bg(bg.exp_fit, ...).
                busearch_m, busearch_L, busearch_F, busearch_computefret: forwarded to d.burst_search.
                fuse_ms (float): if > 0, forwarded to d.fuse_bursts(ms=fuse_ms).
        """
        self.filename = filename
        self.figs = dict()
        self.res = dict()

        # Load data
        d = loader.photon_hdf5(filename)
        if meas_type == 'ALEX':
            print('ALEX')
        elif meas_type == 'PAX':
            if settPAX is None:
                setd_PAX(d)
            else:
                setd_PAX(d, **settPAX)

        bpl.plot_alternation_hist(d)
        self.figs["alt"] = plt.gcf()
        loader.alex_apply_period(d)

        # background fit; the rate summary from calc_phrate_mean is written on the figure
        d.calc_bg(bg.exp_fit, time_s=bg_time_s, tail_min_us=bg_tail_min_us, F_bg=F_bg)
        dplot(d, hist_bg, show_fit=True)
        s = calc_phrate_mean(d)['str']
        plt.text(0.5, 0.95, s, ha='center', va='top', transform=plt.gca().transAxes)
        self.figs["bg"] = plt.gcf()

        # corrections
        d.leakage = leakage
        d.dir_ex = dir_ex
        d.gamma = gamma

        # burst search on all photon streams
        d.burst_search(m=busearch_m, L=busearch_L, F=busearch_F,
                       computefret=busearch_computefret, pax=(meas_type == 'PAX'),
                       ph_sel=Ph_sel(Dex='DAem', Aex='DAem'))

        # optional burst fusion
        if fuse_ms > 0:
            df = d.fuse_bursts(ms=fuse_ms, mute=True)
        else:
            df = d

        # basic burst selection by size
        ds = df.select_bursts(select_bursts.size, add_naa=True, th1=30)

        # timetrace
        plot_timetrace(ds, tmin=0, tmax=100)
        self.figs["timetrace"] = plt.gcf()

        self.d = d
        self.df = df
        self.ds = ds

    def proc(self, options=()):
        """ Run a basic set of plots/analyses and store the figures in self.figs.

        It is more convenient to use a separate function for the processing purposes,
        since the data processing requirements always change over time; this method
        is included just to show an example of how the data processing can be done.

        Arguments:
            options (tuple of str): analyses to run. The empty tuple (default) runs
                "bu_timedistr", "intensity", "ES", "hist_fret", "hist_size",
                "hist_width", "avgwidth", "scatter_fret_width", "scatter_fret_size",
                "scatter_width_size", "ACF" and "BVA". "RASP" runs only when listed.

        For various other pre-implemented plot/analysis functions see fretbursts/burst_plot.py.
        """
        d = self.d
        ds = self.ds
        flag_all = (options == ())

        def wanted(key):
            # True when the analysis key was requested (or everything is requested)
            return flag_all or (key in options)

        if wanted("bu_timedistr"):
            plot_burst_timedistribution(d, tbin_s=100)
            self.figs["bu_timedistr"] = plt.gcf()

        if wanted("intensity"):
            plot_intensity(d, tbin_s=2)
            self.figs["intensity"] = plt.gcf()

        if wanted("ES"):
            alex_jointplot(ds, kind='hex')
            plt.suptitle(get_fname_woext(self.filename), fontsize=10)
            self.figs["ES"] = plt.gcf()

        if wanted("hist_fret"):
            dplot(ds, hist_fret, show_kde=True)
            plt.xlim(-0.2, 1.2)
            self.figs["hist_fret"] = plt.gcf()

        if wanted("hist_size"):
            dplot(ds, hist_size)
            self.figs["hist_size"] = plt.gcf()

        if wanted("hist_width"):
            dplot(ds, hist_width)
            self.figs["hist_width"] = plt.gcf()
            plot_hist_width(ds, Ebins=(0, 0.5, 1.0))
            self.figs["hist_width_Ebins"] = plt.gcf()

        if wanted("avgwidth"):
            E, tw = plot_fret_avgwidth(ds)
            self.figs["avgwidth"] = plt.gcf()
            self.res['avgwidth'] = dict(E=E, twidth=tw)

        if wanted("scatter_fret_width"):
            plot_scatter_fret_width(ds)
            plt.xlim(-0.2, 1.2)
            self.figs["scatter_fret_width"] = plt.gcf()

        if wanted("scatter_fret_size"):
            dplot(ds, scatter_fret_size)
            plt.xlim(-0.2, 1.2)
            self.figs["scatter_fret_size"] = plt.gcf()

        if wanted("scatter_width_size"):
            dplot(ds, scatter_width_size)
            self.figs["scatter_width_size"] = plt.gcf()

        if wanted("ACF"):
            try:
                tau, g = calc_acf(ds, ph_sel=Ph_sel(Aex='Aem'), flag_fit_plot=True)
                self.figs["acf"] = plt.gcf()
                self.res['acf'] = dict(tau=tau, g=g)
            except Exception as err:
                # best effort: a failing ACF must not abort the remaining analyses
                print('Unable to calculate the ACF. ({})'.format(err))

        if wanted("BVA"):
            self.bva = BVA(ds, mode=1, flag_E_orig=False)
            self.figs["bva"] = plt.gcf()

        if "RASP" in options:
            self.rasp = RASP(ds)
            self.rasp.plot2d()
            self.figs["rasp"] = plt.gcf()

    def reportpng(self, outname=None, figkeys=None):
        """ Combine figures from self.figs into one png saved as outname, then close all figures.

            Arguments:
                outname (string): filename of the output image. If not provided, it is
                    generated from the data filename by adding '_report.png'.
                figkeys (list of strings): only figures whose keys are listed are combined.
                    If None or empty, all figures in self.figs are combined.
        """
        if figkeys is None or len(figkeys) == 0:
            figlist = list(self.figs.values())
        else:
            figlist = []
            for key in figkeys:
                if key in self.figs:
                    figlist.append(self.figs[key])
                else:
                    print('key '+key+' not recognized and was skipped')

        if outname is None:
            outname = get_path_woext(self.filename)+'_report.png'
        self.figreport = figs_to_one_png(figlist, outname, tidy_up=True)
        plt.close('all')
    

#class Analysis_1(Analysis):
#    """ This is an example how another class can be derived from the ALEXAnalysis class. """
#    def __init__(self,**kwargs):
#        ALEXData.__init__(**kwargs)
#        self.var_1 = 1
#        
#    def proc(self):
#        pass
#        
#    def method_1(self):
#        print(self.filename)
#        print(self.var_1)

def bursts_info(d,idxs=None,nhead=5,ntail=3):
    """ Print a compact table of the bursts of channel 0.

        Arguments:
            d: fretbursts Data object after burst search.
            idxs (array-like of int, optional): indices of the bursts to print. If None,
                the first nhead and the last ntail bursts are printed, separated by
                '...'; when there are no more than nhead+ntail bursts, all are printed once.
            nhead, ntail (int): number of leading/trailing bursts printed when idxs is None.
    """
    bs = d.mburst[0]
    nbursts = bs.size

    print('Number of bursts: {}'.format(bs.num_bursts))
    print('bursts info')

    if idxs is not None:
        idxs_list = [idxs]
    elif nhead + ntail >= nbursts:
        # head and tail would overlap or run past the end: print everything once
        idxs_list = [np.arange(nbursts)]
    else:
        idxs_list = [np.arange(nhead), np.arange(nbursts - ntail, nbursts)]

    templ = '{:>8} {:>13} {:>13} {:>16} {:>16} {:>10} {:>8}'
    print(templ.format('i', 'istart', 'istop', 'start', 'stop', 'width', 'counts'))
    for k, sel in enumerate(idxs_list):
        for i, b in zip(sel, bs[sel]):
            print(templ.format(i, b.istart, b.istop, b.start, b.stop, b.width, b.counts))
        if k < len(idxs_list) - 1:
            print('...  ...  ...')
    
    
def example_values(d):
    """ Print examples of the most commonly used attributes of a fretbursts Data
        object (channel 0) after burst search, followed by bursts_info(d). """
    print('\nphoton times:')
    print(d.ph_times_m[0]) # or d.get_ph_times()
    print('\nclock period:')
    print(d.clk_p)
    print('\nphoton times:')
    print(d.get_ph_times(ich=0, ph_sel=Ph_sel(Dex='DAem', Aex='DAem'), compact=False))
    print('\nA_em mask:')
    print(d.A_em[0])
    print('\nD_ex mask:')
    print(d.D_ex[0])
    print('\nnumber of donor photons in individual bursts (corrected):')
    print(d.nd[0])
    # d.na holds the acceptor counts (the label previously said "donor")
    print('\nnumber of acceptor photons in individual bursts (corrected):')
    print(d.na[0])
    print('\ntotal number of photons in individual bursts (corrected):')
    print(d.nt[0])
    print('\nFRET in individual bursts:')
    print(d.E[0])
    print('\nStoichiometry in individual bursts:')
    print(d.S[0])
    print('\nTime of the last timestamp in seconds:')
    print(d.time_max)
    print('\nNumber of bursts:')
    print(d.num_bursts[0])
    print('\nBurst sizes:')
    print(d.burst_sizes_ich())
    print('\nBurst widths in seconds:')
    print(d.burst_widths[0])

    print('\nBursts object:')
    print('Attributes of the Bursts object:')
    print(dir(d.mburst[0]))
    print('\nSome of the burst properties:')
    print('Bursts start times, widths (in clock periods), and counts:')
    print(d.mburst[0].start)
    print(d.mburst[0].width)
    print(d.mburst[0].counts)
    print('\nCompact bursts info:')
    bursts_info(d)
        
               
def get_test_Analysis(fname=None, **kwargs):
    """ Build an Analysis object for a test data set.

        Arguments:
            fname (str, optional): Photon-HDF5 file; defaults to ../data/s54.hdf5
                relative to the current working directory (the path is printed).
            **kwargs: forwarded to Analysis.

        Returns:
            The Analysis instance.
    """
    import os
    if fname is None:
        # os.path.join avoids the invalid '\d' / '\s' escapes of a hard-coded
        # Windows path and works on every platform
        fname = os.path.join('..', 'data', 's54.hdf5')
        print(fname)
    return Analysis(fname, **kwargs)

#%% ______BVA______
"""
-------------------------------------------------------------------------------
BVA
-------------------------------------------------------------------------------
"""

class BVA(object):
    """ Burst variance analysis.
        Calculation according to Torella et al.,2011: Identifying Molecular Dynamics in Single-Molecule FRET Experiments with Burst Variance Analysis

        Every burst is split into consecutive, non-overlapping windows (sub-bursts) of n
        donor-excited photons; photons left over at the end of a burst are ignored. For each window
        the uncorrected FRET efficiency E = (#acceptor-emission photons)/n is computed. The standard
        deviation of these values inside a burst is plotted against the burst E, together with the
        curve sqrt(E*(1-E)/n) drawn for a static population.

        Arguments:
            ds: fretbursts data object after burst search and burst selection.
            n: number of photons in a window (sub-burst) from which E_FRET is calculated
            wbin: the width of bin on the E_FRET-axis
            mode: switch for the function used to calculate BVA:
                  1 -> bva_my (windows built on the full photon stream using the D_ex / A_em masks),
                  2 -> bva (windows built on bursts re-indexed to donor-excited photons),
                  3 -> bva_tomas (as mode 2, but without E-binning and confidence intervals)
            plot_kind: 'scatter','kde', or 'hist2d'
            flag_E_orig: if True, corrected FRET values ds.E[0] are used instead of the raw uncorrected values calculated during BVA
            long_percentile: bursts with widths longer than the width corresponding to the percentile are plotted in different color

        Raises:
            ValueError: if mode is not 1, 2 or 3.

        Results (attributes set during construction):
            Ebursts: mean sub-burst E of every burst (ds.E[0] if flag_E_orig)
            E_sub_std: standard deviation of sub-burst E inside every burst
            E_binned, std_binned, cfi: per E-bin mean E, std of sub-burst E and the 95th percentile of
                the std simulated for a static population (None for mode 3)
            nsub_binned: number of sub-bursts in every E-bin (None for mode 3)
            res: dict collecting the results above
    """
    def __init__(self, ds, n=5, wbin=0.05, mode=1, plot_kind='scatter', flag_E_orig=False, long_percentile=80):
        self.ds = ds
        self.n = n
        self.wbin = wbin
        self.mode = mode
        self.plot_kind = plot_kind
        self.flag_E_orig = flag_E_orig
        self.long_percentile = long_percentile

        self.Ebursts = None
        self.E_sub_std = None
        self.E_binned = None
        self.std_binned = None
        self.cfi = None
        self.nsub_binned = None
        self.res = None

        if mode == 1:
            self.bva_my(ds, n, wbin)
        elif mode == 2:
            self.bva(ds, n, wbin)
        elif mode == 3:
            self.bva_tomas(ds, n)
        else:
            # previously an unknown mode silently produced no results
            raise ValueError('mode must be 1, 2 or 3, got %r' % (mode,))

    def _update_results(self, Ebursts, E_sub_std, E_binned=None, std_binned=None, cfi=None, nsub_binned=None):
        """ Store the BVA results as attributes and collect them in the dict self.res. """
        self.Ebursts = Ebursts
        self.E_sub_std = E_sub_std
        self.E_binned = E_binned
        self.std_binned = std_binned
        self.cfi = cfi
        self.nsub_binned = nsub_binned
        self.res = dict(E=self.Ebursts, Estd=self.E_sub_std, E_binned=self.E_binned,
                        Estd_binned=self.std_binned, cfi=self.cfi, nsub_binned=self.nsub_binned)

    @staticmethod
    def _sub_burst_E_masked(bursts, D_ex, A_em, n):
        """ Sub-burst E values computed on the full photon stream.

            For every burst, photon indices istart..istop (inclusive) are filtered with the boolean
            mask D_ex; the remaining photons are grouped into consecutive windows of n photons and
            E = A_em[window].sum()/n is computed per window.
            Returns a list (one item per burst) of lists of E values.
        """
        sub_E = []
        for b in bursts:
            iph = np.arange(b.istart, b.istop + 1)[D_ex[b.istart:b.istop + 1]]
            sub_E.append([A_em[iph[i:i + n]].sum() / float(n)
                          for i in range(0, len(iph) - n + 1, n)])
        return sub_E

    @staticmethod
    def _sub_burst_E_reduced(bursts_d, DexAem_mask_d, n):
        """ Sub-burst E values for bursts whose indices refer to the donor-excited photon stream.

            DexAem_mask_d marks, among donor-excited photons, those detected in the acceptor channel.
            Windows [start, start+n) with start = istart, istart+n, ... lie fully inside istart..istop.
            Returns a list (one item per burst) of lists of E values.
        """
        return [[DexAem_mask_d[start:start + n].sum() / float(n)
                 for start in range(b.istart, b.istop + 2 - n, n)]
                for b in bursts_d]

    @staticmethod
    def _donor_reduced_bursts(ds):
        """ Return (bursts re-indexed on the donor-excited photon stream, acceptor-emission mask
            restricted to donor-excited photons). """
        ph_d = ds.get_ph_times(ph_sel=Ph_sel(Dex='DAem'))
        # The following line crashes if the burst search is done with Ph_sel('all') and not with
        # Ph_sel(Dex='DAem'), this is mentioned in the documentation of recompute_index_reduce()
        bursts_d = ds.mburst[0].recompute_index_reduce(ph_d)
        Dex_mask = ds.get_ph_mask(ph_sel=Ph_sel(Dex='DAem'))
        DexAem_mask = ds.get_ph_mask(ph_sel=Ph_sel(Dex='Aem'))
        return bursts_d, DexAem_mask[Dex_mask]

    @staticmethod
    def _bin_index(Ebins, E_mean):
        """ Index i of the E-bin (Ebins[i], Ebins[i+1]] containing E_mean.

            A value equal to the lowest edge Ebins[0] is put into the first bin (with a plain
            searchsorted it would fall outside every bin). NaN (burst without a complete window)
            or values outside [Ebins[0], Ebins[-1]] give an index outside [0, nbins).
        """
        if E_mean == Ebins[0]:
            return 0
        return int(np.searchsorted(Ebins, E_mean)) - 1

    @classmethod
    def _binned_stats(cls, sub_E, E_means, n, wbin):
        """ Pool the sub-burst E values of all bursts whose mean E falls into the same E-bin.

            Returns (E_binned, std_binned, cfi, nsub_binned): per bin the mean and the population
            standard deviation of the pooled sub-burst E values, the Monte Carlo confidence limit
            from conf_int, and the number of pooled sub-bursts. Empty bins hold zeros.
        """
        Ebins = np.arange(0., 1. + wbin, wbin)
        nbins = Ebins.size - 1
        eps = [[] for _ in range(nbins)]
        for E_sub, E_mean in zip(sub_E, E_means):
            idx = cls._bin_index(Ebins, E_mean)
            if 0 <= idx < nbins:
                eps[idx] += E_sub

        nsub_binned = np.array([len(e) for e in eps], dtype=float)
        E_binned = np.zeros(nbins)
        std_binned = np.zeros(nbins)
        cfi = np.zeros(nbins)
        for i, e in enumerate(eps):
            if e:
                epsarr = np.array(e)
                E_binned[i] = epsarr.mean()
                std_binned[i] = epsarr.std()
                cfi[i] = conf_int(E_binned[i], nsub_binned[i], n, reps=1000, qperc=95)
        return E_binned, std_binned, cfi, nsub_binned

    def _compute_and_store(self, ds, sub_E, n, wbin=None):
        """ Compute per-burst statistics (and binned ones when wbin is given) and store them. """
        E_means = np.array([np.mean(e) for e in sub_E])
        E_sub_std = np.array([np.std(e) for e in sub_E])
        # binning always uses the raw mean sub-burst E; flag_E_orig only affects Ebursts
        Ebursts = ds.E[0] if self.flag_E_orig else E_means
        if wbin is None:
            self._update_results(Ebursts, E_sub_std)
        else:
            E_binned, std_binned, cfi, nsub_binned = self._binned_stats(sub_E, E_means, n, wbin)
            self._update_results(Ebursts, E_sub_std, E_binned=E_binned, std_binned=std_binned,
                                 cfi=cfi, nsub_binned=nsub_binned)

    def bva_my(self, ds, n, wbin):
        """ BVA on the full photon stream using the masks ds.D_ex[0] and ds.A_em[0]; results are binned, stored and plotted. """
        sub_E = self._sub_burst_E_masked(ds.mburst[0], ds.D_ex[0], ds.A_em[0], n)
        self._compute_and_store(ds, sub_E, n, wbin)
        self.plot()

    def bva(self, ds, n, wbin):
        """ BVA on bursts re-indexed to donor-excited photons; results are binned, stored and plotted. """
        bursts_d, DexAem_mask_d = self._donor_reduced_bursts(ds)
        sub_E = self._sub_burst_E_reduced(bursts_d, DexAem_mask_d, n)
        self._compute_and_store(ds, sub_E, n, wbin)
        self.plot()

    def bva_tomas(self, ds, n):
        """ Same sub-burst calculation as bva(), without E-binning; results are stored and plotted. """
        bursts_d, DexAem_mask_d = self._donor_reduced_bursts(ds)
        sub_E = self._sub_burst_E_reduced(bursts_d, DexAem_mask_d, n)
        # Ebursts was previously never computed here, which raised a NameError
        self._compute_and_store(ds, sub_E, n)
        self.plot()

    def plot(self):
        """ Plot std(E) versus E of every burst, the static-population curve and, if available,
            the binned values with their confidence limits. Sets self.mask_longbu / mask_not_longbu. """
        plt.figure(figsize=(5,5))

        #Plot the theoretical dependence of std on E in the case of static population
        x = np.arange(0, 1.01, 0.01)
        y = np.sqrt((x*(1-x))/self.n)
        plt.plot(x, y, lw=3, color='red')

        #Select a certain percentile of the longest bursts
        p = np.percentile(self.ds.burst_widths[0], self.long_percentile)
        mask_longbu = self.ds.burst_widths[0] > p
        mask_not_longbu = np.logical_not(mask_longbu)
        self.mask_longbu = mask_longbu
        self.mask_not_longbu = mask_not_longbu

        #scatter plot in the std(E)/E plane where each dot corresponds to a single burst.
        #Longer bursts are plotted separately because the dynamics can be more visible there
        if self.plot_kind == 'scatter':
            plt.scatter(self.Ebursts[mask_not_longbu], self.E_sub_std[mask_not_longbu], s=12, alpha=0.15, color='blue', edgecolors='none')
            plt.scatter(self.Ebursts[mask_longbu], self.E_sub_std[mask_longbu], s=12, alpha=0.2, color='cyan', edgecolors='none')
        elif self.plot_kind == 'kde':
            sns.kdeplot(self.Ebursts, self.E_sub_std, shade=True, cmap='viridis', shade_lowest=False)
        elif self.plot_kind == 'hist2d':
            plt.hist2d(self.Ebursts, self.E_sub_std, bins=40, cmap='viridis')

        plt.xlim(-0.2,1.2)
        plt.ylim(0,0.6)
        plt.xlabel('E', fontsize=14)
        plt.ylabel(r'$stdev(E)$', fontsize=14)

        #Plot the binned values; bins without sub-bursts hold zeros and are skipped
        #(they were previously drawn as spurious points at the origin)
        if (self.E_binned is not None) and (self.std_binned is not None):
            E_binned = np.asarray(self.E_binned)
            nsub_binned = getattr(self, 'nsub_binned', None)
            if nsub_binned is not None:
                filled = np.asarray(nsub_binned) > 0
            else:
                filled = np.ones(E_binned.size, dtype=bool)
            plt.plot(E_binned[filled], np.asarray(self.std_binned)[filled], 'rd', ms=10)
            if self.cfi is not None:
                plt.plot(E_binned[filled], np.asarray(self.cfi)[filled], 'o-', color='lime')


def conf_int(Emean,sumMi,n,reps=1000,qperc=99):
    """ Monte Carlo method to estimate the variability of E in the sub-bursts within the bin centered at Emean.

        `reps` times, a bin of int(sumMi) sub-bursts is simulated, each with E = Binomial(n, Emean)/n.
        For each simulated bin the (population) standard deviation of E is computed. The function
        returns the qperc-th percentile of these standard deviations, i.e. it is improbable that a
        static population would randomly yield a sigma larger than this value.

        Arguments:
            Emean: mean E of the bin, used as binomial probability (must lie in [0, 1]).
            sumMi: number of sub-bursts in the bin (truncated to int).
            n: number of photons per sub-burst.
            reps: number of Monte Carlo repetitions.
            qperc: percentile (0-100) of the simulated sigma distribution that is returned.

        Returns:
            the percentile value (float); NaN when int(sumMi) < 1, since an empty bin has no spread.
    """
    nsub = int(sumMi)
    if nsub < 1:
        return np.nan
    # one row per Monte Carlo repetition, one column per simulated sub-burst
    F = np.random.binomial(n, Emean, size=(reps, nsub)) / float(n)
    sigma = F.std(axis=1)  # population std (ddof=0) of every simulated bin
    return np.percentile(sigma, qperc)


     
    
#%% ______RASP______
"""
-------------------------------------------------------------------------------
RASP
-------------------------------------------------------------------------------
"""

class RASP(object):
    """ RASP analysis of pairs of bursts.

        A RASP burst-pair consists of an earlier burst (E0-burst) and a later burst (E1-burst) whose
        center times are separated by a time within chosen limits. The pairs are collected at
        construction (get_pairs) and stored in E0E1dt (columns: E0, E1, time-separation in s); the
        columns are also available as E0, E1 and dt. The plot* methods visualize subsets of the pairs.
    """
    def __init__(self, ds, dt_minmax_ms=(0,100), E0_minmax=(0,1), E1_minmax=(0,1)):
        """ RASP object unites several methods and variables related to RASP analysis.

            Arguments:
                ds: fretbursts Data object after burst search (ds.mburst[0], ds.E[0] and ds.clk_p are used)
                dt_minmax_ms (tuple with two elements): minimal and maximal time-separation (in ms) of two bursts so that they are included in the RASP burst-pairs
                E0_minmax (tuple with two elements): minimal and maximal E of the first burst so that it is considered for RASP analysis
                E1_minmax (tuple with two elements): minimal and maximal E of the second burst so that it is considered for RASP analysis
        """
        self.ds = ds  #Data object
        self.get_pairs(ds, dt_minmax_ms, E0_minmax, E1_minmax)

    def get_pairs(self, ds, dt_minmax_ms, E0_minmax, E1_minmax):
        """ Generate the table of RASP burst-pairs.

            Each row has the form [E0, E1, dt]: the FRET efficiencies (ds.E[0]) of an earlier and a
            later burst and the separation of their center times in seconds. Only pairs with E0 in
            E0_minmax, E1 in E1_minmax and separation (in ms) within dt_minmax_ms are kept; all limits
            are inclusive. The table is stored (see _update_E0E1dt) and returned with shape (npairs, 3).
        """
        bs = ds.mburst[0]  #table of bursts
        E = ds.E[0]  #E_FRET of individual bursts
        clkp = ds.clk_p  #clock period in seconds, typically 12.5ns (80 MHz)
        # time limits converted from ms to clock periods (the unit of burst start/stop times)
        dt_minmax_clkp = tuple(p*1e-3/clkp for p in dt_minmax_ms)

        # center-times of bursts in clock periods
        tbs = (bs.stop + bs.start)/2.
        nbursts = tbs.size
        E0E1dt = []  # rows [E0, E1, time-separation in s]

        for i in range(nbursts):
            if not in_interval(E[i], E0_minmax):
                continue
            # Scan the following bursts until the time-separation exceeds the upper limit;
            # stopping early relies on the bursts being ordered in time.
            for k in range(i + 1, nbursts):
                dt = tbs[k] - tbs[i]
                if in_interval(dt, dt_minmax_clkp):
                    if in_interval(E[k], E1_minmax):
                        E0E1dt.append([E[i], E[k], dt*clkp])
                elif dt >= dt_minmax_clkp[1]:
                    break

        self._update_E0E1dt(E0E1dt)
        return self.E0E1dt

    def _update_E0E1dt(self, E0E1dt):
        """ Store the burst-pair table and its columns E0, E1, dt.

            The table is always reshaped to (npairs, 3), so that an empty list of pairs gives empty
            columns instead of an IndexError.
        """
        self.E0E1dt = np.asarray(E0E1dt, dtype=float).reshape(-1, 3)
        self.res = self.E0E1dt
        self.E0 = self.E0E1dt[:,0]
        self.E1 = self.E0E1dt[:,1]
        self.dt = self.E0E1dt[:,2]

    def plot2d_tseries(self, tseries=(5,10,20,40,60,100), cummul=True):
        """ plots an array of RASP 2D-diagram for a series of time intervals.

            Arguments:
                tseries (list or tuple): a series of upper-limits of time intervals (in ms)
                cummul (bool): if True, the lower limit of time-interval is always zero, if False, the lower limit is the upper limit of the previous interval
        """
        tmin = 0
        # np.int / np.float were removed from numpy; plain int/float are equivalent here
        numrows = int(np.round(np.sqrt(len(tseries))))
        numcols = int(np.ceil(len(tseries)/float(numrows)))
        fig, axarr = plt.subplots(numrows, numcols)
        # a 1x1 grid returns a single Axes; normalize to a flat array of Axes
        axes = np.atleast_1d(axarr).ravel()

        for k, tmax in enumerate(tseries):
            print(tmax)
            if not cummul and k > 0:
                tmin = tseries[k-1]
            self.plot2d(dt_limits_ms=(tmin, tmax), ax=axes[k])
        # hide the unused panels of the grid
        for ax in axes[len(tseries):]:
            ax.axis('off')
        plt.tight_layout()

    def plot2d(self, dt_limits_ms=(0,100), E0_limits=(0,1), ax=None):
        """ Plots a 2D (E0-E1) RASP diagram using RASP burst-pairs within the specified limits.

            Returns the plotted E0 and E1 values.
        """
        # both the time and the E0 limits select the plotted pairs (E0_limits was previously
        # only used for the pair count in the title)
        mask = in_interval(self.dt*1000, dt_limits_ms) & in_interval(self.E0, E0_limits)
        npairs = int(np.sum(mask))

        if ax is None:
            plt.figure()
            strtitle = 'RASP 2D plot,'+str(dt_limits_ms[0])+'-'+str(dt_limits_ms[1])+'ms, #pairs:'+str(npairs)
        else:
            plt.sca(ax) #Set current axis to the passed ax
            strtitle = str(dt_limits_ms[0])+'-'+str(dt_limits_ms[1])+'ms, #pairs:'+str(npairs)

        # Plot RASP burst-pairs into the E0E1 plot
        sns.kdeplot(self.E0[mask], self.E1[mask], shade=True, cmap='rainbow', shade_lowest=False, aspect=2)
        plt.scatter(self.E0[mask], self.E1[mask], s=1, marker='+', color='black', alpha=0.2)
        # plot diagonal line
        plt.plot([0,1], [0,1], '-', color='gray')
        plt.xlim(0,1)
        plt.ylim(0,1)
        plt.gca().set_aspect('equal')
        plt.xlabel('E0', fontsize=10)
        plt.ylabel('E1', fontsize=10)
        plt.title(strtitle, fontsize=10)

        return self.E0[mask], self.E1[mask]

    def plothist(self, dt_limits_ms=(0,np.inf), E0_limits=(0,1), bins=20, normed=False):
        """ Plots the E1-histogram for E1-bursts which follow E0 falling to a specific range.

            The E0-histogram of all pairs within dt_limits_ms is drawn for reference. `normed` is
            passed to plt.hist as `density` (the `normed` keyword was removed from matplotlib).
            Returns the E1-histogram values and the bin edges.
        """
        plt.figure()
        # Create mask for burst-pairs within the specified limits for E0 and time-separation
        mask_time = in_interval(self.dt*1000, dt_limits_ms)
        mask_E0 = in_interval(self.E0, E0_limits)
        mask = mask_time & mask_E0

        #plot histogram of E0 values
        plt.hist(self.E0[mask_time], bins=bins, density=normed, color='blue', histtype='step')
        h, bin_edges, _ = plt.hist(self.E0[mask_time], bins=bins, density=normed, color='blue', alpha=0.05, label="E0-bursts")
        #Plot histogram of E1 values which are in pair with a burst that has E0 within specified limits and their time-separation is within the limits
        h, bin_edges, _ = plt.hist(self.E1[mask], bins=bin_edges, density=normed, alpha=0.3, color='red', label="E1-bursts")
        #Draw rectangle showing the E0 limits used for the burst-pair selection
        plt.gca().add_patch(patches.Rectangle((E0_limits[0],0),
                                                E0_limits[1]-E0_limits[0],
                                                plt.gca().get_ylim()[1],
                                                fill=False, edgecolor='blue', linewidth=1))

        plt.xlabel('E')
        plt.ylabel('#bursts')
        strtitle = 'RASP hist, E0 in ('+str(E0_limits[0])+'-'+str(E0_limits[1])+') ,'+str(dt_limits_ms[0])+'-'+str(dt_limits_ms[1])+'ms, #pairs:'+str(mask.sum())
        plt.title(strtitle)
        return h, bin_edges

    def plothist_model(self, dt_limits_ms=(0,200), E0_limits=(0,1), bins=20, normed=False):
        """ This function is in a testing/experimental mode, should be used as such.

            Overlays on the E1-histogram (see plothist) a "static model" built from a two-gaussian
            fit of the E histogram weighted with p_same(dt), and the fitted E histogram itself
            labelled "dynamic model".
        """
        p_same = psame(self.ds, flag_plot=False)
        mask = in_interval(self.dt*1000, dt_limits_ms) & in_interval(self.E0, E0_limits)

        # Fit the E_FRET histogram with two gaussians
        gauss_params, Ehist_fit = fit_Ehist_gauss(self.ds, mode="two", bins=bins, flag_plot=False)
        print(gauss_params)
        print(Ehist_fit)

        w1 = 0
        w2 = 0
        A1 = gauss_params[0]
        A2 = gauss_params[3]

        # Modelling how E1 would be distributed if there is no dynamics in the system.
        # It is calculated what is the probability that the E0 burst belongs to each of the two FRET populations and
        # then it is considered that E1-burst can be the same particle with probability p_same,
        # or that it can be another particle (with probability 1-p_same) which can have an arbitrary E_FRET according to the FRET distribution.
        for E0, dt in zip(self.E0[mask], self.dt[mask]):
            p_same_t = p_same['fitfun'](dt, *p_same['fitpar'])
            y1 = Fitfuns.gauss(E0, *gauss_params[0:3])  #relative amplitude of the FRET-hist band 1 at E=E0
            y2 = Fitfuns.gauss(E0, *gauss_params[3:6])  #relative amplitude of the FRET-hist band 2 at E=E0
            p_band1 = y1/(y1+y2)  #probability that E0 belongs to the FRET band 1
            p_band2 = y2/(y1+y2)  #probability that E0 belongs to the FRET band 2
            w1 += p_band1*p_same_t + (1-p_same_t)*A1/(A1+A2)  #probability that the recurrent particle will belong to the population 1 if there is no dynamics
            w2 += p_band2*p_same_t + (1-p_same_t)*A2/(A1+A2)  #probability that the recurrent particle will belong to the population 2 if there is no dynamics

        histE1, bin_edges = self.plothist(dt_limits_ms=dt_limits_ms, E0_limits=E0_limits, bins=bins, normed=normed)

        x = (bin_edges[:-1] + bin_edges[1:])/2.
        y_stat = w1*Fitfuns.gauss(x, *gauss_params[0:3]) + w2*Fitfuns.gauss(x, *gauss_params[3:6])
        y_stat = y_stat / y_stat.sum() * histE1.sum()
        plt.plot(x, y_stat, 'mo-', label='static model')

        y_dyn = Ehist_fit['y'] / Ehist_fit['y'].sum() * histE1.sum()
        plt.plot(Ehist_fit['x'], y_dyn, 'co-', label='dynamic model')
        strtitle = 'RASP hist, E0 in ('+str(E0_limits[0])+'-'+str(E0_limits[1])+') ,'+str(dt_limits_ms[0])+'-'+str(dt_limits_ms[1])+'ms, #pairs:'+str(mask.sum())
        plt.title(strtitle)
        plt.xlabel('E_FRET')
        plt.legend(loc='best')

    def plot_intertime(self, bins=20):
        """ plot histogram of all analysed interburst times between E0 and E1.

            Returns the histogram counts and the bin edges (in seconds).
        """
        h, bin_edges = np.histogram(self.dt, bins=bins)
        x = (bin_edges[:-1]+bin_edges[1:])/2.
        plt.figure()
        plt.plot(x*1000, h, 'o-')
        plt.xlabel('time-separation [ms]')
        plt.yscale('log')
        plt.title('histogram of time-separations between all investigated E0-E1 pairs')
        return h, bin_edges
    
"""
class RASP_Tomas(object):
    def __init__(self):
        pass
"""    
    

#%%______CORRELATE______
"""
-------------------------------------------------------------------------------
CORRELATE
-------------------------------------------------------------------------------
""" 

def calc_cf(t,u,unit,model=None,flag_fit_plot=True):
    """ calculate cross-correlation of t and u and plot it.

        Arguments:
            t, u: timestamp arrays sorted in non-decreasing order, in units of `unit`.
            unit: the time unit of the timestamps in seconds (e.g. clock period).
            model: lmfit Model used for the fit. Initial values are taken from the model's parameter
                   hints (model.make_params()). If None, a Fitfuns.diffusion_3d model with
                   predefined parameters is used.
            flag_fit_plot: if True, the model is fitted and the fit with residuals is plotted.
                           The raw correlation function is always plotted.

        Returns:
            tau: bin-center time lags in seconds.
            Gn: correlation function from pycorrelate.pcorrelate(..., normalize=True).

        Raises:
            ValueError: if t or u is not sorted.
    """
    if not (np.diff(t) >= 0).all() or not (np.diff(u) >= 0).all():
        raise ValueError('timestamps t and u must be sorted in non-decreasing order')

    # compute lags in sec. then convert to timestamp units
    bins_per_dec = 20
    bins = pyc.make_loglags(-6, 1, bins_per_dec)[bins_per_dec // 2:] / unit
    print('Number of time-lag bins:', bins.size)

    Gn = pyc.pcorrelate(t, u, bins, normalize=True)
    tau = 0.5 * (bins[1:] + bins[:-1]) * unit

    plt.subplots(figsize=(10, 6))
    plt.semilogx(bins[1:]*unit, Gn, drawstyle='steps-pre')
    plt.xlabel('Time (s)')
    plt.grid(True); plt.grid(True, which='minor', lw=0.3)

    if model is None:
        model = lmfit.Model(Fitfuns.diffusion_3d)
        params = model.make_params(A0=1, tau_diff=1e-3)
        params['A0'].set(min=0.01, value=1)
        params['tau_diff'].set(min=1e-6, value=1e-3)
        params['waist_z_ratio'].set(value=1/6, vary=False)  # 3D model only
    else:
        # previously `params` was left undefined for a user-supplied model (NameError in the fit)
        params = model.make_params()

    if flag_fit_plot:
        weights = np.ones_like(Gn)
        #weights = np.log(np.sqrt(G*np.diff(bins)))  # and example of using weights
        fitres = model.fit(Gn, timelag=tau, params=params, method='least_squares',
                           weights=weights)
        print('\nList of fitted parameters for %s: \n' % model.name)
        fitres.params.pretty_print(colwidth=10, columns=['value', 'min', 'max'])

        fig, ax = plt.subplots(2, 1, figsize=(10, 8), sharex=True,
                               gridspec_kw={'height_ratios': [3, 1]})
        plt.subplots_adjust(hspace=0)
        ax[0].semilogx(tau, Gn)
        for a in ax:
            a.grid(True); a.grid(True, which='minor', lw=0.3)
        ax[0].plot(tau, fitres.best_fit)
        ax[1].plot(tau, fitres.residual*weights, 'k')
        ym = np.abs(fitres.residual*weights).max()
        ax[1].set_ylim(-ym, ym)
        ax[1].set_xlim(bins[0]*unit, bins[-1]*unit)
        # the annotation needs the diffusion-model parameters; custom models may not have them
        if 'A0' in fitres.values and 'tau_diff' in fitres.values:
            tau_diff_us = fitres.values['tau_diff'] * 1e6
            msg = ((r'$G(0)-1$ = {A0:.2f}'+'\n'+r'$\tau_D$ = {tau_diff_us:.0f} μs')
                   .format(A0=fitres.values['A0'], tau_diff_us=tau_diff_us))
            ax[0].text(.75, .9, msg,
                       va='top', ha='left', transform=ax[0].transAxes, fontsize=18)
        ax[0].set_ylabel('G(τ)')
        ax[1].set_ylabel('residuals')
        ax[0].set_title('Correlation function')
        ax[1].set_xlabel('Time Lag, τ (s)')

    return tau,Gn

def calc_acf(d, ph_sel=Ph_sel(Aex='Aem'), flag_fit_plot=True):
    """ Autocorrelation of the photons of d selected by ph_sel (default Ph_sel(Aex='Aem')).

        The timestamps are correlated with themselves via calc_cf, using d.clk_p as time unit.
        Returns the tuple (tau, G) produced by calc_cf.
    """
    timestamps = d.get_ph_times(ph_sel=ph_sel)
    return calc_cf(timestamps, timestamps, d.clk_p, flag_fit_plot=flag_fit_plot)

    
def psame(d,tau_min=1e-3,fitfun=Fitfuns.psame_bursts_1,p0=(0.1,0.03),flag_plot=False):
    """ Calculates the p_same probability that the particle returns to the confocal volume after a time tau.

        The correlation G(tau) of the burst center-times (from calc_cf) is converted into
        p_same = 1 - 1/G for lags tau > tau_min (in seconds), and p_same(tau) is fitted with fitfun
        starting from p0 (via myfit).

        Returns:
            dict with keys 'fitpar' (fit parameters), 'fitfun', 'taus' (lags in s) and 'p_same'.
    """
    # get autocorrelation from center-times of bursts.
    times_clk = (d.mburst[0].start + d.mburst[0].stop)//2
    taus, G = calc_cf(times_clk, times_clk, d.clk_p, flag_fit_plot=False)
    # Keep lags above tau_min and drop bins where G is zero or not finite (no burst pairs at that lag):
    # there 1 - 1/G is infinite/undefined, which makes the fit fail.
    with np.errstate(invalid='ignore'):
        mask = (taus > tau_min) & np.isfinite(G) & (G > 0)
    taus = taus[mask]
    G = G[mask]
    p_same_vals = 1 - 1/G

    # Fit the p_same(tau)
    fitpar, _, yfit = myfit(fitfun, taus, p_same_vals, p0, flag_plot=False)

    if flag_plot:
        plt.figure()
        plt.plot(taus, p_same_vals, 'bo', ms=5)
        plt.plot(taus, yfit, 'r-')
        plt.xscale('log')
        plt.xlabel('time [s]')
        plt.ylabel('p_same')
        plt.grid()
        plt.title('p_same calculated from ACF on bursts')

    return {'fitpar':fitpar, 'fitfun':fitfun, 'taus':taus, 'p_same':p_same_vals}
           

#%% ______UTILS______
"""
-------------------------------------------------------------------------------
UTILS
-------------------------------------------------------------------------------
"""

def find_files(glob2patt, contains_list=[], avoid_list=[]):
    """ searches files according to the glob2 pattern,
        and chooses those filenames that contain at least one substring from the contains_list
        and do not contain any substring from avoid_list.
        Returns a list of filenames. """
        
    def test_file(f,contains_list,avoid_list):
        list1 = [s in f for s in contains_list]
        flag_OK1 = not(False in list1)
        flag_OK2 = True
        for s in avoid_list:
            if s in f:
                flag_OK2=False 
                break
        flag_OK = flag_OK1 and flag_OK2
        return flag_OK
        
    fpaths = glob2.glob(glob2patt)
    return [f for f in fpaths if test_file(f,contains_list,avoid_list)]
    
    
def get_fname_woext(fname):
    """ Return the last path component of fname ('/' or backslash separated) with the text from
        its last '.' onwards removed; a name without a dot is returned unchanged. """
    stem, *_ = ntpath.basename(fname).rsplit('.', 1)
    return stem

    
def get_path_woext(fname):
    """ return whole path without file extension.

        Only the last path component is inspected ('/' and backslash are both treated as
        separators), so dots in directory names are kept: '../data.v2/file' is returned unchanged,
        whereas a plain rsplit('.') would have truncated it to '.'.
    """
    sep_pos = max(fname.rfind('/'), fname.rfind('\\'))
    dot_pos = fname.rfind('.')
    return fname[:dot_pos] if dot_pos > sep_pos else fname


def get_part(s, sep, istart, iend):
    """ Split s at every occurrence of sep and re-join, with sep, the pieces selected by the
        slice [istart:iend]. """
    pieces = s.split(sep=sep)
    selected = pieces[istart:iend]
    return sep.join(selected)


def in_interval(x,tp_minmax):
    """ Checks whether x is within the closed interval defined by tuple tp_minmax = (lower_limit, upper_limit).

        Works element-wise for numpy arrays (returns a boolean array). For a scalar x a bool is
        returned; the previous `*` combination returned the int 0/1 for Python scalars.
    """
    lower, upper = tp_minmax[0], tp_minmax[1]
    return (x >= lower) & (x <= upper)


def fullprint(*args, **kwargs):
    """ pprint the arguments with numpy summarization disabled, so arrays are printed in full.

        Arguments are passed to pprint.pprint. numpy print options are restored afterwards,
        also when printing raises.
    """
    import sys
    from pprint import pprint
    import numpy
    opt = numpy.get_printoptions()
    # threshold='nan' is rejected by current numpy; sys.maxsize disables summarization
    numpy.set_printoptions(threshold=sys.maxsize)
    try:
        pprint(*args, **kwargs)
    finally:
        numpy.set_printoptions(**opt)
  

def binvec(y, binw):
    """ Sum consecutive, non-overlapping groups of binw elements of the 1-D array y.

        Trailing elements that do not fill a complete group are dropped.
    """
    nfull = y.size // binw
    trimmed = y[:nfull * binw]
    return trimmed.reshape(nfull, binw).sum(axis=1)


def print_list(stringlist):
    """ prints the numbered stringlist (one 'index item' line per element), followed by blank lines """
    # plain loop: a list comprehension used only for its print side effects builds a useless list
    for i, s in enumerate(stringlist):
        print(i, s)
    print('\n\n')

    
def sublist(lis, idxs, printit=False):
    """ returns sublist: [lis[i] for i in idxs]; if printit, each selected 'index item' pair is printed """
    lis_sel = [lis[i] for i in idxs]
    if printit:
        # plain loop instead of a list comprehension used only for its side effects
        for i in idxs:
            print(i, lis[i])
    return lis_sel

def savetxt_table(outname, tp_npvectors, **kwargs):
    """ Save a tuple of numpy vectors as a text table with one vector per column.

        The vectors are stacked vertically and the result transposed before being written with
        np.savetxt to outname; kwargs are forwarded to np.savetxt (see its documentation).
    """
    table = np.vstack(tp_npvectors).T
    np.savetxt(outname, table, **kwargs)
    
def num_rows_cols(n):
    """ Number of rows and columns of a near-square grid holding n panels.

        Columns: ceil(sqrt(n)); rows: ceil(n / columns). Returns (rows, cols) as Python ints,
        or (0, 0) when n <= 0.
    """
    if n <= 0:
        return 0, 0
    # np.int was removed from numpy; int() gives the same value
    nc = int(np.ceil(np.sqrt(n)))
    nr = int(np.ceil(n / float(nc)))
    return nr, nc

def fnamelist(fnpatt, sortmode='time'):
    """ Print a numbered list of, and return, the files matching a pattern.

        Arguments:
            fnpatt (string): pattern for matching the filenames (glob2.glob syntax, '**' allowed).
            sortmode (string): 'time' sorts by last modification time (os.path.getmtime),
                'abc' sorts alphabetically ignoring case; any other value leaves
                the order returned by glob2.glob.

        Returns:
            list of matching filenames.
    """
    fns = glob2.glob(fnpatt)
    if sortmode == 'time':
        fns.sort(key=os.path.getmtime)
    elif sortmode == 'abc':
        fns.sort(key=str.lower)
    # plain loop: printing is a side effect, not something to collect into a list
    for i, f in enumerate(fns):
        print(i, f)
    return fns
    
        
    


#%% ______PNG EXPORT______
"""
-------------------------------------------------------------------------------
PNG EXPORT
-------------------------------------------------------------------------------
"""
  
def combine_pngs(fnlist, outname, gridsize = None, padding=0):
    """ Loads multiple .pngs, combines them into one png and saves.

        Images are pasted row by row into equally sized cells whose size is that
        of the largest input image (plus padding), aligned to the top-left corner
        of each cell, on a white RGB background. Images beyond nr*nc cells are ignored.

        Arguments:
            fnlist (list): list of filenames of pngs.
            outname (string): filename for the output png (overwritten if it exists).
            gridsize (tuple 2 elements): Number of rows and cols used to organize pngs into the grid.
                If not set, it is calculated automatically.
            padding (int): padding between individual pngs.

        Returns:
            the combined PIL image.

        Raises:
            ValueError: if fnlist is empty.
    """
    if len(fnlist) == 0:
        raise ValueError('combine_pngs: fnlist is empty, nothing to combine')

    imlist = []
    try:
        for fn in fnlist:
            imlist.append(Image.open(fn))
        n = len(imlist)
        if gridsize is None:
            nc = int(np.round(np.sqrt(n)))
            nr = int(np.ceil(n/float(nc)))
        else:
            nr, nc = gridsize[0], gridsize[1]

        w = max(im.size[0] for im in imlist)
        h = max(im.size[1] for im in imlist)

        im_new = Image.new("RGB", (nc*(w+padding), nr*(h+padding)), "white")
        # row-major placement: image k goes to row k//nc, column k%nc
        for k, im in enumerate(imlist[:nr*nc]):
            ir, ic = divmod(k, nc)
            im_new.paste(im, (ic*(w+padding), ir*(h+padding)))
    finally:
        # release the file handles of the inputs (paste has already copied the pixels)
        for im in imlist:
            im.close()

    if os.path.exists(outname):
        os.remove(outname)
    im_new.save(outname)
    print('combined png saved as: '+outname)

    return im_new
    
    
def figs_to_one_png(figlist,outname,tidy_up=True,**kwargs):
    """ Generates one png image from a list of pyplot figures.

        Each figure is saved as an intermediate png in the current directory
        ('temporary_file<k>.png'); these are combined with combine_pngs and shown
        in a new pyplot figure.

        Arguments:
            figlist (list): List of pyplot figures
            outname (string): filename of the output png file.
            tidy_up=True: deletes all the intermediate files generated by the function,
                also when saving or combining fails.
            kwargs are passed to the combine_pngs function.

        Returns:
            the new pyplot figure showing the combined image.
    """
    base = 'temporary_file'
    fnlist = [base+str(k)+'.png' for k in range(len(figlist))]
    try:
        for fn, fig in zip(fnlist, figlist):
            fig.savefig(fn)
        im_combined = combine_pngs(fnlist, outname, **kwargs)
    finally:
        # cleanup runs even on failure so no stray temporary pngs are left behind
        if tidy_up:
            for fn in fnlist:
                if os.path.exists(fn):
                    os.remove(fn)

    plt.figure()
    plt.imshow(im_combined)
    plt.axis('off')
    return plt.gcf()
    

#%% ______OTHER FUNCTIONS______
"""
-------------------------------------------------------------------------------
OTHER FUNCTIONS                                                     
------------------------------------------------------------------------------- 
"""

def calc_phrate_mean(d):
    """ Calculate the overall mean photon rate and a string reporting it.

        The rate is the number of photons in channel 0 (d.ph_data_sizes[0])
        divided by the measurement span d.time_max - d.time_min.

        Returns:
            dict with keys "phrate" (the rate; presumably in Hz with times in
            seconds, given the kHz label - confirm) and "str" (report text).
    """
    phrate = d.ph_data_sizes[0]/(d.time_max - d.time_min)
    s = 'Mean photon rate: %.2f kHz' % (phrate/1000)
    return {"phrate": phrate, "str": s}
    

   
def load_d(filename, leakage=0.0, dir_ex=0.00, gamma=1.0, bg_time_s=50, bg_tail_min_us='auto', F_bg=1.7):
    """ Load a Photon-HDF5 file and prepare the resulting Data object.

        Steps: plot the alternation histogram, set the leakage / dir_ex / gamma
        correction factors, apply the alternation period, fit the background
        (bg.exp_fit with bg_time_s, bg_tail_min_us, F_bg) and plot the background
        histogram annotated with the mean photon rate.

        Returns:
            the loaded Data object d.
    """
    d = loader.photon_hdf5(filename)
    bpl.plot_alternation_hist(d)

    corrections = (('leakage', leakage), ('dir_ex', dir_ex), ('gamma', gamma))
    for attr, value in corrections:
        setattr(d, attr, value)

    loader.alex_apply_period(d)

    d.calc_bg(bg.exp_fit, time_s=bg_time_s, tail_min_us=bg_tail_min_us, F_bg=F_bg)
    dplot(d, hist_bg, show_fit=True)
    report = calc_phrate_mean(d)['str']
    ax = plt.gca()
    plt.text(0.5, 0.95, report, ha='center', va='top', transform=ax.transAxes)

    return d


def get_test_data(fname=None,mode='basic'):
    """ Load a test measurement with load_d (automatic background tail threshold).

        fname defaults to '../data/s54_test.hdf5'; mode is currently not used.
    """
    path = '../data/s54_test.hdf5' if fname is None else fname
    return load_d(path, bg_tail_min_us='auto')

#%%_______EXPERIMENTAL FUNCTIONS______
class experimental_myacf(object):
    """ Experimental, binning-based autocorrelation (ACF) helpers for timestamps.

        All methods are static. They refer to each other through the class name,
        because a bare call such as acf(...) inside a staticmethod raises NameError.
    """

    @staticmethod
    def bin_timestamps(times_clk,binw_clk,y=None):
        """ Creates a vector (intensities in time) by binning the timestamps.

            times_clk: integer timestamps in clock ticks (np.bincount needs non-negative ints).
            binw_clk: bin width in clock ticks.
            y: optional weights of the timestamps, e.g. corresponding to the number of photons
               in a burst if the timestamps are the time-centers of bursts.
        """
        ibins = times_clk // binw_clk
        if y is None:
            ybinned = np.bincount(ibins)
        else:
            ybinned = np.bincount(ibins,weights=y)
        return ybinned

    @staticmethod
    def acf_timestamps_dynbin(times_clk, clk_p, taus_ms, y=None):
        """ Returns (acf, taus_ms_acf) at delays taus_ms (np.array, in ms).

            The bin width grows with the delay: it starts at 1/n of the first delay
            (at least one clock tick) and is multiplied by m whenever the delay exceeds
            m*n bin widths. (Whether this dynamic binning is correct is still an open question.)
            taus_ms_acf are the delays actually used, which differ slightly from taus_ms
            because shifts are rounded to whole bins.
        """
        n=10
        m=10
        a = np.zeros(taus_ms.size)
        shift = np.zeros(taus_ms.size,dtype=np.int64)
        taus_ms_acf = np.zeros(taus_ms.size)

        for k,tau_ms in enumerate(taus_ms):
            tau_clk = tau_ms*0.001/clk_p
            if k==0: #the first element of the delay vector taus_ms
                # bin width in clock ticks: a fraction (1/n) of the first delay, but at least 1 tick
                # (a zero width would divide by zero in bin_timestamps)
                binw_clk = max(1, int(np.round(tau_clk//n)))
                # creates vector by binning timestamps
                ybinned = experimental_myacf.bin_timestamps(times_clk, binw_clk, y)
            else:
                # if the current delay is larger than m*n times the binwidth,
                # then the ybinned vector is binned again using binning window of size m.
                if tau_clk > m*n*binw_clk:
                    imax = ybinned.size - ybinned.size%m
                    ybinned = ybinned[:imax].reshape(ybinned.size//m,m).sum(axis=1)
                    binw_clk*=m
            shift[k] = np.round(tau_clk/binw_clk).astype(int)
            a[k] = experimental_myacf.acf(ybinned,shift[k])[0]
            taus_ms_acf[k] = shift[k]*binw_clk*clk_p*1000

        # the real delays for which ACF is computed may be slightly different from the original taus_ms due to binning in the time-domain
        return a, taus_ms_acf

    @staticmethod
    def acf_timestamps(times_clk, binw_clk, clk_p, taus_ms, y=None):
        """ Performs binning on the timestamps and thus creates a vector for the later calculation of ACF.
            It also creates vector of shifts for ACF calculation based on the taus_ms shift-times.
            Returns autocorrelation function at times taus_ms. """
        ybinned = experimental_myacf.bin_timestamps(times_clk, binw_clk, y)
        shifts = np.round(taus_ms*0.001/clk_p/binw_clk).astype(int)
        a = experimental_myacf.acf(ybinned,shifts)#/(binw_clk**2)
        return a

    @staticmethod
    def acf(y,shifts):
        """ Returns mean(y[t]*y[t+shift]) / mean(y)**2 for each shift (array, or a single int). """
        if not isinstance(shifts,np.ndarray):
            shifts=np.array([shifts])
        ymean = y.mean()
        dy = y#-ymean
        a = np.zeros(shifts.size)
        for j,sh in enumerate(shifts):
            if sh==0:
                a[j] = np.mean(dy*dy)
            else:
                a[j] = np.mean(dy[:-sh]*dy[sh:])
            a[j] /= ymean**2

        return a

    @staticmethod
    def acf_bu(d,taus_ms,binw_ms=None):
        """ Calculates the autocorrelation function at times taus_ms based on burst center-times
            (it is not taking all photon timestamps, just the burst times).
            If binw_ms is not set (default), then the binning is computed dynamically.
            Returns (acf, taus_ms); with dynamic binning taus_ms are the delays actually used. """
        times_clk = (d.mburst[0].start + d.mburst[0].stop)//2
        if binw_ms is not None:
            binw_clk = int(binw_ms*0.001/d.clk_p)
            a = experimental_myacf.acf_timestamps(times_clk, binw_clk, d.clk_p, taus_ms, y=None)
        else:
            a, taus_ms = experimental_myacf.acf_timestamps_dynbin(times_clk, d.clk_p, taus_ms, y=None)
        return a, taus_ms

    @staticmethod
    def acf_d(d,taus_ms,binw_ms=None):
        """ ACF of all photon timestamps of channel 0 (d.ph_times_m[0]) at delays taus_ms.

            binw_ms: fixed bin width in ms; if None (default) dynamic binning is used.
            Returns only the ACF values. """
        times_clk = d.ph_times_m[0] # or d.get_ph_times()
        if binw_ms is not None:
            binw_clk = int(binw_ms*0.001/d.clk_p)
            return experimental_myacf.acf_timestamps(times_clk, binw_clk, d.clk_p, taus_ms, y=None)
        # the original call omitted the required binw_clk argument (TypeError)
        a, _ = experimental_myacf.acf_timestamps_dynbin(times_clk, d.clk_p, taus_ms, y=None)
        return a

    @staticmethod
    def plot_acf(t,a):
        """ Plots a(t) with a logarithmic time axis in a new figure. """
        plt.figure()
        plt.plot(t,a,'o-')
        plt.xscale('log')
        
        
class CDE(object):
    """ FRET-2CDE burst analysis (Tomov et al. BJ 2012, doi:10.1016/j.bpj.2011.11.4025).

        All methods are static. CDE.cde is the entry point; calc_fret_2cde and
        calc_fret_2cde_gauss compute the per-burst values with a Laplace
        (exponential) or a Gaussian kernel, respectively.
    """

    @staticmethod
    def cde(a, tau_s=50e-6, kernel_type='laplace'):
        """
        Compute FRET-2CDE for each burst and plot it against E.

        FRET-2CDE is a quantity that tends to be around 10 for bursts which have no
        dynamics, while it has larger values (e.g. 30..100) for bursts with
        millisecond dynamics.

        References:
            Tomov et al. BJ (2012) doi:10.1016/j.bpj.2011.11.4025

        Arguments:
            a (Analysis object): Analysis object containing a.d (photon data)
                and a.ds (selected bursts).
            tau_s (scalar): time-constant of the KDE kernel in seconds.
            kernel_type (string): 'laplace' or 'gauss'; any other value prints
                an error message and falls back to 'laplace'.

        Returns:
            dict with keys 'E' (ds.E[0]) and 'CDE' (FRET-2CDE), one element per
                burst. Bursts containing too few photons to compute FRET-2CDE
                (NaN) are excluded from both arrays.

        Side effect:
            opens a new matplotlib figure with a KDE + scatter plot of FRET-2CDE vs E.
        """
        d = a.d
        ds = a.ds

        tau = int(tau_s/ds.clk_p)  # in raw timestamp units

        ph = d.get_ph_times(ph_sel=Ph_sel('all'))
        mask_d = d.get_ph_mask(ph_sel=Ph_sel(Dex='Dem'))
        mask_a = d.get_ph_mask(ph_sel=Ph_sel(Dex='Aem'))

        bursts = ds.mburst[0]
        if kernel_type == 'gauss':
            fret_2cde = CDE.calc_fret_2cde_gauss(tau, ph, mask_d, mask_a, bursts)
        else:
            if kernel_type != 'laplace':
                print('Error: kernel_type unrecognised, Laplace will be used')
            fret_2cde = CDE.calc_fret_2cde(tau, ph, mask_d, mask_a, bursts)

        valid = np.isfinite(fret_2cde)
        E_valid = ds.E[0][valid]
        cde_valid = fret_2cde[valid]

        plt.figure(figsize=(4.5, 4.5))
        # NOTE(review): positional y plus shade=/shade_lowest= follow the older seaborn
        # kdeplot API; recent seaborn versions may reject them - confirm the installed version.
        sns.kdeplot(E_valid, cde_valid,
                    cmap='Spectral_r', shade=True, shade_lowest=False, n_levels=20)
        plt.scatter(E_valid, cde_valid, s=0.2, c='k', marker='o', alpha=0.25)
        plt.xlabel('E', fontsize=16)
        plt.ylabel('FRET-2CDE', fontsize=16)
        plt.xlim(-0.2,1.2)
        plt.ylim(-10, 50)
        plt.axhline(10, ls='--', lw=2, color='k')
        plt.text(0.05, 0.95, '2CDE', va='top', fontsize=22, transform=plt.gca().transAxes)
        plt.text(0.95, 0.95, '# Bursts: %d' % valid.sum(),
                 va='top', ha='right', transform=plt.gca().transAxes)
        #plt.savefig('2cde.png', bbox_inches='tight', dpi=200, transparent=False)
        return dict(E=E_valid, CDE=cde_valid)

    @staticmethod
    def calc_fret_2cde(tau, ph, mask_d, mask_a, bursts):
        """
        Compute FRET-2CDE for each burst (Laplace / exponential kernel).

        FRET-2CDE is a quantity that tends to be around 10 for bursts which have no
        dynamics, while it has larger values (e.g. 30..100) for bursts with
        millisecond dynamics.

        References:
            Tomov et al. BJ (2012) doi:10.1016/j.bpj.2011.11.4025

        Arguments:
            tau (scalar): time-constant of the exponential KDE
            ph (1D array): array of all-photons timestamps.
            mask_d (bool array): mask for DexDem photons
            mask_a (bool array): mask for DexAem photons
            bursts (Bursts object): object containing burst data
                (start-stop indexes are relative to `ph`).

        Returns:
            FRET_2CDE (1D array): array of FRET_2CDE quantities, one element
                per burst. This array contains NaN in correspondence of bursts
                containing too few photons to compute FRET-2CDE.
        """
        # phrates is imported at module level only under `if __name__ == "__main__"`;
        # importing it here makes the method usable when this file is imported.
        from fretbursts.phtools import phrates

        # Computing KDE burst-by-burst would cause inaccuracies at the burst edges.
        # Therefore, we first compute KDE on the full timestamps array and then
        # we take slices for each burst.
        # These KDEs are evaluated on all-photons array `ph` (hence the Ti suffix)
        # using D or A photons during D-excitation (argument ph[mask_d] or ph[mask_a]).
        KDE_DTi = phrates.kde_laplace(ph[mask_d], tau, time_axis=ph)
        KDE_ATi = phrates.kde_laplace(ph[mask_a], tau, time_axis=ph)

        FRET_2CDE = []
        for burst in bursts:
            burst_slice = slice(int(burst.istart), int(burst.istop) + 1)
            in_d = mask_d[burst_slice]
            in_a = mask_a[burst_slice]
            if not in_d.any() or not in_a.any():
                # Either D or A photon stream has no photons in current burst,
                # thus FRET_2CDE cannot be computed. Fill position with NaN.
                FRET_2CDE.append(np.nan)
                continue

            # Take slices of KDEs for current burst
            kde_adi = KDE_ATi[burst_slice][in_d]
            kde_ddi = KDE_DTi[burst_slice][in_d]
            kde_dai = KDE_DTi[burst_slice][in_a]
            kde_aai = KDE_ATi[burst_slice][in_a]

            # nbKDE does not include the "center" timestamp which contributes 1.
            # We thus subtract 1 from the precomputed KDEs.
            # The N_CHD (N_CHA) value in the correction factor is the number of
            # timestamps in DexDem (DexAem) stream falling within the current burst.
            N_CHD = in_d.sum()
            N_CHA = in_a.sum()
            nbkde_ddi = (1 + 2/N_CHD) * (kde_ddi - 1)
            nbkde_aai = (1 + 2/N_CHA) * (kde_aai - 1)

            # N_CHD (N_CHA) in eq. 6 (eq. 7) of (Tomov 2012) is the number of photons
            # in DexDem (DexAem) in current burst. Thus the sum is a mean.
            ED = np.mean(kde_adi / (kde_adi + nbkde_ddi))  # (E)_D
            EA = np.mean(kde_dai / (kde_dai + nbkde_aai))  # (1 - E)_A

            # Compute fret_2cde for current burst
            FRET_2CDE.append(110 - 100 * (ED + EA))
        return np.array(FRET_2CDE)

    @staticmethod
    def calc_fret_2cde_gauss(tau, ph, mask_d, mask_a, bursts):
        """
        Compute a modification of FRET-2CDE using a Gaussian kernel.

        Reference: Tomov et al. BJ (2012) doi:10.1016/j.bpj.2011.11.4025

        Instead of using the exponential kernel (i.e. laplace distribution)
        of the original paper, here we use a Gaussian kernel.
        Photon density using Gaussian kernel provides a smooth estimate
        regardless of the evaluation time. On the contrary, the
        laplace-distribution kernel has discontinuities in the derivative
        (cusps) on each time point corresponding to a timestamp.
        Using a Gaussian kernel removes the need of using the heuristic
        correction (pre-factor) of nbKDE.

        Arguments:
            tau (scalar): time-constant (width) of the Gaussian KDE
            ph (1D array): array of all-photons timestamps.
            mask_d (bool array): mask for DexDem photons
            mask_a (bool array): mask for DexAem photons
            bursts (Bursts object): object containing burst data

        Returns:
            FRET_2CDE (1D array): array of FRET_2CDE quantities, one element
                per burst. This array contains NaN in correspondence of bursts
                containing too few photons to compute FRET-2CDE.
        """
        # see calc_fret_2cde: phrates is otherwise only available in __main__
        from fretbursts.phtools import phrates

        # Computing KDE burst-by-burst would cause inaccuracies at the edges
        # So, we compute KDE for the full timestamps
        KDE_DTi = phrates.kde_gaussian(ph[mask_d], tau, time_axis=ph)
        KDE_ATi = phrates.kde_gaussian(ph[mask_a], tau, time_axis=ph)

        FRET_2CDE = []
        for burst in bursts:
            burst_slice = slice(int(burst.istart), int(burst.istop) + 1)
            in_d = mask_d[burst_slice]
            in_a = mask_a[burst_slice]
            if not in_d.any() or not in_a.any():
                # Either D or A photon stream has no photons in current burst,
                # thus FRET_2CDE cannot be computed.
                FRET_2CDE.append(np.nan)
                continue

            kde_ddi = KDE_DTi[burst_slice][in_d]
            kde_adi = KDE_ATi[burst_slice][in_d]
            kde_dai = KDE_DTi[burst_slice][in_a]
            kde_aai = KDE_ATi[burst_slice][in_a]

            ED = np.mean(kde_adi / (kde_adi + kde_ddi))  # (E)_D
            EA = np.mean(kde_dai / (kde_dai + kde_aai))  # (1 - E)_A

            FRET_2CDE.append(110 - 100 * (ED + EA))
        return np.array(FRET_2CDE)

    

#%% ______TEST FUNCTIONS______

def test_kwargs(a,b=33,c=26,**kwargs):
    """ Scratch helper: print the full argument specification of this function. """
    import inspect
    # inspect.getargspec was removed in Python 3.11; getfullargspec replaces it
    args = inspect.getfullargspec(test_kwargs)
    #dic = {k:v for k,v in locals().items() if k in args}
    #print(dic)
    print(args)

# Not a unit test despite its name: keep pytest from collecting it
# (collection would fail looking for a fixture named 'a').
test_kwargs.__test__ = False
    
def test_kwargs2(a,**kwargs):
    """ Scratch helper: print a, then each keyword argument as 'key value'. """
    print(a)
    for key, value in kwargs.items():
        print(key, value)

# Not a unit test despite its name: keep pytest from collecting it
# (collection would fail looking for a fixture named 'a').
test_kwargs2.__test__ = False

def _test():
    """ Placeholder for ad-hoc tests; intentionally does nothing. """
    return None






#%% ______MAIN______
    
if __name__ == "__main__":

    from fretbursts.phtools import phrates
    #sns = init_notebook(apionly=True)
    #sns.__version__

    # CDE.cde(a) below needs an Analysis object `a`. With this line commented out,
    # the script stopped with NameError: name 'a' is not defined.
    # NOTE(review): assumes Analysis is the analysis class defined earlier in this module - confirm.
    a = Analysis('../data/ceb1-ctrl_PBSK-ctrl_01.hdf5', leakage=0.1, dir_ex=0.0, gamma=1.0, busearch_L=15, busearch_F=4, busearch_computefret=True, fuse_ms=2, meas_type='PAX', settPAX=dict(D_ON=(400,1800), A_ON=(1900, 3300)))
    #alex_jointplot(a.ds)

    cde = CDE.cde(a)




    #plt.close('all')
    #d = get_test_data()
    #a = get_test_Analysis(fuse_ms=5)
    #a.proc()
    #a.reportpng()

    #sett = Sett_hdf5(donor=0, acceptor=1, alex_period=8000, alex_offset=10, alex_period_donor=(300,3300), alex_period_acceptor=(4400, 7300))
    #fnhdf,_ = convert_sm('../data/data1.sm',sett)
    #convert_sm_multi(sett,pattern="../data/*.sm",flag_plot_althist=True)
    #merge_sm_to_hdf_patt(sett, patt="../data/*.sm")
    #fnlist = ['data1.sm','data2.sm']
    #fnlist = ['data/'+f for f in fnlist]
    #am.cvt.merge_sm_to_hdf(fnlist,sett)

    #a = Analysis('../data/s54.hdf5',leakage=0.1, dir_ex=0.0, gamma=1.1, busearch_F=5, fuse_ms=2)
    #pprint.pprint(vars(a.d))
    #a = Analysis('../data/ceb1-ctrl_PBSK-ctrl_01.hdf5',leakage=0.1, dir_ex=0.0, gamma=1.1, busearch_F=5, fuse_ms=2)

#    a.ds = a.df.select_bursts(select_bursts.time, time_s1=0, time_s2=500, computefret=False)
#    a.ds = a.ds.select_bursts(select_bursts.size, th1=50, computefret=False, add_naa=True)
#    a.ds = a.ds.select_bursts(select_bursts.S, S1=0.2, S2=0.8, computefret=False)
#
    #pprint.pprint(vars(a.d))
    #a.proc(options=('ES'))
    #np.savetxt('E.txt', a.ds.E[0], fmt='%.3f', delimiter=' ', newline='\n', header='FRET values of bursts', footer='', comments='# ')
    #savetxt_table('ES.txt', (a.ds.E[0],a.ds.S[0]), fmt='%.3f', delimiter='; ', newline='\n', header='E and S values of bursts', footer='', comments='# ')

    #bva = BVA(a.ds,mode=1,flag_E_orig=False)

    #r = RASP(a.ds, dt_minmax_ms=(0,200))
    #r.plot2d_tseries(tseries=[10,20,50,200],cummul=False)
    #r.plot2d_tseries(tseries=[10,50,200],cummul=True)




    #test_kwargs2(1,b=2)
    #test_kwargs2(1)


    #print(a.d.mburst[0])
    #print(type(a.d.mburst[0]))
    #print(type(a.ds.mburst[0]))

    #bursts_info(a.ds)
    #example_values(a.ds)

    #print(a.ds.mburst[0])
    #plt.close('all')
    #b = BVA(a.ds,mode=1,flag_E_orig=True,plot_kind='scatter')
    #b = BVA(a.ds,mode=3,long_percentile=100)

    #plot_Ehist_gaussfit(a.ds,numgauss=2)

    #model = mfit.factory_two_gaussians(add_bridge=True)
    #model.print_param_hints()
    #model.set_param_hint('p1_center', value=0.3, min=-0.1, max=0.6)

    #plot_Ehist_gaussfit(a.ds,model=model)


    #a.proc(options=())
    #a.reportpng()
    #a.example_values()

    #r = RASP(a.ds)
    #r.plot2d()
    #r.plothist(E0_limits=(0.1,0.3))

    #a = Analysis('../data/CEB1-2uMRAD51-NoATP-new_BPES-10K-2uMRAD_45PM_01.hdf5')
    #d = a.d

    #times_clk = (d.mburst[0].start + d.mburst[0].stop)//2
    #tau1, G1 = calc_cf(times_clk,times_clk,d.clk_p)
    #tau2,G2 = calc_cf(d.ph_times_m[0], d.ph_times_m[0], d.clk_p)
    #plt.figure()
    #plt.plot(tau1,(G1-1)/70,tau2,G2-1,'-')
    #plt.xscale('log')
    #plt.close('all')
    #ps = psame(d, fitfun=Fitfuns.psame_bursts_1, p0=[0.1,0.01], flag_plot=True)
    #plt.plot(ps['taus'],Fitfuns.psame_bursts_1(ps['taus'],0.05,0.02))

    #r = RASP(a.ds)
    #r.plot2d_tseries(tseries=(10,20,50))

    #a.proc1(nrowscols=(2,2))



    

