# R script that creates Figure 5 of the paper “A neural network model of hippocampal contributions to category learning”.
# Preprint: https://www.biorxiv.org/content/10.1101/2022.01.12.476051v1
# Load libraries
library(readxl)
library(dplyr)
library(ggplot2)
library(reshape2)
library(RColorBrewer)
library(ggpubr)
library(grid)
library(forcats)
# Read the two CSV source-data files (read.csv already returns a data frame)
data_fig5cd <- read.csv("Figure 5 - Source Data 1.csv")
data_fig5ef <- read.csv("Figure 5 - Source Data 2.csv")
# Put the levels of "atypical_features" in the order used for plotting,
# starting with "4 atypical" and ending with "1 atypical"
data_fig5ef <- mutate(
  data_fig5ef,
  atypical_features = fct_relevel(
    atypical_features,
    "4 atypical", "3 atypical", "2 atypical", "1 atypical"
  )
)
# Accuracy rows of the recognition test
rec <- filter(data_fig5ef, test_type == "rec")
# Accuracy rows of the generalization test
gen <- filter(data_fig5ef, test_type == "gen")
# Shared theme for the summary plots: black-and-white base theme with the grid
# and panel border removed, black axis lines, enlarged text, no legend title
plot_specs <- theme_bw() +
  theme(
    axis.line = element_line(colour = "black"),
    axis.text = element_text(size = 15),
    axis.title = element_text(size = 20),
    plot.title = element_text(size = 25),
    legend.text = element_text(size = 20),
    legend.title = element_blank(),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )
# Intact model generalization across learning, with 10 trials prior to each interim test.
# Keep the tests at counters 5, 15, 25 and 35, and fix the display order of
# the AtypFeat_r levels (prototypes first, then increasing distance)
m_gen_1 <- data_fig5cd %>%
  filter(test_counter %in% c(5, 15, 25, 35)) %>%
  mutate(
    AtypFeat_r = factor(
      AtypFeat_r,
      levels = c("Prototypes", "Distance 1", "Distance 2", "Distance 3", "Distance 4")
    )
  )
# Line plot of accuracy per interim test, one grey-scale line per level, with
# error bars spanning Accuracy - se to Accuracy + se
ggplot(
  m_gen_1,
  aes(x = test_counter, y = Accuracy, group = AtypFeat_r, color = AtypFeat_r)
) +
  geom_line(size = 3) +
  geom_errorbar(
    aes(ymin = Accuracy - se, ymax = Accuracy + se),
    width = .2,
    position = position_dodge(0.05)
  ) +
  # the test counters are relabelled as interim test numbers 1 to 4
  scale_x_continuous(
    breaks = seq(5, 35, by = 10),
    n.breaks = 4,
    limits = c(0, 40),
    labels = 1:4
  ) +
  scale_color_manual(
    values = c("#252525", "#525252", "#737373", "#969696", "#bdbdbd")
  ) +
  coord_cartesian(ylim = c(0.5, 1)) +
  labs(
    title = "Model generalization across learning",
    x = "Interim test number",
    y = "Proportion correct"
  ) +
  plot_specs
# Intact model generalization at the end of learning.
# Keep only the rows of the final test (test_counter equal to 40)
m_gen_2 <- filter(data_fig5cd, test_counter == 40)
# Bar chart of accuracy by AtypFeat, with error bars spanning
# Accuracy - se to Accuracy + se
ggplot(m_gen_2, aes(x = AtypFeat, y = Accuracy)) +
  geom_col(width = 0.5, color = "black", fill = "#737373") +
  geom_errorbar(
    aes(ymin = Accuracy - se, ymax = Accuracy + se),
    width = .2,
    position = position_dodge(.9)
  ) +
  coord_cartesian(ylim = c(0.5, 1)) +
  labs(
    title = "Model generalization",
    x = "Distance from prototypes",
    y = "Proportion correct"
  ) +
  plot_specs
# Model generalization broken down by typicality and model type.
# Generalization accuracy over trials: one line per condition and one facet
# per level of atypical_features; the dashed grey line marks 0.5
ggplot(gen, aes(x = trial, y = accuracy, group = condition, color = condition)) +
  geom_line(size = 2) +
  geom_errorbar(
    aes(ymin = accuracy - standard_error, ymax = accuracy + standard_error),
    width = .2,
    position = position_dodge(0.05)
  ) +
  geom_hline(yintercept = 0.5, linetype = "dashed", colour = "grey") +
  facet_grid(~ atypical_features) +
  scale_x_continuous(breaks = seq(0, 100, 20)) +
  scale_color_manual(values = c("#1b9e77", "#d95f02", "#7570b3")) +
  coord_cartesian(ylim = c(0.3, 1)) +
  labs(
    title = "Model generalization",
    x = "Trial",
    y = "Accuracy"
  ) +
  theme_bw() +
  # the legend is hidden for this panel
  theme(
    axis.line = element_line(colour = "black"),
    axis.text = element_text(size = 20),
    axis.title = element_text(size = 30),
    plot.title = element_text(size = 40, hjust = 0.5),
    legend.position = "none",
    legend.title = element_blank(),
    strip.background = element_blank(),
    strip.text.x = element_text(size = 30),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank()
  )
# Model atypical feature recognition broken down by typicality and model type.
# Recognition accuracy over trials: one line per condition and one facet per
# level of atypical_features, laid out in a single row; the dashed grey line
# marks 0.5
ggplot(rec, aes(x = trial, y = accuracy, group = condition, color = condition)) +
  geom_line(size = 2) +
  geom_errorbar(
    aes(ymin = accuracy - standard_error, ymax = accuracy + standard_error),
    width = .2,
    position = position_dodge(0.05)
  ) +
  geom_hline(yintercept = 0.5, linetype = "dashed", colour = "grey") +
  facet_wrap(~ atypical_features, nrow = 1) +
  scale_x_continuous(breaks = seq(0, 100, 20)) +
  scale_color_manual(values = c("#1b9e77", "#d95f02", "#7570b3", "#bcbddc")) +
  coord_cartesian(ylim = c(0.3, 1)) +
  labs(
    title = "Model atypical feature recognition",
    x = "Trial",
    y = "Accuracy"
  ) +
  theme_bw() +
  # the legend is shown on the right for this panel
  theme(
    axis.line = element_line(colour = "black"),
    axis.text = element_text(size = 20),
    axis.title = element_text(size = 30),
    plot.title = element_text(size = 40, hjust = 0.5),
    plot.subtitle = element_text(hjust = 0.5, size = 30),
    legend.position = "right",
    legend.title = element_blank(),
    legend.text = element_text(size = 20),
    legend.key.width = unit(3, "line"),
    strip.background = element_blank(),
    strip.text = element_text(size = 30),
    panel.border = element_blank(),
    panel.grid.major = element_blank(),
    panel.grid.minor = element_blank(),
    panel.spacing = unit(2, "lines")
  )
# Representational similarity for the initial and settled response of the
# intact network. Each item appears in the rows and columns of the heatmap,
# organized by most prototypical members of one category to most prototypical
# members of the other. The diagonals are always 1, as this reflects items
# correlated to themselves, and the off-diagonals are symmetric. Black boxes
# delineate categories.
# Remove only the objects created earlier in this script. Clearing everything
# returned by ls() would also delete unrelated objects in the user's session.
# intersect() skips names that do not exist, so no warning is raised when only
# part of the script has been run.
remove(list = intersect(
  ls(),
  c(
    "data_fig5cd", "data_fig5ef", "rec", "gen",
    "plot_specs", "m_gen_1", "m_gen_2"
  )
))
# Convert a data frame whose first column holds row labels into a matrix.
#
# x: a data frame (or tibble); column 1 supplies the row names and all
#    remaining columns supply the matrix values.
# Returns: a matrix with one row per row of `x`, one column per remaining
#    column of `x`, and row names taken from column 1.
matrix.please <- function(x) {
  # drop = FALSE keeps a data frame even when a single value column remains,
  # so that column's name survives the conversion
  m <- as.matrix(x[, -1, drop = FALSE])
  # [[1]] extracts the label column as a plain vector for tibbles as well
  rownames(m) <- x[[1]]
  m
}
# Heatmap palette, listed from dark red to dark blue
dd <- c(
  "#a50026", "#d73027", "#f46d43", "#fdae61", "#fee090",
  "#e0f3f8", "#abd9e9", "#74add1", "#4575b4", "#313695"
)
# The same colours in the opposite order (dark blue first); this is the vector
# passed to the fill scale
pp <- dd[length(dd):1]
# Read the six sheets of the Excel source file: for each of DG, CA3 and CA1,
# one sheet ending in "_init" and one ending in "_set"
DG_initial <- read_excel(
  "Figure 5 - Source Data 3.xlsx",
  sheet = "DG_MeanCorMatrix_test_init"
)
DG_settled <- read_excel(
  "Figure 5 - Source Data 3.xlsx",
  sheet = "DG_MeanCorMatrix_test_set"
)
CA3_initial <- read_excel(
  "Figure 5 - Source Data 3.xlsx",
  sheet = "CA3_MeanCorMatrix_test_init"
)
CA3_settled <- read_excel(
  "Figure 5 - Source Data 3.xlsx",
  sheet = "CA3_MeanCorMatrix_test_set"
)
CA1_initial <- read_excel(
  "Figure 5 - Source Data 3.xlsx",
  sheet = "CA1_MeanCorMatrix_test_init"
)
CA1_settled <- read_excel(
  "Figure 5 - Source Data 3.xlsx",
  sheet = "CA1_MeanCorMatrix_test_set"
)
# Gather the sheets in a named list; the names follow "<subfield>_<phase>" and
# the order alternates initial/settled within each subfield
all_files <- list(
  DG_initial = DG_initial,
  DG_settled = DG_settled,
  CA3_initial = CA3_initial,
  CA3_settled = CA3_settled,
  CA1_initial = CA1_initial,
  CA1_settled = CA1_settled
)
# Build one correlation heatmap per entry of `all_files`. The plots are stored
# in the same order as `all_files` and named "heatmap_<subfield>_<phase>".
plots_list <- list()
for (i in seq_along(all_files)) {
  # name of the current dataset, e.g. "DG_initial"
  dataset_name <- names(all_files)[i]
  # the subfield (DG, CA3 or CA1) is the text before the first underscore
  current_subfield <- sub("_.*$", "", dataset_name)
  # the phase is abbreviated to the first three letters after the underscore
  # ("ini" or "set"); a fixed character offset gives the wrong letters for the
  # two-letter "DG" prefix
  phase <- substr(sub("^[^_]*_", "", dataset_name), 1, 3)
  temp_name <- paste0("heatmap_", current_subfield, "_", phase)
  # convert the sheet to a plain data frame, then to a matrix whose row names
  # come from the first column
  dat_temp <- as.data.frame(all_files[[i]])
  # NOTE: dat_temp_M from the last iteration is reused by the legend code
  # further down, so it must remain a top-level variable
  dat_temp_M <- matrix.please(dat_temp)
  # produce the heatmap; the y axis uses the reversed sort order of Var2
  p <- ggplot(
    melt(dat_temp_M),
    aes(Var1, ordered(Var2, levels = rev(sort(unique(Var2)))), fill = value)
  ) +
    geom_tile() +
    coord_equal() +
    scale_fill_gradientn(
      colours = pp,
      values = scales::rescale(c(-0.04, 1)),
      breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1),
      labels = c(0, 0.2, 0.4, 0.6, 0.8, 1),
      limits = c(-0.04, 1)
    ) +
    ggtitle(current_subfield) +
    # thin outline from 0.5 to 42.5 on both axes
    # (assumes a 42 x 42 matrix - TODO confirm)
    annotate("rect", xmin = 0.5, xmax = 42.5, ymin = 0.5, ymax = 42.5,
             alpha = 0, color = "black", size = 0.3) +
    # thick boxes around the upper-left and lower-right quadrants
    annotate("rect", xmin = 0.5, xmax = 21.5, ymin = 21.5, ymax = 42.5,
             alpha = 0, color = "black", size = 1.2) +
    annotate("rect", xmin = 21.5, xmax = 42.5, ymin = 0.5, ymax = 21.5,
             alpha = 0, color = "black", size = 1.2) +
    theme_void() +
    theme(
      legend.title = element_blank(),
      legend.text = element_text(size = 35),
      legend.position = "none",
      plot.title = element_text(size = 30, hjust = 0.5),
      panel.border = element_blank(),
      panel.background = element_blank(),
      panel.grid.major = element_blank(),
      panel.grid.minor = element_blank(),
      axis.title = element_blank(),
      axis.text = element_blank(),
      axis.ticks = element_blank()
    )
  # store by position (later code indexes the list numerically) and attach the
  # descriptive name as well
  plots_list[[i]] <- p
  names(plots_list)[i] <- temp_name
}
# Assemble the six heatmaps into a 2 x 3 figure.
# Top row: plots 1, 3 and 5, labelled "Initial response" on the left
initial_row <- annotate_figure(
  ggarrange(
    plots_list[[1]], plots_list[[3]], plots_list[[5]],
    ncol = 3, nrow = 1
  ),
  left = text_grob("Initial response", size = 30, rot = 90, hjust = 0.5)
)
# Bottom row: plots 2, 4 and 6, labelled "Settled response" on the left
settled_row <- annotate_figure(
  ggarrange(
    plots_list[[2]], plots_list[[4]], plots_list[[6]],
    ncol = 3, nrow = 1
  ),
  left = text_grob("Settled response", size = 30, rot = 90, hjust = 0.5)
)
# Stack the two rows and add the overall title
combined_heatmaps <- annotate_figure(
  ggarrange(
    initial_row, settled_row,
    ncol = 1, nrow = 2,
    widths = c(1, 1)
  ),
  top = text_grob("Model representations after learning", size = 40)
)
combined_heatmaps
# Draw the colour legend on its own. A heatmap with the same fill scale is
# rebuilt from dat_temp_M (the matrix left over from the last loop iteration),
# this time with its legend enabled, and only the legend is extracted.
legend_source <- ggplot(
  melt(dat_temp_M),
  aes(Var1, ordered(Var2, levels = rev(sort(unique(Var2)))), fill = value)
) +
  geom_tile() +
  scale_fill_gradientn(
    colours = pp,
    values = scales::rescale(c(-0.04, 1)),
    breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1),
    labels = c(0, 0.2, 0.4, 0.6, 0.8, 1),
    limits = c(-0.04, 1)
  ) +
  theme(
    legend.title = element_blank(),
    legend.text = element_text(size = 16),
    legend.position = "left"
  )
as_ggplot(get_legend(legend_source))