function [p, mat] = meg_stats(results,baseline,varargin)

% Function for calculating peak statistics results on time-series data (e.g.
% MEG/EEG) using a number of different methods. For many permutations this
% may take a while! All tests here are one-sided (right-tailed, i.e. they
% test whether results exceed baseline)!
%
%
% INPUT:
%     results: results matrix (number of tests x number of subjects)
%     baseline: chance (e.g. 0.5 or 50) or 0 or whatever refers to the
%       reference; it is subtracted from results before testing
% additional inputs (either as name/value pairs or as one struct variable
% whose field names are the option names):
%     'method': 't' OR 'wilkoxon' (alias: 'wilcoxon') OR 'signperm' (default: t)
%     'q': cluster-inducing threshold or alpha for desired false discovery rate (default: 0.05)
%     'multcomp': 'none', 'fdr_indep', 'fdr_dep', 'cluster', 'clustermass', 'tfce' (default: none)
%     'n_perm' (only used when permutations are needed): 'full', or number (default: 10000)
%     'memlim' (optional, read by signperm): memory limit in MB used for
%       chunking the permutations
%
% OUTPUT:
%     p: p-values (corrected according to 'multcomp' if requested)
%     mat: struct with the settings used plus intermediate results
%       (e.g. mean, se, t, t_cutoff, t_perm, n_perm, p_orig, p - depending
%       on the method and correction chosen)
%
%     PLEASE NOTE: We only compute unique random permutations if a subset
%       is chosen which makes it more accurate
%     PLEASE ALSO NOTE  since the permutation distribution is symmetric, we
%     actually only compute half of the desired permutations and then flip
%     them. While this increases the number, the increase in accuracy is
%     limited, because sampling errors are correlated. For that reason, it
%     is recommended to run 2x as many permutations as requested and just
%     report that 1x as many have been calculated (higher precision than
%     reported should not be an issue).
%
%     The cluster-based corrections ('cluster', 'clustermass', 'tfce') need
%     t-values and are therefore only available for methods 't' and
%     'signperm'.

% by Martin Hebart, 2016


% defaults
opt.method = 't';
opt.q = 0.05;
opt.multcomp = 'none';
opt.n_perm = 10000;

results = results-baseline; % remove chance

% parse optional settings: either one struct or name/value pairs
n_extra = numel(varargin);
if n_extra == 1
    user_opt = varargin{1};
    if isempty(user_opt)
        % nothing passed, keep defaults
    elseif isstruct(user_opt)
        % overwrite fields
        fn = fieldnames(user_opt);
        for i_fn = 1:length(fn)
            opt.(fn{i_fn}) = user_opt.(fn{i_fn});
        end
    else
        error('For three input variables, the third must be a struct variable containing all relevant settings')
    end
else
    if mod(n_extra,2) ~= 0
        error('Additional inputs must be passed as name/value pairs.')
    end
    for i = 1:2:n_extra % step by 2 because inputs come in pairs
        opt.(varargin{i}) = varargin{i+1};
    end
end

% return settings to output
mat = opt;

[n_test,n_sub] = size(results);

switch lower(opt.method)
    case 't'
        M = mean(results,2);
        const1 = 1/sqrt(n_sub-1);
        const2 = 1/sqrt(n_sub);
        se = sqrt(sum(bsxfun(@minus,results,M).^2,2)) * const1 * const2;
        t = M./se; % t-value is the stat
        mat.mean = M+baseline; % add chance back
        mat.se = se;
        mat.t = t;
        mat.t_cutoff = tinv(1-opt.q,n_sub-1);
        % one-sided (right-tailed) p-value P(T >= t); by symmetry of the
        % t-distribution this equals the cdf at -t (numerically more
        % accurate than 1-tcdf(t) for large t)
        p = tcdf(-t,n_sub-1);
    case {'wilkoxon','wilcoxon'} % 'wilkoxon' kept for backward compatibility
        for i_p = size(results,1):-1:1 % counting down allocates p in the first iteration
            p(i_p) = signrank(results(i_p,:),0,'tail','right');
        end
        mat.median = median(results,2);
    case 'signperm'
        rng('shuffle')
        M = mean(results,2);
        const1 = 1/sqrt(n_sub-1);
        const2 = 1/sqrt(n_sub);
        se = sqrt(sum(bsxfun(@minus,results,M).^2,2)) * (const1 * const2);
        t = M./se;
        mat.mean = M+baseline; % add chance back
        mat.se = se;
        mat.t = t;
        
        % run the sign permutations
        [t_perm,n_perm] = signperm(results,opt);
        
        mat.t_perm = t_perm;
        mat.n_perm = n_perm;
        
        % get percentiles of t_perm (even though sorting is less memory
        % efficient, it is just a lot faster for many permutations)
        t_perm_sorted = sort(t_perm,2,'descend');
        % max(1,...) avoids an invalid index 0 when opt.q*n_perm < 1
        t_cutoff = t_perm_sorted(:,max(1,floor(opt.q*n_perm)));
        
        mat.t_cutoff = t_cutoff;
        
        % calculate p-values
        ind = sum(bsxfun(@minus,t_perm_sorted,t)>=0,2); % all the values where permutations reach or exceed the true value
        p = ind/n_perm;
        p = max(p,1/n_perm); % less conservative than dividing by (n_perm+1)
        % the line below would make this two sided
%         p(p>0.5) = 1-p(p>0.5);
        
    otherwise
        error('Method %s not implemented.',opt.method)
end


mat.p_orig = p; % save original p-values

% the cluster-based corrections all need t-values and a permutation
% distribution of t-values; make sure both are available
if any(strcmpi(opt.multcomp,{'cluster','clustermass','tfce'}))
    if ~isfield(mat,'t')
        error('Multiple comparison correction ''%s'' requires t-values and is not available for method ''%s''.',opt.multcomp,opt.method)
    end
    if ~isfield(mat,'t_perm')
        [t_perm,n_perm] = signperm(results,opt);
        mat.t_perm = t_perm;
        mat.n_perm = n_perm;
    end
    n_perm = size(mat.t_perm,2);
end

switch lower(opt.multcomp)
    case 'none'
        % do nothing
    case 'fdr_indep'
        [~,crit_p,~,p] = fdr_bh(p,opt.q,'pdep');
        mat.crit_p = crit_p;
        p = min(p,1);
        mat.p = p;
    case 'fdr_dep'
        [~,crit_p,~,p] = fdr_bh(p,opt.q,'dep');
        mat.crit_p = crit_p;
        p = min(p,1);
        mat.p = p;
    case 'cluster'
        % null distribution: largest cluster in each permutation
        maxclust = zeros(1,n_perm);
        for i_perm = 1:n_perm
            h = mat.t_perm(:,i_perm) > mat.t_cutoff;
            clust_perm = find_clusters(h);
            % a permutation without any supra-threshold cluster keeps a
            % maximum of 0 (max of an empty result could not be assigned)
            if ~isempty(clust_perm)
                maxclust(i_perm) = max(clust_perm);
            end
        end
        
        maxclust_sorted = sort(maxclust,'descend');
        mat.clust_cutoff = maxclust_sorted(max(1,ceil(opt.q*n_perm)));
        
        % p-value of each observed cluster: fraction of permutations whose
        % largest cluster is at least as large
        [clust,~,~,labels] = find_clusters(mat.t > mat.t_cutoff);
        p = ones(1,n_test);
        for i_label = 1:length(clust)
            p(labels==i_label) = sum(maxclust_sorted>=clust(i_label))/n_perm;
        end
        p = max(p,1/n_perm);
        mat.p = p;
        
    case 'clustermass' % like cluster, but using the summed t-values within each cluster
        % null distribution: largest cluster mass in each permutation
        maxmass = zeros(1,n_perm);
        for i_perm = 1:n_perm
            h = mat.t_perm(:,i_perm) > mat.t_cutoff;
            [clust,~,~,labels] = find_clusters(h);
            mass = zeros(1,length(clust));
            for i_label = 1:length(clust)
                mass(i_label) = sum(mat.t_perm(labels==i_label,i_perm));
            end
            % no cluster in this permutation: keep a maximum mass of 0
            if ~isempty(mass)
                maxmass(i_perm) = max(mass);
            end
        end
        
        maxmass_sorted = sort(maxmass,'descend');
        mat.clustmass_cutoff = maxmass_sorted(max(1,ceil(opt.q*n_perm)));
        
        [clust,~,~,labels] = find_clusters(mat.t > mat.t_cutoff);
        p = ones(1,n_test);
        for i_label = 1:length(clust)
            currmass = sum(mat.t(labels==i_label));
            p(labels==i_label) = sum(maxmass_sorted>=currmass)/n_perm;
        end
        p = max(p,1/n_perm);
        mat.p = p;
        
    case 'tfce'
        
        % TODO: this is not entirely accurate, we would in fact need to
        % specify the cutoff using percentiles of the distribution to make
        % them more comparable (because the cutoff might be different for
        % each time point - it is correct if it is the same)
        
        % we use 20 thresholding steps of 0 to max(t); thresholds are kept
        % non-negative because h^H would turn negative heights into
        % positive weights
        n_steps = 20;
        h_steps = linspace(0,max(max(mat.t),0),n_steps);
        % exponents: extent is raised to the power of 3/4, height to the power of 2
        E = 3/4; H = 2;
        
        tfce = zeros(n_test,1); % initialize        
        for i_step = 1:n_steps
            % first get TFCE values for original statistic (same formula
            % as used for the permutations below, including ^H)
            [clust,~,~,labels] = find_clusters(mat.t > h_steps(i_step));
            for i_label = 1:length(clust)
                tfce(labels==i_label) = tfce(labels==i_label) + clust(i_label)^E * h_steps(i_step)^H;
            end
        end
        mat.tfce = tfce;

        tfce_perm = zeros(n_test,n_perm); % now repeat for each permutation
        for i_perm = 1:n_perm
            for i_step = 1:n_steps
                % TFCE values for the permuted statistic
                [clust,~,~,labels] = find_clusters(mat.t_perm(:,i_perm) > h_steps(i_step));
                for i_label = 1:length(clust)
                    tfce_perm(labels==i_label,i_perm) = tfce_perm(labels==i_label,i_perm) + clust(i_label)^E * h_steps(i_step)^H;
                end
            end
        end
        
%         (BELOW WOULD BE THE UNCORRECTED P-VALUE)
%         % sort tfce_perm 
%         tfce_perm_sorted = sort(tfce_perm,'descend');
%         mat.tfce_cutoff = tfce_perm_sorted(:,ceil(opt.q*n_perm));
%         
%         ind = sum(bsxfun(@minus,tfce_perm_sorted,tfce)>=0,2); % all the values where permutations reach or exceed the true value
%         p = ind/n_perm;
%         p = max(p,1/n_perm); % less conservative than dividing by (n_perm+1)
%             
%         mat.p = p;

%       USING MAX STATISTIC
        tfce_perm_max = sort(max(tfce_perm,[],1),'descend'); % maximum across tests for each permutation
        mat.tfce_cutoff = tfce_perm_max(max(1,ceil(opt.q*n_perm)));
        ind = sum(bsxfun(@minus,tfce_perm_max,tfce)>=0,2); % all the values where permutations reach or exceed the true value
        p = ind/n_perm;
        p = max(p,1/n_perm);        
        
        mat.p = p;

        
    case 'maxt'
        error('Sorry, maxt not implemented yet.')
    otherwise
        error('unknown method %s for opt.multcomp',opt.multcomp)
end


function checkmem(n_elements,maxmem)
% Throws an error if a double array with n_elements entries (8 bytes each)
% would need more than maxmem megabytes.
%
% INPUT:
%     n_elements: number of double-precision elements to be allocated
%     maxmem: memory limit in MB

MBest = n_elements*8 / (1024*1024); % estimated size in MB

if  MBest > maxmem
    % report the limit that was actually passed in rather than a fixed number
    error('Estimated required memory for calculation of permutation matrix alone is %.0f MB and exceeds set memory capacity of %.0f MB. Either increase memory limit in function or run fewer permutations at once.',MBest,maxmem)
end


function [t_perm,n_perm] = signperm(results,opt)
% Computes t-values under random sign flips of the subjects' data.
%
% INPUT:
%     results: data matrix (number of tests x number of subjects), with the
%       reference value already subtracted
%     opt: struct with field n_perm ('full'/'all' or a positive number) and
%       optionally memlim (memory limit in MB, default 1024)
%
% OUTPUT:
%     t_perm: t-values for each permutation (number of tests x n_perm)
%     n_perm: number of permutations actually returned. Only half of the
%       permutations are computed and the other half is added as the
%       sign-flipped mirror image, so n_perm is always even; it can differ
%       from opt.n_perm when opt.n_perm is odd (rounded up) or exceeds the
%       number of possible permutations (2^n_sub).

if isfield(opt,'memlim')
    memlim = opt.memlim;
else
    memlim = 1024; % I would estimate the actual memory used to be 3-4x as high
end
maxmem = 0;

[n_test,n_sub] = size(results);
const1 = 1/sqrt(n_sub-1);
const2 = 1/sqrt(n_sub);

% number of unique sign patterns up to a global sign flip (-1 because of symmetry)
maxperm = 2^(n_sub-1);

% choose which of the maxperm sign patterns are used (indices 1..maxperm)
if ischar(opt.n_perm) && (strcmpi(opt.n_perm,'full') || strcmpi(opt.n_perm,'all'))
    n_perm = maxperm;
    k = 1:n_perm;
elseif isnumeric(opt.n_perm) && isscalar(opt.n_perm) && opt.n_perm >= 1
    n_perm = ceil(opt.n_perm/2); % only calculate half at first (rounded up so the count stays an integer)
    if n_perm > maxperm
        warning(['Number of selected permutations exceeds number of possible permutations. Calculating full permutation instead (' num2str(2*maxperm) ' permutations).'])
        n_perm = maxperm;
        k = 1:n_perm; % all available sign patterns
    else
        % randomly sample uniquely from the first half of all available permutations (this is more accurate than sampling just randomly and 2x as fast because we can flip symmetric permutations)
        k = randperm(maxperm,n_perm);
    end
else
    error('opt.n_perm must be ''full'' or a positive number.')
end

% check memory for the number of permutations that is actually computed
checkmem(n_perm*n_sub,memlim)

% binary expansion of k-1 over n_sub digits (adjusted from dec2bin), mapped
% from 0/1 to +1/-1. Because k-1 < 2^(n_sub-1), the leading digit is always
% 0, i.e. the first subject is never flipped; the mirrored patterns are
% added at the end by negating the t-values.
permmat = -2*rem(floor(bsxfun(@times,k-1,pow2(1-n_sub:0)')),2)+1; % n_sub x n_perm

% we do sign permutations across subjects, split in chunks to limit memory
h = whos('results');
est_MB = (h.bytes * n_perm)/(1024*1024);
n_iter = max(1,min(ceil(est_MB/memlim),n_perm)); % at least one chunk, at most one chunk per permutation
chunk = floor(n_perm/n_iter);
first_ind = 1;
t_perm = zeros(n_test,n_perm);
if n_iter == 1
    fprintf('Running %i permutations.\n',n_perm);
else
    fprintf('Running %i permutations in %i separate steps (because of memory constraints).\n',n_perm,n_iter)
end
for i_iter = 1:n_iter
    if i_iter == n_iter
        last_ind = n_perm; % last chunk takes whatever is left
    else
        last_ind = first_ind + chunk - 1;
    end
    idx = first_ind:last_ind;
    n_step = numel(idx);
    
    signs = permmat(:,idx); % n_sub x n_step
    
    % this is a short-cut for calculation of our permutation means
    Mperm = results*signs / n_sub; % n_test x n_step
    
    % for the std (and associated se) we need the sign-flipped data itself:
    % n_test x n_sub x n_step (explicit reshape also works for n_test == 1
    % or n_step == 1, where trailing singleton dimensions would be dropped)
    results_perm = bsxfun(@times,results,reshape(signs,1,n_sub,n_step));
    dev = bsxfun(@minus,results_perm,reshape(Mperm,n_test,1,n_step));
    se_perm = reshape(sqrt(sum(dev.^2,2)),n_test,n_step) * (const1*const2);
    t_perm(:,idx) = Mperm./se_perm;
    
    first_ind = last_ind+1; % update
    
    % keep track of the memory held by the variables in this workspace
    tmp = whos;
    maxmem = max(maxmem,sum([tmp.bytes])/(1024*1024));
    
end

% replicate results as negative flip
t_perm = [t_perm, -t_perm];
n_perm = 2*n_perm; % total number including the mirrored half

tmp = whos;
maxmem = max(maxmem,sum([tmp.bytes])/(1024*1024));
if maxmem > memlim
    disp(['Maximum memory used: ' num2str(maxmem) 'MB'])
end
