from os import listdir, makedirs
import numpy as np
from scipy.optimize import least_squares
# import matplotlib.pyplot as plt
import sys
import pickle
from collections import defaultdict


# If plotting is desired, uncomment the matplotlib import statement above before running
# (the plotting methods of FitData use the module-level name `plt`).
# When this file is run from the command line, the `if __name__ == '__main__'` block at
# the bottom of the file calls main().


class Dataset:
    """This class stores the dataset that is parsed by the FitData class below.
    One could swap out the equations frac_x and frac_a with any other equations
    (still named frac_x, frac_a) of any number of variables and the fitting
    should still work.

    FitData looks up every model parameter by the *argument names* of frac_x and
    frac_a (see get_variable_order), so those names must match the variable names
    listed in the dataset file (which parse_data lower-cases)."""

    def __init__(self, **kwargs):
        self.name = kwargs.get('name')
        self.group = kwargs.get('group')  # Tuple of dataset names sharing group_global_vars, or None.
        self.global_vars = kwargs.get('global_vars')  # Set of variable names shared by all datasets, or None.
        self.local_vars = kwargs.get('local_vars')  # Set of variable names private to this dataset, or None.
        self.group_global_vars = kwargs.get('group_global_vars')  # Set of names shared within self.group, or None.
        self.data_activity = kwargs.get('data_activity')
        self.data_crosslink = kwargs.get('data_crosslink')
        self.data_mg_conc = kwargs.get('data_mg_conc')  # frac_x / frac_a are evaluated at these values.
        self.upper_bounds = kwargs.get('upper_bounds')  # Dict: variable name -> upper bound.
        self.lower_bounds = kwargs.get('lower_bounds')  # Dict: variable name -> lower bound.
        self.initial_guess = kwargs.get('initial_guess')  # Dict: variable name -> initial value.
        self.fit_activity = None  # This is where the fit values (to the exact datapoints) are stored.
        self.fit_crosslink = None
        self.fit_activity_interpolate = None  # This is where interpolated values of the fit are stored.
        self.fit_crosslink_interpolate = None
        self.fit_interpolate_mg_conc = None
        self._param_indicies = dict()
        self.var_order_frac_x = None  # Argument names of frac_x (including 'self').
        self.var_order_frac_a = None  # Argument names of frac_a (including 'self').
        self.fit_activity_bootstrap = None
        self.fit_crosslink_bootstrap = None
        # Synthetic data generated by FitData's bootstrapping.
        self.data_activity_bootstrap = None
        self.data_crosslink_bootstrap = None

    def frac_x(self, kd, k1, k2, k3, kss, a1, a2):
        """Fraction cross-linked, evaluated element-wise at self.data_mg_conc."""
        return ((((1+self.data_mg_conc/kss)**2)*(k1+(k1*k2*a1)+(k1*k2*k3*a1*a2)+(k1*k3)))/(((1+self.data_mg_conc/kd)**2)*(k3+(k2*k3*a2+k2+1))+((1+self.data_mg_conc/kss)**2)*((k1*k2*k3*a1*a2)+(k1*k3)+k1+(k1*k2*a1))))

    def frac_a(self, s, kd, k1, k2, k3, kss, a1, a2):
        """Fraction active, evaluated element-wise at self.data_mg_conc.
        Note the scaling parameter s has been added (it multiplies the whole expression)."""
        return (s*(((1+self.data_mg_conc/kd)**2)*(k3+(k2*k3*a2))+((1+self.data_mg_conc/kss)**2)*((k1*k2*k3*a1*a2)+(k1*k3)))/(((1+self.data_mg_conc/kd)**2)*(k3+(k2*k3*a2)+k2+1)+((1+self.data_mg_conc/kss)**2)*((k1*k2*k3*a1*a2)+(k1*k3)+k1+(k1*k2*a1))))

    @staticmethod
    def get_variable_order(func):
        """Return the names of func's positional arguments, in order.

        For a func defined 'def func(a, b, c): ...', get_variable_order(func)
        returns ('a', 'b', 'c'); for a bound method the first name is 'self'.
        Only arguments are returned: co_varnames on its own also lists local
        variables of the body, which would break a model function that uses locals.
        """
        code = func.__code__
        return code.co_varnames[:code.co_argcount]

    def set_variable_orders(self):
        """Cache the argument names of frac_x and frac_a (each starting with 'self')."""
        self.var_order_frac_x = self.get_variable_order(self.frac_x)
        self.var_order_frac_a = self.get_variable_order(self.frac_a)


class FitData:
    """FitData can run a single fit calculation (.fit()), using the initial
    conditions from the input files, or it can run a multi-start fit calculation
    (.fit_rand_search()) using random starting points chosen on a log scale
    between the variable bounds.  Bootstrapped confidence intervals of the fit
    can be obtained with method .bootstrap_ci().  An example of how to run a fit
    can be seen in the main() function below.

    All fitted parameters live in one flat vector ``x``.
    ``_variable_map_dict[dataset_name][var]`` is the index in ``x`` of variable
    *var* for that dataset: global variables share one index across all datasets,
    group-global variables share one index within a group, and local variables
    get an index of their own."""

    def __init__(self, path_to_datasets):
        self.path_to_datasets = path_to_datasets  # Directory with one text file per dataset.
        self.datasets = list()
        self.globals = dict()  # Global variable name -> index in x.
        self.upper_bounds = list()
        self.lower_bounds = list()
        self.initial_cond = list()
        self.mean_data_activity = None
        self.mean_data_crosslink = None
        self.scale = None  # Activity data are divided by this before fitting.
        self.lsq_res = None  # Result of .fit().
        self._variable_map_dict = dict()
        self.params_fit = dict()
        self.params_fit_min = dict()
        self.lsq_res_rand = list()  # Results of .fit_rand_search().
        self._min_index = None  # Index of the lowest-cost entry of lsq_res_rand.
        self.lsq_res_bootstrap = list()
        self.params_mean_bootstrap = None
        self.params_std_bootstrap = None
        self.params_ci_low_bootstrap = None
        self.params_ci_high_bootstrap = None
        self.params_fit_min_ci = dict()

    @staticmethod
    def parse_data(file):
        """Parse one dataset text file into keyword arguments for Dataset.

        Header lines have the form ``key=value``; lines starting with '#' and
        blank header lines are skipped.  The first header line without an '='
        (e.g. a column-title line) switches to data mode; every following line is
        ``mg_conc,activity,crosslink`` until the first blank line.

        Recognised keys: name, global_vars, group_global_vars, local_vars
        (comma-separated, lower-cased), group (comma-separated dataset names),
        initial_guess (``var:value, ...``) and bounds (``var:[low high], ...``;
        'inf' and '-inf' are accepted).

        Returns a dict of keyword arguments for Dataset; the three data columns
        are returned as numpy arrays."""
        print(file)
        kwargs = dict()
        data_mg_conc = list()
        data_activity = list()
        data_crosslink = list()
        initial_guess = dict()
        upper_bounds = dict()
        lower_bounds = dict()
        data = False  # True once the header has been read.
        with open(file, 'r') as infile:
            for line in infile:
                line = line.strip()
                if data:
                    if not line:
                        break  # A blank line ends the data section.
                    datapoint = line.split(',')
                    data_mg_conc.append(float(datapoint[0]))
                    data_activity.append(float(datapoint[1]))
                    data_crosslink.append(float(datapoint[2]))
                    continue
                if not line or line.startswith('#'):
                    continue  # Blank lines (previously an IndexError) and comments.
                # Split on the first '=' only, so a value may itself contain '='.
                spl = line.split('=', 1)
                if len(spl) == 1:
                    # First header line without '=' (e.g. column titles): data follow.
                    data = True
                    continue
                key, value = spl
                if key == 'name':
                    kwargs['name'] = value
                elif key in ('global_vars', 'group_global_vars', 'local_vars'):
                    kwargs[key] = set([v.strip() for v in value.lower().split(',')]) if value else None
                elif key == 'group':
                    # Group members are matched against dataset names, so case is kept.
                    kwargs['group'] = tuple([g.strip() for g in value.split(',')]) if value else None
                elif key == 'initial_guess':
                    for guess in value.split(','):
                        parts = guess.strip().split(':')
                        initial_guess[parts[0]] = float(parts[1])
                    kwargs['initial_guess'] = initial_guess
                elif key == 'bounds':
                    for bound in value.split(','):
                        parts = bound.strip().split(':')
                        low, high = parts[1].replace('[', '').replace(']', '').split()[:2]
                        # float() parses 'inf' / '-inf'.  Infinite bounds are rejected by
                        # the log-scale sampling in fit_rand_search.
                        lower_bounds[parts[0]] = float(low)
                        upper_bounds[parts[0]] = float(high)
                    kwargs['upper_bounds'] = upper_bounds
                    kwargs['lower_bounds'] = lower_bounds
        kwargs['data_mg_conc'] = np.array(data_mg_conc)
        kwargs['data_activity'] = np.array(data_activity)
        kwargs['data_crosslink'] = np.array(data_crosslink)
        return kwargs

    def setup(self):
        """This parses all datasets in the dataset directory and initializes bounds,
         initial conditions, and variable index mappings.  It also determines a scaling
         constant by which to scale all the activity data before global fitting.  (If the
         activity data are not scaled, then its residuals will dominate the global fit.)
         The scaling constant ensures the mean value of all activity and crosslinking datasets
         is equal before the fitting procedure begins.  After the fit, the activity data
         are rescaled back to their original scale.

         Files are read in sorted order and variable names are visited in sorted
         order, so parameter indices (and therefore seeded random searches) do not
         depend on directory order or on per-process string hashing.

         Raises ValueError if the directory contains no dataset files, or if a
         dataset lists group_global_vars without belonging to any group."""
        from os.path import join

        for file in sorted(f for f in listdir(self.path_to_datasets) if not f.startswith('.')):
            kwargs = self.parse_data(join(self.path_to_datasets, file))
            self.datasets.append(Dataset(**kwargs))
        if not self.datasets:
            raise ValueError('No dataset files found in %r' % self.path_to_datasets)

        n_params = 0
        self.globals = dict()
        group_globals = defaultdict(dict)  # Group key (tuple of names) -> {var: index}.
        for dataset in self.datasets:
            dataset.set_variable_orders()
            mapping = self._variable_map_dict[dataset.name] = dict()

            # Global variables: one index shared by every dataset.
            for var in sorted(dataset.global_vars or ()):
                if var not in self.globals:
                    self.globals[var] = n_params
                    n_params += 1
                mapping[var] = self.globals[var]

            # Group-global variables: one index shared within a group.  Join the first
            # existing group whose member tuple contains this dataset's name; otherwise
            # start a new group keyed by this dataset's own group tuple.
            if dataset.group_global_vars is not None:
                key = next((k for k in group_globals if dataset.name in k), dataset.group)
                if key is None:
                    raise ValueError('Dataset %r lists group_global_vars but no group.' % dataset.name)
                shared = group_globals[key]
                for var in sorted(dataset.group_global_vars):
                    if var not in shared:
                        shared[var] = n_params
                        n_params += 1
                    mapping[var] = shared[var]

            # Local variables: a fresh index each.
            for var in sorted(dataset.local_vars or ()):
                mapping[var] = n_params
                n_params += 1

        # For shared variables the last dataset processed supplies the guess/bounds.
        init_cond = [None] * n_params
        upper_bounds = [None] * n_params
        lower_bounds = [None] * n_params
        for dataset in self.datasets:
            for var, index in self._variable_map_dict[dataset.name].items():
                init_cond[index] = dataset.initial_guess[var]
                upper_bounds[index] = dataset.upper_bounds[var]
                lower_bounds[index] = dataset.lower_bounds[var]
        self.initial_cond = init_cond
        self.upper_bounds = upper_bounds
        self.lower_bounds = lower_bounds

        for dataset in self.datasets:
            print(dataset.data_activity, dataset.name)

        # Pool all points; concatenation also works when datasets differ in length.
        self.mean_data_activity = np.mean(np.concatenate([d.data_activity for d in self.datasets]))
        self.mean_data_crosslink = np.mean(np.concatenate([d.data_crosslink for d in self.datasets]))
        self.scale = self.mean_data_activity / self.mean_data_crosslink

    def _model_args(self, x, dataset, var_order):
        """Return the entries of *x* for the arguments in *var_order*, skipping 'self'."""
        indices = self._variable_map_dict[dataset.name]
        return [x[indices[var]] for var in var_order[1:]]

    def residual(self, x):
        """This is the function the least_squares minimizer is minimizing.  It is the
        residuals of the frac_a, frac_x functions and the respective datasets, in
        dataset order (activity residuals, then crosslink residuals, per dataset)."""

        res = list()
        for dataset in self.datasets:
            args_a = self._model_args(x, dataset, dataset.var_order_frac_a)
            args_x = self._model_args(x, dataset, dataset.var_order_frac_x)
            res.append(dataset.frac_a(*args_a) - dataset.data_activity / self.scale)  # Activity is scaled.
            res.append(dataset.frac_x(*args_x) - dataset.data_crosslink)
        return np.concatenate(res)

    def residual_bootstrap(self, x):
        """This is the function the least_squares minimizer is minimizing during
        bootstrapping for parameter confidence intervals.  The new datasets are generated directly
        from the best fit parameters + randomly chosen residuals, so the activity dataset does not
        need to be scaled in this case."""

        res = list()
        for dataset in self.datasets:
            args_a = self._model_args(x, dataset, dataset.var_order_frac_a)
            args_x = self._model_args(x, dataset, dataset.var_order_frac_x)
            res.append(dataset.frac_a(*args_a) - dataset.data_activity_bootstrap)
            res.append(dataset.frac_x(*args_x) - dataset.data_crosslink_bootstrap)
        return np.concatenate(res)

    def fit(self):
        """Performs a single fitting from *initial_cond* and stores the result in *lsq_res*.
        Best fit params are extracted to *params_fit*."""
        if not self.datasets:
            self.setup()

        self.lsq_res = least_squares(self.residual, self.initial_cond, bounds=(self.lower_bounds, self.upper_bounds))
        self._extract_params_fit()

    @staticmethod
    def log10uniform(low=0, high=1, size=None):
        """low and high should be the result of np.log10(low), np.log10(high).
        Note that all returned values are positive numbers, so bounds can not
        be negative."""
        return np.power(10, np.random.uniform(low, high, size))

    def fit_rand_search(self, num_trials=100, seed=None, zero_lb_decades=6):
        """Performs multiple (*num_trials*) fits from initial conditions chosen
        randomly on a uniform log scale to be between the bounds set in the dataset
        text file.  The random number generator can be initialized with *seed* (integer;
        0 is a valid seed).  The results of all fits are appended to *lsq_res_rand* and
        parameters of the best fit are stored in *params_fit_min*.

        Bounds must be non-negative and finite for log-scale sampling (ValueError
        otherwise).  A parameter whose lower bound is 0 is sampled on a log scale
        from *zero_lb_decades* decades below its upper bound up to the upper bound,
        and about 6% of the trials start it at exactly 0."""

        if seed is not None:
            np.random.seed(seed)

        if not self.datasets:
            self.setup()

        initial_conds = np.zeros((num_trials, len(self.initial_cond)))
        for k, (lb, ub) in enumerate(zip(self.lower_bounds, self.upper_bounds)):
            if lb < 0 or not np.isfinite(ub):
                raise ValueError('Parameter %d: log-scale sampling needs finite, non-negative '
                                 'bounds, got [%r, %r].' % (k, lb, ub))
            if lb == 0:
                high = np.log10(ub)
                # log10(0) is undefined, so sample below ub over a fixed span of decades.
                # (The span used to be computed as high*1e-6, which put every start
                # above ub whenever ub < 1.)
                initial_conds[:, k] = self.log10uniform(high - zero_lb_decades, high, num_trials)
                num_zero = int(np.ceil(0.06 * num_trials))
                zero_rows = np.random.choice(num_trials, num_zero, replace=False)
                initial_conds[zero_rows, k] = 0
            else:
                initial_conds[:, k] = self.log10uniform(np.log10(lb), np.log10(ub), num_trials)

        for initial_cond in initial_conds:
            lsq_res = least_squares(self.residual, initial_cond, bounds=(self.lower_bounds, self.upper_bounds))
            self.lsq_res_rand.append(lsq_res)

        self._extract_best_params_fit()

    def _params_by_dataset(self, values):
        """Map a per-parameter sequence onto {dataset name: {variable name: value}}."""
        result = dict()
        for dataset in self.datasets:
            indices = self._variable_map_dict[dataset.name]
            names = dataset.var_order_frac_a[1:] + dataset.var_order_frac_x[1:]
            result[dataset.name] = {var: values[indices[var]] for var in names}
        return result

    def _extract_best_params_fit(self):
        """sets *_min_index* and *params_fit_min* after the datasets are fit by .fit_rand_search().
        Ties in cost go to the earliest fit."""

        self._min_index = min(range(len(self.lsq_res_rand)), key=lambda k: self.lsq_res_rand[k].cost)
        lsq_res = self.lsq_res_rand[self._min_index]
        self.params_fit_min.update(self._params_by_dataset(lsq_res.x))

    def _extract_best_params_fit_bootstrap(self):
        """sets *params_fit_min_ci* (a (low, high) tuple per variable) after the
        bootstrapping of the best fit is performed."""

        bounds = list(zip(self.params_ci_low_bootstrap, self.params_ci_high_bootstrap))
        self.params_fit_min_ci.update(self._params_by_dataset(bounds))

    def _extract_params_fit(self):
        """sets *params_fit* after the datasets are fit by .fit."""

        self.params_fit.update(self._params_by_dataset(self.lsq_res.x))

    def _bootstrap(self, num_trials=100, fit='min', seed=None):
        """Performs bootstrapping of the best fit variables.
        Data is refit *num_trials* times. fit='min' denotes the best fit variables
        are derived from fit_rand_search (otherwise from .fit()). *seed* controls the
        random number generator for generating randomly chosen residuals from the best
        fit.  The bootstrapped fits are appended to lsq_res_bootstrap."""

        if seed is not None:
            np.random.seed(seed)

        best = self.lsq_res_rand[self._min_index] if fit == 'min' else self.lsq_res
        residuals = best.fun
        # Synthetic data come from the chosen fit's own parameters, so the model
        # values and the residual pool always belong to the same fit.
        x_best = best.x

        for _ in range(num_trials):

            for dataset in self.datasets:
                args_a = self._model_args(x_best, dataset, dataset.var_order_frac_a)
                args_x = self._model_args(x_best, dataset, dataset.var_order_frac_x)

                # Residuals are drawn with replacement from all residuals of the fit
                # (activity and crosslink residuals pooled).
                pick = np.random.randint(0, len(residuals), len(dataset.data_activity))
                residuals_boot_a = residuals[pick]
                pick = np.random.randint(0, len(residuals), len(dataset.data_crosslink))
                residuals_boot_x = residuals[pick]

                dataset.data_activity_bootstrap = dataset.frac_a(*args_a) + residuals_boot_a
                dataset.data_crosslink_bootstrap = dataset.frac_x(*args_x) + residuals_boot_x

            lsq_res_boot = least_squares(self.residual_bootstrap, x_best,
                                         bounds=(self.lower_bounds, self.upper_bounds))
            self.lsq_res_bootstrap.append(lsq_res_boot)

    def bootstrap_ci(self, num_trials=100, fit='min', ci_low=25, ci_high=75, status=(2,), seed=None):
        """Performs bootstrapping of the best fit variables and calculate confidence intervals.
           Defaults are lower bound = 25%, upper bound = 75%, so a 50% confidence interval.
           Data is refit *num_trials* times (only if no bootstrap fits exist yet). fit='min'
           denotes the best fit variables are derived from fit_rand_search. *seed* controls
           the random number generator for generating randomly chosen residuals from the best
           fit.  Only refits whose least_squares status is in *status* are used.  The
           bootstrapped fits are stored in lsq_res_bootstrap.  The best params lower and upper
           ci bounds are stored in .params_fit_min_ci.

           Raises ValueError if no bootstrap fit has an accepted status."""

        if not self.lsq_res_bootstrap:
            self._bootstrap(num_trials, fit, seed)

        fits = [lsq_res.x for lsq_res in self.lsq_res_bootstrap if lsq_res.status in status]
        if not fits:
            raise ValueError('No bootstrap fit has a status in %r.' % (status,))
        self.params_mean_bootstrap = np.mean(fits, axis=0)
        self.params_std_bootstrap = np.std(fits, axis=0)
        self.params_ci_low_bootstrap = np.percentile(fits, ci_low, axis=0)
        self.params_ci_high_bootstrap = np.percentile(fits, ci_high, axis=0)
        self._extract_best_params_fit_bootstrap()

    def _fit_curves(self, dataset):
        """Evaluate the fit in *params_fit* for *dataset* at its data points and on a
        100-point log grid spanning 0.1x the lowest non-zero concentration to 10x the
        last one (assumes concentrations are sorted ascending).  Stores fit_activity,
        fit_crosslink, fit_activity_interpolate, fit_crosslink_interpolate and
        fit_interpolate_mg_conc on *dataset*; activity is multiplied by self.scale to
        undo the fitting scale.  Returns the interpolation grid."""

        mg = dataset.data_mg_conc
        lowest = mg[1] if mg[0] == 0 else mg[0]  # A zero cannot go on a log axis.
        mg_conc_interp = np.logspace(np.log10(0.1 * lowest), np.log10(mg[-1] * 10), 100)

        param_dict = self.params_fit[dataset.name]
        args_a = [param_dict[var] for var in dataset.var_order_frac_a[1:]]
        args_x = [param_dict[var] for var in dataset.var_order_frac_x[1:]]

        dataset.fit_activity = self.scale * dataset.frac_a(*args_a)
        dataset.fit_crosslink = dataset.frac_x(*args_x)

        # The model functions read dataset.data_mg_conc, so swap in the grid
        # temporarily and always restore the real concentrations.
        dataset.data_mg_conc = mg_conc_interp
        try:
            dataset.fit_activity_interpolate = self.scale * dataset.frac_a(*args_a)
            dataset.fit_crosslink_interpolate = dataset.frac_x(*args_x)
        finally:
            dataset.data_mg_conc = mg
        dataset.fit_interpolate_mg_conc = mg_conc_interp
        return mg_conc_interp

    def _plot_fit(self, dataset):
        """Plots the fit for *dataset* (requires the matplotlib import at the top of the file)."""

        mg_conc_interp = self._fit_curves(dataset)
        mg = dataset.data_mg_conc

        plt.figure()
        plt.title('Activity of ' + dataset.name)
        plt.semilogx(mg_conc_interp, dataset.fit_activity_interpolate, 'k',
                     mg, dataset.fit_activity, 'go',
                     mg, dataset.data_activity, 'bx')
        plt.xlabel('[Mg] (M)')
        plt.ylabel('Activity (Miller Units)')

        plt.figure()
        plt.title('Fraction crosslinked of ' + dataset.name)
        plt.semilogx(mg_conc_interp, dataset.fit_crosslink_interpolate, 'k',
                     mg, dataset.fit_crosslink, 'go',
                     mg, dataset.data_crosslink, 'bx')

        plt.xlabel('[Mg] (M)')
        plt.ylabel('Fraction crosslinked')

    def plot_fit(self):
        """Plots fits separately for all datasets."""

        for dataset in self.datasets:
            self._plot_fit(dataset)

    def grid_plot(self, chunk_size=8, nrows=4, ncols=4):
        """Plots the fits on nrows x ncols grids, *chunk_size* datasets per figure.
        Each dataset uses two panels (activity, crosslinking), so chunk_size should
        be at most nrows * ncols // 2."""

        for i in range(0, len(self.datasets), chunk_size):
            self._grid_plot(self.datasets[i:i + chunk_size], nrows, ncols)

    def _grid_plot(self, datasets, nrows=4, ncols=4):
        """Plots fits for *datasets* on an nrows x ncols grid: activity in panel j and
        crosslinking in panel j + 1.  Only bottom-row panels get an x label."""

        plt.figure()
        last_non_bottom = nrows * ncols - ncols  # Panels numbered above this are in the bottom row.
        for k, dataset in enumerate(datasets):
            j = 2 * k + 1
            mg_conc_interp = self._fit_curves(dataset)
            mg = dataset.data_mg_conc

            plt.subplot(nrows, ncols, j)
            plt.title('Act: ' + dataset.name, fontsize=10, fontweight='bold')
            plt.semilogx(mg_conc_interp, dataset.fit_activity_interpolate, 'k',
                         mg, dataset.fit_activity, 'go',
                         mg, dataset.data_activity, 'bx')
            plt.ylim(-10, 100)

            if j > last_non_bottom:
                plt.xlabel('[Mg] (M)', fontsize=10)
            plt.ylabel('Activity (Miller Units)', fontsize=10)

            plt.subplot(nrows, ncols, j + 1)
            plt.title('X-link: ' + dataset.name, fontsize=10, fontweight='bold')
            plt.semilogx(mg_conc_interp, dataset.fit_crosslink_interpolate, 'k',
                         mg, dataset.fit_crosslink, 'go',
                         mg, dataset.data_crosslink, 'bx')
            plt.ylim(-0.1, 1.1)

            if j + 1 > last_non_bottom:
                plt.xlabel('[Mg] (M)', fontsize=10)
            plt.ylabel('Fraction crosslinked', fontsize=10)


path = '/wynton/home/degradolab/nick.polizzi/phoq/012721/data/'


def main(argv=None):
    """Performs fitting of datasets in *path*.
        Does not perform bootstrapping. Stores the
        fit in outpath/fit.pkl, where outpath is
        /wynton/scratch/nick.polizzi/phoq/fits/<seed>/.

    argv: command-line arguments (defaults to sys.argv); argv[1] is the integer
    seed for the 250-trial random search.  Only the best fit is kept in the
    pickled FitData object."""
    if argv is None:
        argv = sys.argv
    seed = argv[1]  # Must be an integer; the raw string is also used in the output path.
    fit = FitData(path)
    fit.fit_rand_search(num_trials=250, seed=int(seed))
    # Keep only the best fit and point _min_index at it, so that e.g.
    # bootstrap_ci(fit='min') still finds it after the pickle is reloaded.
    fit.lsq_res_rand = [fit.lsq_res_rand[fit._min_index]]
    fit._min_index = 0

    outpath = '/wynton/scratch/nick.polizzi/phoq/fits/' + str(seed) + '/'
    makedirs(outpath, exist_ok=True)

    with open(outpath + 'fit.pkl', 'wb') as outfile:
        pickle.dump(fit, outfile)


# Script entry point: run the seeded fit when this file is executed directly
# (see the command-line example below).
if __name__ == '__main__':
    main()


# example to run from command line:
#       $ python phoq_fit.py 123
# '123' is the seed for random number generation.


# example to run from command line:
#       $ python phoq_fit.py 123
# '123' is the seed for random number generation.


# Fitting was run on the UCSF Wynton HPC with the following command:
# $  qsub phoq_fit.job
# The best fit was found via:
# import os
# import pickle
# costs = []
# for d in os.listdir('.'):  # Current directory is output directory of fitting.
#     with open(d + '/fit.pkl', 'rb') as infile:
#         fit = pickle.load(infile)
#     lsq_res = fit.lsq_res_rand[0]
#     costs.append((lsq_res.cost, lsq_res.status, d))
# costs_sorted = sorted(costs)
# best_fit = costs_sorted[0]  # Best fit was that with seed 372, which was then
#                             # used for bootstrapping.

# Plot best fit in an iPython session:
# run phoq_fit_local_global_ipython.py
# %matplotlib
# import pickle
# with open('/Volumes/disk1/phoq/fits/372/fit.pkl', 'rb') as infile:
#     fit = pickle.load(infile)
# fit.params_fit = fit.params_fit_min
# fit.grid_plot()
# plt.tight_layout()

# # In a Python session, print the fit parameters and the bounds used in fitting:
# run phoq_fit_local_global_ipython.py
# from collections import defaultdict
# import pickle
# with open('/Volumes/disk1/phoq/fits/372/fit.pkl', 'rb') as infile:
#     fit = pickle.load(infile)
# inv = defaultdict(list)
# num_to_var = dict()
# for k, d in fit._variable_map_dict.items():
#     for v, n in d.items():
#         inv[n].append(k)
#         num_to_var[n] = v
# for i, (u, l), in enumerate(zip(fit.upper_bounds, fit.lower_bounds)):
#     print(i, num_to_var[i], fit.lsq_res_rand[0].x[i], l, u)
#     # parameter number, parameter name, fit value, lower bound used in fitting, upper bound used in fitting
# for k, v in inv.items():
#     print(k, num_to_var[k], v)
#     # parameter number, parameter name, associated datasets.
#     print()

# Bootstrapping for estimation of parameter confidence intervals was run
# on UCSF Wynton HPC with the following command:
# $  qsub phoq_fit_ci.job

# #Setup parameters for plotting confidence intervals in an iPython session:
# cd ~/Projects/phoq/src/022521/
# run phoq_fit_local_global_ipython.py
# %matplotlib
# import pickle
# from collections import defaultdict
# fits = []
# for i in range(1,101):
#     with open('/Volumes/disk1/phoq/fits/372/372/fit_ci_' + str(i) + '.pkl', 'rb') as infile:
#         fit_ci = pickle.load(infile)
#         fits.extend([lsq_res.x for lsq_res in fit_ci.lsq_res_bootstrap if lsq_res.status in (2,)])
# fits_ = np.array(fits)
#
# inv = defaultdict(list)
# num_to_var = dict()
# for k, d in fit_ci._variable_map_dict.items():
#     for v, n in d.items():
#         inv[n].append(k)
#         num_to_var[n] = v
# inds_fixed = np.where(np.array(fit_ci.upper_bounds) - np.array(fit_ci.lower_bounds) < 1)[0]
#
# with open('/Volumes/disk1/phoq/fits/372/fit.pkl', 'rb') as infile:
#     fit = pickle.load(infile)

# #Plotting confidence intervals:
# for s in [0, 24, 48]:
#     plt.figure();
#     for i in range(0 + s, 24 + s):
#         if i == 63:
#             break
#         if i not in inds_fixed:
#             plt.subplot(6, 4, i + 1 - s);
#             n_, b_, p_ = plt.hist(np.log10(fits_[:, i]), bins='scott')
#             plt.xlim([np.log10(fit.lower_bounds[i]) - 0.1, np.log10(fit.upper_bounds[i]) + 0.1])
#             max_ = max(n_)
#             # plt.ylim([-0.1 * max_, 1.1 * max_])
#             plt.plot(np.log10(fit.lsq_res_rand[0].x[i]), 50, '|', markersize=6, color='r', mew=1.5)
#             plt.plot(np.log10(fit.lsq_res_rand[0].x[i]), 25, '|', markersize=6, color='r', mew=1.5)
#             plt.plot(np.log10(fit.lsq_res_rand[0].x[i]), 0, '|', markersize=6, color='r', mew=1.5)
#         else:
#             plt.subplot(6, 4, i + 1 - s)
#         title = 'p' + str(i) + ', ' + num_to_var[i]
#         if len(inv[i]) > 1:
#             title += '*'
#         if i in inds_fixed:
#             title = 'Fixed ' + title + ' = 1'
#         plt.xlabel(r'$\log_{10}$' + '(' + title + ')', fontsize=10)
#         if i - s in list(range(0, 6 * 4, 4)):
#             plt.ylabel('Counts', fontsize=10)
#     plt.tight_layout()

# #Plotting residual sweeps:
# for s in [0, 24, 48]:
#     plt.figure();
#     for i in range(0 + s, 24 + s):
#         fixed_vals = np.logspace(np.log10(fit.lower_bounds[i]), np.log10(fit.upper_bounds[i]), 100)
#         costs = []
#         for p in fixed_vals:
#             x = fit.lsq_res_rand[0].x.copy()
#             x[i] = p
#             r = fit.residual(x)
#             costs.append(np.sum(r ** 2))
#         if i not in inds_fixed:
#             plt.subplot(6, 4, i + 1 - s);
#             plt.plot(np.log10(fixed_vals), np.log10(costs), '-')
#             plt.plot(np.log10(fit.lsq_res_rand[0].x[i]), np.log10(fit.lsq_res_rand[0].cost * 2), 'x', markersize=6,
#                      color='r', mew=1.5)
#         else:
#             plt.subplot(6, 4, i + 1 - s)
#         title = 'p' + str(i) + ', ' + num_to_var[i]
#         if len(inv[i]) > 1:
#             title += '*'
#         if i in inds_fixed:
#             title = 'Fixed ' + title + ' = 1'
#         plt.xlabel(r'$\log_{10}$' + '(' + title + ')', fontsize=10)
#         if i - s in list(range(0, 6 * 4, 4)):
#             plt.ylabel(r'$\log_{10}(\sum R^2 )$', fontsize=10)
#     plt.tight_layout()

# Calculating R^2 value:
# var_a = np.sum(np.array([(dataset.data_activity - fit.mean_data_activity)/fit.scale for dataset in fit.datasets])**2)
# var_c = np.sum(np.array([(dataset.data_crosslink - fit.mean_data_crosslink) for dataset in fit.datasets])**2)
# rss = np.sum(fit.lsq_res_rand[0].fun**2)
# tss = var_a + var_c
# R2 = 1 - rss/tss










