function varargout = plotGenealogyTree_BranchList(BranchList, varargin)
%PLOTGENEALOGYTREE_BRANCHLIST Plot the genealogy tree of one bacterial family.
%
%   plotGenealogyTree_BranchList(BranchList)
%   plotGenealogyTree_BranchList(BranchList, 'Name', Value, ...)
%   plotGenealogyTree_BranchList(BranchList, optionsStruct)
%   handles = plotGenealogyTree_BranchList(...)
%
% Inputs
%   BranchList is an m x 7 (or m x 8) array, m is the number of unique bacteria.
%   The first row is treated as the founder (landing) cell; every other row must
%   appear AFTER the row of its parent. If BranchList is missing or empty, a
%   built-in example family is plotted.
%     col 1: Family       Unique number to identify the landing cell and its progeny
%     col 2: Branch       Unique number counting the total number of division events
%     col 3: Heir         1 or 2 for the 2 daughter cells after division
%     col 4: FromBranch   Parent Branch
%     col 5: FromHeir     Parent Heir
%     col 6: FirstFrame   1st BF frame bacterium is present in
%     col 7: LastFrame    last BF frame bacterium is present in
%   extra (optional) column:
%     col 8: EndFate      0 = 'stay'/'stay and divide' (isleaf=0)
%                         1 = 'detach' (isleaf=1)
%                         2 = 'move off screen' (isleaf=1)
%
% Options (name/value pairs or a struct with these fields; defaults in brackets)
%   DisplayName            figure name; '' keeps the default figure title ['']
%   VertLineWidth          line width of the vertical (lifetime) lines [10]
%   HorzLineWidth          line width of the horizontal (division) lines [5]
%   AxisLineWidth          line width of the axes [2.5]
%   PlotFontSize           axes font size [48]
%   tUnitConversion        time per frame, any unit [3/3600, i.e. 3 s frames in h]
%   showTimeGrid           show Y grid lines [true]
%   showTimeAxis           show the Y (time) axis [true]
%   TimeOffset             offset added to all times, in units of tUnitConversion;
%                          a non-finite value is replaced by min(YData)*tUnitConversion [0]
%   alwaysStartAtTimeZero  set the duration of the invisible root node to 1 frame [false]
%   PlotBoxAspectRatioX    X component of the axes PlotBoxAspectRatio [2]
%   TotalTime              in frames; rows whose LastFrame exceeds
%                          min(FirstFrame)+TotalTime-1 are removed [inf]
%   showTextLabel          show the node text labels [false]
%   VertLineColor          RGB colour of the vertical lines [0 0 0]
%   HorzLineColor          RGB colour of the horizontal lines [125 0 0]/255
%   DetachMarker           marker drawn at the end of 'detach' branches ['v']
%   MoveOffScreenMarker    marker drawn at the end of 'move off screen' branches ['none']
%   plotGenerations        plot by generation depth instead of by duration [false]
%
% Output
%   handles  struct with the figure/axes handles and the trees of labels,
%            durations, line/text/marker handles, end fates and asymmetry
%            values. The same struct is stored in the figure's UserData.
%
% Requires the third-party `tree` class and `checkOptionalInputs` on the path.

if ~exist('BranchList','var') || isempty(BranchList)
    %% example:
    BranchList = [
        91	0	0	0	0	3170	5131	0
        91	1	1	0	0	5132	7359	0
        91	1	2	0	0	5132	7155	0
        91	2	1	1	2	7156	8933	0
        91	2	2	1	2	7156	9153	0
        91	3	1	1	1	7360	9043	0
        91	3	2	1	1	7360	8997	0
        91	4	1	2	1	8934	10872	0
        91	4	2	2	1	8934	11180	0
        91	5	1	3	2	8998	11242	0
        91	5	2	3	2	8998	12000	0
        91	6	1	3	1	9044	10928	0
        91	6	2	3	1	9044	10818	0
        91	7	1	2	2	9154	9398	1
        91	7	2	2	2	9154	11164	0
        91	8	1	6	2	10819	11096	1
        91	8	2	6	2	10819	12000	0
        91	9	1	4	1	10873	11493	1
        91	9	2	4	1	10873	12000	0
        91	10	1	6	1	10929	12000	0
        91	10	2	6	1	10929	12000	0
        91	11	1	7	2	11165	12000	0
        91	11	2	7	2	11165	12000	0
        91	12	1	4	2	11181	12000	0
        91	12	2	4	2	11181	12000	0
        91	13	1	5	1	11243	12000	0
        91	13	2	5	1	11243	12000	0
        ];
end

% fail early with a readable message instead of an index error further down
if size(BranchList,2) < 7
    error('plotGenealogyTree_BranchList:invalidBranchList',...
        'BranchList must have at least 7 columns (got %d).',size(BranchList,2));
end

%% convert varargin into optional inputs either as a struct or a 'parameter','value' combination
if length(varargin) == 1 && isstruct(varargin{1})
    p = varargin{1};
else
    p = struct(varargin{:});
end

% d holds the default of every option, v the matching validation rule
d.DisplayName = '';
v.DisplayName = {{'char'},{'vector'}};
d.VertLineWidth = 10;
v.VertLineWidth = {{'numeric'},{'scalar'}};
d.HorzLineWidth = 5;
v.HorzLineWidth = {{'numeric'},{'scalar'}};
d.AxisLineWidth = 2.5;
v.AxisLineWidth = {{'numeric'},{'scalar'}};
d.PlotFontSize = 48;
v.PlotFontSize = {{'numeric'},{'scalar'}};
d.tUnitConversion = 3/3600; % how much time each frame is; currently in h (3 s per frame), but can be any units
v.tUnitConversion = {{'numeric'},{'scalar'}};
d.showTimeGrid = true;
v.showTimeGrid = {{'logical','numeric'},{'scalar'}};
d.showTimeAxis = true;
v.showTimeAxis = {{'logical','numeric'},{'scalar'}};
d.TimeOffset = 0; % how much to offset the family time by, in units of tUnitConversion (NaN to just use the tracking data starting offset)
v.TimeOffset = {{'numeric'},{'scalar'}};
d.alwaysStartAtTimeZero = false;
v.alwaysStartAtTimeZero = {{'logical','numeric'},{'scalar'}};
d.PlotBoxAspectRatioX = 2;
v.PlotBoxAspectRatioX = {{'numeric'},{'scalar'}};
d.TotalTime = inf;
v.TotalTime = {{'numeric'},{'scalar'}};
d.showTextLabel = false;
v.showTextLabel = {{'logical','numeric'},{'scalar'}};
d.VertLineColor = [0,0,0]./255;
v.VertLineColor = {{'numeric'},{'numel',3}};
d.HorzLineColor = [125,0,0]./255;
v.HorzLineColor = {{'numeric'},{'numel',3}};
d.DetachMarker = 'v';
v.DetachMarker = {'o','+','*','.','x','s','d','^','v','<','>','p','h','none'};
d.MoveOffScreenMarker = 'none';
v.MoveOffScreenMarker = {'o','+','*','.','x','s','d','^','v','<','>','p','h','none'};
d.plotGenerations = false;
v.plotGenerations = {{'logical','numeric'},{'scalar'}};

% checkOptionalInputs is defined outside this file; presumably it fills in the
% defaults from d and validates against v -- confirm against its source
p = checkOptionalInputs(p,d,v);
DisplayName = p.DisplayName;
VertLineWidth = p.VertLineWidth;
HorzLineWidth = p.HorzLineWidth;
AxisLineWidth = p.AxisLineWidth;
PlotFontSize = p.PlotFontSize;
tUnitConversion = p.tUnitConversion;
showTimeGrid = p.showTimeGrid ~= 0;
showTimeAxis = p.showTimeAxis ~= 0;
TimeOffset = p.TimeOffset;
alwaysStartAtTimeZero = p.alwaysStartAtTimeZero ~= 0;
PlotBoxAspectRatioX = p.PlotBoxAspectRatioX;
TotalTime = p.TotalTime;
showTextLabel = p.showTextLabel ~= 0;
VertLineColor = p.VertLineColor;
HorzLineColor = p.HorzLineColor;
DetachMarker = p.DetachMarker;
MoveOffScreenMarker = p.MoveOffScreenMarker;
plotGenerations = p.plotGenerations ~= 0;

%% Calculate tree using 3rd party functions

% cut off identities that exist for more than TotalTime
if isfinite(TotalTime)
    FirstFrame = min(BranchList(:,6));
    BranchList(BranchList(:,7) > (FirstFrame + TotalTime - 1), :) = [];
    if isempty(BranchList)
        error('plotGenealogyTree_BranchList:emptyBranchList',...
            'No bacteria remain after applying TotalTime = %g frames; nothing to plot.',TotalTime);
    end
end

% initialize tree with empty node (so that we can plot the real founder cell at another frame besides 1)
TreeLabel = tree(' ');
TreeDuration = tree(BranchList(1,6));
TreeLabelValues = tree([-1,-1]);

% add founder cell node at first frame it appears in
TreeLabel = TreeLabel.addnode(1,sprintf('Family %d, %d,%d,%d:%d',BranchList(1,[1,2,3,6,7])));
TreeDuration = TreeDuration.addnode(1,BranchList(1,7)-BranchList(1,6)+1);
TreeLabelValues = TreeLabelValues.addnode(1,[0,0]);


for iElement = 2:size(BranchList,1)
    
    % find parent index (+1 because tree node 1 is the invisible root, so
    % BranchList row k lives in tree node k+1)
    FromIdx = find((BranchList(:,2) == BranchList(iElement,4)) & (BranchList(:,3) == BranchList(iElement,5)))+1;
    
    % the parent must be unique and already in the tree; before this row is
    % added the tree holds iElement nodes (root + rows 1..iElement-1)
    if numel(FromIdx) ~= 1
        error('plotGenealogyTree_BranchList:invalidParent',...
            'BranchList row %d (Branch %d, Heir %d): expected exactly 1 parent with Branch %d, Heir %d but found %d.',...
            iElement,BranchList(iElement,2),BranchList(iElement,3),BranchList(iElement,4),BranchList(iElement,5),numel(FromIdx));
    elseif FromIdx > iElement
        error('plotGenealogyTree_BranchList:invalidParent',...
            'BranchList row %d (Branch %d, Heir %d): parent is listed in row %d; parents must be listed before their daughters.',...
            iElement,BranchList(iElement,2),BranchList(iElement,3),FromIdx-1);
    end
    
    % add daughter cell node with parent info
    TreeLabel = TreeLabel.addnode(FromIdx,sprintf('%d,%d,%d:%d',BranchList(iElement,[2,3,6,7])));
    TreeDuration = TreeDuration.addnode(FromIdx,BranchList(iElement,7)-BranchList(iElement,6)+1);
    TreeLabelValues = TreeLabelValues.addnode(FromIdx,BranchList(iElement,[2,3]));
    
end

% translate the optional EndFate codes (col 8) into text, one node per bacterium
EndFateTree = tree(TreeLabelValues,'clear');
if size(BranchList,2) > 7
    EndFateTree = EndFateTree.set(1,'');
    for iElement = 2:EndFateTree.nnodes
        EndFateCode = BranchList(iElement-1,8);
        switch EndFateCode
            case 0
                if EndFateTree.isleaf(iElement)
                    EndFate = 'stay';
                else
                    EndFate = 'stay and divide';
                end
            case 1
                EndFate = 'detach';
            case 2
                EndFate = 'move off screen';
            otherwise
                % without this branch an unknown code silently reused the
                % fate of the previous row
                error('plotGenealogyTree_BranchList:invalidEndFate',...
                    'BranchList row %d has EndFate %g; expected 0, 1 or 2.',iElement-1,EndFateCode);
        end
        EndFateTree = EndFateTree.set(iElement,EndFate);
    end
end
[BranchTreeAsym,SubTreeAsym] = getTreeAsymmetry(EndFateTree);

if alwaysStartAtTimeZero
    TreeDuration = TreeDuration.set(1,1);
end

%% Plot tree using 3rd party functions

% initialize handles
if isempty(DisplayName)
    FigureHandle = figure();
else
    FigureHandle = figure('NumberTitle','off','Name',DisplayName);
end
PlotHandle = axes('parent',FigureHandle,'Color','none','Box','off','YDir','reverse','XTickLabel','','XTick',[],'LineWidth',AxisLineWidth);

% plot tree
if plotGenerations
    [VertLineHandle,HorzLineHandle,TreeTextHandle] = TreeLabel.plot([],'parent',PlotHandle);
else
    [VertLineHandle,HorzLineHandle,TreeTextHandle] = TreeLabel.plot(TreeDuration,'parent',PlotHandle);
end
[VertLineHandle,HorzLineHandle,TreeTextHandle] = flipGenealogyTree(VertLineHandle,HorzLineHandle,TreeTextHandle);

% set empty initial node to invisible 
set(VertLineHandle.get(1),'Visible','off');
set(HorzLineHandle.get(1),'Visible','off');

% trick to get to plot founder cell correctly
set(VertLineHandle.get(2),'YData',get(VertLineHandle.get(2),'YData')-[0,1]);

% show ID labels or not
if ~showTextLabel
    set([TreeTextHandle.Node{:}],'Visible','off');
else
    set([TreeTextHandle.Node{:}],'Clipping','on');
end

% set horizontal lines (division events)
set([HorzLineHandle.Node{:}],'LineWidth',HorzLineWidth);
uistack([HorzLineHandle.Node{:}],'bottom');
set([HorzLineHandle.Node{:}],'Color',HorzLineColor);

% set YLim to automatically fit the tree
YData = get([VertLineHandle.Node{2:end}],'YData');
if iscell(YData) % get() returns a cell only when there is more than one line
    YData = [YData{:}];
end
set(PlotHandle,'YLim',[min(YData),max(YData)*1.02]);
if ~isfinite(TimeOffset)
    TimeOffset = min(YData) .* tUnitConversion;
end

% set XLim with enough space to show plot
XData = get([VertLineHandle.Node{:}],'XData');
XData = cat(1, XData{:});
if ~all(diff(XData)==0)
    set(PlotHandle,'Xlim',[min(XData(:)),max(XData(:))]+0.025.*diff([min(XData(:)),max(XData(:))]).*[-1,1]);
end

% go through every vertical line to color it (if needed, set a function instead of static color)
for n2 = 2:TreeLabel.nnodes
    set(VertLineHandle.get(n2),'Color',VertLineColor,'LineWidth',VertLineWidth,'DisplayName',sprintf('%d-%d',TreeLabelValues.get(n2)));
end

% add detach points if given
DetachPointsHandle = tree(VertLineHandle,'clear');
if size(BranchList,2) > 7
    for n2 = 1:DetachPointsHandle.nnodes
        % marker goes at the end of the branch (largest Y; the Y axis is reversed)
        markerX = mean(get(VertLineHandle.get(n2),'XData'));
        markerY = max(get(VertLineHandle.get(n2),'YData'));
        switch EndFateTree.get(n2)
            case 'detach'
                det_handle = line(markerX,markerY+VertLineWidth.*0.0015.*abs(diff(get(PlotHandle,'ylim'))),'parent',PlotHandle,...
                    'Color',VertLineColor,'LineStyle','none','DisplayName',sprintf('%d-%d',TreeLabelValues.get(n2)),...
                    'MarkerSize',VertLineWidth*2,'MarkerEdgeColor','none','Marker',DetachMarker,'MarkerFaceColor',VertLineColor);
                set(VertLineHandle.Node{n2},'LineWidth',VertLineWidth.*0.95);
            case 'move off screen'
                det_handle = line(markerX,markerY+VertLineWidth.*0.0015.*abs(diff(get(PlotHandle,'ylim'))),'parent',PlotHandle,...
                    'Color',VertLineColor,'LineStyle','none','DisplayName',sprintf('%d-%d',TreeLabelValues.get(n2)),...
                    'MarkerSize',VertLineWidth*2.5,'LineWidth',VertLineWidth/3.5,'MarkerEdgeColor','auto','Marker',MoveOffScreenMarker,'MarkerFaceColor','none');
                set(VertLineHandle.Node{n2},'LineWidth',VertLineWidth.*0.99);
            otherwise
                det_handle = [];
        end
        DetachPointsHandle = DetachPointsHandle.set(n2,det_handle);
    end
end

% add real time to y axes (also style it): every Y coordinate is converted
% from frames to time units and shifted by TimeOffset
for n4 = 1:VertLineHandle.nnodes
    if ~isempty(VertLineHandle.get(n4)) && ishghandle(VertLineHandle.get(n4))
        set(VertLineHandle.get(n4),'YData',get(VertLineHandle.get(n4),'YData') .* tUnitConversion + TimeOffset);
    end
    if ~isempty(HorzLineHandle.get(n4)) && ishghandle(HorzLineHandle.get(n4))
        set(HorzLineHandle.get(n4),'YData',get(HorzLineHandle.get(n4),'YData') .* tUnitConversion + TimeOffset);
    end
    if ~isempty(DetachPointsHandle.get(n4)) && ishghandle(DetachPointsHandle.get(n4))
        set(DetachPointsHandle.get(n4),'YData',get(DetachPointsHandle.get(n4),'YData') .* tUnitConversion + TimeOffset);
    end
    if ~isempty(TreeTextHandle.get(n4)) && ishghandle(TreeTextHandle.get(n4))
        pos = get(TreeTextHandle.get(n4),'Position');
        pos(2) = pos(2) .* tUnitConversion + TimeOffset;
        set(TreeTextHandle.get(n4),'Position',pos);
    end
end
ylim(PlotHandle, ylim(PlotHandle) .* tUnitConversion + TimeOffset);
set(PlotHandle,'XTick',[],'YDir','reverse','TickDir','in','TickLength',0.015+[0,0],'YTickMode','auto','YTickLabelMode','auto','FontSize',PlotFontSize,'LineWidth',AxisLineWidth);
if showTimeAxis
    set(PlotHandle,'YColor','k');
else
    set(PlotHandle,'YColor','none');
end
if showTimeGrid
    set(PlotHandle,'YGrid','on');
else
    set(PlotHandle,'YGrid','off');
end
set(PlotHandle,'Color','none','XColor','none');
set(PlotHandle,'PlotBoxAspectRatio',[PlotBoxAspectRatioX,1,1]);

% collect everything in one struct; also stored on the figure for later access
handles.FigureHandle = FigureHandle;
handles.PlotHandle = PlotHandle;
handles.TreeLabel = TreeLabel;
handles.TreeDuration = TreeDuration;
handles.TreeLabelValues = TreeLabelValues;
handles.VertLineHandle = VertLineHandle;
handles.HorzLineHandle = HorzLineHandle;
handles.TreeTextHandle = TreeTextHandle;
handles.DetachPointsHandle = DetachPointsHandle;
handles.EndFateTree = EndFateTree;
handles.BranchTreeAsym = BranchTreeAsym;
handles.SubTreeAsym = SubTreeAsym;

set(FigureHandle,'UserData',handles);

% only a single output is supported
if nargout > 0
    varargout{1} = handles;
end

end

%% Helper function to flip genealogy tree plot
function [VertLineHandle,HorzLineHandle,TreeTextHandle] = flipGenealogyTree(VertLineHandle,HorzLineHandle,TreeTextHandle)
%{
Mirror sub-trees of the plotted family tree left<->right (X coordinates only),
so that the following branches end up as far left as possible:
- Branches with non-detached immediate descendants
- Branches with higher Strahler numbers or have any descendants with higher Strahler numbers
- Branches with higher fraction of non-detached total descendants
- Branches that are on the surface longer if both descendants don't stay

Inputs/outputs are the three handle trees returned by the tree plot (vertical
lines, horizontal lines, text labels). The graphics objects are modified in
place; the trees themselves are returned unchanged.
%}

% reflect coordinate(s) about a pivot
mirrorAbout = @(coord,pivot) coord+2.*(pivot-coord);

strahlerTree = getStrahlerNumber(VertLineHandle);

for nodeIdx = VertLineHandle.breadthfirstiterator
    if VertLineHandle.isleaf(nodeIdx)
        continue;
    end
    
    % work on the sub-trees rooted at this node (local index 1 = this node)
    vertSub = VertLineHandle.subtree(nodeIdx);
    horzSub = HorzLineHandle.subtree(nodeIdx);
    textSub = TreeTextHandle.subtree(nodeIdx);
    strahlerSub = strahlerTree.subtree(nodeIdx);
    kids = vertSub.getchildren(1);
    if numel(kids) ~= 2
        continue; % only real divisions (two daughters) can be flipped
    end
    
    % X of the dividing node is the mirror axis
    pivotPos = get(textSub.get(1),'Position');
    pivotX = pivotPos(1);
    
    % current X position and leaf status of both daughters
    kidX = zeros(1,2);
    kidIsLeaf = false(1,2);
    for k = 1:2
        kidPos = get(textSub.get(kids(k)),'Position');
        kidX(k) = kidPos(1);
        kidIsLeaf(k) = horzSub.isleaf(kids(k));
    end
    firstIsRight = kidX(1) > kidX(2);
    firstIsLeft = kidX(1) < kidX(2);
    
    if ~kidIsLeaf(1) && ~kidIsLeaf(2)
        % both daughters divide: compare Strahler number, largest descendant
        % Strahler number and fraction of non-leaf nodes in each sub-tree
        strahlerNum = zeros(1,2);
        maxDescStrahler = zeros(1,2);
        fracNonLeaf = zeros(1,2);
        for k = 1:2
            kidStrahlerTree = strahlerSub.subtree(kids(k));
            strahlerNum(k) = strahlerSub.get(kids(k));
            maxDescStrahler(k) = max([kidStrahlerTree.Node{2:end}]);
            fracNonLeaf(k) = (kidStrahlerTree.nnodes - numel(kidStrahlerTree.findleaves)) ./ kidStrahlerTree.nnodes;
        end
        firstWins = strahlerNum(1) > strahlerNum(2) || maxDescStrahler(1) > maxDescStrahler(2) || fracNonLeaf(1) > fracNonLeaf(2);
        secondWins = strahlerNum(1) < strahlerNum(2) || maxDescStrahler(1) < maxDescStrahler(2) || fracNonLeaf(1) < fracNonLeaf(2);
        doFlip = (firstIsRight && firstWins) || (firstIsLeft && secondWins);
    elseif ~kidIsLeaf(1) && kidIsLeaf(2)
        % only the first daughter divides: it must not sit on the right
        doFlip = firstIsRight;
    elseif kidIsLeaf(1) && ~kidIsLeaf(2)
        % only the second daughter divides: the first must not sit on the left
        doFlip = firstIsLeft;
    else
        % neither divides: the longer vertical line goes left
        lifeSpan = zeros(1,2);
        for k = 1:2
            lifeSpan(k) = abs(diff(get(vertSub.get(kids(k)),'YData')));
        end
        doFlip = (firstIsRight && lifeSpan(1) > lifeSpan(2)) ||...
            (firstIsLeft && lifeSpan(1) < lifeSpan(2));
    end
    
    if ~doFlip
        continue;
    end
    
    % mirror the division bar of this node, then every descendant's bar,
    % vertical line and label
    visitOrder = vertSub.breadthfirstiterator;
    rootBar = horzSub.get(visitOrder(1));
    set(rootBar,'XData',mirrorAbout(get(rootBar,'XData'),pivotX));
    for subIdx = visitOrder(2:end)
        horzLine = horzSub.get(subIdx);
        set(horzLine,'XData',mirrorAbout(get(horzLine,'XData'),pivotX));
        vertLine = vertSub.get(subIdx);
        set(vertLine,'XData',mirrorAbout(get(vertLine,'XData'),pivotX));
        labelHandle = textSub.get(subIdx);
        labelPos = get(labelHandle,'Position');
        labelPos(1) = mirrorAbout(labelPos(1),pivotX);
        set(labelHandle,'Position',labelPos);
    end
end

end

%% Helper function to calculate Strahler number
function strahlerTree = getStrahlerNumber(InputTree)
%{
Return a tree with the same structure as InputTree whose nodes hold the
Strahler number of that node.
- every node starts at 1 (so leaves end up with 1)
- a parent gets max(children) + 1 when that maximum is reached by at least two
  children, otherwise max(children)
Only the structure of InputTree is used, not its content.
Assumes recursivecumfun passes the children's values as one numeric array --
confirm against the tree class.
%}

strahlerTree = tree(InputTree, 1);
strahlerTree = strahlerTree.recursivecumfun(@calcStrahler);

    function y = calcStrahler(x)
        % x: Strahler numbers of the children of one node
        topOrder = max(x);
        % count ties at the maximum instead of testing diff(x)==0, which is
        % only correct for exactly two children (e.g. [2,2,1] must give 3)
        if sum(x == topOrder) >= 2
            y = topOrder+1;
        else
            y = topOrder;
        end
    end

end

%% Helper function to calculate tree asymmetry
function [BranchTreeAsym,SubTreeAsym] = getTreeAsymmetry(EndFateTree)
%{
Calculate tree asymmetry.
Original definition of terminal branches: All leaf nodes are terminal branches.
Modified definition of terminal branches: 'detach' leaf nodes are terminal branches. 'move off
screen' and 'stay' leaf nodes are division branches that have 2 terminal branches ('detach').

The original definition is used when every node of EndFateTree is empty,
the modified one otherwise.

Outputs (same structure as EndFateTree, node 1 is never filled in):
BranchTreeAsym  per node [L, R, A]: terminal-branch counts of the two daughters
                and the branch asymmetry A
SubTreeAsym     per non-leaf node: sum of A over its sub-tree divided by
                (L + R - 1) of that node, capped at 1; NaN elsewhere
%}

BranchTreeAsym = tree(EndFateTree, NaN(1,3));
SubTreeAsym = tree(EndFateTree, NaN);

visitOrder = EndFateTree.breadthfirstiterator;
hasEndFate = ~all(EndFateTree.isemptynode);

% count L,R terminal branches for individual branch tree asymmetry values
for nodeIdx = visitOrder(2:end)
    
    if EndFateTree.isleaf(nodeIdx)
        if ~hasEndFate
            % regular definition: every leaf is a terminal branch
            BranchTreeAsym = BranchTreeAsym.set(nodeIdx,[0,0,NaN]);
        else
            % modified definition: only 'detach' leaves are terminal
            switch EndFateTree.get(nodeIdx)
                case 'detach'
                    BranchTreeAsym = BranchTreeAsym.set(nodeIdx,[0,0,NaN]);
                case {'move off screen','stay'}
                    BranchTreeAsym = BranchTreeAsym.set(nodeIdx,[1,1,0]);
                otherwise
            end
        end
        continue;
    end
    
    % non-terminal branch; count number of L,R terminal branches
    daughters = EndFateTree.getchildren(nodeIdx);
    if numel(daughters) ~= 2
        continue;
    end
    lrCounts = BranchTreeAsym.get(nodeIdx);
    for side = 1:numel(daughters)
        daughterTree = EndFateTree.subtree(daughters(side));
        leafIdx = daughterTree.findleaves;
        if ~hasEndFate
            lrCounts(side) = numel(leafIdx);
        else
            % 'detach' leaves count once, 'move off screen' and 'stay' twice
            [~,fateCode] = ismember(daughterTree.Node(leafIdx),{'detach';'move off screen';'stay'});
            lrCounts(side) = nnz(fateCode==1) + 2.*nnz(fateCode==2) + 2.*nnz(fateCode==3);
        end
    end
    BranchTreeAsym = BranchTreeAsym.set(nodeIdx,lrCounts);
    
end

% calculate individual branch tree asymmetry values from the L,R counts
% (the order of the overrides below matters)
lrTable = cat(1,BranchTreeAsym.Node{2:end});
nLeft = lrTable(:,1);
nRight = lrTable(:,2);
asym = abs(nLeft - nRight)./(nLeft + nRight - 2); % abs(L-R)/(L+R-2)
asym(nLeft==1 & nRight==1) = 0; % [1,1] case should be 0
oneSided = xor(nLeft==0,nRight==0);
largerSide = max(lrTable(oneSided,1:2),[],2);
asym(oneSided) = largerSide./(largerSide+1); % [0,x] and [x,0] cases should be x./(x+1)
asym(~isfinite(asym)) = 1; % other nonfinite cases marked as 1
asym(asym < 0) = 1; % other negative cases marked as 1
asym(nLeft==0 & nRight==0) = NaN; % [0,0] case should be NaN
lrTable(:,3) = asym;
for nodeIdx = 2:BranchTreeAsym.nnodes
    BranchTreeAsym = BranchTreeAsym.set(nodeIdx,lrTable(nodeIdx-1,:));
end

% calculate subtree asymmetry values from the individual branch values
for nodeIdx = visitOrder(2:end)
    if EndFateTree.isleaf(nodeIdx)
        continue;
    end
    branchSub = BranchTreeAsym.subtree(nodeIdx);
    branchAsymValues = cellfun(@(x) x(end), branchSub.Node);
    asymTotal = sum(branchAsymValues,'omitnan');
    rootEntry = branchSub.Node{1};
    normaliser = sum(rootEntry(1:2))-1;
    SubTreeAsym = SubTreeAsym.set(nodeIdx,min(asymTotal ./ normaliser, 1));
end

end