#!/usr/bin/python

#Developed by: Peter Freddolino, Tavazoie lab, Columbia University
# https://tavazoielab.c2b2.columbia.edu/lab/
#
#Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal with the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
#    * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimers.
#    * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimers in the documentation and/or other materials provided with the distribution.
#    * Neither the names of the Tavazoie lab, Columbia University, nor the names of its contributors may be used to endorse or promote products derived from this Software without specific prior written permission.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE SOFTWARE.

import os
import pylab
import glob
import table
import numpy
import scipy.stats
import bootstrap

# ---------------------------------------------------------------------------
# Configuration: generate all plots needed to summarize some facs data.
#
# Every *_reps list below holds one entry per biological replicate, and the
# entries are matched up by position across the lists.
# ---------------------------------------------------------------------------

# one plotting color per biological replicate (indexed by replicate number)
colors = ["blue", "red", "green", "purple", "cyan"]

# directory containing initial condition facs data
# this should be a list, with one element for each biological replicate
initial_cond_run_reps = ["20150205_0835", "20150206_0926"]

# NOTE(review): mindoub_starve is not referenced anywhere else in the visible
# code; the growth test in the main loop hard-codes a 4-fold increase over the
# initial count -- confirm which of the two is intended.
mindoub_starve = 4

# all other run directories -- we figure out whether each timepoint is starving
# based on whether it has undergone enough generations
# this should be a list, with one sublist for each biological replicate
other_cond_runs_reps = [
  ["20150206_0926", "20150209_1226", "20150210_1116", "20150211_1235",
   "20150212_0835", "20150213_0941", "20150214_1636", "20150215_1415",
   "20150216_1933", "20150217_2119"],
  ["20150209_1226", "20150210_1116", "20150211_1235", "20150212_0835",
   "20150213_0941", "20150214_1636", "20150215_1415", "20150216_1933",
   "20150217_2119"],
]

# number of days before each timepoint
# this should be a list with one sublist per biological replicate; each sublist
# starts with the time of the initial-condition run, followed by one time per
# entry of the matching other_cond_runs_reps sublist
cond_time_reps = [
  [0.0, 1.03541666667, 4.16041666667, 5.11180555556, 6.16666666667, 7.0,
   8.04583333333, 9.33402777778, 10.2361111111, 11.4569444444, 12.5305555556],
  [0.0, 3.125, 4.07638888889, 5.13125, 5.96458333333, 7.01041666667,
   8.29861111111, 9.20069444444, 10.4215277778, 11.4951388889],
]

# cell types to consider
# this should be a list with one element per biological replicate
# note that the cell types in each replicate need to be in the same order --
# they will be plotted together!!
celltype_reps = [['1034_rep1'], ['1034_rep2']]

# --- sanity checks: fail early instead of part-way through the plotting ---

for types in celltype_reps:
  if len(types) != len(celltype_reps[0]):
    raise ValueError("number of cell types in each replicate must match")

n_replicates = len(initial_cond_run_reps)

for listname, per_rep in (("other_cond_runs_reps", other_cond_runs_reps),
                          ("cond_time_reps", cond_time_reps),
                          ("celltype_reps", celltype_reps)):
  if len(per_rep) != n_replicates:
    raise ValueError("%s must have one entry per biological replicate (expected %i, found %i)" % (listname, n_replicates, len(per_rep)))

# each replicate is drawn with colors[replicate index]
if n_replicates > len(colors):
  raise ValueError("not enough plotting colors for %i replicates" % n_replicates)

# the first time belongs to the initial-condition run, the rest map 1:1 onto the other runs
for runs, times in zip(other_cond_runs_reps, cond_time_reps):
  if len(times) != len(runs) + 1:
    raise ValueError("each replicate needs one time for the initial run plus one per other run (found %i times for %i other runs)" % (len(times), len(runs)))

# latest usable timepoint seen for each cell type (across replicates); filled
# in by the main loop and used to set the x range of the saved plots
maxt_arr = [0] * len(celltype_reps[0])

def do_sims(count,nreps,nsim):
  """Draw samples from the posterior distribution for the count data.

  The draws come from a gamma distribution with shape ``count + 0.5`` and
  rate ``nreps`` (i.e. scale ``1/nreps``), equivalent to R's
  ``rgamma(nsim, shape=count + 0.5, rate=nreps)``.
  (count + 0.5 is what a Jeffreys prior on a Poisson rate would give --
  presumably the intent; confirm.)

  Note that for this simple distribution we could obtain the quantiles
  analytically - simulations are used here to maintain compatibility with
  the possible future need for a more complex posterior distribution.

  count -- observed number of events
  nreps -- number of replicate observations the count was pooled over (must be non-zero)
  nsim  -- number of posterior draws to return

  Returns a numpy array of nsim draws.
  """

  shape = count + 0.5
  # gamma.rvs takes (a, loc, scale): the scale has to be given by keyword,
  # otherwise 1/nreps is interpreted as a location shift and every draw is
  # offset by 1/nreps while the scale silently stays at 1
  return scipy.stats.gamma.rvs(shape, scale=1.0/nreps, size=nsim)

def get_ci(simdata):
  """Return (low, high) bounds of the central 95% credible interval.

  The bounds are the 2.5th and 97.5th percentiles of simdata, as computed
  by scipy.stats.scoreatpercentile.
  """
  bounds = [scipy.stats.scoreatpercentile(simdata, pct) for pct in (2.5, 97.5)]
  return tuple(bounds)


def getcounts_from_file(infile,celltype):
  """Read the cell and bead counts for a given celltype from the given file.

  Each line of infile is split on whitespace.  Field 1 (0-based) is the tube
  name, field 3 the cell count and field 5 the bead count.  A line matches
  when the tube name, truncated at the first underscore found at or after
  character index 6, equals celltype.

  Lines that are blank, have fewer than two fields, or whose tube name has
  no underscore at or after index 6 cannot match and are skipped.

  Returns (cellcount, beadcount) as ints for the first matching line, or
  (0, 0) after printing a warning when no line matches.

  Raises IndexError / ValueError if a matching line lacks the count fields or
  they are not integers, and IOError / OSError if infile cannot be opened.
  """
  # the with-block closes the file on every exit path, including the early return
  with open(infile) as instr:
    for line in instr:
      linearr = line.split()
      if len(linearr) < 2:
        continue

      tubename = linearr[1]
      tind = tubename.find('_',6)
      if tind < 0:
        continue

      # progress trace of each tube name prefix examined
      print(tind)
      print(tubename[:tind])
      if tubename[:tind] == celltype:
        cellcount = int(linearr[3])
        beadcount = int(linearr[5])
        return (cellcount,beadcount)

  print("Warning: could not find cell information for infile %s and cells %s" % (infile,celltype))
  return (0,0)

# Set up the figures up front: for every cell type of the first replicate we
# need one growth figure, one GFP figure and one mRuby figure (8x3 each).
# The three lists are indexed by the position of the cell type.
growth_figs = []
gfp_figs = []
mruby_figs = []

for celltype in celltype_reps[0]:
  # created in the order growth, gfp, mruby for each cell type
  for fig_list in (growth_figs, gfp_figs, mruby_figs):
    fig_list.append(pylab.figure(figsize=(8,3)))

# loop over biological replicates and do all needed plotting for each of them

def _plot_median_series(fig, datasets, times, baseline, color, ylabel):
  """Draw per-timepoint medians with bootstrapped 95% CIs onto fig.

  fig      -- figure to draw on (made current via its number)
  datasets -- list of 1-d arrays, one per plotted timepoint
  times    -- x position (days) for each entry of datasets
  baseline -- y value of the dashed horizontal reference line
  color    -- color used for markers, error bars and the reference line
  ylabel   -- y axis label

  The confidence interval of each median is taken from the 'ci95' entry
  returned by bootstrap.bootstrap_1v with 1000 resamples.
  """
  pylab.figure(num=fig.number)

  medians = []
  cis_lo = []
  cis_hi = []
  for datset in datasets:
    medians.append(numpy.median(datset))
    lo_ci,hi_ci = bootstrap.bootstrap_1v(datset, numpy.median,nsamp=1000)['ci95']
    cis_lo.append(lo_ci)
    cis_hi.append(hi_ci)

  medians = numpy.array(medians)
  cis_lo = numpy.array(cis_lo)
  cis_hi = numpy.array(cis_hi)

  pylab.plot(times,medians,color=color,mfc=color,marker='o',ls='None')
  # errorbar wants distances from the point, not absolute interval bounds
  pylab.errorbar( times, medians, xerr=None, yerr=numpy.vstack( (medians - cis_lo,cis_hi - medians) ),ls='None',ecolor=color,capsize=6)

  pylab.axhline(baseline,color=color,ls='--')
  pylab.xlabel("Time (d)")
  pylab.ylabel(ylabel)
  pylab.xticks(numpy.arange(1+max(times)))


for rep_i in range(len(initial_cond_run_reps)):

  celltypes = celltype_reps[rep_i]
  initial_cond_run = initial_cond_run_reps[rep_i]
  other_cond_runs = other_cond_runs_reps[rep_i]
  cond_times = cond_time_reps[rep_i]

  for cell_i,celltype in enumerate(celltypes):

    # ---- cell counts ----
    # first get and plot all of the counts, and figure out which timepoints are starving
    start_cells,start_beads = getcounts_from_file(os.path.join("data",initial_cond_run,"cellcounts.txt"),celltype)
    other_counts = [getcounts_from_file(os.path.join("data",rundir,"cellcounts.txt"),celltype) for rundir in other_cond_runs]

    # a later timepoint is usable only if cells were actually counted for it
    goodflags = [ cells > 0 for cells,beads in other_counts ]

    # times of all timepoints (initial one first); unusable ones are marked with -1
    good_times = [cond_times[0]]
    for i,isgood in enumerate(goodflags):
      if isgood:
        good_times.append(cond_times[i+1])
        # remember the latest usable time for this cell type (shared across replicates)
        maxt_arr[cell_i] = max(maxt_arr[cell_i], cond_times[i+1])
      else:
        good_times.append(-1)

    # constants in this section indicate dilution factors of the cells and beads in our experimental setup
    real_counts = [ float(start_cells)* ((1000.0/200) * (200.0/25000)* (9980/float(start_beads))) ]
    cell_sims = do_sims(start_cells,1,1000)
    bead_sims = do_sims(start_beads,1,1000)

    ci_los = []
    ci_his = []

    real_cilo, real_cihi = get_ci(cell_sims* ((1000.0/200) * (200.0/25000)* (9980/bead_sims)) )
    ci_los.append(real_cilo)
    ci_his.append(real_cihi)

    for cells,beads in other_counts:
      if cells == 0:
        # placeholder entries keep the arrays aligned with good_times;
        # they are filtered out below because the count is not > 0
        real_counts.append(0)
        ci_los.append(0)
        ci_his.append(0)

      else:
        real_counts.append( float(cells) * (1000.0/490) * (9980/float(beads)))
        cell_sims = do_sims(cells,1,1000)
        bead_sims = do_sims(beads,1,1000)
        cell_ci_lo,cell_ci_hi = get_ci( cell_sims * (1000.0/490) * (9980.0/bead_sims))
        ci_los.append(cell_ci_lo)
        ci_his.append(cell_ci_hi)

    # growth state of each timepoint, selects the fluorescence correction below:
    #   2 -- the initial timepoint
    #   1 -- count is more than 4 times the initial count
    #   0 -- everything else
    growthflags = [0] * len(real_counts)
    growthflags[0] = 2
    for i in range(1,len(growthflags)):
      if real_counts[i] > (4*real_counts[0]):
        growthflags[i] = 1

    pylab.figure(growth_figs[cell_i].number)
    pylab.gca().set_yscale('log')

    print(good_times)
    print(real_counts)
    good_times = numpy.array(good_times)
    real_counts = numpy.array(real_counts)
    plot_flags = ( real_counts > 0)
    ci_los = numpy.array(ci_los)
    ci_his = numpy.array(ci_his)
    pylab.plot(good_times[plot_flags],real_counts[plot_flags],marker='o')
    print(ci_los)
    print(ci_his)
    pylab.errorbar(good_times[plot_flags], real_counts[plot_flags], xerr=None, yerr = numpy.vstack((real_counts[plot_flags] - ci_los[plot_flags], ci_his[plot_flags]-real_counts[plot_flags])), ls='None',ecolor=colors[rep_i])

    pylab.xlim((-0.5, max(good_times) + 0.5))
    pylab.xlabel("Time (d)")
    pylab.ylabel("Cell count per mL")

    # ---- fluorescence ----
    # now we correct and plot the fluorescence values
    # NOTE(review): `basey` is the keyword understood by older matplotlib
    # releases; newer releases spell it `base` -- confirm against the
    # installed version before changing.
    gfpfig = pylab.figure(gfp_figs[cell_i].number)
    pylab.gca().set_yscale('log',basey=2)

    mrubyfig = pylab.figure(mruby_figs[cell_i].number)
    pylab.gca().set_yscale('log',basey=2)

    gfpvals_all = []
    mrubyvals_all = []
    times_all = []

    # medians at the initial timepoint, drawn as the dashed reference lines
    gfpvals_start = 0
    mrubyvals_start = 0

    for i,dirname in enumerate([initial_cond_run] + other_cond_runs):
      print(dirname)

      # index 0 is the initial run; goodflags only covers the later runs
      if (i > 0) and (not goodflags[i-1]):
        continue

      # look the data file up by path rather than chdir-ing into the run
      # directory, so an exception cannot leave the process in the wrong cwd
      datfile = glob.glob(os.path.join('data',dirname,"%s_*fcs_data.txt" % celltype))
      if len(datfile) != 1:
        print("Couldn't figure out the right data file")
        continue

      mytab = table.read_table(datfile[0],sep=",",header=True,guesstypes=False)
      fscvals = numpy.array([float(f) for f in mytab.get_column('fsc')])
      sscvals = numpy.array([float(f) for f in mytab.get_column('ssc')])
      gfpvals = numpy.array([float(f) for f in mytab.get_column('gfp')])
      mrubyvals = numpy.array([float(f) for f in mytab.get_column('mruby')])

      print(gfpvals)

      # apply appropriate fluorescence corrections
      # these were obtained empirically for population of non-fluorescent cells in each growth state
      if growthflags[i] == 2:
        gfpvals -= 12.7
        mrubyvals -= 4.0
      elif growthflags[i] == 1:
        gfpvals -= 32.30624
        mrubyvals -= 7.37964
      elif growthflags[i] == 0:
        gfpvals -= (-5.6026163 + fscvals*0.0554416 + sscvals*0.0393806)
        mrubyvals -= (-1.8894818 + fscvals*0.0440239 + sscvals*0.0561541)
      else:
        # a bare string cannot be raised; use a real exception type
        raise ValueError("Invalid growthflags value")

      print("Dropfrac: ")
      dropfrac = float(numpy.sum(gfpvals<0.1)) / len(gfpvals)
      print(dropfrac)

      # discard events at or below 0.1 after correction (the axes are log scaled)
      gfpvals = gfpvals[gfpvals>0.1]
      mrubyvals = mrubyvals[mrubyvals>0.1]
      print(gfpvals)

      gfpvals_all.append(gfpvals)
      mrubyvals_all.append(mrubyvals)
      print(numpy.median(gfpvals))
      print(numpy.median(mrubyvals))
      times_all.append(cond_times[i])

      if i == 0:
        gfpvals_start = numpy.median(gfpvals)
        mrubyvals_start = numpy.median(mrubyvals)

    print("***")
    print(growthflags)

    _plot_median_series(gfpfig, gfpvals_all, times_all, gfpvals_start, colors[rep_i], "Normalized DHFR-GFP fluorescence")
    _plot_median_series(mrubyfig, mrubyvals_all, times_all, mrubyvals_start, colors[rep_i], "Normalized URA3-mRuby fluorescence")

# Finally give all three figures of each cell type the same x range and
# integer day ticks, and write them out as PDFs named after the cell type.

print("XXX")
print(maxt_arr)

for i,celltype in enumerate(celltype_reps[0]):
  xm = 0
  xa = maxt_arr[i]   # latest usable timepoint recorded for this cell type
  x_range = (xm-0.5,xa+0.5)

  for figs,name_template in ((growth_figs,"%s_growth.pdf"),
                             (gfp_figs,"%s_gfp.pdf"),
                             (mruby_figs,"%s_mruby.pdf")):
    pylab.figure(num=figs[i].number)
    pylab.xlim(x_range)
    pylab.xticks( numpy.arange(0,xa+1) )
    pylab.savefig(name_template % celltype)



  


