#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Gillespie stochastic simulation of n(r) for a simple model in which the
branching and terminating rates are constant across space.


@author: cherri
"""
import numpy as np
#import random as rand
import matplotlib.pyplot as plt
from neuronclass import *
datadir = 'data/'
graphdir = 'graphs/'

# Count handed to Neuron_one / Neuron_two (presumably the initial number of
# dendrites -- confirm in neuronclass).
n0 = 30
v0 = 0.  # not referenced anywhere else in this file

# Rate parameter sets, selected by `key` below.
# Each entry is (kb, kt, t_switch):
#   kb, kt   -- branching and terminating rates (lists with one value per
#               stage for two-stage models)
#   t_switch -- list of switch points passed to Neuron_two.nrtraj (it is added
#               to rmax, so presumably in units of r); None for one-stage models
PARAMETER_SETS = {
    'wildtype': (0.360, 0.594, None),  # Tm20, Ting 2014
    'babo': (0.322, 0.431, None),  # Tm20 in babo mutant, Ting 2014
    'Tm2': (0.393, 0.602, None),  # Tm2, Ting 2014
    'Tm9': (0.371, 0.589, None),  # Tm9, Ting 2014
    'big005-360': (0.360, 0.360 + 0.005, None),
    'big002-360': (0.360, 0.360 + 0.002, None),
    'big001-360': (0.360, 0.360 + 0.001, None),
    # early stage has increased branching
    'big2stage': ([0.42, 0.360], [0.40, 0.40], [100]),
    'big2stage_kfinc': ([0.42, 0.360], [0.40, 0.42], [100]),
}

key = 'wildtype'  # modify this key to pick one of the PARAMETER_SETS entries
try:
    kb, kt, t_switch = PARAMETER_SETS[key]
except KeyError:
    # Fail here with a clear message instead of a NameError on kb further down.
    raise ValueError(
        f'unknown parameter key {key!r}; expected one of {sorted(PARAMETER_SETS)}'
    ) from None
twostage = t_switch is not None

filename = key
#filename=key+str(n0)

# rmax = 10 / |kb - kt| using the final-stage rates; two-stage models are
# additionally shifted by the last switch point.
if not twostage:
    diffk = kb - kt
else:
    diffk = kb[-1] - kt[-1]
if diffk == 0:
    raise ValueError('kb and kt must differ: rmax is proportional to 1/|kb - kt|')
if not twostage:
    rmax = 1 / abs(diffk) * 10
else:
    rmax = 1 / abs(diffk) * 10 + t_switch[-1]

nsample = 100  # number of independent trajectories to simulate
num_collect_data = 200  # number of points on the common r grid
# Evenly spaced r grid onto which every trajectory is resampled.
r_collect = np.linspace(0, rmax, num_collect_data)
# One trajectory resampled onto r_collect (overwritten every sample).
n_collect = np.zeros(num_collect_data)
# Running sums of n and n^2 at each grid point, for mean and variance.
n_collect_1sum = np.zeros(num_collect_data)
n_collect_2sum = np.zeros(num_collect_data)

# Per-sample summary statistics, concatenated across samples.
R95 = np.zeros(0)
Rg = np.zeros(0)
densityR95 = np.zeros(0)
densityRg = np.zeros(0)
rterm = np.zeros(0)
#plotobj=plt.figure()
# neuron.startend* records, stacked row-wise (two columns per row).
se = np.zeros((0, 2))
se_b = np.zeros((0, 2))
se_t = np.zeros((0, 2))

for isample in range(nsample):

    if twostage:
        neuron = Neuron_two(n0)
        neuron.nrtraj(kb, kt, t_switch)
    else:
        neuron = Neuron_one(n0)
        neuron.nrtraj(kb, kt)
    neuron.statistics()
    R95 = np.append(R95, neuron.R95)
    Rg = np.append(Rg, neuron.Rg)
    densityR95 = np.append(densityR95, neuron.densityR95)
    densityRg = np.append(densityRg, neuron.densityRg)
    rterm = np.append(rterm, neuron.r_terminating)

    # Resample the step trajectory (r_traj, n_traj) onto r_collect: each grid
    # point takes the n value of the last r_traj entry strictly below it.
    # The single forward pointer assumes r_traj is non-decreasing.
    r_collect[0] = neuron.r_traj[0]
    # The first grid point must be filled too; it was previously left at 0.
    n_collect[0] = neuron.n_traj[0]
    ntraj = len(neuron.r_traj)
    itraj = 0
    for i in range(1, num_collect_data):
        while itraj < ntraj and neuron.r_traj[itraj] < r_collect[i]:
            itraj += 1
        # Clamp so itraj == 0 cannot wrap around to n_traj[-1].
        n_collect[i] = neuron.n_traj[max(itraj - 1, 0)]
#    f=plt.step(neuron.r_traj, neuron.n_traj, '-')
#    plt.setp(f, 'color', '#cccccc', 'linewidth', 0.5)
    n_collect_1sum += n_collect
    n_collect_2sum += n_collect * n_collect

    se = np.append(se, neuron.startend, axis=0)
    se_b = np.append(se_b, neuron.startend_b, axis=0)
    se_t = np.append(se_t, neuron.startend_t, axis=0)
#    if isample<10 :
#        np.save(datadir+"startend_"+filename+str(isample),neuron.startend)
# end isample
# Save the raw per-sample arrays.
np.save(datadir + "startend_all_" + filename, se)
np.save(datadir + "startend_branch_" + filename, se_b)
np.save(datadir + "startend_term_" + filename, se_t)
np.save(datadir + "R95-" + filename, R95)

# Write one "name mean std" line per statistic and echo it to stdout.
# The context manager closes the file even if a write fails.
summary_stats = (
    ('rterm', rterm),
    ('R95', R95),
    ('Rg', Rg),
    ('densityR95', densityR95),
    ('densityRg', densityRg),
)
with open(datadir + filename + '.txt', 'w', encoding='utf-8') as datafile:
    datafile.write('variable mean std\n')
    for label, values in summary_stats:
        line = f'{label} {np.mean(values):.3f} {np.std(values):.3f} \n'
        datafile.write(line)
        print(line)


# Pointwise statistics of n over the nsample trajectories at each r_collect
# point, built from the running sums of the sampling loop:
#   mean = <n>,  std = sqrt(<n^2> - <n>^2)   (divides by nsample, i.e. ddof=0)
n_collect_mean = n_collect_1sum / nsample
mean_of_squares = n_collect_2sum / nsample
n_collect_std = np.sqrt(mean_of_squares - n_collect_mean * n_collect_mean)

# Optional plot of mean +/- std (disabled):
#f1=plt.plot(r_collect,n_collect_mean,'-')
#plt.setp(f1, 'color', '#000000', 'linewidth', 2.0)
#f2=plt.plot(r_collect,n_collect_mean+n_collect_std, '-.',r_collect, n_collect_mean-n_collect_std, '-.')
#plt.setp(f2, 'color', '#000000', 'linewidth', 1.0)
#plt.xlabel("length of dendrite segments")
#plt.ylabel("number of dendrites")
#plt.show()
#plotobj.savefig(graphdir+filename+".pdf",bbox_inches='tight')

