%copyright 2010 Alexandre Mathy

function [ mtn, rois] = analyze_movie( fn ,rois)
%ANALYZE_MOVIE Per-ROI motion index of a video plus spike-triggered-average data.
%   [MTN, ROIS] = ANALYZE_MOVIE(FN, ROIS) reads the video FN frame by frame
%   with mmreader, computes a motion index for every region of interest
%   with getmotionidx, shows live previews and a running plot of the
%   traces, and saves the data for a spike-triggered average to
%   STAstuff.mat in the current folder.
%
%   With fewer than two arguments a file dialog selects the video (any FN
%   passed is ignored) and the ROIs come from getrois(fn, 1).
%
%   A dialog then asks for a start-frame offset N; loop iteration k reads
%   video frame k+N. A blank or non-numeric entry means N = 0; cancelling
%   the dialog raises an error.
%
%   Inputs
%     fn   - path of the video file.
%     rois - non-empty struct array; .name is used as a label and .pos is
%            read as [x1 y1 x2 y2], the inclusive pixel bounds of the crop.
%   Outputs
%     mtn  - one row per processed frame: column 1 is (row-1)/FrameRate in
%            seconds, column j+1 is the motion index of rois(j).
%     rois - the ROIs used, so interactively chosen ROIs can be reused.
%
%   The variable "bins" is loaded with "load bins". Its length fixes how
%   many frames are analysed, and it is saved as "spikes". Processing stops
%   early if a frame cannot be read.

if nargin < 2
    [filename, pathname] = uigetfile('*');
    fn = fullfile(pathname, filename);
    rois = getrois(fn, 1);
end
if isempty(rois)
    error('analyze_movie:noROIs', 'At least one ROI is required.');
end

nstartframe = ask_start_frame();

load bins;
% Surrogate spike trains used for control runs (need m=mean(bins), v=std(bins)):
%mu=log(m^2/sqrt(v+m^2));
%sigma=sqrt(log(v/(m^2)+1));
%spikes=poissrnd(m,size(bins,1),size(bins,2));
%spikes=m*rand(size(bins,1),size(bins,2));
%spikes=lognrnd(mu,sigma,size(bins,1),size(bins,2)) ;
spikes = bins;

info.color = 'red';
info.title = 'Analyzing video';
p = progbar(info);

% NOTE(review): mmreader is superseded by VideoReader in later MATLAB
% releases; consider switching if this must run on a current release.
readerobj = mmreader(fn);
numFrames = length(spikes);
frame_rate = readerobj.FrameRate;

alph = 0.3;     % background update weight, passed to getmotionidx
bet = 0.9;      % difference-image smoothing weight, passed to getmotionidx
thresh = 40;    % threshold on the smoothed difference image
se = strel('square', 2);

% Per-ROI state and live preview windows.
for j = 1:length(rois)
    state(j).label = rois(j).name;
    state(j).roi = rois(j);
    state(j).currstate.k = 1;
    state(j).tot = [];
    state(j).STAtot = [];
    % Crop index ranges, computed once instead of on every frame.
    state(j).rngY = uint32(rois(j).pos(2)):uint32(rois(j).pos(4));
    state(j).rngX = uint32(rois(j).pos(1)):uint32(rois(j).pos(3));

    % Size the placeholder images like the actual (inclusive) crop so the
    % CData replaced each frame keeps the same extent.
    nrow = numel(state(j).rngY);
    ncol = numel(state(j).rngX);
    state(j).fig1 = figure('Position', [(j-1)*250, 600, 200, 300]);
    set(state(j).fig1, 'Name', state(j).label);
    state(j).im1 = image(zeros(nrow, ncol, 3));
    state(j).fig2 = figure('Position', [(j-1)*250, 100, 200, 300]);
    set(state(j).fig2, 'Name', state(j).label);
    state(j).im2 = imagesc(zeros(nrow, ncol), [0 1]);
end

f1 = figure;
xlim([0 numFrames]);
for k = 1:numFrames
    % NOTE(review): iteration k reads frame k+nstartframe, so an entry of N
    % starts at video frame N+1 -- confirm this offset is intended.
    try
        vidFrames = read(readerobj, k + nstartframe);
    catch err
        % A failed read (e.g. past the end of the video) ends the analysis.
        fprintf('reading error! (%s)\n', err.message);
        break;
    end

    for j = 1:numel(state)
        crop = vidFrames(state(j).rngY, state(j).rngX, :);
        state(j).currstate.k = k;
        state(j).currstate.frame = crop;
        [mtnidx, currstate] = getmotionidx(state(j).currstate, alph, bet, thresh, ...
            state(j).im1, state(j).im2, se);
        state(j).currstate = currstate;
        state(j).tot = [state(j).tot mtnidx];

        % Rows/columns of the difference image. STAtot keeps it flattened
        % column-major, one frame per row.
        state(j).wdth = size(currstate.smoothed_diffim, 1);
        state(j).hght = size(currstate.smoothed_diffim, 2);
        npix = state(j).wdth * state(j).hght;

        if k == 1
            state(j).STAtot = zeros(numFrames, npix);
            if j == 1
                % Only the first ROI's crops are kept for the saved movie.
                RGBmovie = zeros(size(crop, 1), size(crop, 2), size(crop, 3), numFrames);
            end
        end
        if j == 1
            RGBmovie(:,:,:,k) = crop;
        end
        state(j).STAtot(k,:) = reshape(currstate.smoothed_diffim, 1, npix);
    end

    if mod(k, 20) == 1
        progbar(p, k/numFrames*100);
        disp(f1, state, numFrames)
    end
end

if isempty(state(1).tot)
    progbar(p, -1);
    error('analyze_movie:noFrames', 'No video frames could be read from %s.', fn);
end

mtn = vertcat(state.tot);   % one row per ROI

close all;
% close all also closed f1; draw the final traces in a fresh figure rather
% than calling figure() on the closed handle.
f1 = figure;
disp(f1, state, numFrames);

tm = ((1:size(mtn, 2)) - 1) / frame_rate;
mtn = [tm' mtn'];

hold off;
progbar(p, -1);

% Spike-triggered-average data (first ROI only).
wdth = state(1).wdth;
hght = state(1).hght;
mIdx = state(1).tot;
wsize = 1;
STAavg = build_sta_matrix(state(1).STAtot, wsize, true);

save('STAstuff.mat','STAavg','wsize','wdth','hght','spikes','mIdx','rois','RGBmovie')

% Exploratory STA analysis, kept for reference:
%STAavg=STAavg-mean(STAavg);
%STAvec=[];
%for i=1:length(spikes)
%    if spikes(i)>0
%        STAvec=[STAvec;spikes(i)*STAavg(i,:)];
%    end;
%end;
%norm
%mSTAvec=mean(STAvec);
%whitened
%mSTAvec=(STAavg*STAavg'+0.5*eye(size(STAavg,1),size(STAavg,1)))\mean(STAvec);
%mSTA=mean(STAtot);
%imSTA=reshape(mSTAvec,wsize,size(mSTAvec,2)/wsize);
%figure;
%for i=1:size(imSTA,1)
%    imagesc(reshape(imSTA(i,:),wdth,hght));
%    pause(1);
%end;
end

function nstartframe = ask_start_frame()
%ASK_START_FRAME Prompt for the start-frame offset; blank/non-numeric -> 0.
prompt = {'Start frame'};
name = 'Number of the starting frame';
numlines = 1;
defaultanswer = {''};
answer = inputdlg(prompt, name, numlines, defaultanswer);
if isempty(answer)
    error('analyze_movie:cancelled', 'The start-frame dialog was cancelled.');
end
nstartframe = str2double(answer{1});
if isnan(nstartframe)
    % The default entry is blank, and str2double('') is NaN, which made
    % every read(readerobj, k+NaN) fail; fall back to no offset instead.
    nstartframe = 0;
end
end

function STAavg = build_sta_matrix(STAtot, wsize, pre)
%BUILD_STA_MATRIX Stack WSIZE consecutive rows of STAtot into single rows.
%   STAtot holds one flattened frame per row. WSIZE zero rows are added
%   before (PRE true) or after (PRE false) the data. Row i of the result is
%   rows i..i+WSIZE-1 of the padded matrix flattened column-major, for
%   i = 1..size(STAtot,1).
ncols = size(STAtot, 2);
pad = zeros(wsize, ncols);
if pre
    STAtot = [pad; STAtot];
else
    STAtot = [STAtot; pad];
end
nrows = size(STAtot, 1) - wsize;
STAavg = zeros(nrows, wsize*ncols);
for i = 1:nrows
    STAavg(i,:) = reshape(STAtot(i:(i+wsize-1),:), 1, wsize*ncols);
end
end


function [mtnidx,newstate]=getmotionidx(currstate, alph, bet, thresh, fig1,fig2,se)
%GETMOTIONIDX Motion index of one ROI frame by running-background subtraction.
%   [MTNIDX, NEWSTATE] = GETMOTIONIDX(CURRSTATE, ALPH, BET, THRESH, FIG1, FIG2, SE)
%   computes the distance between CURRSTATE.frame and a running background.
%   It then smooths that difference over time, thresholds it, erodes the
%   mask with SE and counts the pixels that remain.
%
%   currstate - struct with fields
%       k               frame counter; k == 1 initialises the background
%       frame           current ROI image, shown in FIG1
%       BG              running background (read when k > 1)
%       smoothed_diffim previous smoothed difference (read when k > 1)
%   alph   - weight of the current frame in the background update
%   bet    - weight of the current difference in the smoothed difference
%   thresh - smoothed-difference values above this count as motion
%   fig1, fig2 - image handles whose CData is replaced with the frame and
%                with the binary motion mask
%   se     - structuring element passed to imerode
%
%   mtnidx   - number of pixels in the eroded motion mask
%   newstate - struct with only smoothed_diffim and BG; the caller sets k
%              and frame before the next call.

frame = currstate.frame;

% The background is kept in double. Frames are presumably an integer class
% (e.g. uint8 from the video reader -- confirm). Integer arithmetic rounds
% (1-alph).*BG and alph.*frame separately, so BG drifted even when it
% equalled the frame (0.7*5 -> 4, 0.3*5 -> 2, sum 6). Small differences
% were also never absorbed, and integer subtraction clipped negative
% differences to 0.
if currstate.k == 1
    bg = double(frame);
else
    bg = double(currstate.BG);
end
diffim = distfunct(double(frame), bg);

if currstate.k == 1
    newstate.smoothed_diffim = diffim;
else
    newstate.smoothed_diffim = (1-bet).*currstate.smoothed_diffim + bet.*diffim;
end

set(fig1, 'CData', frame);

binim = newstate.smoothed_diffim > thresh;
binim = imerode(binim, se);
set(fig2, 'CData', binim);

drawnow

mtnidx = nnz(binim);

% Exponential running average of the background.
newstate.BG = (1-alph).*bg + alph.*double(frame);
end
function [C]=distfunct(A,B)
%DISTFUNCT Per-pixel L1 distance between two images, summed over dimension 3.
%   C = DISTFUNCT(A, B) returns sum(abs(A - B), 3) computed in double
%   precision. A and B are expected to have the same size. C has the first
%   two dimensions of A and a single plane.
%   Converting to double first matters for integer images: with uint8,
%   A - B saturates at 0 wherever A < B. abs() then never sees a negative
%   difference, so pixels that darkened were reported as unchanged.
C = sum(abs(double(A) - double(B)), 3);
end

function disp(f1, state,nf)
%DISP Plot each ROI's motion trace, normalised to its own maximum, in figure F1.
%   DISP(F1, STATE, NF) plots state(j).tot / max(state(j).tot) for every
%   element of STATE, with the legend taken from state(j).label and the
%   x-axis set to [0 NF]. A trace whose maximum is 0 is plotted unscaled
%   instead of as 0/0 = NaN.
%   NOTE(review): this local function shadows MATLAB's built-in DISP inside
%   this file, so disp('text') cannot be used here. Renaming it requires
%   updating its callers at the same time.

figure(f1);
% Start from clean axes: "hold on" left over from the previous call made
% every refresh stack the new lines on top of the old ones.
hold off;

cols = ['r','g','b','k','y','m','c'];
nroi = numel(state);
leg = cell(1, nroi);
for j = 1:nroi
    trace = state(j).tot;
    peak = max(trace);
    if ~isempty(peak) && peak ~= 0
        trace = trace ./ peak;
    end
    % Cycle the colours so more than seven ROIs do not index past cols.
    plot(trace, cols(mod(j-1, numel(cols)) + 1));
    hold on;
    leg{j} = state(j).label;
end
xlim([0 nf]);

if nroi > 0
    legend(leg);
end
end