function model = LearnGRF(tf,trg,logic,rep,p_t1,p_t2)
% LearnGRF Fits a gene regulation function (GRF) to the synthesis-rate
% pattern of a target gene by simulated annealing (via SimAnnFit).
%
% ARGUMENTS:
% TF:    index of the regulating transcription factor, or a row vector of
%        two indices [tf1,tf2]; at most two TFs are supported
% TRG:   index of the target gene
% LOGIC: number of the logic to use in the fit; assign 0 if the logic is
%        unknown and is to be inferred together with the GRF
% REP:   replicate dataset used in the fit: 1 (expr1/syn1), 2 (expr2/syn2)
%        or 0 (element-wise mean of both replicates)
% P_T1, P_T2 (optional): "protein model" trajectories of TF 1 / TF 2, one
%        row per value of lambdaP. Pass [] (or omit) to have the grid
%        computed here for lambdaP in [log(2)/70, log(2)/5]. A scalar is
%        taken as a fixed lambdaP for that TF.
%
% RETURNS:
% MODEL: struct with fields n_tfs, tf_id, n_trgs, rep, logic, trg_id, p0,
%        lambdaP, b, K, n, alpha_max, best_score and norm_score
%        (best_score / (numel(t) * var(target synthesis))).
%
% SIDE EFFECTS: reads the globals expr1, expr2, syn1, syn2 and sets the
% global Ngrid (number of rows of the lambdaP grid).
%
% Missing (NaN) samples are filled by spline interpolation; at least four
% non-NaN samples per profile are required.

global Ngrid;

if nargin < 4
    error ('expected 4,5 or 6 arguments.');
end
if nargin < 5
    p_t1 = [];
end
if nargin < 6
    p_t2 = [];
end

% Grid size: rows of the first supplied non-empty grid, otherwise 10000.
% NOTE(review): a scalar P_T1 (fixed lambdaP) gives Ngrid = 1, which also
% reduces an internally computed P_T2 grid to a single point -- confirm.
if size(p_t1,1) ~= 0
    Ngrid = size(p_t1,1);
elseif size(p_t2,1) ~= 0
    Ngrid = size(p_t2,1);
else
    Ngrid = 10000;
end

% simulated-annealing settings passed to SimAnnFit
lambda = 0.0005;
tau = 100;
theta0 = 0.1;

% TRUE plots the resh, th and arh histories returned by SimAnnFit
% (named so as not to shadow MATLAB's built-in PRINT)
show_plots = false;

% sampling times of the data; all trajectories have numel(t) columns
t = 0:5:205;
nt = numel(t);

n_tfs = numel(tf);
if n_tfs ~= 1 && n_tfs ~= 2
    error ('LearnGRF supports only 2 input factors.');
end

[tf_expr, trg_syn] = load_replicate(tf, trg, rep);
trg_syn = fill_nan_spline(trg_syn, t, 'target synthesis rate');

if n_tfs == 1
    tf_expr = fill_nan_spline(tf_expr, t, 'TF expression');

    % parameter bounds
    LB = zeros(1,5);
    UB = LB;
    LB(1) = log(2)/70;         UB(1) = log(2)/5;                       % lp
    LB(2) = 0;                 UB(2) = 1000*max(trg_syn);              % b
    LB(3) = 0.01*max(trg_syn); UB(3) = 2*(max(trg_syn)-min(trg_syn));  % a_max
    LB(4) = 0;                 UB(4) = 100*max(tf_expr);               % K
    LB(5) = 0;                 UB(5) = 30;                             % n

    if isempty(p_t1)
        p_t1 = protein_grid(tf_expr, t, LB(1), UB(1), Ngrid);
    end
    [p_t1, lam] = fixed_lambda_traj(p_t1, tf_expr, t);

    if logic ~= 0
        [bp, bsc] = fit_grf(logic, p_t1, [], trg_syn, LB, UB, ...
                            lambda, tau, theta0, lam, show_plots);
        model = grf_model_1tf(tf, trg, rep, logic, tf_expr, trg_syn, bp, bsc, nt);
    else
        % unknown logic: fit both single-TF logics and keep the better one
        for l = 1:2
            [bp, bsc] = fit_grf(l, p_t1, [], trg_syn, LB, UB, ...
                                lambda, tau, theta0, lam, show_plots);
            mm(l) = grf_model_1tf(tf, trg, rep, l, tf_expr, trg_syn, bp, bsc, nt);
        end
        % ties go to logic 2
        if mm(1).best_score < mm(2).best_score
            model = mm(1);
        else
            model = mm(2);
        end
    end

else
    tf_expr1 = fill_nan_spline(tf_expr(1,:), t, 'TF expression');
    tf_expr2 = fill_nan_spline(tf_expr(2,:), t, 'TF expression');

    % parameter bounds
    LB = zeros(1,8);
    UB = LB;
    LB(1) = log(2)/70;         UB(1) = log(2)/5;                       % lp1
    LB(2) = log(2)/70;         UB(2) = log(2)/5;                       % lp2
    LB(3) = 0;                 UB(3) = 1000*max(trg_syn);              % b
    LB(4) = 0.01*max(trg_syn); UB(4) = 2*(max(trg_syn)-min(trg_syn));  % a_max
    LB(5) = 0;                 UB(5) = 100*max(tf_expr1);              % K1
    LB(6) = 0;                 UB(6) = 100*max(tf_expr2);              % K2
    LB(7) = 0;                 UB(7) = 30;                             % n1
    LB(8) = 0;                 UB(8) = 30;                             % n2

    if isempty(p_t1)
        p_t1 = protein_grid(tf_expr1, t, LB(1), UB(1), Ngrid);
    end
    if isempty(p_t2)
        p_t2 = protein_grid(tf_expr2, t, LB(2), UB(2), Ngrid);
    end
    [p_t1, lam1] = fixed_lambda_traj(p_t1, tf_expr1, t);
    [p_t2, lam2] = fixed_lambda_traj(p_t2, tf_expr2, t);
    lam = [lam1 lam2];

    if logic ~= 0
        [bp, bsc] = fit_grf(logic, p_t1, p_t2, trg_syn, LB, UB, ...
                            lambda, tau, theta0, lam, show_plots);
        model = grf_model_2tf(tf, trg, rep, logic, tf_expr1, tf_expr2, ...
                              trg_syn, bp, bsc, nt);
    else
        % Determine the direction of regulation of each TF individually.
        % The recursive calls overwrite the global Ngrid, so restore it
        % afterwards; the fits below then see the same Ngrid as a call with
        % an explicit LOGIC would.
        ngrid_outer = Ngrid;
        mi1 = LearnGRF(tf(1),trg,0,rep,p_t1);
        mi2 = LearnGRF(tf(2),trg,0,rep,p_t2);
        Ngrid = ngrid_outer;

        % candidate two-TF logics compatible with the single-TF logics
        if mi1.logic == 1 && mi2.logic == 1
            ll = [1 4 5];
        elseif mi1.logic == 1 && mi2.logic == 2
            ll = [2 7];
        elseif mi1.logic == 2 && mi2.logic == 1
            ll = [3 8];
        elseif mi1.logic == 2 && mi2.logic == 2
            ll = [6 9 10];
        else
            error('unexpected single-TF logics %g and %g.', mi1.logic, mi2.logic);
        end

        % keep the first candidate with the lowest score
        best = inf;
        best_l = 1;
        for l = 1:numel(ll)
            [bp, bsc] = fit_grf(ll(l), p_t1, p_t2, trg_syn, LB, UB, ...
                                lambda, tau, theta0, lam, show_plots);
            mm(l) = grf_model_2tf(tf, trg, rep, ll(l), tf_expr1, tf_expr2, ...
                                  trg_syn, bp, bsc, nt);
            if mm(l).best_score < best
                best_l = l;
                best = mm(l).best_score;
            end
        end
        model = mm(best_l);
    end
end


function [tf_expr, trg_syn] = load_replicate(tf, trg, rep)
% LOAD_REPLICATE Return the rows expr(TF,:) and syn(TRG,:) of replicate REP
% (1 or 2), or the element-wise mean of both replicates when REP is 0.
global expr1;
global expr2;
global syn1;
global syn2;
if rep == 1
    tf_expr = expr1(tf,:);
    trg_syn = syn1(trg,:);
elseif rep == 2
    tf_expr = expr2(tf,:);
    trg_syn = syn2(trg,:);
elseif rep == 0
    tf_expr = 0.5*(expr1(tf,:)+expr2(tf,:));
    trg_syn = 0.5*(syn1(trg,:)+syn2(trg,:));
else
    error('argument rep must be 0,1, or 2.');
end


function x = fill_nan_spline(x, t, what)
% FILL_NAN_SPLINE Replace NaN samples of the profile X (sampled at times T)
% by spline interpolation through the known samples (interp1 'spline' also
% extrapolates NaNs at the ends). WHAT names the profile in the error.
missing = isnan(x);
if any(missing)
    % interp1(...,'spline') requires at least four known samples
    if sum(missing) >= numel(missing) - 3
        error('at least four values of %s must not be NaN!', what);
    end
    x = interp1(t(~missing), x(~missing), t, 'spline');
end


function P = protein_grid(x, t, lb, ub, ngrid)
% PROTEIN_GRID Protein-model trajectories of expression profile X, one row
% per value of lambdaP in linspace(LB, UB, NGRID).
lp = linspace(lb, ub, ngrid);
P = zeros(ngrid, numel(t));
for i = 1:ngrid
    p0 = set_p0(x, lp(i));
    P(i,:) = protein_traj(x, t, p0, lp(i), lp(i));
end


function [P, lam] = fixed_lambda_traj(P, x, t)
% FIXED_LAMBDA_TRAJ If P is a scalar, treat it as a fixed lambdaP: return
% it in LAM and replace P by the single trajectory protein_traj gives for
% it. Otherwise return P unchanged and LAM = 0 ("not fixed").
lam = 0;
if isscalar(P)
    lam = P;
    % NOTE(review): unlike protein_grid this passes p0 = [] and a single
    % lambda to protein_traj -- confirm protein_traj supports this form.
    P = protein_traj(x, t, [], lam);
end


function [bp, bsc] = fit_grf(logic, p_t1, p_t2, trg_syn, LB, UB, lambda, tau, theta0, lam, show_plots)
% FIT_GRF Run SimAnnFit for one LOGIC, overwrite the lambdaP entries of the
% best parameters BP with any fixed (non-zero) values in LAM, and
% optionally plot the annealing histories.
if show_plots
    [bp, bsc, resh, th, arh] = SimAnnFit(logic,p_t1,p_t2,[],trg_syn,LB,UB,lambda,tau,theta0);
else
    [bp, bsc] = SimAnnFit(logic,p_t1,p_t2,[],trg_syn,LB,UB,lambda,tau,theta0);
end
fixed = find(lam ~= 0);
bp(fixed) = lam(fixed);
if show_plots
    figure; plot(resh);
    figure; plot(th);
    figure; plot(arh);
end


function m = grf_model_1tf(tf, trg, rep, logic, tf_expr, trg_syn, bp, bsc, nt)
% GRF_MODEL_1TF Assemble the single-TF model struct from the parameter
% vector BP = [lp, b, a_max, K, n] and best score BSC.
m.n_tfs = 1;
m.tf_id = tf;
m.n_trgs = 1;
m.rep = rep;
m.logic = logic;
m.trg_id = trg;
m.p0 = set_p0(tf_expr, bp(1));
m.lambdaP = bp(1);
m.b = bp(2);
m.K = bp(4);
m.n = bp(5);
m.alpha_max = bp(3);
m.best_score = bsc;
m.norm_score = bsc/(nt*var(trg_syn));


function m = grf_model_2tf(tf, trg, rep, logic, tf_expr1, tf_expr2, trg_syn, bp, bsc, nt)
% GRF_MODEL_2TF Assemble the two-TF model struct from the parameter vector
% BP = [lp1, lp2, b, a_max, K1, K2, n1, n2] and best score BSC.
m.n_tfs = 2;
m.tf_id = tf;
m.n_trgs = 1;
m.rep = rep;
m.logic = logic;
m.trg_id = trg;
m.p0 = [set_p0(tf_expr1, bp(1)) set_p0(tf_expr2, bp(2))];
m.lambdaP = [bp(1) bp(2)];
m.b = bp(3);
m.K = [bp(5) bp(6)];
m.n = [bp(7) bp(8)];
m.alpha_max = bp(4);
m.best_score = bsc;
m.norm_score = bsc/(nt*var(trg_syn));