using Graphs, Parameters, Random, Distributions, GLMakie, StaticArrays, Statistics, JLD

include("cryptfunctions.jl")

# Usage banner printed when this file is loaded. Each line shows how to call one of the
# entry points; the argument lists must match the actual method signatures.
println("Run the following functions for the various simulations in paper")
println("fiveFUsim(tmax;dose,doseschedule)")
println("fiveFUvideo(tmax;dose,doseschedule)")
println("fiveFUki67()")
println("ablationsim(tmax;celltype,duration,doseschedule)")
println("ablationvideo(tmax;celltype,duration,doseschedule)")
println("cdkinhibsim(tmax;duration,doseschedule)")
println("cdkinhibvideo(tmax;duration,doseschedule)")
println("brdusim()")
println("----Analysis Plots----")
println("countsplot(), ki67plot(), brduplot(), FUplot(), plot3d(), ki67plot3d(), detailsplot()")

# ---- model parameters ----
const circum = 9.0
const rad = circum / (2 * pi)          # radius of a circle whose circumference is `circum`
const nummy = 32
const cue = 1
const g1dur = 0.25882999999999967
const stemmylim = 3.5
const c = ((1 / 0.8)^(cue) + 1) / 2
const Inc0 = 4
const dzero = (c / g1dur) * log(c / (c - 1))
const kwnt = 64 * dzero / Inc0
const stemdiv = 21.5
const absdiv = 10
const dee = 2 * (stemdiv - absdiv)
const p0 = 3.2
const pannydeff = 54.0
const kdi_p = 0.8
const Vatffac = 1.0
const sphasedur = 8.0
# paired "slow" / "fast" parameter variants
const massconstslow = 1.175
const Vsifacslow = 1.0
const kawee_pslow = 0.97
const massconstfast = 0.962
const Vsifacfast = 0.1
const kawee_pfast = 2.76

# ---- signalling constants ----
const ijratio = 1
const Wntrange = 0.35
const Notchrange = 0.0
const knotch = 200
const wntbound = 64
const pannynotchold = 3
const notchold = 2
const overnotchold = 5
const DNAdamagethreshold = 0.25
const RNAdamagethreshold = 0.25
const apoptime = 12 / 24

# ---- other constants ----
const movescale = 10000
const cyclescale = Int(100000 / movescale)
const dt = 1 / (movescale)             # time step; the simulations run ceil(tmax/dt)+1 steps
const stdev = 40
# The first 13 offsets of the 3×3×3 neighbourhood in column-major order (first coordinate
# fastest), i.e. every offset that precedes CartesianIndex(0, 0, 0).
const offsets = [CartesianIndex(i, j, k) for k in -1:1 for j in -1:1 for i in -1:1][1:13]
# All 27 offsets of the 3×3×3 neighbourhood, including (0, 0, 0), with the last
# coordinate varying fastest.
const bigoffsets = [CartesianIndex(i, j, k) for i in -1:1 for j in -1:1 for k in -1:1]
const spacing = 1.5

"""
    fiveFUvideo(tmax; dose = 50, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])

Run the `"jeju"` 5FU simulation for `tmax` days while recording an animation to
`lastrun.mp4`, then pass the results to `postprocess` (whose return value is returned).

The figure shows the 3D cell positions; per-cell ZNRF3, log2 bound Wnt, BMP and Notch
against height `z`; crypt and villus cell counts; and log10(1 + concentration) of 5FU
(plasma), FdUTP and FUTP. Time axes are drawn shifted by -1. Each video frame advances the
model by 50 steps of `dt`.

A `dose` of zero clears `doseschedule`. `tmax / (50*dt)` must be integer-valued, otherwise
`Int` throws an `InexactError` while building the frame range.
"""
function fiveFUvideo(tmax; dose = 50, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])
    println("---------LET'SA GO!---------")
    
    cells, villus, consts = setup("jeju")

    # a zero dose means no administrations at all
    if dose == .0
        doseschedule = []
    end

    # dose times expressed in simulation steps (whole-valued floats)
    indoseschedule = round.(doseschedule/dt,digits=0)

    for i in eachindex(cells)
        cells[i].ID = [i]
    end

    fiveFU = initfiveFU(dose,doseschedule)
    cellcounts, villuscounts, fiveFUtrack, ki67, brdu, celldeets = fancyinitcounts(tmax,length(cells),consts.mintop)
    fiveFUtrack[:,1] = fiveFU[:]
    updatedacounts(cells,villus,cellcounts,villuscounts,0,tmax,consts.mintop)

    # observables driving the 3D scatter; reassigned every frame
    positions, markers, colours = plotstuff(cells)
    points = Observable(positions)
    colournode = Observable(colours)
    markernode = Observable(markers)

    figgy = Figure(resolution = (2000,1200))
    ax2 = figgy[1:2,4:6] = Axis(figgy, title = "ZNRF3", ylabel = "z", xlabel = "ZNRF3 [A.U.]")
    ax3 = figgy[1:2,7:9] = Axis(figgy, title = rich("Wnt"), xlabel = rich("log", subscript("2"),"Wnt [A.U.]"))
    ax4 = figgy[1:2,10:12] = Axis(figgy, title = "BMP", xlabel = "BMP [A.U.]")
    ax5 = figgy[1:2,13:15] = Axis(figgy, title = "Notch", xlabel = "Notch [A.U.]")

    ax6 = figgy[4,1:5] = Axis(figgy, title = "Crypt Cell Counts", xlabel = "Time [days]", ylabel = "Number of cells")
    ax7 = figgy[4,6:10] = Axis(figgy, title = "Villus Cell Counts", xlabel = "Time [days]", ylabel = "Number of cells")
    ax8 = figgy[4,11:15] = Axis(figgy, title = "Concentration of 5FU and metabolites", xlabel = "Time [days]", ylabel = "log(Concentration + 1) [ng/ml]",yticks=0:5)

    colaz = [:yellow,:yellow4,:orange,:red,:green,:purple,:blue,:turquoise,:darkgrey,:brown4,:black]

    meshscatter(figgy[1:2,1:3],points,markersize=markernode,color=colournode)
    
    # initial two-sample traces, used to build the legends
    trange = 0:dt*50:50*dt
    lins = [lines!(ax6,trange.-1,cellcounts[k,1:2],color=colaz[k],linewidth=3) for k = 1:length(cellcounts[:,1])]
    otherlins = [lines!(ax7,trange.-1,villuscounts[k,1:2],color=colaz[k],linewidth=3) for k = 1:length(cellcounts[:,1])]
    lines!(ax8,trange.-1, log10.(1 .+fiveFUtrack[1,1:2]),color=:dodgerblue2,linewidth=3,label="5FU plasma")
    lines!(ax8,trange.-1, log10.(1 .+fiveFUtrack[5,1:2]),color=:goldenrod2,linewidth=3,label="FdUTP")
    lines!(ax8,trange.-1, log10.(1 .+fiveFUtrack[7,1:2]),color=:mediumseagreen,linewidth=3,label="FUTP")
    axislegend(ax8,labelsize=10)
    Legend(figgy[3,1:15],lins,["Stem", "Uncom. Prog.", "Abs. Prog.", "Enterocyte", "Sec. Prog.", "Goblet", "Paneth", "Entero.", "Arrested", "Apoptotic", "Total"],orientation = :horizontal, nbanks=1)
    
    profileaxes = (ax2, ax3, ax4, ax5)
    for ax in profileaxes
        ylims!(ax,(-rad, 17))
    end
    ylims!(ax6,(-15,370))
    ylims!(ax7,(-50,1050))
    ylims!(ax8,(-0.25,5))

    xlims!(ax2,(-0.1,3+stemmylim))
    xlims!(ax3,(-3.1,7.1))
    xlims!(ax4,(-0.1,32.1))
    xlims!(ax5,(-0.1,6.1))
    xlims!(ax6,(-1.1,tmax+0.1-1))
    xlims!(ax7,(-1.1,tmax+0.1-1))
    xlims!(ax8,(-1.1,tmax+0.1-1))

    display(figgy)

    record(figgy,"lastrun.mp4", 0:Int(tmax/(50*dt))-1; framerate=30, compression=20) do i
        # advance the model 50 steps per video frame
        for j in 1:50
            fiveFU = mainloop5FU(consts,cells,villus,cellcounts,villuscounts,fiveFU,fiveFUtrack,dose,indoseschedule,ki67,brdu,celldeets,tmax,50*i+j)
        end

        points[], markernode[], colournode[] = plotstuff(cells)

        for ax in (ax2, ax3, ax4, ax5, ax6, ax7, ax8)
            empty!(ax)
        end

        zeds = Vector{Float64}(undef,length(cells))
        wntvec = Vector{Float64}(undef,length(cells))
        znrf3strength = Vector{Float64}(undef,length(cells))
        BMPvec = Vector{Float64}(undef,length(cells))
        notchvec = Vector{Float64}(undef,length(cells))
        cooolls = Vector{Symbol}(undef,length(cells))
        testvec = znrf3calc(cells)

        enteronum = count(x->typeof(x)==entero,villus)
        BMP = BMPcalc(enteronum,consts.nent,consts.p)
        top = topcalc(enteronum,consts.mintop,consts.nent)

        # the profile panels share a height axis that tracks the current crypt top
        zmax = round(top)
        for ax in profileaxes
            ylims!(ax,(-rad, zmax))
        end

        for k in eachindex(cells)
            znrf3strength[k] = znrf3field(cells[k],testvec)
            zeds[k] = cells[k].pos[3]
            BMPvec[k] = BMP * exp(-log(8)/(top - consts.zent)*(top - cells[k].pos[3]))
            wntvec[k] = log(2,cells[k].boundwnt)
            notchvec[k] = cells[k].notch
            cooolls[k] = colour(cells[k])
        end

        scatter!(ax2,znrf3strength,zeds,color=cooolls)
        scatter!(ax3,wntvec,zeds,color=cooolls)
        text!(ax3,(5.9,15),text="Paneth/Stem Wnt Threshold", rotation = pi/2, align = (:right, :baseline))
        scatter!(ax4,BMPvec,zeds,color=cooolls)
        scatter!(ax5,notchvec,zeds,color=cooolls)
        vlines!(ax3,6,color=:black)

        # redraw the full history up to the current frame (i+2 recorded samples)
        trange = 0:dt*50:(i+1)*50*dt
        for k in 1:length(cellcounts[:,1])
            lines!(ax6,trange.-1,cellcounts[k,1:i+2],color=colaz[k],linewidth=3)
            lines!(ax7,trange.-1,villuscounts[k,1:i+2],color=colaz[k],linewidth=3)
        end
        lines!(ax8,trange.-1, log10.(1 .+fiveFUtrack[1,1:i+2]),color=:dodgerblue2,linewidth=3,label="5FU plasma")
        lines!(ax8,trange.-1, log10.(1 .+fiveFUtrack[5,1:i+2]),color=:goldenrod2,linewidth=3,label="FdUTP")
        lines!(ax8,trange.-1, log10.(1 .+fiveFUtrack[7,1:i+2]),color=:mediumseagreen,linewidth=3,label="FUTP")
    end

    global crypttype = "jeju"
    postprocess(cells,villus,cellcounts,villuscounts,dose,doseschedule,ki67,brdu,celldeets,tmax)
end

"""
    brdusim()

Average BrdU labelling over 8 `"ileum"` crypts.

Each of 8 threaded runs pairs a crypt from column 1 of `cvmat` with a villus from
column 2, both at rows drawn independently from `1:10`. It then simulates 7 days with
`brdusuperloop`. The results are aggregated with `sumbrdu` and `avgcount`. They are
published as the globals `crypttype`, `brdu`, `cellcounts`, `villuscounts` and `tee`,
together with `doseschedule = []` and `dose = 0`, and the function displays `brduplot()`.
"""
function brdusim()
    println("May be slow as producing average of 8 crypts")
    println("---------LET'SA GO!---------")

    cvmat, consts = subset("ileum")

    nruns = 8
    cryptrows = rand(1:10, nruns)
    villusrows = rand(1:10, nruns)
    tmax = 7

    # per-run outputs, indexed by run along the first dimension
    nlabelframes = ceil(Int, tmax*24*3)
    ncountframes = ceil(Int, tmax*20) + 1
    brdumat = zeros(Int, nruns, 2, Int(consts.mintop*5), nlabelframes)
    cellcountsmat = Array{Int,3}(undef, nruns, 11, ncountframes)
    villuscountsmat = Array{Int,3}(undef, nruns, 11, ncountframes)

    # every iteration writes only its own r-slice, so iterations can run in parallel
    Threads.@threads for r in 1:nruns
        labelled, crypttally, villustally =
            brdusuperloop(cvmat[cryptrows[r], 1], cvmat[villusrows[r], 2], consts, tmax)
        brdumat[r, :, :, :] = labelled
        cellcountsmat[r, :, :] = crypttally
        villuscountsmat[r, :, :] = villustally
    end

    # publish results for the plotting helpers
    global crypttype = "ileum"
    global brdu = sumbrdu(brdumat)
    meancrypt = avgcount(cellcountsmat)
    global cellcounts = meancrypt
    global villuscounts = avgcount(villuscountsmat)
    global tee = range(0, tmax, length=size(meancrypt, 2))
    global doseschedule = []
    global dose = 0
    display(brduplot())
end

"""
    fiveFUki67()

Average Ki67 over 8 `"jeju"` crypts.

Each of 8 threaded runs pairs a crypt from column 1 of `cvmat` with a villus from
column 2, both at rows drawn independently from `1:10`. It then simulates 10 days with
`ki67superloop`. The results are aggregated with `sumbrdu` and `avgcount` and displayed
with `ki67plot()`.

Sets the globals `crypttype`, `ki67`, `cellcounts`, `villuscounts` and `tee`. It also sets
`doseschedule` and `dose`, but only when they are not already defined.
"""
function fiveFUki67()
    println("May be slow as producing average of 8 crypts")
    println("---------LET'SA GO!---------")

    cvmat, consts = subset("jeju")

    cryptinds = rand(1:10,8)
    villusinds = rand(1:10,8)
    tmax = 10

    ki67mat = zeros(Int,8,2,Int(consts.mintop*5),ceil(Int,tmax*24*3))
    cellcountsmat = Array{Int,3}(undef,8,11,ceil(Int,tmax*20)+1)
    villuscountsmat = Array{Int,3}(undef,8,11,ceil(Int,tmax*20)+1)

    # every iteration writes only its own k-slice, so iterations can run in parallel
    Threads.@threads for k = 1:8
        ki67mat[k,:,:,:], cellcountsmat[k,:,:], villuscountsmat[k,:,:] = ki67superloop(cvmat[cryptinds[k],1],cvmat[villusinds[k],2],consts,tmax)
    end

    global crypttype = "jeju"
    global ki67 = sumbrdu(ki67mat)
    global cellcounts = avgcount(cellcountsmat)
    global villuscounts = avgcount(villuscountsmat)
    global tee = range(0, tmax, length=size(cellcounts,2))

    # `global x = x` would only re-read the global itself and throw UndefVarError when no
    # earlier run has defined it. Fall back to the default 5FU regimen used by fiveFUsim
    # and fiveFUvideo instead.
    # NOTE(review): assumes ki67superloop simulates that same regimen — confirm. Values left
    # by an earlier run (e.g. brdusim sets dose = 0, doseschedule = []) are kept as before.
    if !isdefined(@__MODULE__, :doseschedule)
        global doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5]
    end
    if !isdefined(@__MODULE__, :dose)
        global dose = 50
    end

    figure = ki67plot()
    display(figure)
end

"""
    fiveFUsim(tmax; dose = 50, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])

Run the `"jeju"` 5FU simulation for `tmax` days without rendering. It runs
`ceil(Int, tmax/dt) + 1` steps of `mainloop5FU`, then calls `postprocess` and displays
`countsplot()`.

A `dose` of zero clears `doseschedule`. Sets the globals `fiveFUtrack` (the tracking array
filled during the run) and `crypttype`.
"""
function fiveFUsim(tmax; dose = 50, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])
    println("---------LET'SA GO!---------")

    cells, villus, consts = setup("jeju")

    # a zero dose means no administrations at all
    if dose == .0
        doseschedule = []
    end

    # dose times expressed in simulation steps (whole-valued floats)
    indoseschedule = round.(doseschedule/dt,digits=0)

    for i in eachindex(cells)
        cells[i].ID = [i]
    end

    fiveFU = initfiveFU(dose,doseschedule)
    cellcounts, villuscounts, track, ki67, brdu, celldeets = initcounts(tmax,length(cells),consts.mintop)
    # Publish the tracking array right away. The global is the very same array object, so
    # every later mutation is visible through it. The hot loop below uses the local
    # `track`; if it used the untyped global, each call would be dynamically dispatched.
    global fiveFUtrack = track
    track[:,1] = fiveFU[:]
    updatedacounts(cells,villus,cellcounts,villuscounts,0,tmax,consts.mintop)

    for i in 1:ceil(Int,tmax/dt)+1
        fiveFU = mainloop5FU(consts,cells,villus,cellcounts,villuscounts,fiveFU,track,dose,indoseschedule,ki67,brdu,celldeets,tmax,i)
    end

    global crypttype = "jeju"

    postprocess(cells,villus,cellcounts,villuscounts,dose,doseschedule,ki67,brdu,celldeets,tmax)
    figure = countsplot()
    display(figure)
end

"""
    ablationsim(tmax; celltype = stem, duration = 12/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])

Run the `"ablation"` crypt model for `tmax` days without rendering. It runs
`ceil(Int, tmax/dt) + 1` steps of `mainloopablat`, then calls `postprocess` (with dose 0)
and displays `countsplot()`.

`celltype`, `duration` and `doseschedule` are forwarded unchanged to `mainloopablat`.
Sets the global `crypttype`.
"""
function ablationsim(tmax; celltype = stem, duration = 12/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])
    println("---------LET'SA GO!---------")

    cells, villus, consts = setup("ablation")
    foreach(idx -> (cells[idx].ID = [idx]), eachindex(cells))

    nsteps = ceil(Int, tmax/dt) + 1
    cellcounts, villuscounts, fiveFUtrack, ki67, brdu, celldeets =
        initcounts(tmax, length(cells), consts.mintop)
    updatedacounts(cells, villus, cellcounts, villuscounts, 0, tmax, consts.mintop)

    for step in 1:nsteps
        mainloopablat(consts, cells, villus, cellcounts, villuscounts, fiveFUtrack,
                      celltype, duration, doseschedule, ki67, brdu, celldeets, tmax, step)
    end

    global crypttype = "ablation"
    postprocess(cells, villus, cellcounts, villuscounts, 0, doseschedule, ki67, brdu, celldeets, tmax)
    display(countsplot())
end

"""
    ablationvideo(tmax; celltype = stem, duration = 12/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])

Run the `"ablation"` crypt model for `tmax` days while recording `lastrun.mp4`. The video
shows a 3D view of the cells beside crypt and villus cell-count traces, with dotted
vertical lines at the `doseschedule` times. Each frame advances the model by 50 steps of
`dt`.

`celltype`, `duration` and `doseschedule` are forwarded unchanged to `mainloopablat`.
Sets the global `crypttype` and returns whatever `postprocess` returns.
"""
function ablationvideo(tmax; celltype = stem, duration = 12/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])
    println("---------LET'SA GO!---------")

    cells, villus, consts = setup("ablation")
    for idx in eachindex(cells)
        cells[idx].ID = [idx]
    end

    cellcounts, villuscounts, fiveFUtrack, ki67, brdu, celldeets =
        fancyinitcounts(tmax, length(cells), consts.mintop)
    updatedacounts(cells, villus, cellcounts, villuscounts, 0, tmax, consts.mintop)

    # observables driving the 3D scatter; reassigned every frame
    initpos, initmarkers, initcolours = plotstuff(cells)
    points = Observable(initpos)
    colournode = Observable(initcolours)
    markernode = Observable(initmarkers)

    figgy = Figure(resolution = (1500,1000))
    lscene = LScene(figgy[1, 1], show_axis=false)
    meshscatter!(lscene, points, markersize=markernode, color=colournode)
    cryptax = figgy[1,2][1,1] = Axis(figgy, title = "Crypt Cell Counts", xlabel = "Time [days]", ylabel = "Number of cells")
    villusax = figgy[1,2][2,1] = Axis(figgy, title = "Villus Cell Counts", xlabel = "Time [days]", ylabel = "Number of cells")
    typecolours = [:yellow,:yellow4,:orange,:red,:green,:purple,:blue,:turquoise,:darkgrey,:brown4,:black]
    typelabels = ["Stem", "Uncom. Prog.", "Abs. Prog.", "Enterocyte", "Sec. Prog.", "Goblet", "Paneth", "Entero.", "Arrested", "Apoptotic", "Total"]
    ntypes = size(cellcounts, 1)

    # one-point plots that exist only to give the legend its entries
    legendlines = [lines!(cryptax, -1, cellcounts[k, 1], color=typecolours[k], linewidth=3) for k in 1:ntypes]
    Legend(figgy[1,2][3,1], legendlines, typelabels, orientation = :horizontal, nbanks=2)

    xlims!(cryptax, (-0.1, tmax+0.1))
    xlims!(villusax, (-0.1, tmax+0.1))
    ylims!(cryptax, (-15, 300))
    ylims!(villusax, (-50, 750))

    display(figgy)

    # advance the model, then redraw both count panels with the full history so far
    function drawframe(frame)
        for substep in 1:50
            mainloopablat(consts, cells, villus, cellcounts, villuscounts, fiveFUtrack,
                          celltype, duration, doseschedule, ki67, brdu, celldeets, tmax, 50*frame + substep)
        end
        points[], markernode[], colournode[] = plotstuff(cells)

        empty!(cryptax)
        empty!(villusax)
        if !isempty(doseschedule)
            vlines!(cryptax, doseschedule, color=:darkgrey, linewidth=3, linestyle=:dot)
            vlines!(villusax, doseschedule, color=:darkgrey, linewidth=3, linestyle=:dot)
        end

        nsamples = frame + 2
        times = 0:dt*50:(frame+1)*50*dt
        for k in 1:ntypes
            lines!(cryptax, times, cellcounts[k, 1:nsamples], color=typecolours[k], linewidth=3)
            lines!(villusax, times, villuscounts[k, 1:nsamples], color=typecolours[k], linewidth=3)
        end
    end

    nframes = Int(tmax/(50*dt))
    record(drawframe, figgy, "lastrun.mp4", 0:nframes-1; framerate=30, compression=20)

    global crypttype = "ablation"
    postprocess(cells, villus, cellcounts, villuscounts, 0, doseschedule, ki67, brdu, celldeets, tmax)
end

"""
    cdkinhibsim(tmax; duration = 6/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])

Run the `"ileum"` crypt model for `tmax` days without rendering. It runs
`ceil(Int, tmax/dt) + 1` steps of `mainloopCDK`, then calls `postprocess` (with dose 0)
and displays `countsplot()`.

`duration` and `doseschedule` are forwarded unchanged to `mainloopCDK`. Sets the global
`crypttype`.
"""
function cdkinhibsim(tmax; duration = 6/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])
    println("---------LET'SA GO!---------")

    cells, villus, consts = setup("ileum")
    foreach(idx -> (cells[idx].ID = [idx]), eachindex(cells))

    nsteps = ceil(Int, tmax/dt) + 1
    cellcounts, villuscounts, fiveFUtrack, ki67, brdu, celldeets =
        initcounts(tmax, length(cells), consts.mintop)
    updatedacounts(cells, villus, cellcounts, villuscounts, 0, tmax, consts.mintop)

    for step in 1:nsteps
        mainloopCDK(consts, cells, villus, cellcounts, villuscounts, fiveFUtrack,
                    duration, doseschedule, ki67, brdu, celldeets, tmax, step)
    end

    global crypttype = "ileum"
    postprocess(cells, villus, cellcounts, villuscounts, 0, doseschedule, ki67, brdu, celldeets, tmax)
    display(countsplot())
end

"""
    cdkinhibvideo(tmax; duration = 6/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])

Run the `"ileum"` crypt model for `tmax` days while recording `lastrun.mp4`. The video
shows a 3D view of the cells beside crypt and villus cell-count traces, with dotted
vertical lines at the `doseschedule` times. Each frame advances the model by 50 steps of
`dt`.

`duration` and `doseschedule` are forwarded unchanged to `mainloopCDK`. Sets the global
`crypttype` and returns whatever `postprocess` returns.
"""
function cdkinhibvideo(tmax; duration = 6/24, doseschedule = [1.0,1.5,2.0,2.5,3.0,3.5,4.0,4.5])
    println("---------LET'SA GO!---------")

    cells, villus, consts = setup("ileum")
    for idx in eachindex(cells)
        cells[idx].ID = [idx]
    end

    cellcounts, villuscounts, fiveFUtrack, ki67, brdu, celldeets =
        fancyinitcounts(tmax, length(cells), consts.mintop)
    updatedacounts(cells, villus, cellcounts, villuscounts, 0, tmax, consts.mintop)

    # observables driving the 3D scatter; reassigned every frame
    initpos, initmarkers, initcolours = plotstuff(cells)
    points = Observable(initpos)
    colournode = Observable(initcolours)
    markernode = Observable(initmarkers)

    figgy = Figure(resolution = (1500,1000))
    lscene = LScene(figgy[1, 1], show_axis=false)
    meshscatter!(lscene, points, markersize=markernode, color=colournode)
    cryptax = figgy[1,2][1,1] = Axis(figgy, title = "Crypt Cell Counts", xlabel = "Time [days]", ylabel = "Number of cells")
    villusax = figgy[1,2][2,1] = Axis(figgy, title = "Villus Cell Counts", xlabel = "Time [days]", ylabel = "Number of cells")
    typecolours = [:yellow,:yellow4,:orange,:red,:green,:purple,:blue,:turquoise,:darkgrey,:brown4,:black]
    typelabels = ["Stem", "Uncom. Prog.", "Abs. Prog.", "Enterocyte", "Sec. Prog.", "Goblet", "Paneth", "Entero.", "Arrested", "Apoptotic", "Total"]
    ntypes = size(cellcounts, 1)

    # one-point plots that exist only to give the legend its entries
    legendlines = [lines!(cryptax, -1, cellcounts[k, 1], color=typecolours[k], linewidth=3) for k in 1:ntypes]
    Legend(figgy[1,2][3,1], legendlines, typelabels, orientation = :horizontal, nbanks=2)

    xlims!(cryptax, (-0.1, tmax+0.1))
    xlims!(villusax, (-0.1, tmax+0.1))
    ylims!(cryptax, (-15, 300))
    ylims!(villusax, (-50, 1100))

    display(figgy)

    # advance the model, then redraw both count panels with the full history so far
    function drawframe(frame)
        for substep in 1:50
            mainloopCDK(consts, cells, villus, cellcounts, villuscounts, fiveFUtrack,
                        duration, doseschedule, ki67, brdu, celldeets, tmax, 50*frame + substep)
        end
        points[], markernode[], colournode[] = plotstuff(cells)

        empty!(cryptax)
        empty!(villusax)
        if !isempty(doseschedule)
            vlines!(cryptax, doseschedule, color=:darkgrey, linewidth=3, linestyle=:dot)
            vlines!(villusax, doseschedule, color=:darkgrey, linewidth=3, linestyle=:dot)
        end

        nsamples = frame + 2
        times = 0:dt*50:(frame+1)*50*dt
        for k in 1:ntypes
            lines!(cryptax, times, cellcounts[k, 1:nsamples], color=typecolours[k], linewidth=3)
            lines!(villusax, times, villuscounts[k, 1:nsamples], color=typecolours[k], linewidth=3)
        end
    end

    nframes = Int(tmax/(50*dt))
    record(drawframe, figgy, "lastrun.mp4", 0:nframes-1; framerate=30, compression=20)

    global crypttype = "ileum"
    postprocess(cells, villus, cellcounts, villuscounts, 0, doseschedule, ki67, brdu, celldeets, tmax)
end
;