%jpb CEST/CPMG SQ,MQ global fitting
%Used for the global fit of the two geminal methyls of the same residue: one
%geminal is described by a 3-state model that includes the concerted process,
%the other by a 2-state model; the fast-process parameters are shared by the
%two geminals.
%Model: linear B <-> A (major) <-> C (BAC) for the 3-state geminal,
%A (major) <-> C for the 2-state geminal
%used for L101, L131 (apo hTS)

% Start from a clean workspace. A plain "clear" only empties the local
% workspace, so the globals listed on the next line are wiped explicitly.
clear;
clear global field taucp data n180array delta datainput dataout;

% State shared through globals (presumably read by the fit function, which
% is not part of this file -- confirm there).
global field protonfield taucp data n180array delta ptsmax isMQ numres
global freq sllist sllist_allres datainput numpts chemshift fityn dataout
global SLdura chemcenter relErr idxlist is3state

% lsqnonlin settings: bounded trust-region-reflective solver, silent, with
% tight tolerances. Alternatives tried before: 'levenberg-marquardt' as the
% algorithm, and 'FinDiffRelStep' = 1e-3.
options = optimset(optimset('lsqnonlin'), ...
    'algorithm',   'trust-region-reflective', ...
    'Display',     'off', ...
    'TolFun',      1E-12, ...
    'TolX',        1E-12, ...
    'MaxFunEvals', 10000);

% Load the CPMG mats (from the profile plotting script): R2eff profiles, their
% errors and the tau_cp delays of the MQ 850, MQ 600 and SQ 850 datasets.
% The load order is kept as is: a later file overwrites any variable of the
% same name loaded from an earlier one.
load R2eff_MQ850.mat
load eR2eff_MQ850.mat
load R2eff_MQ600.mat
load eR2eff_MQ600.mat
load R2eff_SQ850.mat
load eR2eff_SQ850.mat
load taucpmg_MQ850.mat
load taucpmg_MQ600.mat
load taucpmg_SQ850.mat

% Load the CEST mats (from the profile plotting / spin lock calibration scripts).
load CEST25_full2_mat.mat
load CEST40_full3_mat.mat
load powcalib40.mat
load calib25pow.mat

%residue = {'L192','L198','L187','L121','L101 1','L101 2','L131 1','L131 2','L212','L252','L73 1','L73 2','L74','V79','L221','L259','L269','V164'};
residue = {'L131 1','L131 2'};  % profile names of the two geminal methyls (plot titles)
dw0  = [0.6757 0];  % slow-process 13C dw per residue (from CEST-only fit)
dwh0 = [0.2 0];     % slow-process 1H dw initial guess per residue
mc   = 0;           % number of Monte Carlo simulations (0 = none)

numres   = length(residue);
isMQ     = [1 1 0]; % per CPMG dataset: 1 = MQ, 0 = SQ
is3state = [1 0];   % per residue: 1 = 3-state, 0 = 2-state
ylolim   = 0;       % lower limit for y-axis in plot

% Shared parameters, in order: kab, kba, kac, kca.
% kab and kba are pinned (lower == upper bound) to the CEST-only fit values.
x0  = [3.23 237.77 5    6000];
xlb = [3.23 237.77 0.1  20];
xub = [3.23 237.77 1000 100000];

% Append the per-residue bounds.
%   3-state (9 values): dwcab, dwcac, dwhab, dwhac, r2abc1, r2abc2, r2abc3,
%                       CEST R1, CEST R2   (dwcab pinned to dw0, dwhac to 0)
%   2-state (7 values): dwcac, dwhac, r2abc1, r2abc2, r2abc3, CEST R1, CEST R2
%                       (dwhac pinned to 0)
for ires = 1:numres
    if is3state(ires) == 1
        lbRes = [dw0(ires) 0 0 0 5 5 5 0 0];
        ubRes = [dw0(ires) 5 1 0 80 80 80 5 50];
    else
        lbRes = [0 0 5 5 5 0 0];
        ubRes = [5 0 80 80 80 5 50];
    end
    xlb = [xlb lbRes]; %#ok<AGROW>
    xub = [xub ubRes]; %#ok<AGROW>
end

%set up CPMG stuff
flds = 3;   % number of CPMG datasets: MQ 850, MQ 600, SQ 850

% 13C field strength (MHz) of each dataset, and the matching 1H frequency
% obtained by scaling (assumes the heteronucleus is 13C)
field = [213.802506 150.902406 213.802506];
protonfield = field*(850.28/213.802506);

% CPMG delays of each dataset as column vectors, and their point counts
taucp = {taucpmg_MQ850', taucpmg_MQ600', taucpmg_SQ850'};
ptsmax = cellfun(@length, taucp);

%index values for profiles to be fit
%for example, 21(MQ), 23(SQ), 39(CEST) for L198
idxMQ = [69 74];    % row of each residue in R2eff_MQ850/R2eff_MQ600 (and errors)
idxSQ = [73 78];    % row of each residue in R2eff_SQ850 (and errors)
idxlist = [87 92];  % CEST profile index of each residue

% Optional manual initial CPMG R2 guesses [MQ850 MQ600 SQ850] per residue.
% Use values from a previous fit for residues having dw1H > 0, where r2abc is
% noticeably different from the value of the last MQ CPMG point. Residues left
% empty start from the smallest measured R2eff of each of their datasets.
r2abc_guess = cell(1,numres);
r2abc_guess{1} = [37.5 30.0 54.8]; %L131
%r2abc_guess{1} = [30.0 24.9 48.0]; %L101

r2abc = zeros(numres,flds);
r2eff = cell(numres,flds);
err = cell(numres,flds);
for i = 1:numres
    % one [R2eff error] matrix per residue (row) and CPMG dataset (column)
    data{i,1} = [R2eff_MQ850(idxMQ(i),:)' eR2eff_MQ850(idxMQ(i),:)'];
    data{i,2} = [R2eff_MQ600(idxMQ(i),:)' eR2eff_MQ600(idxMQ(i),:)'];
    data{i,3} = [R2eff_SQ850(idxSQ(i),:)' eR2eff_SQ850(idxSQ(i),:)'];

    for j = 1:flds
        r2eff{i,j} = data{i,j}(:,1);
        err{i,j} = data{i,j}(:,2);
    end

    if ~isempty(r2abc_guess{i})
        r2abc(i,:) = r2abc_guess{i};
    else
        % Index by residue AND dataset: linear data{1}, data{2}, data{3}
        % would walk down the first column of the numres-by-flds cell and
        % mix in other residues' profiles.
        r2abc(i,:) = cellfun(@min, r2eff(i,:));
    end
end

% CPMG relaxation periods (s): MQ datasets use ctimeMQ, the SQ dataset ctimeSQ
ctimeMQ = 0.04;
ctimeSQ = 0.02;
ctime = [ctimeMQ ctimeMQ ctimeSQ];

% number of 180s in each half of the relaxation period, per dataset
for fld = 1:numel(ctime)
    n180array{fld} = ctime(fld)./(2*taucp{1,fld}(:));
end
% half of each CPMG delay
for fld = 1:flds
    delta{fld} = taucp{1,fld}(:)./2;
end

% Append each residue's starting values after the four shared rates in x0.
%   3-state: dwcab, dwcac, dwhab, dwhac, r2 cpmg (x3), CEST R1, CEST R2
%   2-state: dwcac, dwhac, r2 cpmg (x3), CEST R1, CEST R2
for ires = 1:numres
    r2start = r2abc(ires,:);
    if is3state(ires) == 1
        x0 = [x0, dw0(ires), 1.5, dwh0(ires), 0, r2start, 1, 15]; %#ok<AGROW>
    else
        x0 = [x0, 0.5, 0, r2start, 1, 15]; %#ok<AGROW>
    end
end

% ############ set up some plotting stuff ############
% reciprocal CPMG delays (x-axis of the dispersion plots), one cell per dataset
itaucp = cellfun(@(t) 1./t, taucp(1:flds), 'UniformOutput', false);

% marker face colour per CPMG dataset (1st field first)
dat = struct('color', {'r', 'b', 'g', 'y', 'k'});

%set up CEST stuff
% Offsets (Hz) of the 40 Hz and 25 Hz CEST datasets; the first entry of
% freqs/freqs25 and the first column of relInts/relInts25 are skipped.
freqlist_orig40 = freqs(2:length(freqs));
freqlist_orig25 = freqs25(2:length(freqs25));
freqlist_orig = [freqlist_orig40;freqlist_orig25];
freqlist = freqlist_orig;
n40 = length(freqlist_orig40);
n25 = length(freqlist_orig25);

SLdura = 0.4;
SLpow = [powcalib40 calib25pow];
chemcenter = 19.565;

% per residue: normalized intensities (one column each), 13C shift, and the
% peak offset from chemcenter in Hz (ppm * 150.902406439 MHz)
datamat = [];
peakoffset = zeros(1,numres);
for i = 1:numres
    datamat(:,i) = [relInts(idxlist(i),2:size(relInts,2)) relInts25(idxlist(i),2:size(relInts25,2))];
    cslist(:,i) = [0;0;cppm(idxlist(i))];
    peakoffset(i) = (cppm(idxlist(i)) - chemcenter)*150.902406439;
end
chemshift = cslist(3,:);

% Spin-lock power of every offset. Build the vector as a whole: sllist is a
% global that the clear statements at the top do not reset, so an indexed
% assignment would keep the longer fine-grid vector written by the plotting
% section of a previous run.
sllist = [repmat(SLpow(1),1,n40) repmat(SLpow(2),1,n25)];

% Reset the per-residue containers (several are globals) so that nothing
% stale from a previous run with more residues survives.
datainput = cell(1,numres);
freqlist_allres = cell(1,numres);
sllist_allres = cell(1,numres);
numpts = zeros(1,numres);
maxval = zeros(1,numres);
for i = 1:numres
    % remove 2H decoupling sidebands at ~500 Hz (400-600 Hz from the peak);
    % assumes datamat has one row per entry of freqlist
    offpeak = abs(freqlist - peakoffset(i));
    keep = ~(offpeak > 400 & offpeak < 600);
    datainput{i} = datamat(keep,i);
    freqlist_allres{i} = freqlist(keep);
    sllist_allres{i} = sllist(keep)';

    % largest retained 40 Hz offset (used when plotting to split the profile
    % into its 40 Hz and 25 Hz parts). The number of retained 40 Hz points
    % depends on this residue's peak offset, so it is counted, not hard-coded.
    n40kept = nnz(keep(1:n40));
    if n40kept > 0
        maxval(i) = max(freqlist_allres{i}(1:n40kept));
    else
        maxval(i) = NaN;
    end
end
fityn = 1;
freq = freqlist_allres;
for i = 1:numres
    numpts(i) = size(freq{i},1);
end

% #################### fit data ####################
tic
fprintf('\n\n\nGlobally optimizing parameters to CPMG & CEST data:\n');

fitfun = @fitfunction_cpmgCEST_2_3stateSQMQ_full_linear_global_kfit;
[x,resnorm,residual,exitflag,output,lambda,jacobian] = lsqnonlin(fitfun,x0,xlb,xub,options);
exitflag_orig = exitflag;
timer1 = toc;
fprintf('\nOptimized for %6.2f seconds\n',timer1);

% Keep copies of the best-fit results; x, exitflag and output are overwritten
% by the Monte Carlo refits.
xbest = x;
output_orig = output;
chisquared = resnorm;
% confidence intervals from the residuals and the Jacobian at the solution
ci = nlparci(x,residual,'Jacobian',jacobian);

% Monte Carlo Simulations
% Each of the mc simulations adds Gaussian noise to every CPMG point (sigma =
% its error) and every CEST point (sigma = 1.5*relErr of the profile), refits
% from x0, and records the fitted parameters in the struct monte. A draw that
% puts any point more than 3 sigma from its measured value (3*mean error for
% CPMG) is redrawn; redraws are counted in rareval.
if mc > 0
    fityn = 1;
    data_orig = data;
    CESTdata_orig = datainput;
    rareval = 0;

    % Labels of the fitted parameters in parameter-vector order: the four
    % shared rates, the 9 values of residue 1 (3-state) and the 7 of
    % residue 2 (2-state).
    mcfields = {'kab','kba','kac','kca', ...
        'dwcab1','dwcac1','dwhab1','dwhac1', ...
        'R2MQ850_1','R2MQ600_1','R2SQ850_1','R1CEST_1','R2CEST_1', ...
        'dwcac2','dwhac2', ...
        'R2MQ850_2','R2MQ600_2','R2SQ850_2','R1CEST_2','R2CEST_2'};
    if numel(x0) ~= numel(mcfields)
        error('MC parameter labels expect %d fit parameters (3-state + 2-state residue) but x0 has %d.', ...
            numel(mcfields), numel(x0));
    end
    % preallocate exactly the fields that are filled below
    monte = struct();
    for f = 1:numel(mcfields)
        monte.(mcfields{f}) = zeros(1,mc);
    end

    for i = 1:mc
        for k = 1:numres
            % CPMG: perturb every dataset of residue k
            for j = 1:flds
                r2mean = data_orig{k,j}(:,1);
                r2sd = data_orig{k,j}(:,2);
                r2cut = 3*mean(r2sd);
                data{k,j}(:,1) = normrnd(r2mean,r2sd);
                while any(abs(data{k,j}(:,1) - r2mean) > r2cut)
                    data{k,j}(:,1) = normrnd(r2mean,r2sd);
                    rareval = rareval + 1;
                end
            end
            % CEST: perturb the intensity profile of residue k
            cestmean = CESTdata_orig{k};
            cestsd = 1.5*relErr(idxlist(k),1);
            cestcut = 3*1.5*relErr(idxlist(k),1);
            datainput{k} = normrnd(cestmean,cestsd*ones(size(cestmean,1),size(cestmean,2)));
            while any(abs(datainput{k}(:) - cestmean(:)) > cestcut)
                datainput{k} = normrnd(cestmean,cestsd*ones(size(cestmean,1),size(cestmean,2)));
                rareval = rareval + 1;
            end
        end
        [x,resnorm,residual,exitflag,output,lambda,jacobian] = ...
            lsqnonlin(@fitfunction_cpmgCEST_2_3stateSQMQ_full_linear_global_kfit,x0,xlb,xub,options);
        for f = 1:numel(mcfields)
            monte.(mcfields{f})(i) = x(f);
        end
        fprintf('Monte Carlo simulation %d of %d done\n', i, mc);
    end

    % Populations from the rates (detailed balance). Separate names are used
    % so the fitted-parameter vector x is not overwritten.
    % Residue 1 (3-state, B <-> A <-> C): pa/pb = kba/kab, pa/pc = kca/kac
    monte.kexac = monte.kac+monte.kca;
    monte.kexab = monte.kab+monte.kba;
    ratio_ab = (monte.kba./monte.kexab)./(1-(monte.kba./monte.kexab));   % pa/pb
    ratio_ac = (monte.kca./monte.kexac)./(1-(monte.kca./monte.kexac));   % pa/pc
    monte.pb_1 = 1./(1+ratio_ab+(ratio_ab./ratio_ac));
    monte.pa_1 = ratio_ab.*monte.pb_1;
    monte.pc_1 = (ratio_ab./ratio_ac).*monte.pb_1;
    % Residue 2 (2-state, A <-> C)
    monte.pa_2 = monte.kca./monte.kexac;
    monte.pc_2 = monte.kac./monte.kexac;

    %get mean, median, std of params from mc sims
    mcsummary = @(v) [mean(v) median(v) std(v)];
    pa_1_mc = mcsummary(monte.pa_1);
    pb_1_mc = mcsummary(monte.pb_1);
    pc_1_mc = mcsummary(monte.pc_1);
    kexac_mc = mcsummary(monte.kexac);
    pa_2_mc = mcsummary(monte.pa_2);
    pc_2_mc = mcsummary(monte.pc_2);

    % restore the measured data
    data = data_orig;
    datainput = CESTdata_orig;
end

% ####### plot fitted line #######
% remove possible leftovers of these names before evaluating the model
clear r2calc A Aconj B1 B2 B3 pa pc M0

fityn = 0;   % 0 while evaluating the curves below (it is 1 during fitting)

% finer (10 Hz) sampling of the offset range of each CEST dataset
offs40 = freqs(2:length(freqs));
offs25 = freqs25(2:length(freqs25));
freq40 = min(offs40):10:max(offs40);
freq25 = min(offs25):10:max(offs25);
freqlist2 = [freq40 freq25];
sllist = [repmat(SLpow(1),1,length(freq40)) repmat(SLpow(2),1,length(freq25))];

% every residue is evaluated on the same fine grid
freq = repmat({freqlist2'}, 1, numres);
sllist_allres(1:numres) = {sllist'};
numpts(1:numres) = size(freqlist2,2);

r2calc = fitfunction_cpmgCEST_2_3stateSQMQ_full_linear_global_kfit(xbest); %calc CPMG data from best fit
dataout1 = dataout; %calc CEST data from best fit

% For every residue: one dispersion figure per CPMG dataset, then one CEST
% figure for the 40 Hz and one for the 25 Hz data.
plcolor = 'kkkk';   % fitted-line colour per CPMG dataset
n40fine = length(freq40);
for j = 1:numres
    if is3state(j) == 1
        modelTag = ' BAC global';
    else
        modelTag = ' 2 state global';
    end

    %CPMG plots (MQ and SQ data are drawn the same way)
    for i = 1:flds
        itau{i} = 1./taucp{i};
        curve = sortrows([itau{i} r2calc{j,i}'], 1);
        figure
        errorbar(itaucp{i},r2eff{j,i},err{j,i},err{j,i},'ok','MarkerFaceColor',dat(i).color,'MarkerSize',6);
        hold on
        plot(curve(:,1),curve(:,2),plcolor(i));
        title(strcat(residue{j},modelTag));
        xlabel('1/\tau_c_p (Hz)');
        ylabel('R_{2,eff} (rad/sec)');
        ylim([min(r2eff{j,i})-2 min(r2eff{j,i})+8])
    end

    %CEST plots: measured points and the fine-grid fit of each dataset
    [~,idx] = ismember(maxval(j),freqlist_allres{j}); %last index of 40 Hz
    nAll = length(freqlist_allres{j});
    ymax = max(datainput{j})+0.1;
    ptsIdx  = {1:idx, idx+1:nAll};
    fineOff = {freq40, freq25};
    fineIdx = {1:n40fine, n40fine+1:n40fine+length(freq25)};
    tags    = {' 40 Hz', ' 25 Hz'};
    for p = 1:2
        figure
        plot(freqlist_allres{j}(ptsIdx{p}),datainput{j}(ptsIdx{p}), 'bo', 'MarkerFaceColor', 'b', 'MarkerSize', 2)
        hold on
        plot(fineOff{p},dataout1{j}(fineIdx{p}), 'Color', 'b', 'LineWidth', 1)
        hold off
        axis manual
        axis([-1500 1500 -0.1 ymax]);
        set(gca,'fontsize',12)
        title(strcat(residue{j},tags{p}))
        xlabel('^1^3C Frequency Offset (Hz)', 'FontSize', 16, 'FontWeight', 'bold');
        ylabel('Normalized Intensity','FontSize', 16, 'FontWeight', 'bold');
    end
end