#The following R packages are needed:
#tidyverse for general data processing
#RcppRoll for calculating rolling mean
#parallel for mclapply
#rhdf5 for handling H5 files


#tested with R version 3.5.1 (2018-07-02)
#tidyverse 1.2.1
#RcppRoll 0.3.0
#parallel 3.5.1
#rhdf5 1.2.1


require(tidyverse)
require(RcppRoll)
require(parallel)
require(rhdf5)

#The following constants are only needed if calculation from raw data is desired
# (they are read as globals by the kinematic helper functions defined below)


#pixel size of CV7000 camera
# NOTE(review): presumably in micrometres -- confirm against the camera specification
pixel_size <- 6.5

# magnification of the microscope objective (used as a divisor below)
microscope_objective <- 20

#binning factor used to achieve 45 FPS
binning_factor <- 3

# size of one binned image pixel in the sample plane:
# camera pixel size / magnification * binning (6.5 / 20 * 3 = 0.975);
# multiplies pixel distances to convert them into physical distances
pixel_factor <- pixel_size/microscope_objective*binning_factor

# time between consecutive frames in seconds (~1/45 s, matching the 45 FPS above)
interval_time <-  0.022

#window size for moving average VAP
# (default number of points of the centered moving average that
# calc_sperm_kinetics passes on to calc_vap)
window_size <- 11




###Helper functions for calculating kinematic values
###Only needed if calculation from raw data is desired

# Curvilinear velocity (VCL) per particle: summed point-to-point path length
# divided by the elapsed track time.
#
# df: tracked points with columns experiment, well_number, particle, frame,
#     x and y (pixels). Steps are taken between consecutive rows, so rows are
#     assumed to be in frame order within each particle -- TODO confirm upstream.
# Returns (invisibly) one row per particle with
#   total_distance: path length * pixel_factor
#   total_time:     (last frame - first frame) * interval_time
#   vcl:            total_distance / total_time
calc_vcl <- function(df) {
  tracks <- group_by(
    select(df, experiment, well_number, particle, frame, x, y),
    experiment, well_number, particle
  )

  per_track <- summarize(
    tracks,
    total_distance = sum(sqrt(diff(x)^2 + diff(y)^2), na.rm = TRUE) * pixel_factor,
    total_time = (max(frame) - min(frame)) * interval_time,
    vcl = total_distance / total_time
  )

  invisible(per_track)
}

# Edge smoothing for VAP: for each particle, take the first `win_size` rows
# and the last `win_size` rows and apply a centered moving average of length
# `win_size` to x and y. A slice of exactly `win_size` rows yields a single
# smoothed point (its centre row); calc_vap_temp calls this with shrinking
# window sizes to smooth the points near the track ends.
#
# df:       tracked points with experiment, well_number, particle, x, y
#           (rows assumed to be in frame order within each particle).
# win_size: window length, i.e. the number of points averaged.
# Returns (invisibly) the smoothed rows with ma_x / ma_y added: rows from the
# track starts first, then rows from the track ends.
calc_vap_ma_win <- function(df, win_size) {
  # centered rolling mean over one slice; rows without a full window are dropped
  smooth_centre <- function(part) {
    part %>%
      mutate(ma_x = roll_mean(x, win_size, fill = NA, align = 'center'),
             ma_y = roll_mean(y, win_size, fill = NA, align = 'center')) %>%
      filter(!is.na(ma_x))
  }

  by_track <- group_by(df, experiment, well_number, particle)

  head_rows <- smooth_centre(filter(by_track, row_number() <= win_size))
  tail_rows <- smooth_centre(filter(by_track, row_number() > (n() - win_size)))

  invisible(bind_rows(head_rows, tail_rows))
}
# Interior smoothing for VAP: centered moving average of length `win_size`
# over each particle's full track. Rows whose averaged x is NA are dropped;
# these are the edge rows, where the window does not fit, and any window
# containing an NA coordinate.
#
# Returns (invisibly) the grouped rows with ma_x / ma_y added.
calc_vap_middle <- function(df, win_size) {
  by_track <- group_by(df, experiment, well_number, particle)

  smoothed <- mutate(by_track,
                     ma_x = roll_mean(x, win_size, fill = NA, align = 'center'),
                     ma_y = roll_mean(y, win_size, fill = NA, align = 'center'))

  invisible(filter(smoothed, !is.na(ma_x)))
}

# Collect every smoothed point of the average path except the raw track end
# points (calc_vap adds those itself):
#   * edge points smoothed with window sizes window_size - 2, window_size - 4,
#     ... down to 3 (see calc_vap_ma_win), computed in parallel with
#     parallel::mclapply on `nc` workers (mclapply only supports nc = 1 on Windows)
#   * interior points smoothed with the full `window_size` (see calc_vap_middle)
#
# df:          tracked points (experiment, well_number, particle, frame, x, y).
# window_size: length of the interior moving-average window.
# nc:          number of worker processes used for the edge windows (default 4).
# Returns the edge rows followed by the interior rows, with ma_x / ma_y added.
# Stops with an error if any edge-window job fails.
calc_vap_temp <- function(df, window_size, nc = 4) {

  # for window_size < 5 there is no edge window of size >= 3, and
  # seq(window_size - 2, 3, -2) would fail with "wrong sign in 'by'"
  edge_windows <- if (window_size >= 5) seq(window_size - 2, 3, -2) else numeric(0)

  # keep the list that mclapply() returns: assigning to a variable inside the
  # worker function only modifies a copy and never reaches the caller
  edge_list <- mclapply(edge_windows, function(w) calc_vap_ma_win(df, w),
                        mc.cores = nc)

  # mclapply reports failed jobs as 'try-error' values instead of raising
  failed <- vapply(edge_list, function(r) inherits(r, 'try-error'), logical(1))
  if (any(failed)) {
    stop('Edge-window smoothing failed: ',
         as.character(edge_list[[which(failed)[1]]]), call. = FALSE)
  }

  bind_rows(c(edge_list, list(calc_vap_middle(df, window_size))))
}

# Average-path velocity (VAP) and amplitude of lateral head displacement (ALH)
# per particle.
#
# The average path consists of three parts:
#   * the raw first and last point of each track
#   * the interior points, smoothed with a centered moving average of
#     `window_size` points
#   * points near the track ends, smoothed with progressively smaller
#     windows (see calc_vap_temp)
#
# df:          tracked points (experiment, well_number, particle, frame, x, y).
# window_size: length of the interior moving-average window (default 11).
# Returns (invisibly) one row per particle with
#   vap: average-path length * pixel_factor / ((last - first frame) * interval_time)
#   alh: 2 * largest distance between a raw point and its averaged position,
#        times pixel_factor
calc_vap <- function(df, window_size = 11) {
  # track end points stay unsmoothed so the average path spans the same
  # frames as the raw track
  end_points <- df %>%
    group_by(experiment, well_number, particle) %>%
    filter(row_number() %in% c(1, n())) %>%
    mutate(ma_x = x,
           ma_y = y)

  average_path <- bind_rows(end_points, calc_vap_temp(df, window_size)) %>%
    arrange(experiment, well_number, particle, frame)

  result <- summarize(
    average_path,
    vap = (sum(sqrt(diff(ma_x)^2 + diff(ma_y)^2), na.rm = TRUE) * pixel_factor) /
      ((max(frame) - min(frame)) * interval_time),
    alh = 2 * max(sqrt((ma_x - x)^2 + (ma_y - y)^2) * pixel_factor, na.rm = TRUE)
  )

  invisible(result)
}

# Straight-line velocity (VSL) per particle: distance between the first and
# the last point of the track divided by the elapsed track time.
#
# df: tracked points (experiment, well_number, particle, frame, x, y); the
#     first/last rows are taken positionally, so rows are assumed to be in
#     frame order within each particle -- TODO confirm upstream.
# Returns (invisibly) one row per particle with the column vsl.
calc_vsl <- function(df) {
  end_points <- df %>%
    group_by(experiment, well_number, particle) %>%
    filter(row_number() %in% c(1, n()))

  # the distance is symmetric, so the order of the two end points is irrelevant
  result <- summarize(
    end_points,
    vsl = (sqrt(diff(x)^2 + diff(y)^2) * pixel_factor) /
      ((max(frame) - min(frame)) * interval_time)
  )

  invisible(result)
}

# Fractal dimension D of each track after Mortimer ST et al., 1996:
#   D = log(n) / (log(n) + log(d / L))
# where
#   n = number of track intervals (points - 1)
#   d = planar extent, measured here as the largest distance of any point
#       from the FIRST point of the track
#   L = total point-to-point path length
# d and L are both in pixels, so their ratio is unit-free.
# NOTE(review): confirm that "distance from the first point" is the intended
# planar-extent definition; some formulations use the largest distance
# between any two points of the track.
#
# Returns (invisibly) one row per particle with the column D.
calc_fractal_dimension <- function(df) {
  per_track <- df %>%
    group_by(experiment, well_number, particle) %>%
    summarize(planar_extent = max(sqrt((first(x) - x)^2 + (first(y) - y)^2)),
              track_intervals = length(frame) - 1,
              total_distance = sum(sqrt(diff(x)^2 + diff(y)^2), na.rm = TRUE),
              D = log(track_intervals) /
                (log(track_intervals) + log(planar_extent / total_distance)))

  invisible(select(per_track, experiment, well_number, particle, D))
}

# Calculate and combine all kinematics of one recording and classify each sperm.
#
# df:       tracked points (experiment, well_number, particle, frame, x, y).
# win_size: moving-average window used for VAP/ALH (defaults to the global
#           window_size).
# Returns one row per particle with total_distance, total_time, vcl, vap, alh,
# vsl, D and the ratios lin = vsl/vcl, str = vsl/vap, wob = vap/vcl
# (all in %), plus two classifications:
#   status_mot: 'PM'  if vap > 25 and str > 80
#               'NPM' otherwise, if vap > 5 or vsl > 11
#               'IM'  otherwise
#   status_hyp: 'HA' if vcl >= 150 and D >= 1.2, else 'not-HA'
calc_sperm_kinetics <- function(df, win_size = window_size) {

  # join explicitly on the particle key so that a measurement column shared by
  # two tables (e.g. total_distance) can never silently become a join key
  track_key <- c('experiment', 'well_number', 'particle')

  vcl <- calc_vcl(df)
  vap <- calc_vap(df, win_size)
  vsl <- calc_vsl(df)
  d <- calc_fractal_dimension(df)

  vcl %>%
    inner_join(vap, by = track_key) %>%
    inner_join(vsl, by = track_key) %>%
    inner_join(d, by = track_key) %>%
    mutate(lin = vsl / vcl * 100,
           str = vsl / vap * 100,
           wob = vap / vcl * 100,
           status_mot = ifelse(vap > 25 & str > 80, 'PM',
                               ifelse((vap > 5 | vsl > 11), 'NPM', 'IM')),
           status_hyp = ifelse(vcl >= 150 & D >= 1.2, 'HA', 'not-HA'))
}

# Wrapper for multi-position multiprocessing: runs calc_sperm_kinetics() on
# the data of every recording position in parallel.
#
# df: tracked points with a `position` column identifying the recording.
# nc: number of worker processes for parallel::mclapply (mclapply only
#     supports nc = 1 on Windows).
# The combined per-particle table is printed and returned invisibly.
# Stops with an error naming the position(s) whose calculation failed.
calc_sperm_kinematics_mp <- function(df, nc = 4) {

  # ungroup so the whole list of positions is handed to mclapply at once;
  # newer tidyr keeps the grouping after nest(), and a grouped mutate() would
  # call mclapply once per position with a single element, i.e. serially
  nested <- df %>%
    group_by(position) %>%
    nest() %>%
    ungroup()

  results <- mclapply(nested$data, calc_sperm_kinetics, mc.cores = nc)

  # mclapply reports failed jobs as 'try-error' values instead of raising
  failed <- vapply(results, function(r) inherits(r, 'try-error'), logical(1))
  if (any(failed)) {
    stop('Kinematics failed for position(s) ',
         paste(nested$position[failed], collapse = ', '), ': ',
         as.character(results[[which(failed)[1]]]), call. = FALSE)
  }

  nested$tidy <- results

  data_end <- nested %>%
    unnest(tidy) %>%
    print()

  invisible(data_end)
}





#The functions below have to be loaded into the environment

# Wrapper to calculate the well average of each kinematic.
#
# Summarizes the particle-level table per experiment and well_number.
#
# df:     particle-level kinematics. If df has no status_hyp column, motility
#         status is read from a `status` column and only the HA-independent
#         summaries (no HA, ALH, WOB, D) are computed as medians; `median` is
#         not used in that case.
# median: TRUE for per-well medians (columns *_median), FALSE for per-well
#         means (columns *_mean).
# Returns one row per well with:
#   * counts: sperm_count, *_count
#   * percentages: HA, PM, and TM = PM + NPM
#   * the kinematic summaries
# The table is printed and returned invisibly. Stops if `median` is not a
# single TRUE/FALSE.
kinematics_well_avg <- function(df, median = TRUE) {

  # validate first: `if (NA)` / `if ('yes')` would otherwise abort with a
  # generic R error before any explicit check could run
  if (!is.logical(median) || length(median) != 1 || is.na(median)) {
    stop('Median has to be a Boolean value (TRUE/FALSE)', call. = FALSE)
  }

  # summarize positions: one row per experiment and well
  by_well <- group_by(df, experiment, well_number)

  # stats::median() is spelled out because the logical `median` argument
  # has the same name
  if (!'status_hyp' %in% names(df)) {
    avg_df <- summarize(by_well,
                        sperm_count = length(particle),
                        IM_count = sum(status == 'IM'),
                        NPM_count = sum(status == 'NPM'),
                        PM_count = sum(status == 'PM'),
                        PM = PM_count/sperm_count * 100,
                        TM = (PM_count+NPM_count)/sperm_count * 100,
                        VCL_median = stats::median(vcl),
                        VSL_median = stats::median(vsl),
                        VAP_median = stats::median(vap),
                        STR_median = stats::median(str),
                        LIN_median = stats::median(lin))

  } else if (median) {
    avg_df <- summarize(by_well,
                        sperm_count = length(particle),
                        HA_count = sum(status_hyp == 'HA'),
                        HA = HA_count/sperm_count * 100,
                        IM_count = sum(status_mot == 'IM'),
                        NPM_count = sum(status_mot == 'NPM'),
                        PM_count = sum(status_mot == 'PM'),
                        PM = PM_count/sperm_count * 100,
                        TM = (PM_count+NPM_count)/sperm_count * 100,
                        VCL_median = stats::median(vcl),
                        VSL_median = stats::median(vsl),
                        VAP_median = stats::median(vap),
                        ALH_median = stats::median(alh),
                        STR_median = stats::median(str),
                        WOB_median = stats::median(wob),
                        LIN_median = stats::median(lin),
                        D_median = stats::median(D))

  } else {
    avg_df <- summarize(by_well,
                        sperm_count = length(particle),
                        HA_count = sum(status_hyp == 'HA'),
                        HA = HA_count/sperm_count * 100,
                        IM_count = sum(status_mot == 'IM'),
                        NPM_count = sum(status_mot == 'NPM'),
                        PM_count = sum(status_mot == 'PM'),
                        PM = PM_count/sperm_count * 100,
                        TM = (PM_count+NPM_count)/sperm_count * 100,
                        VCL_mean = mean(vcl),
                        VSL_mean = mean(vsl),
                        VAP_mean = mean(vap),
                        ALH_mean = mean(alh),
                        STR_mean = mean(str),
                        WOB_mean = mean(wob),
                        LIN_mean = mean(lin),
                        # was mislabelled D_median although it is a mean
                        D_mean = mean(D))
  }

  print(avg_df)
  invisible(avg_df)
}

# Split the well number into its numeric and letter parts.
#
# For a well identifier such as 'A23':
#   x = numeric part ('23')
#   y = letter part ('A')
# Both parts are kept as character. Existing columns named x or y are
# overwritten. The table is printed and returned invisibly.
split_well_id <- function(df) {

  # '[A-Za-z]+' rather than '[aA-zZ]+': the range 'A-z' also covers the ASCII
  # characters between 'Z' and 'a' ([ \ ] ^ _ `), e.g. 'A_' for 'A_23'
  output_df <- df %>%
    mutate(x = str_extract(well_number, '[0-9]+'),
           y = str_extract(well_number, '[A-Za-z]+')) %>%
    print()

  invisible(output_df)
}

# Classify control and compound wells by plate layout.
#
# Labels:
#   'Negative Control' (DMSO):        column 23 rows A-H, column 24 rows I-P
#   'Positive Control' (pristimerin): column 24 rows A-H, column 23 rows I-P
#   'Compound':                       all other wells
# df must contain the column number `x` and the row letter `y` as character,
# as produced by split_well_id().
# Returns df with an added cpd_type column.
classify_controls <- function(df) {

  # mutate() already returns the labelled table; nothing needs to be assigned
  # back into df (the former `return(df$cpd_type <- ...)` stored the whole
  # table inside a column of df)
  df %>%
    mutate(cpd_type = ifelse(
      (x %in% c('23') & y %in% LETTERS[1:8]) | (x %in% c('24') & y %in% LETTERS[9:16]),
      'Negative Control',
      ifelse(
        (x %in% c('23') & y %in% LETTERS[9:16]) | (x %in% c('24') & y %in% LETTERS[1:8]),
        'Positive Control',
        'Compound'
      )
    ))
}

# Quality flag per well, based on its two recording positions ('001' and '002').
#
# For every experiment (`name` column) and well, the median VCL of each
# position is compared:
#   'sticky'   -- the two medians differ by more than 15
#   'ok'       -- the two medians differ by at most 15
#   'AF error' -- a position's median is missing, typically because only one
#                 position was recorded after an auto focus error
#
# df: particle-level table with name, position, well_number and vcl columns.
# Returns a table with name, well_number and flag; it is printed and returned
# invisibly. Stops if df has no position column.
classify_positions <- function(df) {

  if (!'position' %in% names(df)) {
    stop('This dataframe does not contain a position column', call. = FALSE)
  }

  wide <- df %>%
    group_by(name, position, well_number) %>%
    summarize(vcl = median(vcl)) %>%
    ungroup() %>%
    spread(position, vcl)

  # if no well has a given position at all, spread() creates no column for it;
  # add it as missing so those wells are flagged 'AF error' instead of failing
  for (pos in c('001', '002')) {
    if (!pos %in% names(wide)) {
      wide[[pos]] <- NA_real_
    }
  }

  classified_column <- wide %>%
    mutate(flag = ifelse(abs(`001` - `002`) > 15, 'sticky', 'ok')) %>%
    select(name, well_number, flag) %>%
    mutate(flag = ifelse(is.na(flag), 'AF error', flag)) %>%
    print()

  invisible(classified_column)
}


#Set path to h5 motility file
#the H5 file has /raw , /processed, /well_avg groups
#raw data can be loaded by passing '/motility/raw' to the h5read function below
#this is only recommended if enough memory is available

#Please note: this data contains no compound information. Unblinding was done by Calibr after the dose-response experiments had been performed


h5_file <- 'supplementary_file_1.H5'

# open read-only: the script only reads from the file (rhdf5's default
# H5Fopen flag is read-write)
h5_con <- H5Fopen(h5_file, flags = 'H5F_ACC_RDONLY')

# load data from H5 file: h5read() returns the contents of the group, and
# enframe() + unnest() stack them into one table with a `name` column;
# the file handle is closed even if reading fails
processed_df <- tryCatch(
  h5read(h5_con, '/motility/processed') %>%
    enframe() %>%
    unnest() %>%
    print(),
  finally = H5Fclose(h5_con)
)

#calculate average of each well
# (kinematics_well_avg() summarizes per well with medians by default), then
# reshape to long format -- one row per experiment, well and kinematic
# (e.g. 'VCL_median') in the columns kinematic/value -- and split the well id
# into its column number (x) and row letter (y)
avg_df <- kinematics_well_avg(processed_df) %>% 
  gather(key = kinematic, value = value, -experiment, -well_number) %>% 
  split_well_id() %>% 
  print()

#calculate mean, median, sd, mad of DMSO controls
# (negative-control wells: column 23 rows A-H and column 24 rows I-P),
# per plate (experiment) and kinematic
neg_con <- avg_df %>%
  group_by(experiment, kinematic) %>%
  filter((x == '23' & y %in% LETTERS[1:8]) | (x == '24' & y %in% LETTERS[9:16])) %>%
  summarize(neg_con_mean = mean(value),
            neg_con_sd = sd(value),
            neg_con_median = median(value),
            neg_con_mad = mad(value)) %>% print()

#calculate mean, median, sd, mad of pristimerin controls
# (positive-control wells: column 24 rows A-H and column 23 rows I-P)
pos_con <- avg_df %>%
  group_by(experiment, kinematic) %>%
  filter((x == '24' & y %in% LETTERS[1:8]) | (x == '23' & y %in% LETTERS[9:16])) %>%
  summarize(pos_con_mean = mean(value),
            pos_con_sd = sd(value),
            pos_con_median = median(value),
            pos_con_mad = mad(value)) %>% print()

#combine avg data frame with control data frames
# join explicitly on plate and kinematic so no other shared column can
# silently become a join key
control_df <- avg_df %>%
  inner_join(inner_join(neg_con, pos_con, by = c('experiment', 'kinematic')),
             by = c('experiment', 'kinematic')) %>%
  print()

#calculate zprime and robust zprime per plate and kinematic:
#   Z'        = 1 - 3 * (sd_pos + sd_neg)   / |mean_pos - mean_neg|
#   robust Z' = 1 - 3 * (mad_pos + mad_neg) / |median_pos - median_neg|
# the control statistics are constant within each group, so [1] picks them
stats_df <- control_df %>%
  group_by(experiment, kinematic) %>%
  summarize(zprime = 1 - ((3*pos_con_sd[1] + 3*neg_con_sd[1])/abs(pos_con_mean[1] - neg_con_mean[1]) ),
            rzprime = 1 - ((3*pos_con_mad[1] + 3*neg_con_mad[1])/abs(pos_con_median[1] - neg_con_median[1]))) %>%
  print()

# Z' (VCL) of every plate; dashed lines mark 0.4 and 0.8.
# NOTE: the y-scale limits c(0, 1) drop plates with Z' outside [0, 1]
# from the plot (ggplot2 warns about removed rows)
zprime_plot <- ggplot(stats_df %>% filter(kinematic == 'VCL_median'), aes(experiment, zprime)) +
  geom_point() +
  geom_hline(yintercept = c(0.4, 0.8), linetype = 'dashed', alpha = 0.5) +
  scale_y_continuous(breaks = seq(0,1,0.2), limits = c(0,1)) +
  labs(x = 'Plate ID', y = 'Z\'') +
  theme(
    axis.ticks.x = element_blank(),
    axis.text.x = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )
print(zprime_plot)
# save this plot explicitly instead of relying on last_plot()
ggsave('zprime_plates.png', plot = zprime_plot, height = 1.3, width = 4, scale = 1.3)


#normalize data with DMSO control wells poc = percentage of control
# (each well value relative to the plate's negative-control median of the same
# kinematic), label wells as compound/control and attach the per-well position
# quality flag from classify_positions()
normalized_df <- control_df %>%
  group_by(experiment, kinematic) %>%
  mutate(poc = 100 * (value / neg_con_median)) %>%
  classify_controls() %>%
  inner_join(classify_positions(processed_df), by = c('well_number'='well_number', 'experiment' = 'name')) %>%
  print()

# VCL as % of control for all wells whose two positions agree (flag 'ok');
# the dashed line at 85 % is the hit threshold used below
ggplot(normalized_df %>% filter(kinematic == 'VCL_median', flag == 'ok'),
       aes(factor(cpd_type, levels = c('Compound', 'Negative Control', 'Positive Control'),
                  labels = c('Compounds', 'Negative Controls', 'Positive Controls')), poc)) +
  geom_hline(yintercept = 100, alpha = 0.4) +
  geom_jitter(alpha = 0.5) +
  geom_hline(yintercept = c(85), linetype = 'dashed', alpha = 0.8) +
  labs(subtitle = 'ReFRAME library (11,903 cpds)', y = '% of control', x = '') +
  theme(
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    legend.position = 'none'
  )

# number of compound wells (flag 'ok') with VCL below 85 % of control
normalized_df %>% filter(kinematic == 'VCL_median', poc < 85, flag == 'ok', cpd_type == 'Compound') %>%
  ungroup() %>%
  summarize(hits = n()) %>% print()

