######### 2019-04-10
######### Contact: early@broadinstitute.org
#########
#########
#########
## Script for calculating pi (pairwise nucleotide diversity) per genomic region (e.g., a gene or transcript) from an allele frequency file.
## Takes the output of pop_gen_stats_per_nt.pl as input.
##


use warnings;
use strict;
use POSIX;
use Fcntl qw(:flock);

## ---- User settings: fill these in before running ----
## $trans_site_number and $min_coverage are left empty ("= ;") in this
## template; Perl will not compile the script until numbers are supplied.

my $gene = ""; ## Gene/transcript name or ID (written to the GeneID output column)
my $chrom= ""; ## Chromosome name (only written to the output; not used in calculations)
## Number of sites of class $snp_eff_code in $gene. Multiplied by each
## population's @coverage value to give the denominator for average pi.
my $trans_site_number = ; ## Number of sites of class $snp_eff_code in $gene

## Only input rows whose EFFECT column equals this string (and whose FILTER is PASS) are counted.
my $snp_eff_code = ""; ## missense_variant OR synonymous_variant; originally from column 5 of INPUT_example.txt
## Populations are matched to the input's _pi columns by position.
my @populations = (qw//); # List of population names. MUST BE IN SAME ORDER AS IN INPUT FILE
## A population is reported as NA unless its @coverage value is strictly greater than this.
my $min_coverage = ; # Minimum proportion of sites with at least 0.25 samples having a call. Default suggestion: 0.9
## For each population in @populations, the estimated proportion of sites covered by sequencing data.
## For low-coverage genomes, this is used to correct estimates across regions for the absence of
## variant calls due to low coverage. MUST BE IN SAME ORDER AS POPS IN @populations ARRAY
my @coverage = (); # One value per population, same order as @populations

my $input_file = ""; ## Input file in format produced by script pop_gen_stats_per_nt.pl. ORIGINAL HEADERS MUST BE INTACT
my $out_file = ""; ## Output file name (overwritten if it exists)


### Calculate "corrected length" for the gene based on estimates of total coverage.
### For each population, the number of sites of class $snp_eff_code is scaled by
### that population's estimated coverage; %WL is later used as the denominator
### when averaging pi.
# A missing coverage value would silently become 0 (and the population NA),
# so insist on exactly one coverage entry per population.
die "Expected one \@coverage entry per population (got "
	. scalar(@coverage) . " coverage values for "
	. scalar(@populations) . " populations)\n"
	unless @coverage == @populations;

my %WL;
foreach my $index (0..$#populations) {
	$WL{$populations[$index]} = $trans_site_number * $coverage[$index];
}

#### Make output file header: four fixed columns, then an _S and a _pi column
#### for each population (in @populations order). The data row is started
#### here without a trailing newline; per-population values are appended to it
#### further down.
# 3-arg open: the file name cannot be interpreted as a mode/pipe.
open OUT, '>', $out_file or die "Cannot open output file $out_file: $!\n";
print OUT join("\t",
	"GeneID", "Chrom", "VarClass", "TotalSites",
	map { ($_ . "_S", $_ . "_pi") } @populations), "\n";

print OUT join("\t", $gene, $chrom, $snp_eff_code, $trans_site_number);


## Column indices discovered from the input header while reading SNPS.
## @fst_parts_columns and @Alt_Count_columns are recorded but not otherwise
## used in this script; @ThetaW_columns is never populated.
my @fst_parts_columns;
my @pi_columns;
my @ThetaW_columns;
my @Sample_Coverage_columns;
my @Alt_Count_columns;

# 3-arg open: the file name cannot be interpreted as a mode/pipe.
open SNPS, '<', $input_file
	or die "No variant file $input_file for gene $gene: $!\n";

## Accumulators, keyed by input column index
my %pi_sum;          # sum of per-site pi values (keyed by _pi column)
my %snp_count;       # number of sites with pi > 0 (keyed by _pi column)
my %Sample_Coverage; # summed CALLS_ values at sites with pi > 0 (keyed by CALLS_ column)
my $effect_col;      # index of the EFFECT column
my $filter_col;      # index of the FILTER column
my $total_snps = 0;  # number of PASS sites of class $snp_eff_code

## Read the per-site input. The first line is the header and is used only to
## locate the columns of interest by name; every later line is one site.
while (<SNPS>) {
	chomp;
	my @dl2 = split("\t", $_);

	if ($. == 1) {
		foreach my $col (0..$#dl2) {
			my $name = $dl2[$col];
			push @fst_parts_columns,       $col if $name =~ /_PARTS/;
			push @pi_columns,              $col if $name =~ /_pi/;
			push @Sample_Coverage_columns, $col if $name =~ /CALLS_/;
			push @Alt_Count_columns,       $col if $name =~ /ALT_COUNT_/;
			$filter_col = $col if $name =~ /FILTER/;   # last match wins
			$effect_col = $col if $name =~ /EFFECT/;   # last match wins
		}
		# Without these columns no site can be filtered; fail loudly instead of
		# warning on every row and silently skipping all sites.
		die "Input file $input_file has no FILTER and/or EFFECT column in its header\n"
			unless defined $filter_col && defined $effect_col;
		next;
	}

	## Keep only PASS sites of the requested variant class. Blank or short rows
	## lacking these fields are skipped too (without uninitialized warnings).
	next unless defined $dl2[$filter_col] && defined $dl2[$effect_col];
	next if $dl2[$filter_col] ne "PASS" || $dl2[$effect_col] ne $snp_eff_code;

	## Count number of segregating sites (every retained row is counted)
	++$total_snps;

	## Sum pi per _pi column, and count sites where pi > 0
	foreach my $pi (@pi_columns) {
		if (!exists $pi_sum{$pi}) {
			$pi_sum{$pi}    = 0;
			$snp_count{$pi} = 0;
		}
		next if $dl2[$pi] eq "NA";
		$pi_sum{$pi} += $dl2[$pi];
		++$snp_count{$pi} if $dl2[$pi] > 0;
	}

	## Cumulative sample coverage: add the CALLS_ value at sites where the
	## matching population's pi is > 0. The i-th _pi column is paired with the
	## i-th CALLS_ column (assumes both appear in the same population order).
	## %Sample_Coverage is not read elsewhere in this script.
	foreach my $index (0..$#pi_columns) {
		my $calls_col = $Sample_Coverage_columns[$index];
		my $pi_value  = $dl2[$pi_columns[$index]];
		$Sample_Coverage{$calls_col} = 0 unless exists $Sample_Coverage{$calls_col};
		if ($pi_value ne "NA" && $pi_value > 0) {
			$Sample_Coverage{$calls_col} += $dl2[$calls_col];
		}
	}
}
close SNPS;
   

## Append S (number of sites with pi > 0) and average pi for each population to
## the data row, then terminate the row. Populations are matched to the input's
## _pi columns by position.
foreach my $index (0..$#pi_columns) {
	my $pop    = $populations[$index];
	my $pi_col = $pi_columns[$index];

	## No usable sites after coverage correction: emit one placeholder for each
	## of this population's two header columns (_S and _pi).
	if ($WL{$pop} == 0) {
		print OUT "\tNA\tNA";
		next;
	}

	## Default to 0 when no site passed the filters for this column, so the
	## output shows 0 rather than an empty field.
	my $pi_total = exists $pi_sum{$pi_col}    ? $pi_sum{$pi_col}    : 0;
	my $seg      = exists $snp_count{$pi_col} ? $snp_count{$pi_col} : 0;

	## Average pi = summed per-site pi / coverage-corrected site count
	my $avg_pi = $pi_total / $WL{$pop};

	## Report only populations whose estimated coverage exceeds $min_coverage.
	## NOTE(review): strict ">" means coverage exactly equal to the minimum is
	## reported as NA — confirm this is intended.
	if ($coverage[$index] > $min_coverage) {
		print OUT "\t$seg\t$avg_pi";
	} else {
		print OUT "\tNA\tNA";
	}
}

print OUT "\n";
# Buffered write errors only surface at close, so check it.
close OUT or die "Could not write output file $out_file: $!\n";
