//                                                                   -*- C++ -*-
#include <cmath>
#include <cstdio>
#include <cassert>

#include "util.h"
#include "cuda.h"
#include "array.h"
#include "struct.h"
#include "dims.h"
#include "statevar.h"
#include "tabulation.cuh"
#include "tabulate.cuh"
#include "parameters.cuh"
#include "step.h"
#include "math.h"

#include "functions.tcc"

/*
 * Reaction step for a single cell.
 *
 * Computes the sum of the membrane current terms for the cell at flat
 * address `idx`, given its membrane voltage `volt`, and advances every
 * per-cell state variable stored in `v` by one step of length TIMESTEP.
 *
 * Parameters:
 *   v    - state-variable storage; read and written at address `idx` only.
 *   idx  - flat address of the cell inside the state-variable arrays.
 *   volt - membrane voltage of the cell for the current step.
 *
 * Returns the sum of the thirteen current terms evaluated from the state as
 * it was on entry (all currents are computed before any state is stored).
 *
 * Side effects: overwrites Rbar, CaSR_total, CaSS_total, Cai_total, M, H, J,
 * Xr1, Xr2, Xs, S, R, D, F, F2, FCaSS, Defib_a and Defib_c at `idx`.
 *
 * The statement order matters: later updates deliberately reuse values
 * (ca, ca_ss, ca_sr, cur_ca_ltype, ...) that were derived from the
 * pre-update state.
 */
__device__
real current_reaction(statevar_t *v, int idx, real volt)
{
    real current;

    /* Table position for `volt`, shared by every voltage-indexed TAB_VALUE
     * lookup below. */
    real index = TAB_INDEX(M_inf, volt);
    real ca_total = STRUCT_GET_ADDRESS(v, Cai_total, idx);
    real ca = free_of_total(ca_total, CA_BUF_C, K_BUF_C);

    /* NOTE(review): the index above is taken from table M_inf while the range
     * check uses H_inf -- presumably all voltage tables share one range;
     * confirm against the tabulation headers. */
    assert(TAB_WITHIN(H_inf, volt));
    assert(TAB_WITHIN(Rev_potential_ca, ca));

    /* ---- Current terms, all evaluated from the pre-update state ---- */

    real cur_kr = G_KR * STRUCT_GET_ADDRESS(v, Xr1, idx) *
        STRUCT_GET_ADDRESS(v, Xr2, idx) *
        (volt - const_E_K);

    real cur_ks = G_KS * SQR(STRUCT_GET_ADDRESS(v, Xs, idx)) *
        (volt - const_E_Ks_log_by_RTtoF);

    real cur_k1 = xk1_abs(volt) * (volt - const_E_K);

    real cur_k_to = G_TRAN_OUT * STRUCT_GET_ADDRESS(v, R, idx) *
        STRUCT_GET_ADDRESS(v, S, idx) * (volt - const_E_K);

    real cur_na = G_NA * CUBE(STRUCT_GET_ADDRESS(v, M, idx)) *
        STRUCT_GET_ADDRESS(v, H, idx) *
        STRUCT_GET_ADDRESS(v, J, idx) * (volt - const_E_Na);

    real cur_na_bg = G_NA_BACKGROUND * (volt - const_E_Na);

    real ca_ss_total = STRUCT_GET_ADDRESS(v, CaSS_total, idx);
    real ca_ss = free_of_total(ca_ss_total, CA_BUF_SS, K_BUF_SS);

    real cur_ca_ltype = STRUCT_GET_ADDRESS(v, D, idx) *
        STRUCT_GET_ADDRESS(v, F, idx) *
        STRUCT_GET_ADDRESS(v, F2, idx) *
        STRUCT_GET_ADDRESS(v, FCaSS, idx) *
        (TAB_VALUE(G_cal_first_abs, index) * ca_ss -
         TAB_VALUE(G_cal_second_abs, index));

    real cur_ca_bg = G_CA_BACKGROUND * (volt - rev_potential_ca(ca));

    real cur_nacax =
        TAB_VALUE(P_nacax_first_abs, index) -
        TAB_VALUE(P_nacax_second_abs, index) * ca;

    real cur_pump = TAB_VALUE(Kna_pump_abs, index);

    real cur_ca_plateau = G_PLATEAU_CA * ca / (K_PLATEAU_CA + ca);

    real cur_k_plateau = k_plateau_abs(volt) * (volt - const_E_K);

    /* Uses Defib_a as stored on entry; the gate itself is advanced at the
     * end of this function. */
    real cur_defib = G_BIODEF * pow(STRUCT_GET_ADDRESS(v, Defib_a, idx),8) * (volt - E_BIODEF);

    current = cur_kr + cur_ks + cur_k1 + cur_k_to + cur_na + cur_na_bg
        + cur_ca_ltype + cur_ca_bg + cur_pump + cur_nacax + cur_ca_plateau +
        cur_k_plateau + cur_defib;

    /* ---- CaSR / Rbar update ---- */

    real ca_sr_total = STRUCT_GET_ADDRESS(v, CaSR_total, idx);

    real ca_sr = free_of_total(ca_sr_total, CA_BUF_SR, K_BUF_SR);

    real k_ca_sr = CA_MAX_SR - (CA_MAX_SR - CA_MIN_SR) /
        (REAL(1.0) + SQR(CA_EC / ca_sr));

    real r_bar = STRUCT_GET_ADDRESS(v, Rbar, idx) + TIMESTEP *
        (CA_K4 * (REAL(1.0) - STRUCT_GET_ADDRESS(v, Rbar, idx)) -
        CA_K2 * k_ca_sr * ca_ss * STRUCT_GET_ADDRESS(v, Rbar, idx));
    STRUCT_SET_ADDRESS(v, Rbar, idx, r_bar);

    /* o_gate is built from the freshly advanced r_bar, not the stored one. */
    real o_gate = (CA_K1 / k_ca_sr) * SQR(ca_ss) * r_bar /
        (CA_K3 + (CA_K1 / k_ca_sr) * SQR(ca_ss));

    real cur_rel = VOLUME_REL * o_gate * (ca_sr - ca_ss);
    real cur_leak = VOLUME_LEAK * (ca_sr - ca);
    real serca = V_MAX_UP / (REAL(1.0) + SQR(K_UP / ca));

    ca_sr_total += TIMESTEP * (serca - cur_rel - cur_leak);

    STRUCT_SET_ADDRESS(v, CaSR_total, idx, ca_sr_total);

    /* ---- CaSS update ---- */

    real j_xfer = VOLUME_XFER * (ca_ss - ca);

    ca_ss_total += TIMESTEP * (- j_xfer * VOLUME_C / VOLUME_SS +
               cur_rel * VOLUME_SR / VOLUME_SS +
               (- cur_ca_ltype  / (REAL(2.0) * VOLUME_SS * FARADEY) *
                              CAPACITANCE));

    STRUCT_SET_ADDRESS(v, CaSS_total, idx, ca_ss_total);

    /* ---- Cai update ---- */

    ca_total += TIMESTEP * (
        (- (cur_ca_bg + cur_ca_plateau - REAL(2.0) * cur_nacax) *
         (REAL(1.0) / (REAL(2.0) * VOLUME_C * FARADEY)) * CAPACITANCE) +
        (cur_leak - serca) * (VOLUME_SR / VOLUME_C) + j_xfer);

    STRUCT_SET_ADDRESS(v, Cai_total, idx, ca_total);

    /* ---- Voltage-dependent gates ----
     * Each gate A is replaced by  a_inf - (a_inf - A) * TAB_VALUE(A_tau_er),
     * with a_inf evaluated at `volt` and the factor looked up at `index`.
     * The macro declares a local `<a>_inf_val`, so each gate may be updated
     * only once per scope. */
#define UPDATE(A, a)							\
    real a##_inf_val = a##_inf(volt);					\
    STRUCT_SET_ADDRESS(v, A, idx, a##_inf_val -				\
        (a##_inf_val - STRUCT_GET_ADDRESS(v, A, idx)) *			\
            TAB_VALUE(A##_tau_er, index))

    UPDATE(M, m);
    UPDATE(H, h);
    UPDATE(J, j);
    UPDATE(Xr1, xr1);
    UPDATE(Xr2, xr2);
    UPDATE(Xs, xs);
    UPDATE(S, s);
    UPDATE(R, r);
    UPDATE(D, d);
    UPDATE(F, f);
    UPDATE(F2, f2);
    /*UPDATE(Clock, clock);*/
#undef UPDATE

    /* FCaSS follows the same update form, driven by ca_ss instead of volt. */
    real ca_ss_f_inf_val = ca_ss_f_inf(ca_ss);
    STRUCT_SET_ADDRESS(v, FCaSS, idx,
        ca_ss_f_inf_val -
             (ca_ss_f_inf_val - STRUCT_GET_ADDRESS(v, FCaSS, idx)) *
                    ca_ss_tau_er(ca_ss));

    /* ---- Defib_a / Defib_c gates (forward Euler, step TIMESTEP / 700) ----
     * Literals go through REAL() like the rest of this function so the
     * arithmetic stays in `real` instead of being promoted to double.
     * Both gates are read once up front: Defib_a's update uses the Defib_c
     * value from before this step, exactly as when it was re-read inline. */
    const real defib_a = STRUCT_GET_ADDRESS(v, Defib_a, idx);
    const real defib_c = STRUCT_GET_ADDRESS(v, Defib_c, idx);
    const real defib_dt = TIMESTEP / REAL(700.0);

    /* Defib_a: opening rate is non-zero only for volt >= -60; the closing
     * rate adds a voltage term (volt < -60) to a Defib_c^5 term. */
    const real defib_a_open =
        (volt < -REAL(60.0)) ? REAL(0.0) : REAL(0.5);
    const real defib_a_close =
        ((volt < -REAL(60.0)) ? REAL(3.5) : REAL(0.0)) +
        REAL(10000.0) * pow(defib_c, 5);

    STRUCT_SET_ADDRESS(v, Defib_a, idx,
        defib_a +
        defib_dt * (defib_a_open * (REAL(1.0) - defib_a) -
                    defib_a_close * defib_a));

    /* Defib_c: opens while -55 < volt < 0 and closes outside that window. */
    const real defib_c_open =
        (-REAL(55.0) < volt && volt < REAL(0.0)) ? REAL(0.5) : REAL(0.0);
    const real defib_c_close =
        (volt <= -REAL(55.0) || volt >= REAL(0.0)) ? REAL(100.0) : REAL(0.0);

    STRUCT_SET_ADDRESS(v, Defib_c, idx,
        defib_c +
        defib_dt * (defib_c_open * (REAL(1.0) - defib_c) -
                    defib_c_close * defib_c));

    return current;
}


/*
 * Advance the voltage of one grid cell by one step of length TIMESTEP.
 *
 * Each thread handles the cell (row, col) derived from its block and thread
 * indices, where x maps to the column and y to the row.  The block stages
 * its voltages in a shared tile with a one-cell border so the four
 * neighbours of every cell can be read from shared memory.
 *
 * Launch requirements visible from the code: blockDim must equal
 * (BLOCK_DIMX, BLOCK_DIMY), because the tile and the far-side border
 * offsets are sized with those constants.  There is no bounds guard, so the
 * grid must cover exactly the cells that exist.
 * NOTE(review): border reads go through ARRAY_GET with row/col offsets of
 * -1 and +BLOCK_DIM*; presumably the array layout tolerates that at the
 * domain edge -- confirm in array.h.
 *
 * Parameters:
 *   u      - voltages for the current step (read).
 *   u_next - voltages for the next step (written at this cell's address).
 *   vars   - per-cell state, advanced in place by current_reaction().
 *   wts    - per-cell coupling weights for the four neighbours.
 */
__global__
void step_kernel(real *u, real *u_next, statevar_t *vars, weights_t *wts)
{
    /* Block voltages plus a one-cell border; the corners are never written
     * and never read. */
    __shared__ real tile[BLOCK_DIMY + 2][BLOCK_DIMX + 2];

    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y * blockDim.y + threadIdx.y;

    /* Position inside the tile, shifted by one to skip the border. */
    const int tile_col = threadIdx.x + 1;
    const int tile_row = threadIdx.y + 1;

    const int cell = ADDRESS(row, col);
    const real volt = ARRAY_GET_ADDRESS(u, cell);

    tile[tile_row][tile_col] = volt;

    /* The first column of threads fills both the left and right borders. */
    if (tile_col == 1) {
        tile[tile_row][0] = ARRAY_GET(u, row, col - 1);
        tile[tile_row][BLOCK_DIMX + 1] = ARRAY_GET(u, row, col + BLOCK_DIMX);
    }

    /* The first row of threads fills both the top and bottom borders. */
    if (tile_row == 1) {
        tile[0][tile_col] = ARRAY_GET(u, row - 1, col);
        tile[BLOCK_DIMY + 1][tile_col] = ARRAY_GET(u, row + BLOCK_DIMY, col);
    }

    /* Reached by every thread: the tile must be complete before neighbours
     * are read below. */
    __syncthreads();

    const real current = current_reaction(vars, cell, volt);

    /* Voltage differences to the four neighbours. */
    const real diff_next_i = tile[tile_row + 1][tile_col] - volt;
    const real diff_prev_i = tile[tile_row - 1][tile_col] - volt;
    const real diff_next_j = tile[tile_row][tile_col + 1] - volt;
    const real diff_prev_j = tile[tile_row][tile_col - 1] - volt;

    /* Weighted sum, accumulated in the order next_i, prev_i, next_j,
     * prev_j. */
    const real cur_ax = const_Diff_to_Step_2 *
        (STRUCT_GET_ADDRESS(wts, next_i, cell) * diff_next_i +
         STRUCT_GET_ADDRESS(wts, prev_i, cell) * diff_prev_i +
         STRUCT_GET_ADDRESS(wts, next_j, cell) * diff_next_j +
         STRUCT_GET_ADDRESS(wts, prev_j, cell) * diff_prev_j);

    ARRAY_SET_ADDRESS(u_next, cell, volt + TIMESTEP * (cur_ax - current));
}

/*
 * Host entry point: enqueue one step_kernel launch on the COMPUTATION
 * stream.
 *
 * The launch is asynchronous; this function does not wait for the kernel.
 * The grid is DIMX / BLOCK_DIMX by DIMY / BLOCK_DIMY blocks of
 * BLOCK_DIMX x BLOCK_DIMY threads.  The division truncates, so cells beyond
 * the last full block would not be processed if the dimensions are not
 * multiples of the block size.
 *
 * Parameters:
 *   volt      - voltages for the current step.
 *   volt_next - destination for the next-step voltages.
 *   vars      - per-cell state variables, updated in place by the kernel.
 *   weights   - per-cell neighbour coupling weights.
 * All four pointers are passed straight to the kernel and must therefore be
 * dereferenceable from device code.
 */
extern "C"
void step(real *volt, real *volt_next, statevar_t *vars, weights_t *weights)
{
    dim3 block(BLOCK_DIMX, BLOCK_DIMY);
    dim3 grid(DIMX / BLOCK_DIMX, DIMY / BLOCK_DIMY);

    step_kernel <<<grid, block, 0, Cuda_Stream[COMPUTATION]>>>
        (volt, volt_next, vars, weights);

    /* A kernel launch returns nothing, so a bad configuration would go
     * unnoticed.  Peek (rather than get) leaves the error state intact for
     * any check the caller performs later. */
    cudaError_t launch_status = cudaPeekAtLastError();
    if (launch_status != cudaSuccess) {
        fprintf(stderr, "step: step_kernel launch failed: %s\n",
                cudaGetErrorString(launch_status));
    }
}
