######### 2019-04-10
######### Contact: early@broadinstitute.org
#########
#########
## Script for calculating pi (pairwise nucleotide diversity) *per nt* (i.e. per site) from a table of per-population allele counts.
## 
## Output serves as input for gene-level pop gen calculations (pop_gen_stats_per_region.pl)
#########
#########
## Multiallelic sites are included
## Variants are not polarized
## No minimum sample size limit
## Minimum call rate is currently hard-coded ($min_call_rate below)
#########

use warnings;
use strict;

use List::Util qw(sum);


#############
############# Fill in variables below
#############

# Gene name; only used in messages.
my $gene = q{};

# A population whose PERCENT_CALLED_<pop> value is below this threshold is
# reported as NA for that site.
# NOTE(review): must be on the same scale as that column (fraction vs.
# percent) -- confirm against the input file.
my $min_call_rate = 0.5;

# Population labels. Each must equal the <pop> suffix of the input header
# columns ALT_COUNT_<pop>, REF_COUNT_<pop> and PERCENT_CALLED_<pop>.
my @populations = ();

# Tab-delimited file of allele counts across the populations of interest for
# the genomic region of interest (e.g. a gene). See INPUT_example.txt.
my $input = q{};

# Path of the tab-delimited output file to write.
my $output = q{};


#############
############# Main: per-site pi for each configured population
#############

# Unbiased gene diversity (pi) for one population at one site.
#
# Arguments: the site's allele counts for this population -- the REF count
# first, then one count per ALT allele.
# Returns:   n/(n-1) * (1 - sum_i p_i^2), where n is the total count and
#            p_i = count_i / n; or "NA" when fewer than two distinct alleles
#            are observed (monomorphic or uncalled in this population).
#            Requiring two observed alleles also guarantees n >= 2, so the
#            n/(n-1) correction can never divide by zero.
sub _site_pi {
    my @counts = @_;

    my $observed_alleles = grep { $_ > 0 } @counts;
    return 'NA' if $observed_alleles < 2;

    my $n = sum(@counts);

    # Subtract each squared frequency in turn (same evaluation order as the
    # original script, so printed values are unchanged).
    my $heterozygosity = 1;
    for my $count (@counts) {
        my $freq = $count / $n;
        $heterozygosity -= $freq * $freq;
    }
    return ($n / ($n - 1)) * $heterozygosity;
}

# There is no enclosing per-gene loop here, so a missing input is fatal
# (the old `... and next` could only die with "Can't next outside a loop").
open my $snps_fh, '<', $input
    or die "Variant file ($input) for gene $gene does not exist: $!\n";
open my $out_fh, '>', $output
    or die "Cannot open output file ($output) for gene $gene: $!\n";

# Header column indices, keyed by population label.
my (%alt_col_of, %ref_col_of, %call_rate_col_of);

LINE:
while (my $line = <$snps_fh>) {
    $line =~ s/\r?\n\z//;    # like chomp, but also strips a CRLF ending
    my @fields = split /\t/, $line;

    if ($. == 1) {
        # Echo the header, appending one "<pop>_pi" column per population.
        print {$out_fh} join("\t", $line, map { $_ . '_pi' } @populations), "\n";

        # Locate each population's count and call-rate columns by the
        # population suffix of the column name.
        for my $col (0 .. $#fields) {
            my $name = $fields[$col];
            if ($name =~ /ALT_COUNT_(.+)\z/) {
                $alt_col_of{$1} = $col;
            }
            elsif ($name =~ /REF_COUNT_(.+)\z/) {
                $ref_col_of{$1} = $col;
            }
            elsif ($name =~ /PERCENT_CALLED_(.+)\z/) {
                $call_rate_col_of{$1} = $col;
            }
        }

        # Fail fast on a population list that does not match the header,
        # instead of indexing @fields with undef (which silently reads
        # column 0 and produces meaningless output).
        my @unmatched = grep {
            !(   exists $alt_col_of{$_}
              && exists $ref_col_of{$_}
              && exists $call_rate_col_of{$_})
        } @populations;
        die "Input ($input) has no ALT_COUNT_/REF_COUNT_/PERCENT_CALLED_ "
            . "columns for population(s): @unmatched\n"
            if @unmatched;

        next LINE;
    }

    # Unpolarized diversity: one pi value per population, in configured order.
    my @pi_values;
    for my $pop (@populations) {
        if ($fields[ $call_rate_col_of{$pop} ] < $min_call_rate) {
            push @pi_values, 'NA';    # call rate below the configured minimum
            next;
        }

        # REF count, then the (comma-separated, if multiallelic) ALT counts.
        my @counts = (
            $fields[ $ref_col_of{$pop} ],
            split(/,/, $fields[ $alt_col_of{$pop} ]),
        );
        push @pi_values, _site_pi(@counts);
    }
    print {$out_fh} join("\t", $line, @pi_values), "\n";
}

# Buffered write errors only surface at close, so check it.
close $out_fh or die "Cannot close output file ($output): $!\n";
close $snps_fh;

