%Script to fit CHD2 methyl CEST data, jpb
%adapted from Bo Zhao's script (Qi Zhang lab)
%Used for global 2-state fit of concerted process residues in apo hTS
%params from this fit used in global fit of CPMG and CEST data (apo hTS)

clear all

%Global variables. The fitting function is only ever called with the
%parameter vector, so presumably it picks up data, offsets, spin-lock
%settings and flags through these globals (confirm against
%fitfunc_CEST_jpb_fminsearch2SL_global).
global freq sllist sllist_allres
global datainput dataout
global numpts numres restype
global chemshift chemcenter
global fityn SLdura
%global SLpow

%Load the workspaces saved by the profile plotting script (40 Hz and 25 Hz
%spin-lock datasets) plus the calibrated spin-lock powers produced by a
%separate calibration script. Load order is kept because later files may
%overwrite variables from earlier ones.
load('CEST25_full2_mat.mat')
load('CEST40_full3_mat.mat')
load('powcalib40.mat')
load('calib25pow.mat')

%rows of relInts/relInts25 (one per profile) that take part in the fit
idxlist = [39 27 22 87 37 98 97 14 62 40];
numres = length(idxlist);
mc = 0; %number of monte carlo simulations to run (0 = skip)

%offset lists without their first entry (the first intensity column is
%skipped the same way when the data matrix is built)
freqlist_orig40 = freqs(2:end);
freqlist_orig25 = freqs25(2:end);
%40 Hz offsets stacked on top of the 25 Hz offsets
freqlist_orig = vertcat(freqlist_orig40, freqlist_orig25);
freqlist = freqlist_orig;

%maxval = max(freqlist_orig40);

%One column per fitted profile: 40 Hz intensities followed by the 25 Hz
%intensities (first column of each intensity table is skipped).
datamat = [relInts(idxlist,2:end) relInts25(idxlist,2:end)].';
%Rows 1-2 are zero placeholders; row 3 holds the carbon shift (cppm) of
%each selected profile.
cslist(:,1:numres) = [zeros(2,numres); reshape(cppm(idxlist),1,[])];

%cslist = [0 0 0 0;0 0 0 0;28.201 24.509 22.337 19.835]; %carbon chemical shift in ppm, from spectrometer
%cslist = [0;0;cppm(idxlist(i))];

clear inputhandle placeholder tempReg

%empty placeholders; initparam/lb/ub are filled in further down
datainput1 = [];
datainput2 = [];
initparam = [];
lb = [];
ub = [];

%fit2state contains initial dw guesses for each profile, set to 0 for fit
%without exchange
%fit2state = 0.6*ones(1,numres);
fit2state(1:10) = [1.4 -0.9 0.4 0.66 0.86 0.25 0.68 1.1 -0.75 0.42];
%fit2state(1) = 0.2;

%names for profiles included in fit (used for plotting), in the same order
%as idxlist
residinfo = cell(1,numres);
NOparamlist = cell(1,numres);
residinfo(1:10) = {'L198','L192','I121','L101','L221', ...
                   'V79','L131','I237','L74','L187'};
% residinfo{11}='L74';
% residinfo{12}='V134';
% residinfo{13}='I37';
% residinfo{14}='I165';
% residinfo{15}='I249';
%SLdura = [0.1 0.5 0.2 0.7];
SLdura = 0.4;
%calibrated powers: entry 1 is used for the 40 Hz points, entry 2 for the
%25 Hz points
SLpow = [powcalib40 calib25pow];
%chemcenter = [28.201 24.509 22.337 19.835];
chemcenter = 19.565;
%ppm -> Hz scale factor (presumably the 13C frequency in MHz - confirm)
ppm2hz = 150.902406439;
%offset of each peak from chemcenter, compared against the frequency lists
%when removing decoupling sidebands
peakoffset(1:numres) = (reshape(cppm(idxlist),1,[]) - chemcenter)*ppm2hz;

%get data and chemical shifts for profiles to be fit
    for i = 1:numres
        datainput{i} = datamat(:,i);
    end
    chemshift = cslist(3,:);

        %Build the start vector and bounds handed to fmincon.
        %  no exchange (fit2state(1) == 0): [R1_1 R2_1 R1_2 R2_2 ...]
        %  with exchange:                   [pb kex R1_1 R2_1 dw_1 R1_2 R2_2 dw_2 ...]
        if fit2state(1) == 0 %no exchange
            tmp_x0 = [1 15]; %R1 R2
            tmp_lb = [0 0];
            tmp_ub = [20 100];
            initparam = [];
            lb = [];
            ub = [];
            for i = 1:numres
               initparam = [initparam tmp_x0];
               lb = [lb tmp_lb];
               ub = [ub tmp_ub];
            end
            
        else %with exchange
            %every profile needs its own dw starting value
            if numel(fit2state) < numres
                error('fit2state needs one initial dw guess per profile (%d profiles, %d guesses).', ...
                    numres, numel(fit2state));
            end
            initparam = [0.013 240]; %pb kex
            lb = [0 0];
            ub = [0.1 6000];
            tmp_lb =[0 0 -5];
            tmp_ub = [20 100 5];
            for i = 1:numres
                %use this profile's own dw guess (value and sign); previously
                %fit2state(1) was used as the start value for every profile
                tmp_x0 = [1 15 fit2state(i)]; %R1 R2 dw
                initparam = [initparam tmp_x0];
                lb = [lb tmp_lb];
                ub = [ub tmp_ub];
            end
        end
        
        outputparam = zeros(size(initparam,2), 2);
        
        %spy(jacpat); %to visualize the sparse jacobian generated
        %daspect([0.0091    0.5000    1.0000]);
        %jacpat = ones(size(datainput,1),size(initparam,2));
        freqlist = freqlist_orig; %reset to original freqlist with all points for each res

        %spin-lock power for every point: 40 Hz block first, then 25 Hz block
        n40 = length(freqlist_orig40);
        n25 = length(freqlist_orig25);
        sllist(1:n40) = SLpow(1);
        sllist(n40+1:n40+n25) = SLpow(2);

        for i = 1:numres
            sllist_allres{i} = sllist';
            freqlist_allres{i} = freqlist;

            %adjust this window based on 2H decoupling power: points whose
            %distance from the peak lies strictly between 400 and 600 Hz
            %(2H decoupling sidebands at ~500 Hz) are dropped from the data,
            %offset and spin-lock lists alike
            %(fminsearch/fmincon objective does not use a jacobian pattern)
            npts = size(datainput{i},1);
            dist = abs(freqlist(1:npts) - peakoffset(i));
            sideband = dist > 400 & dist < 600;
            datainput{i}(sideband) = [];
            freqlist_allres{i}(sideband) = [];
            sllist_allres{i}(sideband) = [];

            %maxval holds the highest offset of the first (40 Hz) dataset and
            %is used later to find the last 40 Hz index when the two datasets
            %are separated for plotting. The number of surviving 40 Hz points
            %is counted per profile instead of assuming a fixed count.
            n40kept = n40 - nnz(sideband(1:min(n40,npts)));
            maxval(i) = max(freqlist_allres{i}(1:n40kept));
        end
        %jacpat(1:98,9) = zeros(98,1);
        %jacpat(1:98,2) = zeros(98,1);
        %jacpat(1:98,4) = zeros(98,1);
        %jacpat(99:196,8) = zeros(98,1);
        %jacpat(99:196,1) = zeros(98,1);
        %jacpat(99:196,3) = zeros(98,1);
         %Optimizer options: large evaluation/iteration budgets, tight
         %tolerances. Settings the author left disabled:
         %  'Algorithm','trust-region-reflective'
         %  'scaleProblem','Jacobian'
         %  'JacobPattern', jacpat
         optimSettings = {'MaxFunEvals',  1e5, ...
                          'MaxIter',      1e5, ...
                          'TolFun',       1e-12, ...
                          'TolX',         1e-12};
         opts = optimset(optimSettings{:});
%         options = optimset('lsqnonlin');
%         options = optimset(options,'Algorithm','levenberg-marquardt');
%         options = optimset(options,'TolFun',1E-15);
%         options = optimset(options,'TolX',1E-15);
%         options = optimset(options,'FinDiffRelStep',1E-1);
%         %options = optimoptions(options,'OptimalityTolerance', 1E-15);
%         options = optimset(options,'MaxFunEvals', 2000);
%         %options = optimset(options,'PlotFcns', @optimplotresnorm);
        
        %fityn is 1 for optimisation runs and 0 when the fit function is
        %called directly for profiles (presumably it switches the return
        %value between a scalar objective and simulated curves - confirm)
        fityn = 1;
        restype = fit2state(1);
        freq = freqlist_allres;
        %number of offsets per profile after sideband removal
        numpts(1:numres) = cellfun(@(offs) size(offs,1), freq(1:numres));
        
        %(disabled) MultiStart search for a better starting point
        %ms = MultiStart;
        %problem = createOptimProblem('fmincon','objective',@fitfunc_CEST_jpb_fminsearch2SL,...
        %'x0',initparam,'lb',lb,'ub',ub);
        %[xmin,fmin,flag,outpt,allmins] = run(ms,problem,100);
        
        %initparam = xmin;
        
        %Bounded global fit of all profiles: only lb/ub are supplied, the
        %linear and nonlinear constraint slots are left empty.
        tic
        fprintf('\n\n\nGlobally optimizing parameters to CEST data:\n');
        objfun = @fitfunc_CEST_jpb_fminsearch2SL_global;
        [output,resnorm2,exitflag2,output2,lambda,grad,hessian] = ...
            fmincon(objfun, initparam, [], [], [], [], lb, ub, [], opts);
        
        
        %ci = nlparci(output,residual2,'jacobian',jacobian2,'alpha',0.317);
        %column 1 of outputparam holds the optimised parameters; column 2
        %stays zero (the error estimate below is disabled)
        outputparam(:,1) = output';
        %outputparam(:,2) = output'-ci(:,1);
        timer1 = toc;
        
        fprintf('\nOptimized for %6.2f seconds\n',timer1);
        
        %Evaluate the fit function once at the optimum with fityn = 0, then
        %build 'delta', the per-point error used for the mc sims
        %(1.5x base plane noise).
        fityn = 0;
        dataout0 = fitfunc_CEST_jpb_fminsearch2SL_global(outputparam(:,1));
        %delta = abs(datainput - dataout0);
        %delta=mean(delta)*(ones(length(delta),1));
        for i = 1:numres
            %scale up base plane noise based on comparison with CPMG duplicate pts
            scaledNoise = 1.5*relErr(idxlist(i),1);
            delta{i} = scaledNoise*ones(length(datainput{i}),1);
        end
        %keep the measured data so it can be restored after the mc sims
        datainput_orig = datainput;
        fityn = 1;
        
        %mc simulations: refit synthetic datasets drawn around the measured
        %data and collect the fitted parameters in 'monte'.
        %Parameter layout depends on the fit type chosen above:
        %  with exchange: [pb kex R1_1 R2_1 dw_1 ...] (2 global + 3 per profile)
        %  no exchange:   [R1_1 R2_1 ...]             (2 per profile)
        %In the no-exchange case pb, kex and dw are left as NaN.
        noExchange = (fit2state(1) == 0);
        if noExchange
            nGlobal = 0;
            nPerRes = 2;
        else
            nGlobal = 2;
            nPerRes = 3;
        end
        if mc > 0
            monte.pb = nan(1,mc);
            monte.kex = nan(1,mc);
            monte.R1 = nan(numres,mc);
            monte.R2 = nan(numres,mc);
            monte.dw = nan(numres,mc);
        end
        for z = 1:mc
            for i = 1:numres
                datainput{i}=normrnd(datainput_orig{i},delta{i});
                %if points > 3 standard deviations away from mean are
                %picked, repick points
                while any(abs(datainput{i} - datainput_orig{i}) > 3*delta{i})
                        datainput{i}=normrnd(datainput_orig{i},delta{i});
                end
            end
            [output_mc,resnorm_mc,exitflag_mc,output2_mc]...
                =fmincon(@fitfunc_CEST_jpb_fminsearch2SL_global, initparam, [], [], [], [], lb, ub, [], opts);
            if ~noExchange
                monte.pb(z)=output_mc(1);
                monte.kex(z)=output_mc(2);
            end
            for i = 1:numres
                base = nGlobal + nPerRes*(i-1); %index just before profile i's block
                monte.R1(i,z)=output_mc(base+1);
                monte.R2(i,z)=output_mc(base+2);
                if ~noExchange
                    monte.dw(i,z)=output_mc(base+3);
                end
            end
            z %echo iteration number as progress indicator
        end
        
        if mc > 0
            %calculate avg & standard deviations of params from mc sims
            %(row vectors, one entry per profile)
            avgR1 = mean(monte.R1,2).';
            stdR1 = std(monte.R1,0,2).';
            avgR2 = mean(monte.R2,2).';
            stdR2 = std(monte.R2,0,2).';
            avgdw = mean(monte.dw,2).';
            stddw = std(monte.dw,0,2).';
            avgpb = mean(monte.pb);
            stdpb = std(monte.pb);
            avgkex = mean(monte.kex);
            stdkex = std(monte.kex);
        end
        
        %Finely sampled fit curve for plotting: restore the measured data,
        %then evaluate the fit function (fityn = 0) on a 10 Hz grid spanning
        %the measured offsets of each dataset.
        datainput = datainput_orig;
        fityn = 0;
        meas40 = freqs(2:end);
        meas25 = freqs25(2:end);
        freq40 = min(meas40):10:max(meas40);
        freq25 = min(meas25):10:max(meas25);
        freq = [];
        sllist = [];
        freqlist2 = [freq40 freq25];
        nFine40 = length(freq40);
        nFine25 = length(freq25);
        %spin-lock power per grid point: 40 Hz block first, then 25 Hz block
        sllist(1:nFine40)=SLpow(1);
        sllist(nFine40+1:nFine40+nFine25)=SLpow(2);
        %sllist((size(freq,2)/2)+1,length(freq))=outputparam(9,1);
        fineSL = sllist';
        fineOffsets = freqlist2';
        for i = 1:numres
            %every profile shares the same grid and spin-lock list
            sllist_allres{i} = fineSL;
            freq{i} = fineOffsets;
            numpts(i) = size(fineOffsets,1);
        end
        dataout1 = fitfunc_CEST_jpb_fminsearch2SL_global(outputparam(:,1));
        
        %Plot profiles with fit curves: two figures per profile, 40 Hz first
        %and then 25 Hz, each showing measured points and the fine fit curve.
        nGrid40 = length(freq40);
        nGrid25 = length(freq25);
        gridOffsets = {freq40, freq25};
        gridRange = {1:nGrid40, nGrid40+1:nGrid40+nGrid25};
        panelLabel = {' 40 Hz', ' 25 Hz'};
        for i = 1:numres
            %position of maxval(i) in the offset list = last index of 40 Hz
            [isFound,idx] = ismember(maxval(i),freqlist_allres{i});
            nObs = length(freqlist_allres{i});
            obsRange = {1:idx, idx+1:nObs};

            for p = 1:2
                %f = figure('visible','off');
                figure
                plot(freqlist_allres{i}(obsRange{p}),datainput{i}(obsRange{p}), ...
                    'bo', 'MarkerFaceColor', 'b', 'MarkerSize', 2)
                hold on
                plot(gridOffsets{p},dataout1{i}(gridRange{p}), 'Color', 'b', 'LineWidth', 1)
                hold off
                %cd fluor_nh_indiv
                axis manual
                %same y range on both panels: set from the full profile
                axis([-1500 1500 -0.1 max(datainput{i})+0.1]);
                set(gca,'fontsize',12)
                %title(strcat('^1^3C Methyl',{' '}, residinfo{i}),{' '},{'40 Hz'});
                title(strcat(residinfo{i},panelLabel{p}))
                xlabel('^1^3C Frequency Offset (Hz)', 'FontSize', 16, 'FontWeight', 'bold');
                ylabel('Normalized Intensity','FontSize', 16, 'FontWeight', 'bold');
                %tempresid = strrep(residinfo{i},{' '}, '_');
                %tempresid = strrep(tempresid, '_CEST', '');
                %resname =char(strcat('cest_methyl','_',tempresid));
                %outputname = strcat(resname, '.txt');
                %dlmwrite(outputname, outputparam);
                %saveas(f,resname,'png')
                %cd ../
            end
        end
%end

%if mc > 0
%    mcparammat = [avgR1 stdR1;avgR2 stdR2;avgdw stddw;avgpb stdpb;avgkex stdkex];
%end