% script M-file
% Written by Stephen Yang in January 2014
% Modified by A. Lee 2/11/14 to fix residuals calc in grid search
% and exponent syntax.
% modified by jpb. 3-state fitting of CPMG & CEST data for apo hTS 
% Model: A (major)-> B and A (major)-> C (linear) (BAC)
% for residues in concerted process, slow process params fixed from CEST
% only fit, fast process params fit locally (SINGLE PROFILE FITS HERE)
% fits forward & reverse rates instead of pops & kex

% ---- workspace reset ----
% Remove all local variables, then drop the globals that are rebuilt cell by
% cell further down (a plain "clear" does not remove global variables).
clear;
clear global field taucp data n180array delta;

% ---- shared state ----
% Declared global so that functions declaring the same names can see them.
% NOTE(review): presumably these are read by
% fitfunction_3stateSQMQ_linear2_CEST_kfit - confirm against that file.
global field protonfield taucp data n180array delta ptsmax
global kab dwab kba isMQ numres freq sllist sllist_allres
global numpts chemshift fityn SLdura chemcenter relErr idxlist

% ---- lsqnonlin settings ----
% Start from the solver defaults and override the algorithm, verbosity,
% tolerances and evaluation budget in a single call.
options = optimset('lsqnonlin');
options = optimset(options, ...
    'algorithm',   'trust-region-reflective', ... alternative: 'levenberg-marquardt'
    'Display',     'off', ...
    'TolFun',      1E-12, ...
    'TolX',        1E-12, ...
    'MaxFunEvals', 3000);
% ################ read in data from file ################

% Input .mat files, loaded in this exact order (a later file overwrites any
% variable it shares with an earlier one):
%   - CPMG mats (from profile plotting script): R2eff, eR2eff and taucpmg
%     for MQ850, MQ600 and SQ850
%   - CEST mats (from profile plotting/spin lock calibration scripts)
inputMatFiles = { ...
    'R2eff_MQ850.mat',  'eR2eff_MQ850.mat', ...
    'R2eff_MQ600.mat',  'eR2eff_MQ600.mat', ...
    'R2eff_SQ850.mat',  'eR2eff_SQ850.mat', ...
    'taucpmg_MQ850.mat', 'taucpmg_MQ600.mat', 'taucpmg_SQ850.mat', ...
    'CEST25_full2_mat.mat', 'CEST40_full3_mat.mat', ...
    'powcalib40.mat', 'calib25pow.mat'};
for inputMatIdx = 1:numel(inputMatFiles)
    load(inputMatFiles{inputMatIdx})
end

% Remove any local copies of these two names before declaring them global
% (NOTE(review): presumably the .mat files may contain them - confirm).
clear datainput dataout
global datainput dataout

% ---- settings for the profile being fitted ----
residue = 'L74';   % name of single profile to be fit, used in plot titles
mc      = 0;       % number of Monte Carlo simulations

isMQ = [1 1 0];    % per CPMG dataset: 1 = MQ, 0 = SQ

ylolim = 0;        % lower limit for y-axis in plot

% Bounds on the fitted parameters. Order of the entries:
%   kac kca dw_ac dw_ab1H dw_ac1H r2_fld1 r2_fld2 r2_fld3, then two more
%   entries that the Monte Carlo section stores as R1CEST and R2CEST.
xlb = [0.1  20     0 0 0 5  5  5  0 1 ];   % lower bounds
xub = [1000 100000 5 0 0 80 80 80 5 50];   % upper bounds
%xlb = [0 0 0 0 0 5 5 5 0 1]; % lower bounds for
%xub = [0 0 0 0 0 80 80 80 5 50]; % upper bounds for fitted params

%r2abc = [14 12 36]; % r2 for fields 1, 2, 3, 4, 5

flds = 3;          % number of CPMG datasets looped over below

% 13C field strength (MHz) for each CPMG dataset
field = [213.802506 150.902406 213.802506];

% 1H field for each dataset, scaled from the 13C value
% (assumes heteronuc is 13C)
protonfield = field*(850.28/213.802506);

% number of tau_cp points and the tau_cp list (as a column) per dataset
ptsmax = [length(taucpmg_MQ850) length(taucpmg_MQ600) length(taucpmg_SQ850)];
taucp  = {taucpmg_MQ850', taucpmg_MQ600', taucpmg_SQ850'};

% Row index of the probe to be fit in the CPMG (MQ, SQ) and CEST datasets
% for example, 21(MQ), 23(SQ), 39(CEST) for L198
idxMQ   = 44;
idxSQ   = 47;
idxlist = 62; %CEST

% One cell per CPMG dataset; columns are [R2eff  error]
data = {[R2eff_MQ850(idxMQ,:)' eR2eff_MQ850(idxMQ,:)'], ...
        [R2eff_MQ600(idxMQ,:)' eR2eff_MQ600(idxMQ,:)'], ...
        [R2eff_SQ850(idxSQ,:)' eR2eff_SQ850(idxSQ,:)']};

% Split each dataset into values and errors, and take the smallest measured
% R2eff of each dataset as the initial guess for R2 (CPMG) at that field.
r2eff = cell(1,flds);
err   = cell(1,flds);
r2abc = zeros(1,flds);
for fldIdx = 1:flds
    r2eff{fldIdx} = data{1,fldIdx}(:,1);
    err{fldIdx}   = data{1,fldIdx}(:,2);
    r2abc(fldIdx) = min(r2eff{fldIdx});
end
%r2abc = [min(data{1}(:,1))+5 min(data{2}(:,1))+5 min(data{3}(:,1))+5];
%Use values from previous fit for certain residues having dw1H > 0, where
%r2abc is noticeably different than the value of last point for MQ CPMG
%r2abc = [24.8 18.9 42.3]; %L198
%r2abc = [19.0 14.2 26.8]; %L192
%r2abc = [26.6 20.7 48.3]; %V79
%r2abc = [30.0 24.9 48.0]; %L101
%r2abc = [11.4 10.8 36.4]; %L221
%r2abc = [15.3 14.1 39.6]; %L121
%r2abc = [37.5 30.0 54.8]; %L131

% relaxation time: datasets 1 and 2 use the MQ value, dataset 3 the SQ value
ctimeMQ = 0.04;
ctimeSQ = 0.02;
ctimePerFld = [ctimeMQ ctimeMQ ctimeSQ];

for fldIdx = 1:flds
    tauColumn = taucp{1,fldIdx}(:);
    % number of 180s in each half of relaxation period
    n180array{fldIdx} = ctimePerFld(fldIdx)./(2*tauColumn);
    % half of tau_cp
    delta{fldIdx} = tauColumn./2;
end

% ---- grid-search setup ----
% Slow process params are fixed from the CEST fit and are not scanned.
kab  = 3.23;    %3.23
kba  = 237.77;  %237.77
dwab = -0.7475; % in units of ppm (1.4059 L198)

% Trial values scanned by the grid search
kacgd    = [1 5 10 50 100 500];
kcagd    = [1000 2000 3000 5000 7000 9000];
dwacgd   = 0.5:0.5:4.0;   % in units of ppm
dwab1Hgd = 0;
dwac1Hgd = 0;
%kacgd = [0];
%kcagd = [0];
%dwacgd = [0]; % in units of ppm
%dwab1Hgd = [0 0.1 0.2 0.3];
%dwac1Hgd = [0 0.1 0.2 0.3];

% running chi-square and grid-point counter used by the search loop
ss    = 0;
count = 0;

% index vector over all CPMG datasets
fl = 1:flds;

% GRID SEARCH (uses CPMG data only)
% For every combination of the trial values (kacgd, kcagd, dwacgd, dwab1Hgd,
% dwac1Hgd) with kab, kba and dwab held fixed:
%   1. calculate r2calc for every point of every CPMG dataset
%   2. subtract the measured r2eff values and divide by the errors
%   3. square the weighted residuals and sum them over all datasets
% The set of parameters that obtains the lowest sum (xbest) will be used for
% the initial guess of the fit; the sum for every grid point is also kept in
% gsresid together with the parameter values that produced it.

%ratio = zeros(ptsmax(1),1);
%r2calc = zeros(ptsmax(1),flds);

% index of the next entry in gsresid and the *gdlist record arrays
gsnum = 1;

for n = kacgd
    for j = kcagd
        kac = n;
        kca = j;
        kexac = n+j;
        kexab = kab+kba;
        %calculate pops from rates
        % x is the pa/pb ratio and y the pa/pc ratio (see pa and pc below);
        % pb follows from pa + pb + pc = 1
        if kexac ~= 0
            x = (kba/kexab)/(1-(kba/kexab));
            y = (j/kexac)/(1-(j/kexac));
            pb = 1/(1+x+(x/y));
            pa = x*pb;
            pc = (x/y)*pb;
        else
            % no A<->C exchange: two-state populations, state C empty
            pb = kab/kexab;
            pa = kba/kexab;
            pc = 0;
        end
        % starting populations as a column vector [pa; pb; pc]
        M0 = [pa pb pc]';
        % dwab is a single fixed value, so this loop runs once
        for k = dwab
            % ppm -> rad/s at each dataset's 13C field (field is in MHz)
            dwabgd = k*field(fl)*2*pi;
            for m = dwacgd
                dwac = m*field(fl)*2*pi;
                % kab and kba are single fixed values: both loops run once and
                % p and q are not used in the body
                for p = kab
                    for q = kba
                        for h = dwab1Hgd
                            % 1H shift differences: ppm -> rad/s at each
                            % dataset's 1H field
                            dwab1H = h*protonfield(fl)*2*pi;
                            for v = dwac1Hgd
                                dwac1H = v*protonfield(fl)*2*pi;

                                % number of grid points evaluated so far
                                count = count+1;

                                % calculate r2calc for each CPMG dataset
                                for z = 1:flds
                                    if isMQ(z) == 0
                                        % SQ dataset: 3x3 matrix over states
                                        % A, B, C. Each diagonal element holds
                                        % -r2abc(z) minus the rates leaving that
                                        % state; B and C also carry the
                                        % imaginary offsets dwabgd(z), dwac(z).
                                        A = [(-1*r2abc(z)-kab-kac) kba kca; ...
                                            kab (-1*r2abc(z)-kba+1i*dwabgd(z)) 0; ...
                                            kac 0 (-1*r2abc(z)-kca+1i*dwac(z))];
                                        Aconj = conj(A);

                                        for u = 1:ptsmax(z)
                                            % matrix exponentials of A and
                                            % conj(A) over delta (= tau_cp/2)
                                            B1 = expm(delta{1,z}(u)*A);
                                            B2 = expm(delta{1,z}(u)*Aconj);

                                            B3 = real(((B1*B2*B2*B1)^(n180array{1,z}(u)))*M0); % take real part of M
                                            % state-A element relative to its
                                            % starting population
                                            ratio = (B3(1))/M0(1);

                                            % 4*n180array*delta equals the
                                            % relaxation time given how both
                                            % were defined above
                                            r2calc{z}(u) = log(ratio)*...
                                                ((-1)/(4*n180array{1,z}(u)*delta{1,z}(u)));
                                            clear B1;
                                            clear B2;
                                            clear B3;
                                            clear ratio;
                                        end
                                    elseif isMQ(z) == 1
                                           % MQ dataset: two matrices that
                                           % differ only in the offsets on B
                                           % and C. m1 uses the sum of the 1H
                                           % and 13C shift differences, m2
                                           % uses their difference.
                                           m1 = ...
                                               [(-1*r2abc(z)-kab-kac) kba kca; ...
                                                kab (-1*r2abc(z)-kba-1i*(dwab1H(z)+dwabgd(z))) 0; ...
                                                kac 0 (-1*r2abc(z)-kca-1i*(dwac1H(z)+dwac(z)))];
                                           m2 = ...
                                               [(-1*r2abc(z)-kab-kac) kba kca; ...
                                                kab (-1*r2abc(z)-kba-1i*(dwab1H(z)-dwabgd(z))) 0; ...
                                                kac 0 (-1*r2abc(z)-kca-1i*(dwac1H(z)-dwac(z)))];
                                           m1conj = conj(m1);
                                           m2conj = conj(m2);

                                           for u = 1:ptsmax(z)
                                              % matrix exponentials over delta
                                              % for each matrix and conjugate
                                              M1 = expm(delta{z}(u)*m1);
                                              M2 = expm(delta{z}(u)*m2);
                                              M1conj = expm(delta{z}(u)*m1conj);
                                              M2conj = expm(delta{z}(u)*m2conj);
                                              % four products, each raised to
                                              % half of n180array
                                              W=(M1*M2*M2*M1)^(n180array{z}(u)/2);
                                              S=(M2conj*M1conj*M1conj*M2conj)^(n180array{z}(u)/2);
                                              Y=(M2*M1*M1*M2)^(n180array{z}(u)/2);
                                              Z=(M1conj*M2conj*M2conj*M1conj)^(n180array{z}(u)/2);
                                              % average W*S and Y*Z, take the
                                              % state-A element relative to
                                              % M0(1), convert to a rate as in
                                              % the SQ branch
                                              r2calc{z}(u) = ((-1)/(4*n180array{z}(u)*delta{z}(u)))*log(real((0.5/(M0(1)))*[1 0 0]*(W*S+Y*Z)*M0));
                                           end
                                    end
                                    % squared, error-weighted residuals
                                    resid{z} = ((r2calc{z}(:)-r2eff{z}(:))./err{z}(:)).^2;   
                                end
                                % stack the residuals of all datasets
                                resid_col = [];
                                for z = 1:flds
                                    resid_col = [resid_col;resid{z}];
                                end
                                %fldresid = sum(resid)';
                                ss = sum(resid_col); % chi-square for each grid point
                                % record this grid point's chi-square and the
                                % trial values that produced it
                                gsresid(gsnum) = ss;
                                kacgdlist(gsnum) = n;
                                kcagdlist(gsnum) = j;
                                dwacgdlist(gsnum) = m;
                                dwab1Hgdlist(gsnum) = h;
                                dwac1Hgdlist(gsnum) = v;
                                gsnum = gsnum + 1;
                                % keep the grid point with the lowest
                                % chi-square; the first point always starts
                                % as the best
                                if count == 1
                                    ssbest = ss;
                                    xbest = [n j m h v];
                                    residbest = resid;
                                elseif ss < ssbest
                                    ssbest = ss;
                                    xbest = [n j m h v];
                                    residbest = resid;
                                end
                            end
                        end
                    end
                end
            end
        end
    end
end
% no trailing semicolon: the message is echoed to the command window
drew = 'grid search done'

%set up CEST stuff
% Builds the CEST intensities, offsets and spin-lock powers for the probe(s) in
% idxlist, removes the 2H decoupling sidebands, and publishes the result through
% the globals datainput, freq, sllist_allres, numpts, chemshift, SLdura,
% chemcenter and fityn.
numres = 1; % number of CEST profiles handled by the loops below

% offsets of the 40 Hz and 25 Hz datasets; the first entry of each list is skipped
freqlist_orig40 = freqs(2:length(freqs));
freqlist_orig25 = freqs25(2:length(freqs25));
freqlist_orig = [freqlist_orig40;freqlist_orig25];
freqlist = freqlist_orig;
n40 = length(freqlist_orig40);
n25 = length(freqlist_orig25);

% one column per profile: 40 Hz intensities followed by 25 Hz intensities
% (first column of relInts / relInts25 skipped, matching the offset lists)
datamat = [];
for i = 1:numres
   datamat(:,i) = [relInts(idxlist(i),2:size(relInts,2)) relInts25(idxlist(i),2:size(relInts25,2))];
   cslist(:,i) = [0;0;cppm(idxlist(i))];
end
SLdura = 0.4;
SLpow = [powcalib40 calib25pow]; % one calibrated power per dataset (40 Hz, 25 Hz)
chemcenter = 19.565;
for i = 1:numres
    % peak position relative to chemcenter, converted from ppm to Hz
    peakoffset(i) = (cppm(idxlist(i)) - chemcenter)*150.902406439;
end
for i = 1:numres
    datainput{i} = datamat(:,i);
    %datainput2{i} = datamat(:,i);
end
chemshift = cslist(3,:);
freqlist = freqlist_orig; %reset to original freqlist with all points for each res

% Rebuild the spin-lock list from scratch. sllist is global and is not cleared
% at the top of the script, and the plotting section refills it for a finer
% offset grid. Indexed assignment into the leftover vector would keep stale
% trailing entries when the script is re-run in the same MATLAB session.
sllist = [repmat(SLpow(1),1,n40) repmat(SLpow(2),1,n25)];

for i = 1:numres
    sllist_allres{i} = sllist';
    freqlist_allres{i} = freqlist;
end
for i = 1:numres
    % remove 2H decoupling sidebands at ~500 Hz from the peak
    % (adjust the 400-600 Hz window based on 2H decoupling power)
    distFromPeak = abs(freqlist_allres{i}-peakoffset(i));
    isSideband = distFromPeak > 400 & distFromPeak < 600;
    datainput{i}(isSideband) = [];
    freqlist_allres{i}(isSideband) = [];
    sllist_allres{i}(isSideband) = [];
end
for i = 1:numres
    %adjust based on number of points in 40 Hz dataset (maxval is maximum
    %offset in 40 Hz dataset, used to get index of last point in 40 Hz
    %dataset for plotting)
    % NOTE(review): 71 is hard-coded and is applied after sideband removal;
    % confirm it still matches the data whenever the input files change.
    maxval(i) = max(freqlist_allres{i}(1:71));
end
fityn = 1;
freq = freqlist_allres;
for i = 1:numres
    numpts(i) = size(freq{i},1);
end

% ############ set up some plotting stuff ############
% x values for the CPMG plots: 1/tau_cp for each dataset
itaucp = cellfun(@(tau) 1./tau, taucp(1:flds), 'UniformOutput', false);

% marker face colour for each dataset, 1st field first
dat = struct('color', {'r', 'b', 'g', 'y', 'k'});

% #################### fit data ####################
% initial guesses: best grid-search point, then the R2 guess for each CPMG
% dataset, then starting values 1 and 15 for the last two parameters
x0 = [xbest r2abc 1 15];

tic
fprintf('\n\n\nGlobally optimizing parameters to CPMG & CEST data:\n');

[x,resnorm,residual,exitflag,output,lambda,jacobian] = ...
    lsqnonlin(@fitfunction_3stateSQMQ_linear2_CEST_kfit,x0,xlb,xub,options);

timer1 = toc;
fprintf('\nOptimized for %6.2f seconds\n',timer1);

% confidence intervals of the fitted parameters from the residuals and Jacobian
ci = nlparci(x,residual,'Jacobian',jacobian);

% keep the results of this fit under separate names; x, resnorm, residual,
% exitflag and output are overwritten by later sections of the script
xbest         = x;
resnormbest   = resnorm;
residualbest  = residual;
chisquared    = resnorm;
exitflag_orig = exitflag;
output_orig   = output;

% Monte Carlo Simulations
% Each of the mc rounds redraws the CPMG and CEST data from normal
% distributions centred on the measured values, refits starting from x0, and
% stores the fitted parameters in the monte struct. With mc = 0 the loop is
% skipped and every monte field stays empty.

fityn = 1;
%mc = 200;
data_orig = data;            % measured CPMG data, restored after the simulations
CESTdata_orig = datainput;   % measured CEST data, restored after the simulations
rareval = 0;                 % number of times a draw had to be repeated
monte.kac = zeros(1,mc);
monte.kca = zeros(1,mc);
monte.dwcac = zeros(1,mc);
monte.dwhab = zeros(1,mc);
monte.dwhac = zeros(1,mc);
monte.R2MQ850 = zeros(1,mc);
monte.R2MQ600 = zeros(1,mc);
monte.R2SQ850 = zeros(1,mc);
monte.R1CEST = zeros(1,mc);
monte.R2CEST = zeros(1,mc);
for i=1:mc
   for j = 1:flds
       cpmgMeas = data_orig{1,j}(:,1);
       cpmgErr = data_orig{1,j}(:,2);
       data{1,j}(:,1) = normrnd(cpmgMeas,cpmgErr);
       %if any points > 3 standard deviations from mean are picked, repick
       %points. Each point is compared with its OWN error, the same error it
       %was drawn with; the earlier version compared every point with
       %3*mean(error), which rejected high-error points too often and
       %low-error points too rarely.
       while any(abs(data{1,j}(:,1) - cpmgMeas) > 3*cpmgErr)
                data{1,j}(:,1) = normrnd(cpmgMeas,cpmgErr);
                rareval = rareval + 1;
       end
   end
   % CEST points all share one error: 1.5 times relErr of the fitted probe.
   % Computed inside the loop so relErr is only touched when mc > 0.
   cestErr = 1.5*relErr(idxlist,1)*ones(size(CESTdata_orig{1},1),size(CESTdata_orig{1},2));
   datainput{1} = normrnd(CESTdata_orig{1},cestErr);
   %if any points > 3 standard deviations from mean are picked, repick
   %points
       while any(abs(datainput{1} - CESTdata_orig{1}) > 3*1.5*relErr(idxlist,1))
                datainput{1} = normrnd(CESTdata_orig{1},cestErr);
                rareval = rareval + 1;
       end
   % refit the resampled data from the same starting point as the main fit
   [x,resnorm,residual,exitflag,output,lambda,jacobian] = ...
    lsqnonlin(@fitfunction_3stateSQMQ_linear2_CEST_kfit,x0,xlb,xub,options);
    monte.kac(i) = x(1);
    monte.kca(i) = x(2);
    monte.dwcac(i) = x(3);
    monte.dwhab(i) = x(4);
    monte.dwhac(i) = x(5);
    monte.R2MQ850(i) = x(6);
    monte.R2MQ600(i) = x(7);
    monte.R2SQ850(i) = x(8);
    monte.R1CEST(i) = x(9);
    monte.R2CEST(i) = x(10);
    % no semicolon: echoes the round number as a progress indicator
    i
end

%get mean, median, standard dev. of params from mc sims
% each *_mc vector is [mean median std] over the Monte Carlo rounds
kac_mc = [mean(monte.kac) median(monte.kac) std(monte.kac)];
kca_mc = [mean(monte.kca) median(monte.kca) std(monte.kca)];
dwcac_mc = [mean(monte.dwcac) median(monte.dwcac) std(monte.dwcac)];
dwhab_mc = [mean(monte.dwhab) median(monte.dwhab) std(monte.dwhab)];
dwhac_mc = [mean(monte.dwhac) median(monte.dwhac) std(monte.dwhac)];
R2MQ850_mc = [mean(monte.R2MQ850) median(monte.R2MQ850) std(monte.R2MQ850)];
R2MQ600_mc = [mean(monte.R2MQ600) median(monte.R2MQ600) std(monte.R2MQ600)];
R2SQ850_mc = [mean(monte.R2SQ850) median(monte.R2SQ850) std(monte.R2SQ850)];
R1CEST_mc = [mean(monte.R1CEST) median(monte.R1CEST) std(monte.R1CEST)];
R2CEST_mc = [mean(monte.R2CEST) median(monte.R2CEST) std(monte.R2CEST)];

monte.kexac = monte.kac+monte.kca;
kexab = kab+kba;
%calculate pops from rates
if isempty(monte.kexac)
    % mc = 0: no simulations were run. Indexing monte.kexac(1) would raise an
    % index error and stop the script before plotting, so keep the
    % populations empty like every other monte field.
    monte.pb = zeros(1,0);
    monte.pa = zeros(1,0);
    monte.pc = zeros(1,0);
elseif monte.kexac(1) ~= 0
    % x is the pa/pb ratio, y the pa/pc ratio for each round
    x = (kba/kexab)/(1-(kba/kexab));
    y = (monte.kca./monte.kexac)./(1-(monte.kca./monte.kexac));
    monte.pb = 1./(1+x+(x./y));
    monte.pa = x.*monte.pb;
    monte.pc = (x./y).*monte.pb;
else
    % no A<->C exchange: two-state populations, state C empty
    monte.pb = kab/kexab;
    monte.pa = kba/kexab;
    monte.pc = 0;
end

pa_mc = [mean(monte.pa) median(monte.pa) std(monte.pa)];
pb_mc = [mean(monte.pb) median(monte.pb) std(monte.pb)];
pc_mc = [mean(monte.pc) median(monte.pc) std(monte.pc)];
kexac_mc = [mean(monte.kexac) median(monte.kexac) std(monte.kexac)];

% put the measured data back in place of the last resampled set
data = data_orig;
datainput = CESTdata_orig;

% ####### plot fitted line #######
% remove leftovers from the grid search before recalculating the curves
clear r2calc A Aconj B1 B2 B3 pa pc M0;

fityn = 0;

%finer sampling of CEST profile: 10 Hz steps across the offset range of each
%dataset (first entry of freqs / freqs25 skipped, as in the CEST setup)
offsets40 = freqs(2:length(freqs));
offsets25 = freqs25(2:length(freqs25));
freq40 = min(offsets40):10:max(offsets40);
freq25 = min(offsets25):10:max(offsets25);
freqlist2 = [freq40 freq25];

% spin-lock power for every point of the fine grid: 40 Hz block, then 25 Hz block
sllist = [repmat(SLpow(1),1,length(freq40)) repmat(SLpow(2),1,length(freq25))];
%sllist((size(freq,2)/2)+1,length(freq))=outputparam(9,1);

% hand the fine grid to the fit function through the globals
freq = {};
for i = 1:numres
    freq{i} = freqlist2';
    sllist_allres{i} = sllist';
    numpts(i) = size(freq{i},1);
end

r2calc = fitfunction_3stateSQMQ_linear2_CEST_kfit(xbest); %CPMG data calc from best fit
dataout1 = dataout; %CEST data calc from best fit

% One figure per CPMG dataset (measured points with error bars plus the fitted
% curve), then one figure for each of the two CEST datasets (40 Hz and 25 Hz).
plcolor = 'kkkk';   % line colour of the fitted curve for each dataset
for resIdx = 1:numres
    %CPMG plots
    for fldIdx = 1:flds
        % fitted curve sorted by 1/tau_cp so the line is drawn left to right
        invTau = 1./taucp{fldIdx};
        curve = sortrows([invTau r2calc{resIdx,fldIdx}'],1);

        figure
        % MQ and SQ datasets are drawn with the same marker style
        errorbar(itaucp{fldIdx},r2eff{resIdx,fldIdx},err{resIdx,fldIdx},err{resIdx,fldIdx}, ...
            'ok','MarkerFaceColor',dat(fldIdx).color,'MarkerSize',6);
        hold on
        plot(curve(:,1),curve(:,2),plcolor(fldIdx));
        title(strcat(residue,' BAC global'));
        xlabel('1/\tau_c_p (Hz)');
        ylabel('R_{2,eff} (rad/sec)');
        % fixed 10-unit window starting 2 below the smallest measured value
        r2min = min(r2eff{resIdx,fldIdx});
        ylim([r2min-2 r2min+8])
    end

    %CEST plots
    % position of maxval in the measured offsets = last index of 40 Hz data;
    % everything after it belongs to the 25 Hz dataset
    [~,last40] = ismember(maxval(resIdx),freqlist_allres{resIdx});
    nMeasured = length(freqlist_allres{resIdx});
    nFine40 = length(freq40);
    nFine25 = length(freq25);
    panels = struct( ...
        'measIdx', {1:last40, last40+1:nMeasured}, ...
        'fineX',   {freq40, freq25}, ...
        'fineIdx', {1:nFine40, nFine40+1:nFine40+nFine25}, ...
        'label',   {' 40 Hz', ' 25 Hz'});
    for panelIdx = 1:numel(panels)
        pnl = panels(panelIdx);
        figure
        % measured points
        plot(freqlist_allres{resIdx}(pnl.measIdx),datainput{resIdx}(pnl.measIdx), ...
            'bo', 'MarkerFaceColor', 'b', 'MarkerSize', 2)
        hold on
        % best-fit profile on the fine offset grid
        plot(pnl.fineX,dataout1{resIdx}(pnl.fineIdx), 'Color', 'b', 'LineWidth', 1)
        hold off
        axis manual
        axis([-1500 1500 -0.1 max(datainput{resIdx})+0.1]);
        set(gca,'fontsize',12)
        %title(strcat('^1^3C Methyl',{' '}, residinfo{i}),{' '},{'40 Hz'});
        title(strcat(residue,pnl.label))
        xlabel('^1^3C Frequency Offset (Hz)', 'FontSize', 16, 'FontWeight', 'bold');
        ylabel('Normalized Intensity','FontSize', 16, 'FontWeight', 'bold');
    end
end