#!/usr/bin/python
# Simple class for reading and managing a table of data

#Developed by: Peter Freddolino, Tavazoie lab, Columbia University
# https://tavazoielab.c2b2.columbia.edu/lab/
#
#Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal with the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#
#    * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimers.
#    * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimers in the documentation and/or other materials provided with the distribution.
#    * Neither the names of the Tavazoie lab, Columbia University, nor the names of its contributors may be used to endorse or promote products derived from this Software without specific prior written permission.
#
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE SOFTWARE.

import csv
import sys

def read_table(filename, sep=",", comment="#", header=True, colnames=None, guesstypes=False):
  """
  Read a table from a file, following R-like syntax

  We return a Table object containing the data from the file

  filename is the path of the text file to read
  sep and comment are the field separator and comment prefix, respectively;
    lines starting with comment are skipped (an empty comment disables this)
  header=True indicates that the first non-comment line has column headings
  otherwise, colnames must be provided (if neither is given, the first
    non-comment line is still used as the header)

  If guesstypes is true, we try to assign each field as an int, float, or string
  The types are guessed from the first data row and must stay the same
    throughout the file

  Duplicate column names in the header are made unique by appending "_"
  Blank lines are skipped when the table has more than one column
  Fields beyond the number of columns are ignored

  Raises ValueError if header and colnames are both given, if a line has
    fewer fields than there are columns, or if a value cannot be converted
    to the type guessed for its column
  """

  if (header and colnames):
    raise ValueError("header and colnames are mutually exclusive")

  mytab = Table()
  mycolnames = []
  mycoltypes = []

  if (colnames):
    # Work on a private copy so the caller's list is never shared with the table
    mycolnames = list(colnames)
    mytab.init_keys(mycolnames)

  # The with-block closes the file even when a row fails to parse
  with open(filename) as mystr:
    for linenum, line in enumerate(mystr, 1):
      if (comment and line.startswith(comment)):
        continue

      line = line.rstrip('\r\n')

      if (not mycolnames):
        mycolnames = line.split(sep)

        # deal with duplicate column names; tracking every name already
        # taken guarantees that a renamed column cannot collide with an
        # earlier one
        seen = set()
        for i, name in enumerate(mycolnames):
          while name in seen:
            sys.stderr.write("WARNING reading table: Duplicate column names (resolved automatically)\n")
            name += "_"
          seen.add(name)
          mycolnames[i] = name

        mytab.init_keys(mycolnames)
        continue

      # A blank line cannot be a valid row of a multi-column table
      if (not line and len(mycolnames) > 1):
        continue

      linearr = line.split(sep)

      if (len(linearr) < len(mycolnames)):
        raise ValueError("Line %i of %s has %i fields but the table has %i columns"
                         % (linenum, filename, len(linearr), len(mycolnames)))

      # identify the type of each column
      # Note that these must stay the same throughout the file
      if (not mycoltypes):
        if (guesstypes):
          mycoltypes = [guess_type(field) for field in linearr]
        else:
          mycoltypes = ["string" for field in linearr]

      row = {}
      for key, val, dattype in zip(mycolnames, linearr, mycoltypes):
        row[key] = conv_string(val, dattype)

      mytab.add_row(row)

  return mytab

def write_table(filename, mytable, sep=","):
  """
  Write a table following an R-friendly format

  filename is the path of the file to (over)write
  mytable must provide ordered_keys (the column names, in output order) and
    get_rows() (a list of dictionaries keyed by column name)
  sep is the field separator

  The first line written holds the column names; fields that need quoting
    are quoted with a backtick

  As a side effect the csv dialect 'pythontable' is (re)registered
  """

  csv.register_dialect('pythontable', delimiter=sep, quotechar="`")

  keys = mytable.ordered_keys

  # The with-block closes the file even if a row turns out to be malformed
  with open(filename, 'w') as myf:
    mystr = csv.writer(myf, dialect='pythontable')

    mystr.writerow(keys)
    for row in mytable.get_rows():
      mystr.writerow([row[key] for key in keys])

def merge_tables(table1, table2, outfile=None, sep1=",", sep2=",", sepout=","):
  """
  Merge a pair of tables by appending the rows of table2 to those of table1

  table1 and table2 can either be table objects or file names; file names
    are read with read_table using sep1 and sep2 respectively

  If outfile is None, the merged table is returned instead; otherwise the
    merged table is written to outfile using sepout and None is returned

  If one of the tables has no rows, the other one is used as the result
    (the same object, not a copy)

  Raises ValueError if both tables have rows but their columns differ
  """

  def write_merged_table(merged_tab, outfile, sepout):
    # Write the table that was passed in, rather than a variable of the
    # enclosing function, which is not yet bound on the empty-table paths
    if outfile is not None:
      write_table(outfile, merged_tab, sepout)
      return None

    return merged_tab

  if (isinstance(table1, Table)):
    tab1 = table1
  else:
    tab1 = read_table(table1, sep=sep1)

  if (isinstance(table2, Table)):
    tab2 = table2
  else:
    tab2 = read_table(table2, sep=sep2)

  # Properly handle the case where one of the tables is empty
  if (tab1.length == 0):
    return write_merged_table(tab2, outfile, sepout)

  if (tab2.length == 0):
    return write_merged_table(tab1, outfile, sepout)

  if (tab1.ordered_keys != tab2.ordered_keys):
    raise ValueError("Tables do not have matching keys!")

  newtab = Table()

  for key in tab1.ordered_keys:
    # Concatenation builds a new list, so the inputs are left untouched
    newtab.add_column(tab1.get_column(key) + tab2.get_column(key), key)

  return write_merged_table(newtab, outfile, sepout)

def guess_type(instr):
  """
  Try to guess the type (int, float, string) of some input data

  To do this we just try the relevant conversions, most specific first

  Returns one of the strings "int", "float" or "string"
  """

  # Only conversion failures are caught, so that KeyboardInterrupt and
  # SystemExit are not swallowed
  conversion_errors = (ValueError, TypeError, OverflowError)

  for typename, converter in (("int", int), ("float", float)):
    try:
      converter(instr)
    except conversion_errors:
      continue
    return typename

  return "string"

def conv_string(val, typestring):
  """
  Convert val according to typestring

  "int" and "float" apply the matching builtin conversion; any other
    typestring returns val unchanged
  """

  for label, converter in (("int", int), ("float", float)):
    if (typestring == label):
      return converter(val)

  return val


class Table:
  """
  Class for storing data in tabular form

  The data is stored as a dictionary of lists, with one column of data per header, all of the same length

  Attributes:
    length        number of rows
    data          dictionary mapping each column name to its list of values
    ordered_keys  column names in display/output order
    iterval       cursor used when next() is called on the table directly

  Columns can also be read as attributes (table.colname)
  Iterating over a table yields one dictionary per row
  """

  def __init__(self):

    self.length = 0
    self.data = {}
    self.ordered_keys = []
    self.iterval = 0

  def __iter__(self):
    # Every call gets its own generator, so nested or repeated loops over
    # the same table do not interfere with each other.  The cursor used by
    # next() is still reset, as it was before.
    self.iterval = 0
    return self._iter_rows()

  def _iter_rows(self):
    # Yield the rows one by one; the length is re-read on every step
    rownum = 0
    while rownum < self.length:
      yield self.get_row(rownum)
      rownum += 1

  def __getattr__(self, name):
    # Only called when normal attribute lookup fails.  Going through
    # __dict__ avoids infinite recursion while "data" itself is not set
    # yet, and raising AttributeError (not KeyError) keeps hasattr() and
    # the copy module working.
    columns = self.__dict__.get("data")
    if columns is not None:
      try:
        return columns[name]
      except KeyError:
        pass
    raise AttributeError("Table has no attribute or column named %r" % (name,))

  def __len__(self):
    # return the number of *rows*
    return self.length

  def _as_text(self):
    # Tab-separated rendering shared by __str__ and __repr__.  Every field,
    # including the last one of a line, is followed by a tab.
    lines = ["".join(["%s\t" % header for header in self.ordered_keys])]

    for rowid in range(self.length):
      lines.append("".join(["%s\t" % self.data[header][rowid] for header in self.ordered_keys]))

    return "\n".join(lines) + "\n"

  def __str__(self):
    # make a string representation of the table
    return self._as_text()

  def next(self):
    """
    Return the next row using the cursor stored on the table itself

    Raises StopIteration (and rewinds the cursor) after the last row
    """

    if self.iterval >= self.length:
      self.iterval = 0
      raise StopIteration

    self.iterval += 1
    return self.get_row(self.iterval - 1)

  # Python 3 spelling of the iterator protocol
  __next__ = next

  def __repr__(self):
    return self._as_text()


  def get_row_range(self,start,end):
    """
    Return a new table containing only the specified set of rows

    start and end follow list slicing rules (end is exclusive)
    """

    newtab = Table()
    newtab.init_keys(self.ordered_keys)

    # Slice the row numbers so that only the requested rows are built
    for rownum in range(self.length)[start:end]:
      newtab.add_row(self.get_row(rownum))

    return newtab


  def get_rows(self):
    """
    Return a list of all rows in the table, one dictionary per row
    """

    return [self.get_row(i) for i in range(self.length)]

  def init_keys(self, keys):
    """
    add empty lists with column headers in keys to a table

    this function should only be called for an empty table

    Raises ValueError if the table already contains rows
    """

    if (self.length > 0):
      raise ValueError("init_keys should only be called on an empty table")

    # Keep a private copy: sharing the caller's list would let later column
    # changes on one table silently alter another
    self.ordered_keys = list(keys)

    for key in self.ordered_keys:
      self.data[key] = []

  def add_column(self, column, header):
    """
    Add a column to a table

    Don't allow the addition if the column is the wrong length, or if the column title is already present

    The column can afterwards be read as an attribute named after header
    The column object is stored as given, not copied

    Raises ValueError in either of the rejected cases
    """

    if (header in self.data):
      raise ValueError("Table already contains a column called %s" % header)

    if (not self.data):
      # The first column defines the number of rows
      self.length = len(column)
    elif (self.length != len(column)):
      raise ValueError("New column is of incorrect length for table")

    self.data[header] = column
    self.ordered_keys.append(header)


  def remove_column(self, header):
    """
    Delete the column called header

    Raises KeyError if there is no such column
    """

    self.data.pop(header)
    self.ordered_keys.remove(header)

  def remove_row(self,rownum):
    """
    Remove the rownum-th row (0 indexed) from the table

    A negative rownum counts from the end, as for lists

    Raises KeyError if rownum does not refer to an existing row
    """

    index = rownum
    if index < 0:
      index += self.length

    # Validate before touching any column, so that the table is never left
    # half modified (this also covers the empty table)
    if index < 0 or index >= self.length:
      raise KeyError("Invalid index %i" % rownum)

    for header in self.data:
      thisdat = self.data[header]
      # Build a new list, so that column lists handed out earlier are not
      # changed behind the caller's back
      self.data[header] = thisdat[:index] + thisdat[(index+1):]

    self.length -= 1


  def get_headers(self):
    """
    Returns a list of the names of the columns, in column order
    """

    return list(self.ordered_keys)

  def add_row(self, row):
    """
    Add a row to a table

    The row must be a dictionary containing all of the keys currently existing as column headers, and no others

    Raises ValueError otherwise, leaving the table unchanged
    """

    rowkeys = set(row.keys())
    tablekeys = set(self.data.keys())

    if (tablekeys != rowkeys):
      # Report the difference in the exception rather than on stdout
      unexpected = sorted(rowkeys - tablekeys, key=str)
      missing = sorted(tablekeys - rowkeys, key=str)
      raise ValueError("New row headers do not match those in the table (unexpected: %r, missing: %r)"
                       % (unexpected, missing))

    for key in tablekeys:
      self.data[key].append(row[key])

    self.length += 1

  def get_column(self, header):
    """
    Return the column with the specified title

    The list stored in the table is returned, not a copy
    """

    return self.data[header]

  def get_columns(self, colnames):
    """
    Return a new table with only the columns corresponding to colnames

    The columns are copied, so the new table can be changed independently

    Raises ValueError if one of the columns does not exist
    """

    newtab = Table()
    newtab.length = self.length

    for col in colnames:
      if col not in self.data:
        raise ValueError("Couldn't find column named %s in table" % col)

      newtab.add_column(list(self.data[col]), col)

    return newtab


  def get_row(self, rownum):
    """
    Return the rownum-th row of the table as a dictionary keyed by column name
    """

    return dict((key, column[rownum]) for key, column in self.data.items())

  def set_row(self,rownum,newrow):
    """
    Set the rownum-th row of the table to a new set of values

    newrow must be a dictionary with exactly the column names as keys

    Raises ValueError if the keys do not match
    """

    keys_tab = set(self.data.keys())
    keys_row = set(newrow.keys())

    if keys_tab != keys_row:
      raise ValueError("Keys for new row do not match those of table")

    for key in keys_row:
      self.data[key][rownum] = newrow[key]

  def filter(self, testexpr, global_dict=globals(), extravars = None):
    """
    Return a new table with only rows where testexpr is true

    testexpr can be any python expression, and will be evaluated in an environment for each row
      where variables exist corresponding to each column in the table, named by their columns

    global_dict provides the global names visible to the expression
    extravars is an optional dictionary of additional names; it is merged
      into a copy of global_dict, so neither dictionary is modified

    NOTE: testexpr is passed to eval(), so it must never come from an untrusted source
    """

    rettab = Table()
    rettab.init_keys(self.ordered_keys)

    if extravars is not None:
      # Copy first: updating in place would leak the extra names into the
      # caller's dictionary (by default this module's globals)
      global_dict = dict(global_dict)
      global_dict.update(extravars)

    for i in range(self.length):
      curr_row = self.get_row(i)

      # The row dictionary serves as the local namespace of the expression
      if eval(testexpr, global_dict, curr_row):
        rettab.add_row(curr_row)

    return rettab


  def filter_table(self, testcol, testfunc):
    """
    Return a new table with only rows for which testfunc(testcol[row]) is true

    This function should not be used in the future because filter is better
    """

    rettab = Table()
    # Use ordered_keys so that the column order of the table is kept
    rettab.init_keys(self.ordered_keys)

    for i in range(self.length):
      curr_row = self.get_row(i)
      if testfunc(curr_row[testcol]):
        rettab.add_row(curr_row)

    return rettab

  def sort_table(self, sortcol, cmpfunc=None):
    """
    Returns a sorted version of the table

    sorting is done based on the values in sortcol, with cmpfunc a function
      taking the two values and returning -1/0/1 like cmp()

    If cmpfunc is None the natural ordering of the values is used
    The sort is stable: rows with equal values keep their relative order
    """

    # Imported here to keep this class self-contained
    from functools import cmp_to_key

    newtab = Table()
    newtab.init_keys(self.ordered_keys)

    rowlist = self.get_rows()

    if cmpfunc is None:
      rowlist.sort(key=(lambda a: a[sortcol]))
    else:
      # cmp_to_key adapts the old-style comparison function to a sort key
      # and works on both Python 2.7 and Python 3
      wrap = cmp_to_key(cmpfunc)
      rowlist.sort(key=(lambda a: wrap(a[sortcol])))

    for row in rowlist:
      newtab.add_row(row)

    return newtab

