# -*- coding: utf-8 -*-
"""
@author: Tim Balmer, Trussell Lab, OHSU
This script creates a .CSV file that pClamp can use to stimulate with a sinusoidally frequency-modulated pulse train.
"""
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal

# Choose the preset by setting mf to .3, 1 or 3 (modulator frequency in Hz).
mf = 1

# Each preset fixes the modulator frequency, the carrier frequency and the
# modulation index used by the rest of the script.
if mf == .3:
    modulator_frequency = .3
    carrier_frequency = 26
    modulation_index = 6
elif mf == 1:
    modulator_frequency = 1
    carrier_frequency = 26
    # NOTE(review): an older inline note said "USE 40 for 1"; the code uses
    # 26*2 = 52 - confirm which value is intended.
    modulation_index = 26 * 2
elif mf == 3:
    modulator_frequency = 3
    carrier_frequency = 26
    # NOTE(review): an older inline note said "USE 120 for 3"; the code uses
    # 140 - confirm which value is intended.
    modulation_index = 140
else:
    # Fail here with a clear message instead of a NameError further down,
    # where the preset variables are first used.
    raise ValueError("mf must be one of the presets .3, 1 or 3, got %r" % (mf,))

sample_rt = 20000  # sample rate in Hz
reps = 10  # number of sweeps (copies of the train written to the output file)
before = 5  # silent time in seconds before the steady-state stim and after the modulation
ssdur = 10  # duration of the steady-state stim in seconds
moddur = 10  # duration of the modulated stim in seconds
pulsedur = 100e-6  # pulse duration in seconds

sample_int = 1 / sample_rt  # sample interval in seconds

# Build one modulator period of the frequency-modulated waveform.
period = 1 / modulator_frequency
time = np.arange(0, period, sample_int)  # time axis of one cycle, one point per sample
two_pi = 2.0 * np.pi
mod_phase = two_pi * modulator_frequency * time  # modulator phase over the cycle

# Reference traces (used for plotting). The -pi/2 shift makes the modulator
# start at its minimum so that it rises first.
modulator = np.sin(mod_phase - (np.pi / 2)) * modulation_index
carrier = np.sin(two_pi * carrier_frequency * time)

# Frequency-modulated output: the carrier phase plus a sinusoidal phase term
# of amplitude beta.
beta = modulation_index / modulator_frequency
mod = beta * np.cos(mod_phase + np.pi)
product = np.cos(two_pi * (carrier_frequency * time) + mod)

# Samples where carrier_frequency + mod_effective is not positive are blanked
# when the modulation index exceeds the carrier frequency.
# NOTE(review): mod_effective has amplitude beta, while the instantaneous
# frequency deviation of `product` works out to modulation_index*sin(...);
# the two agree only when modulator_frequency == 1 - confirm this is intended.
mod_effective = beta * np.cos(mod_phase + 1.5 * np.pi)
zero_times = (carrier_frequency + mod_effective) <= 0
zero_inds, = np.nonzero(zero_times)
if modulation_index > carrier_frequency:
    product[zero_times] = 0

# Diagnostic figure with four stacked panels: modulator, carrier,
# mod_effective with the blanking mask, and the output with the mask.
n_panels = 4

plt.subplot(n_panels, 1, 1)
plt.plot(time, modulator)
plt.title('Frequency Modulation')
plt.xlabel('Modulator signal')
plt.ylabel('Amplitude')

plt.subplot(n_panels, 1, 2)
plt.plot(time, carrier)
plt.xlabel('Carrier signal')
plt.ylabel('Amplitude')

plt.subplot(n_panels, 1, 3)
plt.plot(time, mod_effective)
plt.plot(time, zero_times)

plt.subplot(n_panels, 1, 4)
plt.plot(time, product)
plt.plot(time, zero_times)
plt.xlabel('Output signal')
plt.ylabel('Amplitude')

#% find peaks of the output wave
# Everything below 0.9 is zeroed in place, so only crests that reach ~1
# remain as candidates for local maxima.
below_threshold = product < .9
product[below_threshold] = 0
pks_ind = signal.argrelmax(product)[0]  # sample indices of the peaks
pks_val = product[pks_ind]  # peak heights at those indices

# Mark the detected peaks in the last panel (x converted from samples to seconds).
plt.plot(pks_ind * sample_int, pks_val, 'or')
plt.show()

# Stimulation-free segment used at the beginning and the end of the protocol.
deadtime = np.zeros(sample_rt * before)

# Steady-state train at the carrier frequency that precedes the modulation.
prepulse = np.arange(0, ssdur, (1 / carrier_frequency))  # pulse onset times in seconds
prepulse = prepulse * sample_rt  # onset times expressed in samples
ppulse = np.zeros(sample_rt * ssdur)
# Pulse width in samples; round() keeps float error (e.g. 1.9999999) from
# being truncated to one sample too few.
pulsedur_dig = int(round(pulsedur / sample_int))
for onset in prepulse.astype(int):  # truncate each onset to a sample index
    ppulse[onset:onset + pulsedur_dig] = 5

# Modulated train: one 5 V pulse at every peak found in the modulated cycle.
pulse = np.zeros(int(sample_rt * period))
for onset in pks_ind:
    pulse[onset:onset + pulsedur_dig] = 5

# NOTE(review): hard-coded padding in samples, "for proper offset" according
# to the original author; how 280 was derived is not shown here.
leading_zeros = 280
# Append the padding to the END of the cycle. The previous
# np.insert(pulse, -1, ...) inserted it before the last sample, which moved
# that final sample behind the padding.
pulse = np.concatenate([pulse, np.zeros(leading_zeros)])
pulse = np.tile(pulse, int(moddur / period))  # repeat the padded cycle for the whole modulated segment
# Drop the zeros before the first pulse so the modulated train begins on a
# pulse right after the steady-state train.
pulse = np.trim_zeros(pulse, trim='f')

# Put the whole protocol together: silence, steady-state train, modulated
# train, silence.
ptrain = np.concatenate([deadtime, ppulse, pulse, deadtime])

# Time stamp (seconds) of every sample. Built from an integer sample index
# because np.arange with a fractional step and a float stop can return one
# element more or fewer than len(ptrain), which would break the plot below
# and any later stacking of x with ptrain.
x = np.arange(len(ptrain)) * sample_int

# plot the entire train
plt.figure()
plt.plot(x, ptrain, '.-k')
plt.ylabel('Volts')
plt.xlabel('Seconds')

#% calculate the instantaneous pulse rate and plot it against time
# Use only the first sample of each pulse (0 -> high transition). Taking every
# high sample, as before, also measured the one-sample gaps inside a pulse
# (a spurious point at sample_rt Hz) and measured each real interval from the
# end of one pulse to the start of the next, slightly overestimating the rate.
is_high = (ptrain > 0).astype(int)
rising = np.diff(np.concatenate(([0], is_high))) == 1
intervals_ind = np.nonzero(rising)[0]  # sample index of each pulse onset
intervals = np.diff(intervals_ind)  # onset-to-onset intervals in samples
# The first pulse has no preceding interval, so its rate is reported as 0;
# sizing rate from intervals_ind keeps both arrays the same length even when
# there are no pulses at all.
rate = np.zeros(len(intervals_ind))
rate[1:] = 1 / (intervals * sample_int)  # convert intervals to rates in Hz

# plot it
plt.figure()
plt.plot(intervals_ind / sample_rt, rate, '|k')
plt.ylabel('Rate (Hz)')
plt.xlabel('Seconds')
plt.ylim(0, 200)
plt.xlim(0, 30)

#%% Export the waveform as a text file that Clampex can load as a stimulus file.
# The route is not obvious: open the .csv, save it as .atf, then use the .atf
# as the stimulus file of a protocol.
#   1. In Clampex: File / Open Data, open this file, set the units to V,
#      save as an ATF file and close it.
#   2. Acquire / New Protocol: set the sampling rate to match the one above
#      (20 kHz), set the sweep time to match the waveform and choose episodic
#      stimulation.
#      NOTE(review): the original note said 20 s, but with the default
#      durations above the train is roughly 30 s long - confirm.
#   3. On the Outputs tab choose OUT2 (this puts analog output channel 2 into
#      voltage output).
#   4. On the Waveform tab click the lower 'Channel #2' tab, check the analog
#      waveform box, choose the stimulus file radio button, press the Stimulus
#      File button and pick the ATF file you just made.
#   5. Acquire / Save Protocol. Protocols are saved as e.g. 26xp3x10Hz, meaning
#      carrier freq X modulation freq X modulation index; p means point, so p3
#      is 0.3.
time_ms = x * 1000  # Clampex wants the time column in ms
sweep_rows = [ptrain] * reps  # the same train repeated once per sweep
# First column is time, followed by one column per sweep.
waveform = np.vstack([time_ms] + sweep_rows).T
np.savetxt('26xp3x10Hz.csv', waveform, delimiter=",")