classdef KKwikCutter < MClust.Cutter
    %
    % Select clusters from those generated by KKwik
    %
    % Loads a KlustaKwik label file (a .clu file, or a .KKmat file that
    % names one) and builds one MClust.ClusterTypes.KKCluster per label.
    % The cluster with focus is drawn in blue in the AvgWV and ISI figures;
    % if a different comparison cluster is set, it is drawn in red.
    %
    % Version control: ADR-2013-12-12
    
    properties
        figHandle_AvgWV = [];   % figure showing average waveform(s)
        figHandle_ISI = [];     % figure showing ISI histogram(s)
        whoHasFocus = [];       % cluster drawn in blue
        whoHasComparison = [];  % cluster drawn in red when ~= whoHasFocus
        
        nextClusterButton = []; % not assigned in this file
        prevClusterButton = []; % not assigned in this file
    end
    
    methods
        % -------------------------------
        % Constructor/Destructor/Import/Export
        % -------------------------------
        function self = KKwikCutter()
            % Ask the user for a KlustaKwik file. On success, open the AvgWV
            % and ISI figures and give focus (and comparison) to the first
            % cluster; on failure or cancel, close the cutter.
            OK = self.LoadClusters();
            if OK
                MCS = MClust.GetSettings();
                self.figHandle_AvgWV = figure('Name', 'SelectKKwik: AvgWV');
                MCS.PlaceWindow(self.figHandle_AvgWV); % ADR 2013-12-12
                self.figHandle_ISI = figure('Name', 'SelectKKwik: ISI');
                MCS.PlaceWindow(self.figHandle_ISI); % ADR 2013-12-12
                set(self.redrawAxesButton, 'value', true);
                % LoadClusters only returns true with at least one cluster.
                self.whoHasFocus = self.Clusters{1};
                self.whoHasComparison = self.whoHasFocus;
                self.ReGo();
            else
                self.close();
            end
        end
        
        function exportClusters(self)
            % Export kept clusters, merging those that share a merge set.
            %
            % Clusters whose .keep is true are grouped by the string from
            % getMergeSet(). Each distinct group becomes one
            % SpikelistCluster holding all the group's spikes, and the list
            % is passed to MClust.Cutter.exportClusters.
            
            % find clusters to keep
            clu2keep = false(size(self.Clusters));
            for iC = 1:length(self.Clusters)
                clu2keep(iC) = self.Clusters{iC}.keep;
            end
            A = self.Clusters(clu2keep);
            nClu = length(A);
            
            % find sets to merge
            M = cell(1,nClu);
            for iC = 1:nClu
                M{iC} = A{iC}.getMergeSet();
            end
            U = unique(M);
            
            B = cell(1,length(U));
            for iC = 1:length(U)
                % Exact match. A prefix match (strncmp) would also pick up
                % clusters of another set whose name merely starts with
                % U{iC} (e.g. 'KK10' vs 'KK100'), exporting them twice.
                f = find(strcmp(U{iC}, M));
                B{iC} = MClust.ClusterTypes.SpikelistCluster();
                if length(f)==1
                    B{iC}.name = U{iC};
                else
                    B{iC}.name = A{f(1)}.mergeSet;
                end
                B{iC}.color = A{f(1)}.color;
                for jC = 1:length(f)
                    B{iC}.AddSpikes(A{f(jC)}.GetSpikes());
                end
            end
            self.exportClusters@MClust.Cutter(B);
        end
        
        function close(self)
            % Save window placement, delete this cutter's figures, then
            % defer to MClust.Cutter.close.
            MCS = MClust.GetSettings();
            if ~isvalid(self), return; end
            if ~isempty(self.figHandle_AvgWV) && ishandle(self.figHandle_AvgWV)
                MCS.StoreWindowPlace(self.figHandle_AvgWV); % ADR 2013-12-12
                delete(self.figHandle_AvgWV);
            end
            if ~isempty(self.figHandle_ISI) && ishandle(self.figHandle_ISI)
                MCS.StoreWindowPlace(self.figHandle_ISI); % ADR 2013-12-12
                delete(self.figHandle_ISI);
            end
            self.close@MClust.Cutter();
        end
        
        % ---------------------------------
        % Display and callbacks
        % ---------------------------------
        function SetFocus(self, C)
            % Make C the focus cluster; every other cluster gets LoseFocus(C).
            for iC = 1:length(self.Clusters)
                if ~isequal(self.Clusters{iC}, C)
                    self.Clusters{iC}.LoseFocus(C);
                end
            end
            self.whoHasFocus = C;
        end
        
        function SetComparison(self, C)
            % Make C the comparison cluster; every other cluster gets
            % LoseComparison(C).
            for iC = 1:length(self.Clusters)
                if ~isequal(self.Clusters{iC}, C)
                    self.Clusters{iC}.LoseComparison(C);
                end
            end
            self.whoHasComparison = C;
        end
        
        function iFocus = FindFocus(self)
            % Index into self.Clusters of the focus cluster (last match),
            % or [] if no cluster equals whoHasFocus.
            iFocus = [];
            for iC = 1:length(self.Clusters)
                if isequal(self.whoHasFocus, self.Clusters{iC})
                    iFocus = iC;
                end
            end
        end
        
        function RedisplayAvgWV(self)
            % Plot the focus cluster's mean waveform with error bars (blue)
            % and, if different, the comparison cluster's mean (red).
            MCS = MClust.GetSettings();
            figure(self.figHandle_AvgWV);
            errorbar(self.whoHasFocus.xrange, self.whoHasFocus.mWV, self.whoHasFocus.sWV, 'b');
            set(gca, 'YLim',MCS.AverageWaveform_ylim);
            if ~isequal(self.whoHasComparison, self.whoHasFocus)
                hold on
                plot(self.whoHasComparison.xrange, self.whoHasComparison.mWV, 'r');
                hold off
            end
        end
        
        function RedisplayISI(self)
            % Plot the ISI histogram of the focus cluster (blue) and, if
            % different, of the comparison cluster (red). Clusters with two
            % or fewer spikes are skipped.
            MCD = MClust.GetData();
            figure(self.figHandle_ISI);
            cla
            T = MCD.FeatureTimestamps(self.whoHasFocus.GetSpikes());
            if length(T)>2 % check for 1-spike clusters ADR 2013-12-12
                MClust.HistISI(T, 'axesHandle', gca, 'myColor', 'b');
            end
            if ~isequal(self.whoHasComparison, self.whoHasFocus)
                hold on
                T = MCD.FeatureTimestamps(self.whoHasComparison.GetSpikes());
                if length(T)>2 % check for 1-spike clusters ADR 2013-12-12
                    MClust.HistISI(T, 'axesHandle', gca, 'myColor', 'r');
                end
                hold off
            end
            
        end
        
        % ---------------------------------
        % Next/Prev
        % ---------------------------------
        function NextCluster(self)
            % Give focus to the cluster after the current one. Beeps and
            % stays on the last cluster at the end of the list; starts at
            % the first cluster when nothing has focus.
            nC = length(self.Clusters);
            if nC == 0, beep; return; end
            iFocus = self.FindFocus();
            if isempty(iFocus)
                iFocus = 1;
            else
                iFocus = iFocus + 1;
            end
            if iFocus > nC, beep; iFocus = nC; end
            self.Clusters{iFocus}.TakeFocus;
        end
        
        function PrevCluster(self)
            % Give focus to the cluster before the current one. Beeps and
            % stays on the first cluster at the start of the list; starts at
            % the last cluster when nothing has focus.
            nC = length(self.Clusters);
            if nC == 0, beep; return; end
            iFocus = self.FindFocus();
            if isempty(iFocus)
                iFocus = nC;
            else
                iFocus = iFocus - 1;
            end
            if iFocus < 1, beep; iFocus = 1; end
            self.Clusters{iFocus}.TakeFocus;
        end
        
        
        % ---------------------------------
        % Load/Save
        % ---------------------------------
        function OK = LoadClusters(self)
            % Prompt for a KlustaKwik file and build self.Clusters from it.
            %
            % Accepts either
            %   *.KKmat.* - MAT-file with variables fnClu (path of the .clu
            %               file) and spikesToInclude (spike indices the
            %               .clu labels refer to), or
            %   *.clu.*   - labels for spikes 1:MCD.nSpikes().
            % Creates one KKCluster ('KK01', 'KK02', ...) per distinct label.
            %
            % Returns OK = false (without changing self.Clusters) if the
            % dialog is cancelled or the file cannot be used.
            OK = false;
            MCD = MClust.GetData();
            
            % Default dialog filter: .KKmat if any exist, otherwise .clu.
            % NOTE(review): the existence check matches any *.KKmat.* file,
            % not only those for MCD.TTfn.
            if isempty(FindFiles('*.KKmat.*', 'StartingDirectory', MCD.TTdn))
                fn = fullfile(MCD.TTdn, [MCD.TTfn '.clu.*']);
            else
                fn = fullfile(MCD.TTdn, [MCD.TTfn '.KKmat.*']);
            end
            [fn,fd] = uigetfile(fn, 'Load KlustaKwik parameters');
            if isequal(fn,0) % canceled ADR 2013-12-12
                return;
            end
            
            % e.g. 'TT1.clu.1' -> name 'TT1.clu' -> type '.clu'
            [~,fntype,~] = fileparts(fn);
            [~,~,fntype] = fileparts(fntype);
            switch (fntype)
                case '.KKmat'
                    % Load into a struct rather than creating variables
                    % implicitly in this workspace.
                    P = load(fullfile(fd,fn), 'fnClu', 'spikesToInclude', '-mat');
                    if ~isfield(P, 'fnClu') || ~isfield(P, 'spikesToInclude')
                        warning('MCLUST:KKwikCutter', ...
                            'File %s lacks fnClu or spikesToInclude.', fullfile(fd,fn));
                        return
                    end
                    fnClu = P.fnClu;
                    spikesToInclude = P.spikesToInclude;
                    if ~exist(fnClu,'file')
                        % Stored path not found: look for the .clu file in
                        % the same directory as the KKmat file.
                        [~,fnClu,xt] = fileparts(fnClu);
                        fnClu = fullfile(fd,[fnClu xt]);
                    end
                case '.clu'
                    fnClu = fullfile(fd,fn);
                    spikesToInclude = 1:MCD.nSpikes();
                otherwise
                    warning('MCLUST:KKwikCutter','Unknown fntype: %s',fntype);
                    return
            end
            if ~exist(fnClu, 'file')
                warning('MCLUST:KKwikCutter', 'Cannot find clu file: %s', fnClu);
                return
            end
            
            A = dlmread(fnClu, '%d');
            A = A(2:end);  % first line is number of classes found
            if isempty(A)
                warning('MCLUST:KKwikCutter', 'No cluster labels in %s', fnClu);
                return
            end
            % Labels are used as a logical mask over spikesToInclude, so the
            % counts must agree (a shorter mask would silently misassign).
            if numel(A) ~= numel(spikesToInclude)
                warning('MCLUST:KKwikCutter', ...
                    '%s has %d labels but %d spikes were expected.', ...
                    fnClu, numel(A), numel(spikesToInclude));
                return
            end
            U = unique(A);
            nClu = length(U);
            
            %colors = copper(nClu);
            colors = repmat([0 0 0], nClu, 1);
            % Waveforms are optional: without them clusters get [] instead.
            try
                WV = MCD.LoadNeuralWaveforms();
                WVT = WV.range(); WVD = WV.data();
                haveWV = true;
            catch ME
                warning('MCLUST:KKwikCutter', ...
                    'Could not load waveforms (%s); continuing without them.', ME.message);
                haveWV = false;
            end
            
            self.Clusters = cell(1,nClu);
            for iC = 1:nClu
                S = spikesToInclude(A == U(iC));
                
                n = sprintf('KK%02d', iC);
                if haveWV
                    wvC = tsd(WVT(S), WVD(S, :, :));
                else
                    wvC = [];
                end
                self.Clusters{iC} = MClust.ClusterTypes.KKCluster(n, colors(iC,:), S, wvC);
                self.Clusters{iC}.setAssociatedCutter(@self.GetCutter);
            end
            
            self.ReGo();
            OK = true;
        end %LoadClusters
        
    end %methods
    
end % class

