%%Apply Maximum Likelihood Estimation filtering to MED OTU counts

% The data table T is ASV_Counts_Matrix from the attached
% Supp_Info_and_Data file.
% Load the complete ASV count table.
T = readtable('../Data/mat-counts.txt');
nSamp = 28; % number of sample columns analysed
% NOTE(review): assumes the sample counts occupy table columns 2..nSamp+1;
% confirm against the input file layout.

% Order the rows by ascending RAW_STOOL count (column matched case-insensitively)
rawCol = find(strcmpi(T.Properties.VariableNames, 'RAW_STOOL'));
T = sortrows(T, rawCol);

% Raw stool counts and their fraction of all raw stool reads
raw = T.RAW_STOOL;
rawP = raw / sum(raw);

% Count matrix: one row per OTU, one column per sample
counts = table2array(T(:, 2:nSamp+1));
[szR, szC] = size(counts);

% Per-sample fractions: each column divided by its own column total
% (a column totalling zero yields NaN)
p = counts ./ sum(counts, 1);
    
% Confidence intervals on the per-sample fractions p.
% The MLE 95% confidence interval is p +/- 1.96/sqrt(-l'') and the 90% CI
% is p +/- 1.645/sqrt(-l''), where
%   l = const + k*log(p) + (n-k)*log(1-p)
% is the binomial log-likelihood of k reads of an OTU out of n sample reads.
% For low read counts, reparametrize with the logit transform
%   phi = log(p/(1-p))   (~= log(p) for small p),
% whose CI is phi +/- 1.645/sqrt(n*p*(1-p)); transforming back to p
% gives e^phi/(1+e^phi).

se = 1.645; % two-sided 90% normal critical value (z), not a standard error

nReads = sum(counts, 1);                        % total reads per sample (1 x szC)
logitP = log(p ./ (1 - p));
halfWidth = se ./ sqrt(nReads .* p .* (1 - p)); % expands over the szR x szC grid

phiLow = logitP - halfWidth;
phiHigh = logitP + halfWidth;

plow = exp(phiLow) ./ (1 + exp(phiLow));
phigh = exp(phiHigh) ./ (1 + exp(phiHigh));

% The logit interval degenerates at the boundaries of p:
%  - p == 0: phiLow = -Inf (plow evaluates to 0) and phiHigh = -Inf + Inf = NaN,
%    so a NaN upper bound is replaced by 0.
%  - p == 1: phiLow = Inf - Inf = NaN and phigh = Inf/Inf = NaN; pin both
%    bounds to 1, mirroring the p == 0 convention.
phigh(isnan(phigh)) = 0;
plow(p == 1) = 1;
phigh(p == 1) = 1;

% Express the bounds in counts as well
countsLow = plow .* nReads;
countsHigh = phigh .* nReads;

% Apply the logit-transformed CI (critical value se) to the raw stool sample.
nRaw = sum(raw);
rawLogit = log(rawP ./ (1 - rawP));
rawHalfWidth = se ./ sqrt(nRaw * rawP .* (1 - rawP));

rawPhiLow = rawLogit - rawHalfWidth;
rawPhiHigh = rawLogit + rawHalfWidth;

rawLowP = exp(rawPhiLow) ./ (1 + exp(rawPhiLow));
rawHighP = exp(rawPhiHigh) ./ (1 + exp(rawPhiHigh));

% Boundary handling for the degenerate logit interval:
%  - rawP == 0 gives a NaN upper bound (-Inf + Inf); use 0 instead.
%  - rawP == 1 gives NaN for both bounds; pin both to 1.
rawHighP(isnan(rawHighP)) = 0;
rawLowP(rawP == 1) = 1;
rawHighP(rawP == 1) = 1;

rawCountsLow = rawLowP * nRaw;
rawCountsHigh = rawHighP * nRaw;

%% Fit a fractional amount of RAW STOOL to OTUs in each sample
% For each sample j, choose the integer ratio k in kRange that minimises the
% squared error between the sample's low-count OTUs and raw/k. The best
% ratio is stored in r(j) and later used as the filtering threshold raw/r(j).
fitThresh = 10;   % threshold to include data for fitting
kRange = 10:1000; % candidate ratios k
r = zeros([1 szC]); % ratio of best fit per sample

for j = 1:szC

    % The fitData is the portion of counts to fit. OTUs whose count exceeds
    % raw/fitThresh are excluded (they need no fit), as are OTUs absent
    % from the sample; excluded entries are marked NaN.
    fitData = counts(:,j);
    fitData(fitData > raw/fitThresh | fitData == 0) = NaN;
    use = ~isnan(fitData);

    % Squared error for every candidate k at once:
    % rows = OTUs used in the fit, columns = candidate ratios.
    resid = fitData(use) - raw(use) ./ kRange;
    errK = sum(resid.^2, 1);

    % Global minimum for this sample (first k on ties). The search is
    % independent per sample, so r(j) is always a valid k and never 0.
    % With no usable OTUs every error is 0 and the smallest k is chosen.
    [~, best] = min(errK);
    r(j) = kRange(best);
end

% Filter counts: an OTU is zeroed in sample j when the lower bound of its
% count CI is below the scaled raw stool threshold raw/r(j).
belowThresh = countsLow < raw ./ r; % raw (column) ./ r (row) -> szR x szC

countsFiltered = counts;
countsFiltered(belowThresh) = 0;

% Per-sample number of excluded OTUs, counting only those whose lower bound
% is non-zero (szC x 1 column)
numExcluded = sum(belowThresh & countsLow ~= 0, 1).';

richness = sum(countsFiltered > 0); % number of non-zero OTUs per sample

% Create the filtered count table and the filtered percent table,
% written tab-delimited next to the input data.
Tout = T;
Tout{:,2:nSamp+1} = countsFiltered;
writetable(Tout,'../Data/FilteredCountsTable.txt','Delimiter','\t');

% Per-sample fractions of the filtered counts. Normalising along dimension 1
% explicitly keeps this per sample even when the table has a single row.
% NOTE: a sample whose counts were all filtered out has a zero column sum
% and yields NaN fractions in the output.
pctFiltered = countsFiltered ./ sum(countsFiltered, 1);
Pout = T;
Pout{:,2:nSamp+1} = pctFiltered;
writetable(Pout,'../Data/pctFilteredTable.txt','Delimiter','\t');


%% Plots
% Example figure: raw stool counts with their CI, one sample's counts with
% their CI, and that sample's fitted threshold raw/r(examp).
figure(1)

examp = 4; % example sample (column of counts) shown in the figure
ranks = 1:szR; % one x position per OTU row (rows sorted by RAW_STOOL count)

errorbar(ranks, raw, ...
    raw - rawCountsLow, rawCountsHigh - raw, ...
    'd', 'MarkerSize', 5, 'Color', [0.95 0 0], ...
    'MarkerEdgeColor', [1 0 0], 'LineWidth', 1.2)

hold on

sampCounts = counts(:, examp);
errorbar(ranks, sampCounts, ...
    sampCounts - countsLow(:, examp), countsHigh(:, examp) - sampCounts, ...
    '.', 'MarkerSize', 15, 'Color', [0.2 0.2 0.2], ...
    'MarkerEdgeColor', [0 0 0], 'LineWidth', 1.2)

ylim([0 2000])
xlim([0 167]) % NOTE(review): fixed x-range; confirm it suits the data's szR
set(gca, 'LineWidth', 1.2)
set(gca, 'FontSize', 14)
xlabel('Rank', 'FontSize', 18)
ylabel('Counts', 'FontSize', 18)

plot(raw / r(examp), '-b', 'LineWidth', 2)
% NOTE(review): 'BHIS PLATE' is hard-coded; confirm it names column examp.
legend('RAW STOOL', 'BHIS PLATE', 'Threshold', 'Location', 'northwest')
pbaspect([1.2 1 1])

hold off




