% This program simulates an extension trajectory based on a hidden Markov
% model with a normal noise distribution, estimates the model parameters
% from the extension trajectory using both HMM-DB (HMM with detailed balance,
% fitted by gradient-based optimization) and HMM-EM (expectation-maximization),
% and finally compares the results from the two optimization approaches.
% Detailed balance is examined for the results from EM.

% Reference:
% "Hidden Markov Model with Detailed Balance and Its Application to Single Protein Folding" 
% by Yongli Zhang, Junyi Jiao, and Aleksander A. Rebane

% Reset the session: clear the workspace and close every open figure
clear all
close all  
% global data Qnst priorg c 
t0=cputime;  % CPU time at the start of the run

% ---- Length of the simulated trajectory ----
time=10;       % simulated duration in seconds
T=time*5000;   % number of data points (5000 points per second)
disp('Total number of data points:')
disp(T)
Qnst=7; % Number of states

c=1e6;  % Maximum molecular transition rate
dt=2e-4; % Sampling time interval
% Input energy matrix. Diagonal: state energies; off-diagonal g0(i,j):
% transition-state energy between states i and j. Only the upper triangle
% (with the diagonal) is given here; the lower triangle is filled below by
% symmetry.
g0=[0 7 9 8  9  8 9;...
    0 0 9 10 7  9 9;...
    0 0 2 11 11 11 11;...
    0 0 0 0  8  9 10;...
    0 0 0 0  1  9 10;...
    0 0 0 0  0  -1 8;...
    0 0 0 0  0  0  0];

mu0=(0:Qnst-1).*10;       % State extensions
Sigma0=16.*ones(1,Qnst);  % State variances

ld=diag(true(Qnst,1));  % True for diagonal elements
ltriu=triu(true(Qnst),1); % True for upper off-diagonal elements
ltril=ltriu'; % True for lower off-diagonal elements
tmp=g0';
g0(ltril)=tmp(ltril);  % Make g0 symmetric

disp('g0')
disp(g0)

% Display energy barrier
gs=g0(ld);  % State energy column vector
gsm=repmat(gs,1,Qnst);  % gsm(i,j)=gs(i)
gb=g0-gsm;  % Activation energy matrix: gb(i,j)=g0(i,j)-gs(i)
disp('Energy barrier gb=')
disp(gb)

% Parameter vector layout:
%   x(1:Qnst-1)    state energies of states 1..Qnst-1 (the last state
%                  energy is not a free parameter)
%   x(Qnst:ng)     lower off-diagonal transition-state energies g0(ltril)
%   x(ng+1:nmu)    state extensions mu
%   x(nmu+1:nx)    state variances Sigma
ng=0.5*Qnst*(Qnst+1)-1; % number of independent energy parameters
nx=ng+2*Qnst;  % Total number of fitting parameters in x array
nmu=ng+Qnst;   % state position end index in x.
x0=zeros(1,nx); % Input parameters
x0(1:Qnst-1)=gs(1:Qnst-1);
x0(Qnst:ng)=g0(ltril);
x0(ng+1:end)=[mu0 Sigma0];

% Transition rate matrix from the energy matrix:
% q(i,j)=c*exp(gs(i)-g0(i,j)) for i~=j, i.e. the rate decays exponentially
% with the barrier g0(i,j)-gs(i). Each diagonal element is minus the sum of
% the off-diagonal rates in its row, so every row of q sums to zero.
gs=g0(ld);  % state energy vector
g1=repmat(gs,[1 Qnst]);  % g1(i,j)=gs(i)
q=c.*exp(g1-g0);  %q is the transition rate matrix
q(ld)=0;
q(ld)=-sum(q,2);

disp('Rate constants:')
disp(q)

% Warn when any transition rate is too fast to be resolved at the sampling
% interval (rate*dt > 1). Only the off-diagonal rates are tested because the
% diagonal holds the negative total exit rates.
if any(q(~ld).*dt > 1)
    disp('The transition is too fast!')
    disp('q.*dt')
    disp(q.*dt)
end

% Symmetrization factor built from g1(i,j)=gs(i)
G=(g1-g1')./2;
G=exp(G);  % G(i,j)=exp((gs(i)-gs(j))/2)

% Calculate the symmetric rate matrix Q=G'.*q, i.e.
% Q(i,j)=c*exp((gs(i)+gs(j))/2-g0(i,j)) off the diagonal and Q(i,i)=q(i,i).
% Q is symmetric in exact arithmetic because g0 is symmetric. Averaging it
% with its transpose removes round-off asymmetry, so eig treats Q as
% symmetric and returns real eigenvalues with orthonormal eigenvectors.
Q=G'.*q;
Q=(Q+Q')./2;

% Spectral decomposition Q=Ue*diag(lambda)*Ue'
[Ue,lambda]=eig(Q);
lambda=diag(lambda);  % lambda is now a column array of eigenvalues
lambda_exp=exp(lambda.*dt);

% Transition probability matrix over one sampling interval dt:
% transmat=G.*expm(Q.*dt), which equals expm(q.*dt) because q=G.*Q.
transmat=G.*(Ue*diag(lambda_exp)*Ue');

transmat0=transmat;
disp('transmat0')
disp(transmat0)

% Equilibrium state populations are proportional to exp(-gs)
tmp=exp(-gs);
statepop0=tmp./sum(tmp);  % State population
lifetime0=-1000./q(ld);  % State lifetime in ms (inverse of the total exit rate)
disp('State population')
disp(statepop0')
disp('Lifetime')
disp(lifetime0')
 
% Simulate a hidden Markov process that starts in state 1.
% initial_prob0 is the initial state distribution as a Qnst-by-1 column.
initial_prob0=[1; zeros(Qnst-1,1)];
[data, hidden] = mhmm_sample_YZ(T, initial_prob0, transmat0, mu0, Sigma0);

% Transition statistics (Table S2 in Supporting Info.)
% aij(i,j) with i<j counts sampled transitions i->j between consecutive time
% points, and that count is mirrored into aij(j,i) to give a symmetric table.
% The diagonal aij(i,i) counts consecutive points that stay in state i.
% Pairs are taken as (hidden(t), hidden(t+1)) for t=1..T-1, so no
% wrap-around pair (last point, first point) is counted.
hs=hidden(:);
aij=accumarray([hs(1:end-1) hs(2:end)],1,[Qnst Qnst]);
aij=triu(aij)+triu(aij,1)';  % mirror the i<j counts into the lower triangle
disp(['T= ' num2str(T)])
disp('Number of specific transitions:')
disp(aij)

priorg=initial_prob0;  % Initial state distribution used in the fit
% Estimate the parameters in the simulated Markov process
nx=length(x0);

% Starting point for the optimization: the input parameters x0 perturbed by
% uniform random noise in (-1, 1). The same xA also seeds the EM fit.
xA=x0+2.*(rand(1,nx)-0.5);

% disp('xA')
% disp(xA)

% Quasi-Newton minimization with a user-supplied gradient; optimplotfval
% plots the objective value at every iteration.
options = optimoptions('fminunc','Algorithm','quasi-newton','GradObj','on','TolFun',1e-5,'PlotFcns',@optimplotfval);

%options = optimoptions('fminunc','Algorithm','quasi-newton','GradObj','on','TolFun',1e-3,'Display','iter','PlotFcns',@optimplotfval);

tic
f=@(x)hmmfun_grad(x,data,Qnst,ng,priorg,c,dt);
[x,fmin] = fminunc(f,xA,options);
% NOTE(review): fmin is labelled "BIC" here, but -fmin is later compared with
% the EM log-likelihood; confirm what hmmfun_grad returns.
disp('Minimal BIC:')
disp(fmin)

t2=toc;  % Elapsed (wall-clock) time of the gradient optimization

% Recover the objective trace drawn by optimplotfval. Its negative is plotted
% against the EM log-likelihood trace later in the script. Take only the
% first line handle so the negation never applies to a cell array, and fall
% back to an empty trace if no line was found.
h=findobj(gca,'Type','line');
if isempty(h)
    Lgrad=[];
else
    Lgrad=-get(h(1),'ydata');
end
xgrad=1:length(Lgrad);

% disp('The best-fitted parameters:')
% disp(x)
% disp(['CPU time for optimization: ' num2str(t2)])

% disp('Comparison of the fitting (starting point, input, output):')
% disp([xA;x0;x])

% Convert the best-fit parameter vector into model quantities and state
% statistics.
[transmat,mu,Sigma,en_all,en_s]=x2h_grad(x,Qnst,c,dt);
[statepop1,lifetime1,rate1]=state_info_grad(c,en_all);

% Idealize the trajectory with the gradient-fit model. First compute the
% Gaussian likelihood of every data point under every state (one row per
% state), then decode the most likely state sequence with Viterbi.
sig=sqrt(Sigma);
obslik=normpdf(repmat(data(:)',Qnst,1),repmat(mu(:),1,T),repmat(sig(:),1,T));
path = viterbi_path(priorg, transmat, obslik);

% Fraction of time points where the decoded state differs from the truth
numb_mismatch=sum(path~=hidden');
ratio_mismatch=numb_mismatch/T;
disp(['Ratio of mismatch: ' num2str(ratio_mismatch,4)])
%index=find(~(path==hidden'));
  
% Replace every time point by the extension of its state: the fitted mu for
% the decoded path and the input mu0 for the true hidden path.
ideal_path=zeros(size(path));
hidden_path=zeros(size(path));
for k=1:Qnst
   ideal_path(path==k)=mu(k);
   hidden_path(hidden==k)=mu0(k);
end

% xt=(1:T).*dt;
% figure
% plot(xt,data,'-b',xt,hidden_path,'o-r',xt,ideal_path,'*k')
% xlabel('Time (s)')
% ylabel('Extension (nm)')


% Total CPU time since the start of the script (t0 is recorded at startup).
% t2 is the elapsed time of the gradient optimization measured with tic/toc.
ddt=cputime-t0;
disp(['Total CPU time: ' num2str(ddt) ' seconds; optimization time= ' num2str(t2) ' seconds.'])

% ---- Hidden-Markov modeling with the EM algorithm ----
% EM starts from the same initial parameters xA as the gradient search.
mixmat2=ones(Qnst,1);  % mixture weights: a single component per state
tic
% Parse the 1D fitting parameters xA into the transition probability matrix,
% state extensions and state fluctuations.
[transmatA,muA,SigmaA,en_allA,en_sA]=x2h_grad(xA,Qnst,c,dt);
SigmaA=reshape(SigmaA,1,1,Qnst);  % 1-by-1-by-Qnst array of state variances

[LL,prior2,transmat2,mu2,Sigma2,mixmat2]=mhmm_em(data,initial_prob0,transmatA,muA,SigmaA,mixmat2, ...
    'thresh',1e-7,'max_iter',60,'verbose',0);
t3=toc;  % Elapsed time of the EM fit

% Log-likelihood of the data under the EM-fitted model
loglik=mhmm_logprob(data,prior2,transmat2,mu2,Sigma2,mixmat2);

disp('Compare CPU time for gradient search and EM algrothm: ')
disp([t2, t3])

% Log-likelihoods of both fits: the gradient value is the negated objective
% minimum, and the last column is (gradient - EM).
disp('Loglik (gradient, EM, difference):')
disp([-fmin, loglik, -fmin - loglik])

% State and transition-state energies implied by the EM transition matrix.
% EM does not enforce detailed balance, so the transition-state energy
% matrix energy_all_em need not be symmetric.
[energy_state_em, energy_all_em,  statepop2, lifetime2, rate2]=TPM2en(transmat2,c,dt); 

% Antisymmetric part: barrier difference e(i,j)-e(j,i); detailed balance
% means energy_all_em_diff=0.
energy_all_em_diff=energy_all_em-energy_all_em';
% Symmetric part: average transition-state energy (e(i,j)+e(j,i))/2
energy_all_em_av=(energy_all_em+energy_all_em')./2;  

disp('Energy barrier unbalance: ')
disp(energy_all_em_diff)

% Detailed-balance check on the fluxes A(i,j)=statepop2(i)*rate2(i,j).
% db is the relative difference between forward and backward fluxes and is
% 0 everywhere under detailed balance.
A=diag(statepop2)*rate2;
db=2.*(A-A')./(A+A');
disp('Detailed unbalance:')
disp(db)

% Pack the EM results into the same layout as the gradient parameters x:
% [state energies 1..Qnst-1, lower off-diagonal transition-state energies,
%  state extensions, state variances]
x_em=zeros(size(x));
x_em(1:Qnst-1)=energy_state_em(1:Qnst-1);
x_em(Qnst:ng)=energy_all_em_av(ltril);
x_em(ng+1:end)=[mu2, reshape(Sigma2,1,[])];

disp('Comparison of the best-fit parameters (starting point, input, gradient output, EM output):')
disp([xA; x0; x; x_em])

disp('State population (input, gradient output, EM output:')
disp([statepop0'; statepop1; statepop2'])

disp('State lifetime (input, gradient output, EM output:')
disp([lifetime0'; lifetime1; lifetime2'])

% Idealize the trajectory with the EM-fitted model: emission likelihoods,
% then Viterbi decoding.
obslik=mixgauss_prob(data,mu2,Sigma2);
path2=viterbi_path(prior2,transmat2,obslik);

% Replace each decoded state by its EM-fitted extension
ideal_path2=zeros(size(path2));
for k=1:Qnst
   ideal_path2(path2==k)=mu2(k);
end

% Fraction of time points decoded into the wrong state by EM
numb_mismatch2=sum(path2~=hidden');
ratio_mismatch2=numb_mismatch2/T;
disp('Ratio of mismatch (gradient, EM): ')
disp([ratio_mismatch, ratio_mismatch2])

% Plot the data together with the true and the two idealized state paths,
% then save the figure.
xt=(1:T).*dt;
figure
plot(xt,data,'-b',xt,hidden_path,'or',xt,ideal_path,'+k',xt,ideal_path2,'sg')
xlabel('Time (s)')
ylabel('Extension (nm)')
legend('Data','True states','Idealized (gradient)','Idealized (EM)')

savefig('Q7sim_current.fig')

disp('CPU time (gradient, EM):')
disp([t2 t3])

% Convergence of the two fits: the negated objective trace recorded from
% optimplotfval (gradient) and the per-iteration log-likelihood LL from EM.
figure
xem=1:length(LL);
plot(xgrad,Lgrad,'-ok',xem,LL,'-*r')
xlabel('Iteration')
ylabel('Log-likelihood')
legend('Gradient (HMM-DB)','EM')

% ------------------------------------------------------------------------
% Estimate the standard deviations of the fitted parameters by scanning the
% objective f around the optimum x with error_fun.

np=11;     % passed to error_fun; presumably the number of scan points — confirm
ix=1:nx;   % indices of the parameters to scan
ib=0.4;    % half-width of the scan interval around x

xu=x; xu(ix)=xu(ix)+ib;   % upper scan bounds
xb=x; xb(ix)=xb(ix)-ib;   % lower scan bounds
% xu(4)=40;
% xb(4)=8;
plot_or_not=true;

[xfit,yfit,sigfit,xx,yy]=error_fun(f,x,xb,xu,np,plot_or_not);

% disp('Error:')
% disp(sigfit)

