## Retrospective Cohort Analysis Script.
set.seed(123)

# The user of this script supplies cohort-specific values to a series of
# variables. Make sure every user input that appears before the
# "No User Edits Past Here" banner matches the specific components of your
# data set.

#********************************************
#------------Cohort Specification------------
#********************************************

# Establish a folder for output files.
#********************************************
# It is recommended that you create an empty directory for output files.
# This script writes (and overwrites) many files in that directory.
# Create and then choose this directory now, e.g.:
# setwd("~/retro_analysis/output_files")
#********************************************

# Now, where is the rectangular data frame for your cohort?
#********************************************
# Both directories are glued directly onto file names (paste0 / str_c), so a
# non-empty directory must end with a path separator, e.g. "~/data/".
in_dir <- "" # Replace
out_dir <- "" # Replace
# Replace " .csv" with the name of your cohort file, e.g. "cohort.csv".
cohort_filename <- paste0(in_dir, " .csv") # Replace

# What cohort does this data frame represent?
#********************************************
# Choose one of "ARD", "Pneumonia", "COVID", or "COVID_pooled".
# Comment out all but one of these choices.
cohort_name <- "ARD"
#---OR---
# cohort_name <- "Pneumonia"
#---OR---
# cohort_name <- "COVID"
#---OR---
# cohort_name <- "COVID_pooled"
#********************************************
# Choose a data source (only affects created file names; can be anything).
cohort_data_source <- "MS"
#---OR---
# cohort_data_source <- "Optum"
#---OR---
# cohort_data_source <- "Swedish"
#********************************************

# Specify various lists of regression covariates.
#********************************************
#------------Required covariates-------------
#********************************************
# Outcomes (cannot be empty)
outcomes <- c("outcome_vent", "outcome_death_and_vent")
# Link your outcome names to brief plot descriptions of the outcomes.
outcome.altnames <- c("outcome_vent" = "Progression to ventilation",
                      "outcome_death_and_vent" = "Progression to ventilation and death")
#********************************************
# Treatments (e.g., alpha blockers, tamsulosin, doxazosin)
# Each drug_treatment_n names the column flagging use of a particular drug or
# class of drugs; it identifies the treatment group of target cohort n.
# drug_treatment_n_incl_cols lists columns that must equal 1 for a row to be
# kept. drug_treatment_n_excl_cols lists columns in which a 1 marks a row as
# meeting an exclusion criterion; only rows with 0 in all of them are kept.
# Leave either list as c() to skip that filter.
# Set run_target_cohort_n to FALSE to skip target cohort n entirely.
# NOTE(review): the treatment column is located with a case-insensitive regex
# (except drug_treatment_2), but the incl/excl columns must match exactly.
# Check the capitalisation of "treatment_Doxazosin_min180days" and
# "treatment_Tamsulosin_pastyear" against your file.
drug_treatment_1 <- "treatment_alpha_min180days"
drug_treatment_1_incl_cols <- c()
drug_treatment_1_excl_cols <- c("treatment_alpha_180days_less")
run_target_cohort_1 <- TRUE
#********************************************
drug_treatment_2 <- "treatment_tamsulosin_min180days"
drug_treatment_2_incl_cols <- c()
drug_treatment_2_excl_cols <- c("treatment_terazosin_pastyear",
                                "treatment_prazosin_pastyear",
                                "treatment_silodosin_pastyear",
                                "treatment_doxazosin_pastyear",
                                "treatment_alfuzosin_pastyear",
                                "treatment_tamsulosin_180days_less")
run_target_cohort_2 <- TRUE
#********************************************
drug_treatment_3 <- "treatment_Doxazosin_min180days"
drug_treatment_3_incl_cols <- c()
drug_treatment_3_excl_cols <- c("treatment_terazosin_pastyear",
                                "treatment_prazosin_pastyear",
                                "treatment_silodosin_pastyear",
                                "treatment_Tamsulosin_pastyear",
                                "treatment_alfuzosin_pastyear",
                                "treatment_doxazosin_180days_less")
run_target_cohort_3 <- TRUE
#********************************************
# Comorbidities
comorbidities <- c("comorbidity_diabetes_mellitus",
                   "comorbidity_hypertension",
                   "comorbidity_heart_failure",
                   "comorbidity_ischemic_heart_disease",
                   "comorbidity_acute_myocardial_infarction",
                   "comorbidity_chronic_obstructive_pulmonary_disease",
                   "comorbidity_cancer")
# BPH is kept out of the comorbidities group because we don't include it in
# our models. It is listed separately here:
bph <- "comorbidity_benign_prostatic_hyperplasia"
#********************************************
# Age demographics and restriction
age <- "patient_age"
min_age_to_consider <- 45 # inclusive; applied when building target cohorts
max_age_to_consider <- 65 # inclusive; applied when building target cohorts
#********************************************
# Sex demographics
sex <- "patient_sex"
sex_to_consider <- c("M") # applied when building target cohorts; can be c("M", "F"), etc.
male_char <- "M" # the value that flags males in the sex column
#********************************************

#********************************************
#-----------Optional covariates--------------
#********************************************
# Temporal features and restriction.
# Name of the time column; "" disables the time covariate and time filter.
time_value <- "year_factor"
# Time-period window, inclusive at both ends. Give numbers in the same units
# as time_value (integer weeks, numeric or date YMD, etc.).
min_time_period <- 2007
max_time_period <- 2015
#********************************************
# Treatments
# Negative-control exposure drug column name; leave as "" if there is none.
negative_exp_drug <- ""
# Extra covariates used for analysis and regressions:
other_covariates <- c(
  "inpat_stay_weeks_past12mo_log",
  "outpat_stay_weeks_past12mo_log",
  "inpat_duration_past12mo_log",
  "inpat_stay_weeks_past2mo_log"
)
# Extra covariates used for regressions only:
other_regression_covariates <- c()
#********************************************

# Set flags to exclude certain covariates from the
# analysis and bypass associated tables and figures.
#********************************************
#-----------Causal Analysis Flags------------
#********************************************
# Propensity trimming: quantile level used to cut the tails of each treatment
# group's propensity distribution.
propensity_threshold <- 0.01
filter_propensity <- TRUE  # remove low-/high-propensity observations

# Matching
match_flag <- TRUE         # whether to perform matching
scale_flag <- TRUE         # use scaled (dm_s_) instead of de-meaned (dm_) covariates in the model formulas
match.ratio <- 5           # controls per treated unit (N:1); must be a positive integer if match_flag is TRUE

# Models
causal_flag <- FALSE       # extra causal analyses beyond propensity weighting and matching; keep FALSE for runtime
interact_flag <- FALSE     # regression with interaction terms
consider_comorbidities <- TRUE  # use comorbidities in models
consider_time <- TRUE      # use the time variable in models
neg_exp_flag <- FALSE      # use the negative-control drug as a covariate (if specified)
twang_propensities <- FALSE  # TRUE: estimate propensities with twang; FALSE: with a causal forest
#********************************************

#********************************************
#**********No User Edits Past Here***********
#********************************************

#********************************************
#-----------------Libraries------------------
#********************************************
library(tidyverse)
library(cowplot)
library(cobalt)
library(DT)
library(reshape2)
library(MASS)
library(broom)
library(scales)
library(knitr)
library(kableExtra)
library(stargazer)
library(MatchIt)
library(grf)
library(survey)
library(gbm)
library(ggpubr)
library(latex2exp)
library(twang)
library(grid)
library(foreach)
library(doParallel)

#********************************************
#**********Code-Defined Flags****************
#********************************************
# Each flag below is numeric 1/0. An optional feature is switched on only
# when it is both configured (non-empty name or list) and enabled by its
# user flag.
time_factor_flag <- as.numeric(min(c(time_value != "", consider_time)))
negative_drug_flag <- as.numeric(min(c(negative_exp_drug != "", neg_exp_flag)))
other_covariates_flag <- as.numeric(length(other_covariates) > 0)
other_regression_covariates_flag <- as.numeric(length(other_regression_covariates) > 0)
# Names of the treatment columns added to the cohort; filled in below.
relevant_treatments <- NULL

#********************************************
#----------Build the Full Cohort-------------
#********************************************
drug_incl_excl_cols <- c(drug_treatment_1_incl_cols, drug_treatment_1_excl_cols,
                         drug_treatment_2_incl_cols, drug_treatment_2_excl_cols,
                         drug_treatment_3_incl_cols, drug_treatment_3_excl_cols)

# Read the cohort file once. Every column group below is taken from this copy
# instead of re-reading and re-parsing the CSV for each group.
raw_cohort_df <- readr::read_csv(cohort_filename)

# Select the one column whose name matches `col_name` exactly (anchored regex;
# case-insensitive unless ignore.case = FALSE) and rename it to `new_name`.
# Stops if zero or several columns match. Otherwise the column would silently
# vanish, and later code would fall back to the global variable of the same
# name.
SelectSingleColumn <- function(df, col_name, new_name, ignore.case = TRUE) {
  selected <- dplyr::select(df, matches(paste0("^", col_name, "$"),
                                        ignore.case = ignore.case))
  if (ncol(selected) != 1) {
    stop(sprintf("Expected exactly one column matching '%s', found %d.",
                 col_name, ncol(selected)), call. = FALSE)
  }
  names(selected) <- new_name
  selected
}

required_cohort_df <- raw_cohort_df %>%
  dplyr::select(age = matches(paste0("^", age, "$")),
                sex = matches(paste0("^", sex, "$")),
                one_of(outcomes),
                one_of(comorbidities),
                bph = matches(paste0("^", bph, "$")),
                one_of(drug_incl_excl_cols))

# A missing match drops these columns. An ambiguous match renames them with
# numeric suffixes (age1, age2, ...). Either way the name would be gone, so
# check all three came through.
missing_required <- setdiff(c("age", "sex", "bph"), names(required_cohort_df))
if (length(missing_required) > 0) {
  stop("Could not find exactly one column for: ",
       paste(missing_required, collapse = ", "), call. = FALSE)
}

if (isTRUE(run_target_cohort_1)) {
  required_cohort_df <- cbind(required_cohort_df,
                              SelectSingleColumn(raw_cohort_df, drug_treatment_1,
                                                 "drug_treatment_1"))
  relevant_treatments <- c(relevant_treatments, "drug_treatment_1")
}
if (isTRUE(run_target_cohort_2)) {
  # Matched case-sensitively, as in the original configuration.
  required_cohort_df <- cbind(required_cohort_df,
                              SelectSingleColumn(raw_cohort_df, drug_treatment_2,
                                                 "drug_treatment_2",
                                                 ignore.case = FALSE))
  relevant_treatments <- c(relevant_treatments, "drug_treatment_2")
}
if (isTRUE(run_target_cohort_3)) {
  required_cohort_df <- cbind(required_cohort_df,
                              SelectSingleColumn(raw_cohort_df, drug_treatment_3,
                                                 "drug_treatment_3"))
  relevant_treatments <- c(relevant_treatments, "drug_treatment_3")
}
if (time_factor_flag == 1) {
  required_cohort_df <- cbind(required_cohort_df,
                              SelectSingleColumn(raw_cohort_df, time_value,
                                                 "time_value"))
}
if (other_covariates_flag == 1) {
  required_cohort_df <- cbind(required_cohort_df,
                              dplyr::select(raw_cohort_df, one_of(other_covariates)))
}
if (other_regression_covariates_flag == 1) {
  required_cohort_df <- cbind(required_cohort_df,
                              dplyr::select(raw_cohort_df,
                                            one_of(other_regression_covariates)))
}
if (negative_drug_flag == 1) {
  required_cohort_df <- cbind(required_cohort_df,
                              SelectSingleColumn(raw_cohort_df, negative_exp_drug,
                                                 "negative_treatment"))
  relevant_treatments <- c(relevant_treatments, "negative_treatment")
}

# The full raw table is no longer needed; release it.
rm(raw_cohort_df)

# Drop incomplete rows, add a numeric 0/1 male indicator, apply the optional
# time window, and store sex as a factor.
cohort_df <- required_cohort_df %>%
  na.omit() %>%
  mutate(male_flag = as.numeric(sex == male_char))
if (time_factor_flag == 1) {
  # Keep rows within [min_time_period, max_time_period] and add a factor copy
  # of the time column.
  cohort_df <- cohort_df %>%
    dplyr::filter(time_value >= min_time_period, time_value <= max_time_period) %>%
    mutate(time_factor = as.factor(time_value))
}
cohort_df$sex <- factor(cohort_df$sex)

# It's time to build the full covariates list. We'll get there incrementally.
relevant_covariates <- c(comorbidities, other_covariates, "age", "sex")
comorb_and_bph <- c(comorbidities, "bph")

# Centered cubic age polynomial (the columns are created by CohortRestriction()).
age_terms <- c("patient_age_demeaned", "patient_age_demeaned_squared",
               "patient_age_demeaned_cubed")
base_covariates <- c(age_terms, other_covariates, other_regression_covariates,
                     comorbidities)

if (time_factor_flag == 1) {
  # One indicator column per time level, dropping the first level as the
  # reference. drop = FALSE keeps a named matrix even when only two levels
  # remain; otherwise R collapses the result to an unnamed vector.
  time.mtrx <- model.matrix(~ time_factor - 1, data = cohort_df)[, -1, drop = FALSE]
  cohort_df <- cbind(cohort_df, as.data.frame(time.mtrx))
  full_covariates <- c(base_covariates, colnames(time.mtrx))
  match_covariates <- c(base_covariates, "time_value")
} else {
  full_covariates <- base_covariates
  match_covariates <- base_covariates
}

# Columns that get de-meaned ("dm_") and de-meaned + scaled ("dm_s_") copies
# in every target cohort.
all_covar_to_demean <- union(full_covariates, match_covariates)
demean_col_list <- c(relevant_treatments, outcomes, all_covar_to_demean, "male_flag")
renamed_demean_col_list <- paste("dm", demean_col_list, sep = "_")
renamed_scale_demean_col_list <- paste("dm_s", demean_col_list, sep = "_")

# rename covariate lists
demean_full_covar_list <- paste("dm", full_covariates, sep = "_")
scale_demean_full_covar_list <- paste("dm_s", full_covariates, sep = "_")

demean_match_covar_list <- paste("dm", match_covariates, sep = "_")
scale_demean_match_covar_list <- paste("dm_s", match_covariates, sep = "_")

# get regression input covariates
if (scale_flag) {
  match_covariates_demeaned_timenum <- scale_demean_match_covar_list
  full_covariates_demeaned_timenum <- scale_demean_full_covar_list
} else {
  match_covariates_demeaned_timenum <- demean_match_covar_list
  full_covariates_demeaned_timenum <- demean_full_covar_list
}

# Age is exact-matched in matchit(), so drop only the age polynomial terms
# from the matching covariates. The previous substring test on "age" would
# also drop any user covariate whose name happens to contain "age". Positions
# line up with match_covariates because both prefixed lists are built from it
# element by element.
match_covariates_demeaned_timenum <-
  match_covariates_demeaned_timenum[!(match_covariates %in% age_terms)]

#********************************************
#---Cohort Transformations and Restrictions--
#********************************************

# Subtract each column's mean without rescaling. Returns the matrix produced
# by scale(), including its "scaled:center" attribute.
CenterColmeans <- function(x) {
  base::scale(x = x, center = TRUE, scale = FALSE)
}

# Center each column on its mean and divide it by its standard deviation.
# A constant column yields NaN (zero standard deviation).
ScaleCenterColmeans <- function(x) {
  base::scale(x)
}

# Restrict a cohort to the configured sexes (sex_to_consider) and the
# inclusive age window [min_age_to_consider, max_age_to_consider]. Then add a
# centered cubic age polynomial, and append de-meaned ("dm_") and de-meaned +
# scaled ("dm_s_") copies of every column named in demean_col_list.
#
# df: cohort data frame with `sex`, `age` and all demean_col_list columns.
# Returns a data.frame with the original columns plus the derived ones.
CohortRestriction <- function(df){
  target_cohort_n <- df %>%
    dplyr::filter(sex %in% sex_to_consider,
                  age >= min_age_to_consider,
                  age <= max_age_to_consider) %>%
    # scale() returns a one-column matrix. as.numeric() stores plain numeric
    # columns instead of matrix columns carrying "scaled:center" attributes.
    mutate(patient_age_demeaned = as.numeric(CenterColmeans(age)),
           patient_age_demeaned_squared = patient_age_demeaned^2,
           patient_age_demeaned_cubed = patient_age_demeaned^3)

  dm_cols <- as.data.frame(CenterColmeans(target_cohort_n[demean_col_list]))
  colnames(dm_cols) <- renamed_demean_col_list
  target_out_centered <- cbind(target_cohort_n, dm_cols)

  # Columns that are constant within this cohort scale to NaN (zero sd).
  dm_s_cols <- as.data.frame(ScaleCenterColmeans(target_out_centered[demean_col_list]))
  colnames(dm_s_cols) <- renamed_scale_demean_col_list
  cbind(target_out_centered, dm_s_cols)
}

# Apply treatment-based inclusion/exclusion filters, then the demographic
# restriction and de-meaning in CohortRestriction().
#
# cohort_df:           full cohort data frame.
# treatment_incl_cols: columns that must all equal 1 for a row to be kept
#                      (NULL or empty = no inclusion filter).
# treatment_excl_cols: columns that must all equal 0 for a row to be kept
#                      (NULL or empty = no exclusion filter).
# Returns the restricted target cohort as a data.frame.
GenerateTargetCohorts <- function(cohort_df, treatment_incl_cols, treatment_excl_cols){
  df <- cohort_df
  # length() > 0 also covers character(0), which is not NULL but selects nothing.
  if (length(treatment_incl_cols) > 0) {
    df <- df %>%
      filter_at(vars(treatment_incl_cols), all_vars(. %in% 1))
  }
  if (length(treatment_excl_cols) > 0) {
    df <- df %>%
      filter_at(vars(treatment_excl_cols), all_vars(. %in% 0))
  }
  as.data.frame(CohortRestriction(df))
}

# Estimate propensity scores for the 0/1 column `prop_treatment` and build
# ATT-style weights, normalised to sum to 1 within each treatment group.
#   - twang_propensities TRUE:  gbm-based twang::ps(); weights from its
#     "es.mean" stopping rule.
#   - twang_propensities FALSE: grf causal forest; treated weight 1, control
#     weight p / (1 - p).
#
# df:             data.frame with the treatment, covariates and outcomes.
# formula:        treatment ~ covariates formula (twang branch only).
# covars:         covariate column names (causal-forest branch only).
# prop_treatment: name of the 0/1 treatment column.
# Returns list(propensity_scores, weights). The elements are also named, so
# callers may use [[1]]/[[2]] or $propensity_scores/$weights.
PropensityScoresAndWeights <- function(df, formula, covars, prop_treatment){
  treatment <- df[[prop_treatment]]
  is_treated <- treatment == 1
  is_control <- treatment == 0

  if (twang_propensities) {
    ps.df <- ps(formula,
                data = df,
                n.trees = 10000,
                interaction.depth = 2,
                shrinkage = 0.01,
                perm.test.iters = 0,
                stop.method = c("es.mean", "ks.max"),
                estimand = "ATT",
                verbose = FALSE)
    propensity_scores <- ps.df$ps$es.mean.ATT
    weights <- get.weights(ps.df, stop.method = "es.mean")
  } else {
    # grf estimates W.hat from X and the treatment alone, so which outcome is
    # supplied does not change the propensity scores. Use the first
    # configured outcome rather than a hard-coded column name.
    outcome <- df[[outcomes[[1]]]]
    X <- as.matrix(df[, covars, drop = FALSE])
    cf <- causal_forest(X, outcome, treatment, num.trees = 2000, seed = 123)
    propensity_scores <- cf$W.hat
    weights <- ifelse(is_treated, 1, propensity_scores / (1 - propensity_scores))
  }

  # Normalise treated and control weights separately.
  weights[is_control] <- weights[is_control] / sum(weights[is_control])
  weights[is_treated] <- weights[is_treated] / sum(weights[is_treated])

  list(propensity_scores = propensity_scores, weights = weights)
}



# Shorten balance-table row names for the love-plot y axis. Labels are
# returned unchanged unless the longest exceeds `max_chars`. Time-factor rows
# become "tf..." (dropping a long "_0.xxxxx" suffix). Other rows lose digits
# and lower-case vowels. make.unique() keeps labels distinct after stripping:
# for example, "past12mo" and "past2mo" would otherwise collide.
AbbreviateBalanceLabels <- function(labels, max_chars = 20) {
  if (length(labels) == 0 || max(nchar(labels)) <= max_chars) {
    return(labels)
  }
  is_time <- str_detect(labels, "time_factor")
  labels[is_time] <- labels[is_time] %>%
    str_replace("time_factor", "tf") %>%
    str_replace("_0\\.[:digit:]{5,}", "")
  labels[!is_time] <- labels[!is_time] %>%
    str_replace_all("[:digit:]", "") %>%
    str_replace_all("a|e|i|o|u", "")
  make.unique(labels)
}

# Summarise a MatchIt result.
#
# m.out: a matchit object.
# Returns a list with the matched data, a love plot of standardized mean
# differences before/after matching, both balance tables, and the long data
# frame behind the plot.
MatchFunc <- function(m.out){
  matched_target_cohort_df <- match.data(m.out)

  # Look the balance tables up by name rather than list position. Index
  # columns with [, ] so this works whether they are data frames or matrices.
  matching_tables <- summary(m.out, standardize = TRUE)
  bal_tab_before <- matching_tables$sum.all
  bal_tab_after <- matching_tables$sum.matched
  feature_order <- rownames(bal_tab_after)

  df_to_plot <- tibble(feature = rownames(bal_tab_before),
                       Before = unname(bal_tab_before[, "Std. Mean Diff."]),
                       After = unname(bal_tab_after[, "Std. Mean Diff."])) %>%
    gather(key = "Matching", value = "StdMnDff", Before, After) %>%
    mutate(Matching = fct_relevel(Matching, "Before", "After"),
           feature = fct_relevel(feature, feature_order))

  # Previously `love_plot` was returned without ever being created.
  # The abbreviated labels follow feature_order, which is also the factor
  # level order on the y axis.
  love_plot <- ggplot(df_to_plot, aes(x = StdMnDff, y = feature, colour = Matching)) +
    geom_vline(xintercept = 0, linetype = "dashed") +
    geom_point() +
    scale_y_discrete(labels = AbbreviateBalanceLabels(feature_order)) +
    labs(x = "Standardized mean difference", y = NULL) +
    theme_minimal()

  list(matched_target_cohort_df = matched_target_cohort_df,
       love_plot = love_plot,
       bal_tab_before = bal_tab_before,
       bal_tab_after = bal_tab_after,
       df_to_plot = df_to_plot)
}

#********************************************
#----------Build the Target Cohorts----------
#********************************************

# Copy a cohort's treatment indicator (raw, de-meaned "dm_", and scaled
# "dm_s_") into alpha_treatment columns. The downstream propensity and
# matching code can then use one column name for every cohort.
CopyTreatmentToAlpha <- function(tc, treatment_col) {
  for (prefix in c("", "dm_", "dm_s_")) {
    tc[[paste0(prefix, "alpha_treatment")]] <- tc[[paste0(prefix, treatment_col)]]
  }
  tc
}

# Target cohort 1: treatment group drug_treatment_1.
if (run_target_cohort_1) {
  tc_1 <- GenerateTargetCohorts(cohort_df, drug_treatment_1_incl_cols,
                                drug_treatment_1_excl_cols)
  target_cohort_1 <- CopyTreatmentToAlpha(as.data.frame(tc_1), "drug_treatment_1")
}

# Target cohort 2: treatment group drug_treatment_2.
if (run_target_cohort_2) {
  tc_2 <- GenerateTargetCohorts(cohort_df, drug_treatment_2_incl_cols,
                                drug_treatment_2_excl_cols)
  target_cohort_2 <- CopyTreatmentToAlpha(as.data.frame(tc_2), "drug_treatment_2")
  # NOTE(review): presumably zeroed because drug_treatment_3 is constant in
  # this cohort, which makes its scaled copy NaN. Confirm.
  target_cohort_2$dm_s_drug_treatment_3 <- 0
}

# Target cohort 3: treatment group drug_treatment_3.
if (run_target_cohort_3) {
  tc_3 <- GenerateTargetCohorts(cohort_df, drug_treatment_3_incl_cols,
                                drug_treatment_3_excl_cols)
  target_cohort_3 <- CopyTreatmentToAlpha(as.data.frame(tc_3), "drug_treatment_3")
  # NOTE(review): presumably zeroed for the same NaN reason as above. Confirm.
  target_cohort_3$dm_s_drug_treatment_2 <- 0
}


####################### Parallel ############################

# One entry per requested target cohort, at its fixed position (1-3):
# list(data, propensity covariates, run flag, treatment column name,
# matching covariates). Unrequested positions stay NULL.
cohort_list <- local({
  make_entry <- function(target_cohort, run_flag, treatment_name) {
    list(target_cohort,
         full_covariates_demeaned_timenum,
         run_flag,
         treatment_name,
         match_covariates_demeaned_timenum)
  }
  entries <- list()
  if (run_target_cohort_1) {
    entries[[1]] <- make_entry(target_cohort_1, run_target_cohort_1, drug_treatment_1)
  }
  if (run_target_cohort_2) {
    entries[[2]] <- make_entry(target_cohort_2, run_target_cohort_2, drug_treatment_2)
  }
  if (run_target_cohort_3) {
    entries[[3]] <- make_entry(target_cohort_3, run_target_cohort_3, drug_treatment_3)
  }
  entries
})


# Estimate propensities per target cohort in parallel, optionally trim to the
# common-support region, and re-estimate on the trimmed sample.
# NOTE(review): registerDoParallel(3) forks on Unix-alikes. On Windows it
# starts separate worker processes that may lack the attached packages and
# globals used here. Confirm before running there.
registerDoParallel(3)
prop_df <- foreach(cohort = cohort_list) %dopar% {
  # cohort_list leaves NULL slots for cohorts that are not run. isTRUE() skips
  # them; a bare if(NULL) would error.
  if (isTRUE(cohort[[3]])) {
    prop_treatment <- "alpha_treatment"
    treatment_formula <- as.formula(paste0(prop_treatment, " ~ ",
                                           paste(cohort[[2]], collapse = " + ")))

    set.seed(123)
    propensities_and_weights <- PropensityScoresAndWeights(cohort[[1]], treatment_formula,
                                                           cohort[[2]], prop_treatment)
    cohort[[1]]$propensity_scores <- propensities_and_weights[[1]]
    cohort[[1]]$prop_weights <- propensities_and_weights[[2]]

    saveRDS(cohort[[1]], file = str_c(out_dir, "target_cohort_df_", cohort[[4]], "_", cohort_data_source, "_", cohort_name, ".rds"))

    propensity_df <- cohort[[1]]

    if (filter_propensity) {
      # Keep the overlap of the treated and control propensity distributions,
      # each cut at the [propensity_threshold, 1 - propensity_threshold]
      # quantiles.
      treatment <- propensity_df[[prop_treatment]]
      probs <- c(propensity_threshold, 1 - propensity_threshold)
      treated_percentiles <- quantile(propensity_df$propensity_scores[which(treatment == 1)], probs)
      control_percentiles <- quantile(propensity_df$propensity_scores[which(treatment == 0)], probs)

      lower_propensity_threshold <- max(treated_percentiles[1], control_percentiles[1])
      upper_propensity_threshold <- min(treated_percentiles[2], control_percentiles[2])
      propensity_df <- propensity_df %>%
        dplyr::filter(propensity_scores >= lower_propensity_threshold,
                      propensity_scores <= upper_propensity_threshold)
    }
    saveRDS(propensity_df, file = str_c(out_dir, "filtered_cohort_df_", cohort[[4]], "_", cohort_data_source, "_", cohort_name, ".rds"))

    if (filter_propensity) {
      # Re-estimate propensities and weights on the trimmed sample.
      propensities_and_weights <- PropensityScoresAndWeights(propensity_df, treatment_formula,
                                                             cohort[[2]], prop_treatment)
      propensity_df$propensity_scores <- propensities_and_weights[[1]]
      propensity_df$prop_weights <- propensities_and_weights[[2]]
    }

    # With a single sex in sex_to_consider, male_flag is constant and its
    # scaled copy is NaN, so drop it.
    propensity_df$dm_s_male_flag <- NULL
    propensity_df
  }
}

stopImplicitCluster()
saveRDS(prop_df, file = str_c(out_dir, "reestimate_propensity_df", "_", cohort_data_source, "_", cohort_name, ".rds"))

# Pair each re-estimated propensity data frame with its run flag. Positions
# (1-3) match cohort_list, so unrequested cohorts stay NULL.
match_list <- local({
  run_flags <- list(run_target_cohort_1, run_target_cohort_2, run_target_cohort_3)
  entries <- list()
  for (k in seq_along(run_flags)) {
    if (run_flags[[k]]) {
      entries[[k]] <- list(prop_df[[k]], run_flags[[k]])
    }
  }
  entries
})


# N:1 nearest-neighbour Mahalanobis matching per target cohort, with exact
# matching on (scaled) age. Runs only when match_flag is TRUE.
# NOTE(review): how MatchIt applies an unnamed caliper with
# distance = "mahalanobis" depends on the MatchIt version. Confirm.
if (match_flag) {
  registerDoParallel(3)
  match_df <- foreach(match_entry = match_list) %dopar% {
    # match_list leaves NULL slots for cohorts that are not run. isTRUE()
    # skips them; a bare if(NULL) would error.
    if (isTRUE(match_entry[[2]])) {
      propensity_df <- match_entry[[1]]
      propensity_df$dm_s_male_flag <- NULL

      matching_formula <- as.formula(paste0("alpha_treatment ~ ",
                                            paste(match_covariates_demeaned_timenum, collapse = " + ")))

      set.seed(123)
      row_ids <- seq_len(nrow(propensity_df))
      rownames(propensity_df) <- row_ids
      propensity_df$idx <- row_ids
      m.out <- matchit(matching_formula, data = propensity_df, exact = c("dm_s_patient_age_demeaned"),
                       method = "nearest", ratio = match.ratio, caliper = .2, distance = "mahalanobis")

      list(m.out, match.data(m.out))
    }
  }
  stopImplicitCluster()
  saveRDS(match_df, file = str_c(out_dir, "matchit_df", "_", cohort_data_source, "_", cohort_name, ".rds"))
}
