% function [fitpars, max_log_lh, AIC, BIC] = fit_model(data, plotflag)
%
% This function fits the resource-rational model to a given data set.
%
% INPUT
%  data     : data structure that should contain the following fields
%             data.p_vec = vector with probing probability of probed item for each trial
%             data.error_vec = vector with estimation error for each trial (in range [-pi, pi])
%  plotflag : 0=no plot; 1=create plot of fit (optional; default 0)
%
% OUTPUT
%  fitpars    : best-fitting parameter vector [beta, lambda, tau]
%  max_log_lh : log likelihood of the data at fitpars
%  AIC        : Akaike information criterion,  -2*max_log_lh + 2*k
%  BIC        : Bayesian information criterion, -2*max_log_lh + k*log(n)
%               (k = number of fitted parameters, n = number of trials)


function [fitpars, max_log_lh, AIC, BIC] = fit_model(data, plotflag)

if nargin < 2
    plotflag = 0;
end

% set up a structure with "global" variables
gvar.kappa_map    = [linspace(0,5,250) linspace(5.001,5000,1000)];
gvar.J_map        = gvar.kappa_map.*besseli(1,gvar.kappa_map,1)./besseli(0,gvar.kappa_map,1); % mapping between J and kappa (J = kappa*I1(kappa)/I0(kappa))
gvar.n_gamma_bins = 50;   % number of bins to use in discretization of gamma distribution over J (for numerical integration)
gvar.n_VM_bins    = 720;  % number of bins to use in discretization of Von Mises distribution over epsilon (for numerical integration)
gvar.n_plot_bins  = 15;   % number of bins to use for plotting

% Get initial parameters as starting point for optimizer; also get upper and lower bounds on the parameters
[initpars, lb, ub, plb, pub] = get_init_pars(data, gvar);

% Run optimization to find maximum-likelihood parameters
fprintf('Searching for maximum-likelihood parameter vector...\n');
neg_log_lh = @(pars) -log_lh_function(pars, data, gvar);
if exist('bads','file')
    % Use BADS (Acerbi & Ma 2017) for optimization, if it exists
    options = bads('defaults');
    options.Display = 'off';
    [fitpars, neg_max_log_lh] = bads(neg_log_lh, initpars, lb, ub, plb, pub, options);
else
    % Use fminsearch (not used in the original analysis -- cannot guarantee that results are the same as in paper!)
    % NOTE: fminsearch is unconstrained; lb/ub are NOT enforced on this path.
    [fitpars, neg_max_log_lh] = fminsearch(neg_log_lh, initpars);
end
max_log_lh = -neg_max_log_lh;

% Model-comparison criteria
n_pars   = numel(fitpars);
n_trials = numel(data.error_vec);
AIC = -2*max_log_lh + 2*n_pars;
BIC = -2*max_log_lh + n_pars*log(n_trials);
fprintf('  Done: [beta=%2.3f, lambda=%2.6f, tau=%2.2f], max log lh = %2.2f, AIC=%2.1f, BIC=%2.1f\n',fitpars(1),fitpars(2),fitpars(3),max_log_lh,AIC,BIC);

% Plot result
if plotflag
    figure
    set(gcf,'Position',get(gcf,'Position').*[0.1 0.1 1.5 0.8]);
    X_emp = linspace(-pi,pi,gvar.n_plot_bins+1);
    X_emp = X_emp(2:end)-diff(X_emp(1:2))/2;   % histogram bin centers
    X_fit = linspace(-pi,pi,250);
    uPi = unique(data.p_vec);
    % 4 panels per row; keep at least 2 rows, add more when there are more than 8 conditions
    n_cols = 4;
    n_rows = max(2, ceil(numel(uPi)/n_cols));
    for ii=1:numel(uPi)
        % Compute optimal resource, i.e., the value of Jbar that minimizes the expected total cost
        Jbar_opt = fminsearch(@(Jbar) cost_function(Jbar,fitpars,uPi(ii),gvar), 10);

        % Compute the predicted estimation error distribution under Jbar_opt
        J = discretize_gamma(Jbar_opt,fitpars(3),gvar.n_gamma_bins);  % discretize distribution over J in equally sized bins
        kappa = interp1(gvar.J_map,gvar.kappa_map,min(J,max(gvar.J_map)));  % compute kappa value corresponding to each bin center
        Y_fit = mean(bsxfun(@times,1./(2*pi*besseli0_fast(kappa)),exp(bsxfun(@times,kappa,cos(X_fit')))),2); % compute Von Mises under each kappa value and take average

        % plot empirical density (normalized histogram) and model prediction
        subplotids(ii)=subplot(n_rows,n_cols,ii);
        idx = data.p_vec==uPi(ii); % find all trials with probing probability equal to uPi(ii)
        Y_emp = hist(data.error_vec(idx),X_emp);
        Y_emp = Y_emp/sum(Y_emp)/diff(X_emp(1:2));   % counts -> probability density
        plot(X_emp,Y_emp,'ko','markerfacecolor','k','markersize',3);
        hold on
        plot(X_fit,Y_fit,'k-');
        xlim([-pi, pi])
        xlabel('Estimation error');
        if mod(ii,n_cols)==1
            ylabel('Probability density');
        end
        box off
        title(sprintf('p_i=%2.3f',uPi(ii)));
    end
    linkaxes(subplotids,'xy')
end

%---- HELPER FUNCTIONS ---%

% Likelihood function 
function log_lh = log_lh_function(parvec, data, gvar)
% Log likelihood of the observed estimation errors under the parameter
% vector parvec = [beta, lambda, tau].
%
% For each distinct probing probability in data.p_vec, the mean resource
% Jbar that minimizes the expected total cost (cost_function) is found with
% fminsearch. The predicted error density is then the average of Von Mises
% densities over the discretized gamma distribution of J. It is evaluated
% only at the errors of the trials belonging to that condition.
% Densities below 1e-3 are clipped to 1e-3 before taking the log, which
% guards against log(0).
uPi = unique(data.p_vec);
error_vec = data.error_vec(:);   % column vector, so row or column input both work
p_resp = zeros(1,numel(error_vec));
for ii=1:numel(uPi)
    % Compute optimal resource, i.e., the value of Jbar that minimizes the expected total cost
    Jbar_opt = fminsearch(@(Jbar) cost_function(Jbar,parvec,uPi(ii),gvar), 10);

    % Discretize the distribution over J and map each bin center to a kappa value
    J = discretize_gamma(Jbar_opt,parvec(3),gvar.n_gamma_bins);
    kappa = interp1(gvar.J_map,gvar.kappa_map,min(J,max(gvar.J_map)));
    kappa = kappa(:)';   % row: one column per precision bin

    % Evaluate the predicted density only at the trials of this condition:
    % Von Mises density of each error (rows) under each kappa (columns), averaged over kappa
    idx = find(data.p_vec==uPi(ii));
    p_resp(idx) = mean(bsxfun(@times,1./(2*pi*besseli0_fast(kappa)),exp(bsxfun(@times,kappa,cos(error_vec(idx))))),2);
end
log_lh = sum(log(max(p_resp,1e-3)));

% Expected total cost function
function Cbar_total=cost_function(Jbar,parvec,p_i,gvar)
% Expected total cost of investing mean resource Jbar in an item that is
% probed with probability p_i:
%     Cbar_total = lambda*Jbar + p_i * E[|error|^beta]
% The expectation is computed numerically. J follows a discretized gamma
% distribution (mean Jbar, scale tau). For each J bin, the error follows a
% discretized Von Mises distribution whose kappa corresponds to that J.
%
% parvec = [beta, lambda, tau]
J_bins     = discretize_gamma(Jbar, parvec(3), gvar.n_gamma_bins);
kappa_bins = interp1(gvar.J_map, gvar.kappa_map, min(J_bins, max(gvar.J_map)));
[err_grid, err_prob] = discretize_VM(kappa_bins, gvar);

cost_per_err   = abs(err_grid).^parvec(1);                      % behavioral cost of each error value
exp_cost_per_J = sum(bsxfun(@times, err_prob, cost_per_err), 2); % expected behavioral cost for each J bin

Cbar_total = parvec(2)*Jbar + p_i*mean(exp_cost_per_J);

% Discretize a gamma distribution into equal-probability bins
function bins = discretize_gamma(Jbar,tau,nbins)
% Discretize a gamma distribution with shape Jbar/tau and scale tau (hence
% mean Jbar) into nbins equal-probability bins. Returns, as a 1 x nbins row
% vector, the quantile at the probability midpoint of each bin.
%
% Warnings raised inside gaminv are suppressed while it runs. The caller's
% previous warning state is restored on exit, even if gaminv throws,
% instead of blanket re-enabling all warnings.
probs = linspace(0,1,nbins+1);
probs = probs(2:end)-diff(probs(1:2))/2;   % midpoints of the nbins equal-probability intervals
old_warning_state = warning;
warning('off','all');
restore_warnings = onCleanup(@() warning(old_warning_state)); %#ok<NASGU>
bins = gaminv(probs,Jbar/tau,tau);

% Discretize VM distributions into equally-spaced bins (1 distribution per element in kappa vector)
function [VM_x, VM_y] = discretize_VM(kappa,gvar)
% Discretize zero-mean Von Mises distributions (one per element of kappa)
% on gvar.n_VM_bins-1 equally spaced bin centers in (-pi, pi).
%
% INPUT
%  kappa : vector of concentration parameters (row or column)
%  gvar  : struct; only gvar.n_VM_bins is read
% OUTPUT
%  VM_x  : 1 x (n_VM_bins-1) row vector of bin centers
%  VM_y  : numel(kappa) x (n_VM_bins-1) matrix; row k holds the probability
%          of each bin under kappa(k) (each row sums to 1)
VM_x = linspace(-pi,pi,gvar.n_VM_bins);
VM_x = VM_x(2:end)-diff(VM_x(1:2))/2;
kappa = kappa(:);
% exp(kappa*(cos(x)-1)) differs from exp(kappa*cos(x)) only by the per-row
% factor exp(-kappa), which cancels in the normalization below. Keeping the
% exponent <= 0 prevents overflow to Inf (and Inf/Inf = NaN) for large kappa.
VM_y = exp(bsxfun(@times,kappa,cos(VM_x)-1));
VM_y = bsxfun(@rdivide,VM_y,sum(VM_y,2));

function [initpars, lb, ub, plb, pub] = get_init_pars(data,gvar)
% Pick a starting point for the optimizer by a coarse grid search over the
% plausible parameter box, and return the bounds used by BADS.
% Parameter order everywhere is [beta, lambda, tau].
%
% OUTPUT
%  initpars : grid point with the highest log likelihood (first one on ties)
%  lb, ub   : hard lower/upper bounds
%  plb, pub : plausible lower/upper bounds (these also span the search grid)

% hard bounds and plausible bounds on [beta, lambda, tau]
lb  = [-10, 0, 1e-3];
ub  = [10, 1, 1000];
plb = [0, 1e-4, .01];
pub = [3, .1, 100];

% grid: beta spaced linearly, lambda and tau spaced logarithmically
n_grid      = 8;
grid_beta   = linspace(plb(1), pub(1), n_grid);
grid_lambda = logspace(log10(plb(2)), log10(pub(2)), n_grid);
grid_tau    = logspace(log10(plb(3)), log10(pub(3)), n_grid);

LL = -Inf(n_grid, n_grid, n_grid);
fprintf('Doing coarse grid search to set starting point for optimizer...\n');
for i_beta = 1:n_grid
    for i_lambda = 1:n_grid
        for i_tau = 1:n_grid
            candidate = [grid_beta(i_beta), grid_lambda(i_lambda), grid_tau(i_tau)];
            LL(i_beta, i_lambda, i_tau) = log_lh_function(candidate, data, gvar);
        end
    end
    fprintf('  %2.1f%% done\n', i_beta/n_grid*100);
end

% best grid point becomes the starting point
[~, best] = max(LL(:));
[b_best, l_best, t_best] = ind2sub(size(LL), best);
initpars = [grid_beta(b_best) grid_lambda(l_best) grid_tau(t_best)];
