import numpy as np
import pandas as pd
import math
import os
import json
from os.path import join
from tifffile import tifffile as tif
from scipy.stats import norm
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit, minimize
from scipy.integrate import solve_ivp
from scipy.stats import rv_continuous
from scipy.stats import lognorm
from scipy.signal import convolve
import matplotlib.pyplot as plt
import numba
# local modules
from stepfilled import hard_edge
from convolve import *

# Input data and figure output directories, relative to the working directory.
DATA = os.path.join('..', 'data')
FIGS = os.path.join('..', 'figs')
if not os.path.exists(FIGS):
    os.mkdir(FIGS)
# True: display every figure interactively; False: close it after saving.
SHOW = False

# Net intensities are kept only strictly inside (LB, UB) (see get_spots).
LB = 0
UB = 260000
# Intensity used to split each histogram into a minor and a major part.
cutoff = 45000
# Initial guess for the monomer peak; handed to the dimer fit as (s, scale, nu).
p0 = [0.33, 14000, 0.5]
# Radius used in detection, used for filtering out nearby spots.
RAD = 3  # > 1

# Every entry of names.json becomes a path stem inside DATA.
join = os.path.join
with open(join(DATA, 'names.json')) as jf:
    fls = json.load(jf)
files = {entry: join(DATA, entry) for entry in fls}
print(files)

# Sample names ordered by whatever follows their first space-separated word.
namesm = sorted(files, key=lambda s: s.partition(' ')[2])
# Same ordering without the monomer control.
names = list(namesm)
names.remove('Monomers')

N = len(namesm)

# Matplotlib parameters
plotstyle = dict([('figure.figsize', (6, 4)),
                  ('font.size', 10),
                  ('lines.linewidth', 2)])
plt.rcParams.update(plotstyle)

heat = plt.matplotlib.cm.gist_heat
earth = plt.matplotlib.cm.gist_earth

# One colour per sample: names containing 'Act' use the heat map, all others
# the earth map; the position along the map follows the order of `files`.
clu = {}
for i, k in enumerate(files):
    cmap = heat if 'Act' in k else earth
    clu[k] = cmap(1/N*i)


@numba.jit(nopython=True)
def sparsify(xy, dist):
    """
    Flag the points that have no neighbour closer than ``dist``.

    Parameters
    ----------
    xy : 2-D array
        One row per point; column 0 is read as x and column 1 as y.
    dist : float
        Two points whose Euclidean distance is strictly below this value
        are both rejected.

    Returns
    -------
    keep : int8 array with one entry per row of ``xy``
        1 for points to keep, 0 for points that have at least one other
        point closer than ``dist``.
    """
    n = xy.shape[0]
    keep = np.ones(n, dtype=np.int8)
    for i in range(n):
        x0, y0 = xy[i, 0], xy[i, 1]
        # The distance test is symmetric and clears both flags at once, so
        # every unordered pair only has to be visited once (j > i).
        for j in range(i + 1, n):
            x1, y1 = xy[j, 0], xy[j, 1]
            r = np.sqrt((x1-x0)**2 + (y1-y0)**2)
            if r < dist:
                keep[i] = 0
                keep[j] = 0
    return keep

def keep(grp, dist):
    """
    Return the rows of ``grp`` that ``sparsify`` marks as isolated.

    ``grp`` must provide the columns 'position_x' and 'position_y'; ``dist``
    is forwarded unchanged to ``sparsify``.
    """
    coords = grp[['position_x', 'position_y']].values
    isolated = sparsify(coords, dist).astype(bool)
    return grp[isolated]

def get_spots(path):
    """
    Load the spot table ``path + '.csv'`` and return the filtered spots.

    Column names are lower-cased. A 'net_intensity' column is added as
    total_intensity - outer_intensity * pixel_count, rows are kept only when
    LB < net_intensity < UB, and within each 'frame' group the rows that
    ``keep`` rejects for a separation of 2*(RAD + 1) are dropped.
    """
    table = pd.read_csv(path + '.csv', index_col=0)
    table.columns = [str.lower(col) for col in table.columns]
    # intensity = spots['total_intensity'] - bg[name] * spots['pixel_count']
    background = table['outer_intensity'] * table['pixel_count']
    table['net_intensity'] = table['total_intensity'] - background
    in_window = (table.net_intensity > LB) & (table.net_intensity < UB)
    table = table[in_window]
    min_separation = 2*(RAD + 1)
    return table.groupby('frame').apply(keep, min_separation)

# # Test filtering out nearby neighbors
# df = spots['Monomers']
# dff = spots['Monomers'].groupby('frame').apply(keep, 10)
# print(df.shape), print(dff.shape)
# fig, ax = plt.subplots(ncols=2)
# fr = 1
# ax[0].set_aspect('equal')
# ax[1].set_aspect('equal')
# df[df.frame == fr].plot(ax=ax[0], kind='scatter', x='position_x', y='position_y', s=5)
# dff[dff.frame == fr].plot(ax=ax[1], kind='scatter', x='position_x', y='position_y', s=5)
# plt.show()

# Load every sample and pull out its net intensities as a plain array.
spots = {k: get_spots(files[k]) for k in files}
nets = {k: spots[k].net_intensity.values for k in files}
# FA repeats under different imaging conditions, looking for doublets.
# The old camera is 12 bit; rescale to the 16 bit range of the other data.
# NOTE(review): 12 -> 16 bit would be a factor of 16; the factor 4 is kept
# from the original analysis -- confirm.
faa = ['Activated FA 2', 'Activated FA 3']
fau = ['Unactivated FA 2', 'Unactivated FA 3']
for f in faa + fau:
    # Rebind instead of multiplying in place: the array returned by `.values`
    # may share memory with the DataFrame in `spots` (and is read-only when
    # pandas Copy-on-Write is active), so `*=` must be avoided here.
    scaled = nets[f] * 4
    # Re-apply the upper bound after rescaling.
    nets[f] = scaled[scaled < UB]

# All samples pooled; used later to define a common binning.
pool = np.concatenate(list(nets.values()))
print('Pool max', pool.max())

# Overlay the density histogram of every sample on one common binning and
# record, per sample, the bin heights and the area below `cutoff`.
plt.close('all')
bins = np.histogram(pool, 120)[1]
xmp = (bins[:-1] + bins[1:])/2
# Index of the bin edge closest to `cutoff`; the edges up to it are `subbins`.
top = np.abs(bins - cutoff).argmin()
subbins = bins[:top+1]
histstyle = dict(histtype='stepfilled', density=True, bins=bins)
vals, subareas = {}, {}
for sample, intensities in nets.items():
    heights, _, patches = plt.hist(intensities, label=sample, **histstyle)
    vals[sample] = heights
    # Integral of the density over the bins below the cutoff edge.
    subareas[sample] = sum(np.diff(subbins) * heights[:top])
    hard_edge(patches, 0.3)
plt.xlabel('Integrated intensity')
plt.ylabel('Probability')
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'all_data.pdf'))
if SHOW:
    plt.show()
else:
    plt.close()

# Check FA constructs: one row per repeat, activated and unactivated sample
# of the same repeat overlaid on the common binning.
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(5, 6), sharex=True)
axs = axes.ravel()
# Repeat 1 is paired with repeat 1 (the original paired 'Activated FA 1'
# with 'Unactivated FA 2', which was therefore drawn twice).
fa_activated = ['Activated FA 1'] + faa
fa_unactivated = ['Unactivated FA 1'] + fau
for ax, n0, n1 in zip(axs, fa_activated, fa_unactivated):
    s0, s1 = nets[n0], nets[n1]
    _, _, p = ax.hist(s0, label=n0, color=clu[n0], **histstyle)
    hard_edge(p, 0.3)
    _, _, p = ax.hist(s1, label=n1, color=clu[n1], **histstyle)
    hard_edge(p, 0.3)
    ax.legend()
# Three pairs fill the three rows exactly, so no left-over axis needs filling.
plt.xlabel('Integrated intensity')
plt.ylabel('Probability')
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'all FA.pdf'))
plt.show() if SHOW else plt.close()



# Consecutive entries of `names` drawn pairwise, one pair per row.
n_rows = N//2
fig, axes = plt.subplots(nrows=n_rows, ncols=1, figsize=(5, n_rows*2.5), sharex=True)
axs = axes.ravel()
for ax, n0, n1 in zip(axs, names[:-1:2], names[1::2]):
    for label in (n0, n1):
        _, _, p = ax.hist(nets[label], label=label, color=clu[label], **histstyle)
        hard_edge(p, 0.3)
    ax.legend()
# If the pairs did not reach the last row, put the last sample there alone.
if ax != axs[-1]:
    ax = axs[-1]
    n0 = names[-1]
    s0 = nets[n0]
    _, _, p = ax.hist(s0, label=n0, color=clu[n0], **histstyle)
    hard_edge(p, 0.3)
    ax.legend()
plt.xlabel('Integrated intensity')
plt.ylabel('Probability')
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'all pairs.pdf'))
if SHOW:
    plt.show()
else:
    plt.close()


# WT samples only: activated and unactivated overlaid, one pair per row.
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(5, 7.5), sharex=True, sharey=True)
style = dict(histtype='stepfilled', density=True, bins=bins)
wt_names = [f for f in names if 'WT' in f]
wta = [f for f in wt_names if 'Act' in f]
wtu = [f for f in wt_names if 'Unact' in f]
for ax, n0, n1 in zip(axes.ravel(), wta, wtu):
    for label in (n0, n1):
        _, _, p = ax.hist(nets[label], label=label, color=clu[label], **histstyle)
        hard_edge(p, 0.3)
    ax.legend()
plt.xlabel('Integrated intensity')
plt.ylabel('Probability')
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'all WT.pdf'))
if SHOW:
    plt.show()
else:
    plt.close()


# Get dimer peaks: histogram of the part of each sample below the cutoff,
# weighted so that its area equals the sample's sub-cutoff area, then stored
# re-normalised to unit area in `vals_minor` for fitting.
fig, axes = plt.subplots(nrows=N, ncols=1, figsize=(5, N*2.5), sharex=True, sharey=True)
vals_minor = {}
for ax, n in zip(axes.ravel(), namesm):
    s = nets[n]
    # Compare by value: `is not` tests object identity, which is not reliable
    # for strings (names loaded from JSON are distinct objects from the
    # literal) and raises a SyntaxWarning on Python >= 3.8.
    if n != 'Monomers':
        subset = s[s < cutoff]
    else:
        subset = s[s < cutoff/2]  # extended tail skews fit
    # Per-event weight: 1/(count * bin width) scaled by the sub-cutoff area.
    w = np.ones_like(subset)/len(subset)/np.diff(subbins)[0]*subareas[n]
    v, _, p = ax.hist(subset, bins=subbins, weights=w,
                         histtype='stepfilled', color=clu[n])
    A = sum(np.diff(subbins) * v)
    print(f'{n} area: {A}')
    vals_minor[n] = v/A
    hard_edge(p, 0.3)
    ax.set_title(n + ' minor peak')
ax.relim()
plt.tight_layout()
plt.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
plt.savefig(os.path.join(FIGS, 'all dimer2fit.pdf'))
plt.show() if SHOW else plt.close()

def get_lognorm_support(s, loc, scale, xn=1000):
    """
    Return ``xn`` evenly spaced points from ``loc`` up to twice the 99th
    percentile of ``lognorm(s, loc, scale)``.

    Doubling the 0.99 quantile is used instead of a quantile closer to 1,
    because the quantile grows exponentially with every additional 9.
    """
    q99 = lognorm(s, loc, scale).ppf(0.99)
    upper = 2*q99
    return np.linspace(loc, upper, xn)

# Model for the small peak: "dimer" built from a log-normal monomer, with dark
def dimer(xs, s, scale, nu, dx=5, mx=None):
    """
    Evaluate the dimer model at ``xs``.

    A log-normal with shape ``s``, scale ``scale`` and loc 0 is sampled on
    5000 points from ``get_lognorm_support``, wrapped in a ``PMF``, passed
    through ``renorm_pmf(..., 0, 1-nu)``, convolved once with ``convolvePMF``,
    passed through ``renorm_pmf(..., 0, 0)`` and converted with ``pmf2pdf``.
    NOTE(review): presumably ``nu`` is the bright fraction and the two
    renorm calls set / remove the weight of the dark state -- confirm against
    the ``convolve`` module. ``dx`` is accepted but not used; ``mx`` is
    forwarded to ``convolvePMF``.
    """
    monomer = lognorm(loc=0, s=s, scale=scale)
    grid = get_lognorm_support(s, 0, scale, 5000)
    single = renorm_pmf(PMF(grid, monomer.pdf(grid)), 0, (1-nu))
    double = convolvePMF(single, 1, mx=mx)[0]
    density = pmf2pdf(renorm_pmf(double, 0, 0))
    return density(xs)

def dimer_error(xs, xdata=None, ydata=None):
    """
    Sum of squared residuals between ``dimer(xdata, *xs)`` and ``ydata``.

    ``xs`` holds the model parameters that are unpacked into ``dimer``.
    Residuals are multiplied by 1e5 before squaring.
    """
    # the minimizer may propose wild parameters here
    model = dimer(xdata, *xs)
    scaled_residual = 1e5*(model - ydata)
    return (scaled_residual**2).sum()

# test dimer function
# n = names[0]
# plt.hist(nets[n][nets[n] < cutoff], bins=subbins, density=True)
# plt.plot(subx, dimer(subx, *p0))
# plt.show()


# Fit dimer peaks
# fitting is strongly sensitive to initial conditions
subx = (subbins[1:] + subbins[:-1])/2
popts = {}
# bounds = ((0,0,0), (5,cutoff,1))
# Note: minimize with bounds will evaluate the model at the bounds, so make
# sure the model functions work there.
# shape parameters (p0[0]) larger than 1.5 create extremely sharp delta peaks
bounds = [(p0[0]/1e2, 1.5), (p0[1]/1e2, cutoff), (0, 1)]
for n in vals_minor:
    print(n)
    # Names are compared by value. The original `is not 'Monomers'` tested
    # object identity, which is true even for an equal string loaded from
    # JSON, so the monomer branch below was never reliably taken.
    if n != 'Monomers':
        # Dimer model: popts[n] is (s, scale, nu).
        res = minimize(dimer_error, p0, args=(subx, vals_minor[n]), bounds=bounds)
        print(res)
        print('------------\n')
        popts[n] = res.x
    else:
        # Plain log-normal with loc fixed to 0: popts[n] is (s, loc, scale).
        popts[n] = lognorm.fit(nets[n], floc=0)
        print(popts[n])



# Fitted dimer peaks (each)
fig, axes = plt.subplots(nrows=N, ncols=1, figsize=(5, N*2.5), sharex=True, sharey=True)
xs = np.linspace(subbins.min(), subbins.max(), 500)
for ax, n in zip(axes.ravel(), namesm):
    s = nets[n]
    subset = s[s < cutoff]
    # w = np.ones_like(subset)/len(subset)/np.diff(subbins)[0]*subareas[n]
    _, _, p = ax.hist(subset, bins=subbins, density=1, histtype='stepfilled', color=clu[n])
    hard_edge(p, 0.3)
    if n != 'Monomers':
        ax.plot(xs, dimer(xs, *popts[n]), color=clu[n], lw=3, ls='--')
    else:
        ax.plot(xs, lognorm(*popts[n]).pdf(xs), color=clu[n], lw=3, ls='--')
    ax.set_title(n + ' minor peak')
ax.relim()
plt.tight_layout()
plt.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
plt.savefig(os.path.join(FIGS, 'all dimer fit.pdf'))
plt.show() if SHOW else plt.close()


# Fitted dimer peaks (single model)
# Average (s, scale) over all samples. The monomer entry is laid out as
# (s, loc, scale) while the dimer fits are (s, scale, nu), so the scale has to
# be picked from the right position before averaging.
shape_scale = [(popts[n][0], popts[n][2] if n == 'Monomers' else popts[n][1])
               for n in popts]
params = list(np.array(shape_scale, dtype=float).mean(axis=0))
fixed_dimer = lambda x, f: dimer(x, params[0], params[1], f)
fig, axes = plt.subplots(nrows=N, ncols=1, figsize=(5, N*2.5), sharex=True, sharey=True)
xs = np.linspace(subbins.min(), subbins.max(), 500)
for ax, n in zip(axes.ravel(), namesm):
    s = nets[n]
    subset = s[s < cutoff]
    # w = np.ones_like(subset)/len(subset)/np.diff(subbins)[0]*subareas[n]
    _, _, p = ax.hist(subset, bins=subbins, density=1, histtype='stepfilled', color=clu[n])
    hard_edge(p, 0.3)
    if n != 'Monomers':
        # Only nu is free here; shape and scale are the global averages.
        fit, _ = curve_fit(fixed_dimer, subx, vals_minor[n], p0=[0.5])
        ax.plot(xs, dimer(xs, *params, fit[0]))
    else:
        ax.plot(xs, lognorm(params[0], 0, params[1]).pdf(xs))
    ax.set_title(n + ' minor peak')
ax.relim()
plt.tight_layout()
plt.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
plt.savefig(os.path.join(FIGS, 'all dimer fit global.pdf'))
plt.show() if SHOW else plt.close()

# Fitted dimer peaks (each) with monomer
fig, axes = plt.subplots(nrows=N, ncols=1, figsize=(5, N*2.8), sharex=True, sharey=True)
xs = np.linspace(subbins.min(), subbins.max(), 500)
for ax, n in zip(axes.ravel(), namesm):
    s = nets[n]
    subset = s[s < cutoff]
    # w = np.ones_like(subset)/len(subset)/np.diff(subbins)[0]*subareas[n]
    _, _, p = ax.hist(subset, bins=subbins, density=1, histtype='stepfilled', color=clu[n], label=n + ' minor peak')
    hard_edge(p, 0.3)
    popt = popts[n]
    if n != "Monomers":
        nu = popt[-1]
        # NOTE(review): presumably the weights of "exactly one bright" and
        # "both bright" for a pair with bright probability nu -- confirm.
        nu1 = 2*nu*(1-nu)
        nu2 = nu**2
        # Weight applied to the monomer curve drawn below.
        monf = nu1/(nu1+nu2)
        d2 = interpolate(xs, dimer(xs, *popt))
        ax.plot(xs, d2(xs), color='black', label='Fitted Dimer', lw=3, ls='--', alpha=0.5)
        ax.plot(xs, monf*lognorm(loc=0, s=popt[0], scale=popt[1]).pdf(xs), label='Inferred Monomer', color=clu[n], lw=3, ls='--', alpha=0.5)
    else:
        monf = 1.0
        ax.plot(xs, lognorm(*popt).pdf(xs), label='Inferred Monomer', color=clu[n], lw=3, ls='--', alpha=0.5)
    ax.xaxis.set_tick_params(which='both', labelbottom=True)
    ax.yaxis.set_tick_params(which='both', labelleft=True)
    ax.set_xlabel('Net Intensity (b16)')
    ax.set_ylabel('Probability')
    # ax.text(0.6, 0.5, f"% bright monomer {monf*100:.02f}", transform=ax.transAxes)
    # ax.set_title(n + ' minor peak')
    ax.legend()
ax.relim()
plt.tight_layout()
plt.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
plt.savefig(os.path.join(FIGS, 'all dimer fit with monomer.pdf'))
plt.show() if SHOW else plt.close()

# Build, per sample, a table of densities derived from its fitted monomer
# log-normal: key 1 is the monomer pdf itself, keys 8/10/12/14 are the four
# results returned by convolvePMF (13 convolutions, keep=[6, 8, 10, 12]).
dists = {}
xs = np.linspace(0, 100000, 5000)
for n in names:
    shape, scale = popts[n][0], popts[n][1]
    d = lognorm(loc=0, s=shape, scale=scale)
    pmf = PMF(xs, d.pdf(xs))
    ms = convolvePMF(pmf, 13, mx=bins.max(), keep=[6, 8, 10, 12])
    pdf8, pdf10, pdf12, pdf14 = [pmf2pdf(m) for m in ms]
    ds = {1: d.pdf, 8: pdf8, 10: pdf10, 12: pdf12, 14: pdf14}
    dists[n] = ds


# Full histogram of every sample with the monomer curve and the convolved
# curves from `dists` drawn on top.
fig, axes = plt.subplots(nrows=N, ncols=1, figsize=(5, N*2.5), sharex=True, sharey=True)
xs = np.linspace(bins.min(), bins.max(), 1000)
styles = dict(ls='--', lw=3, alpha=0.5)
for ax, n in zip(axes.ravel(), names):
    _, _, p = ax.hist(nets[n], bins=bins, density=1, histtype='stepfilled',
                      color=clu[n], label=n)
    hard_edge(p, 0.3)
    # The monomer curve is scaled by the sub-cutoff area, every other curve
    # by half of the remaining area.
    A = subareas[n]
    B = (1-A)/2
    ds = dists[n]
    ax.plot(xs, A*ds[1](xs), label='Monomer', **styles)
    for i in (8, 10, 12, 14):
        curve = B*ds[i](xs)
        ax.plot(xs, curve, label=str(i), **styles)
plt.tight_layout()
plt.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
plt.savefig(os.path.join(FIGS, 'all full distributions.pdf'))
if SHOW:
    plt.show()
else:
    plt.close()


# Now on fitting the major peak: histogram of the part of each sample above
# the cutoff, weighted so that its area equals 1 - (sub-cutoff area), then
# stored re-normalised to unit area in `vals_major` for fitting.
fig, axes = plt.subplots(nrows=N, ncols=1, figsize=(5, N*2.5), sharex=True, sharey=True)
vals_major = {}
majorareas = {}
# Bin edges above the cutoff edge, and their midpoints.
major_bins = bins[top+1:]
major_mp = (major_bins[:-1] + major_bins[1:])/2
for ax, n in zip(axes.ravel(), names):
    s = nets[n]
    subset = s[s > cutoff]
    w = np.ones_like(subset)/len(subset)/np.diff(major_bins)[0]*(1-subareas[n])
    v, _, p = ax.hist(subset, bins=major_bins, weights=w,
                         histtype='stepfilled', color=clu[n])
    A = sum(np.diff(major_bins) * v)
    print(f'{n} area: {A}')
    majorareas[n] = A
    vals_major[n] = v/A
    hard_edge(p, 0.3)
    ax.set_title(n + ' major peak')
ax.relim()
plt.tight_layout()
plt.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
plt.savefig(os.path.join(FIGS, 'all major2fit.pdf'))
# Honour SHOW like every other figure; an unconditional plt.show() blocks
# non-interactive runs and leaves the figure open.
plt.show() if SHOW else plt.close()


# Model used to fit the NU parameter to the major peak
def major_peak(x, nu, params=(0.26, 12600), dx=50):
    """
    Evaluate the major-peak model at ``x``.

    A log-normal with shape ``params[0]``, scale ``params[1]`` and loc 0 is
    sampled on a grid from ``-dx/2`` to 30000 in steps of ``dx``. Its values
    are multiplied by ``nu`` and the first grid entry is set to ``1 - nu``
    (the "delta peak"). The result is convolved 11 times with ``convolvePMF``
    (limited to ``bins.max()``) and converted back to a density with
    ``pmf2pdf``.
    """
    grid = np.arange(-dx/2, 30000, dx)
    monomer = lognorm(s=params[0], loc=0, scale=params[1])
    labelled = PMF(grid, monomer.pdf(grid))
    # create delta peak
    labelled.y *= nu
    labelled.y[0] = (1-nu)
    convolved = convolvePMF(labelled, 11, mx=bins.max())[0]
    # convert back to PDF
    density = pmf2pdf(convolved)
    return density(x)

# One nu per sample: WT samples get it from a major-peak fit, FA samples
# reuse the third parameter of their dimer fit. The table is written to
# nus.csv and the median becomes the global NU.
nus = {}
for n in names:
    if 'WT' in n:
        popt, pcov = curve_fit(major_peak, major_mp, vals_major[n], p0=0.65)
        wt_nu = popt[0]
        nus[n] = wt_nu
        print('Estimated NU from major peak', n, wt_nu)
    if 'FA' in n:
        fa_nu = popts[n][2]
        nus[n] = fa_nu
        print('Estimated NU from FA dimer', n, fa_nu)
sr = pd.Series(nus, name='NUs')
sr.to_csv('nus.csv')
NU = np.median([value for value in nus.values()])
print('Global NU estimate:', NU)
print()


def norm(a):
    """
    Return ``1 - sum(a)``.

    Used as the 'ineq' constraint function of the mixture fit.
    NOTE: this name shadows ``scipy.stats.norm`` imported at the top of the
    file.
    """
    total = sum(a)
    return 1 - total

def memoize(f):
    """
    Cache the results of the one-argument function ``f``.

    Parameters
    ----------
    f : callable taking a single hashable argument.

    Returns
    -------
    helper : callable
        Calls ``f(x)`` the first time a given ``x`` is seen and returns the
        stored result afterwards. The same object is returned on every hit,
        so callers must not modify a mutable result in place. The cache is
        never emptied. The wrapper carries the name and docstring of ``f``.
    """
    from functools import wraps

    memo = {}

    @wraps(f)  # keep f's __name__ / __doc__ on the wrapper
    def helper(x):
        if x not in memo:
            memo[x] = f(x)
        return memo[x]
    return helper

@memoize
def get_x(lim):
    """
    Return 1000 evenly spaced points covering [0, ``lim``].

    The result is memoized per ``lim``, so repeated calls with the same value
    return the same array object.
    """
    return np.linspace(0, lim, num=1000)

def mixture(x, pm, pd, po, prms=[0.27, 12500, NU], dx=50):
    """
    Evaluate the three-component mixture density at ``x``.

    ``prms`` is (shape, scale, nu) of the monomer log-normal (loc 0). The
    monomer is sampled on ``get_x(int(4*scale))``, its values are multiplied
    by ``nu`` and the first entry is set to ``1 - nu``. ``convolvePMF`` (11
    convolutions, ``keep=[0, 10]``) returns two PMFs; in each, the first
    entry is removed and the rest rescaled by ``1 - first entry`` before the
    PMF is converted with ``pmf2pdf``.

    Returns ``pm * monomer pdf + pd * first density + po * second density``.
    ``dx`` is accepted but not used. Inside this function the parameter
    ``pd`` hides the pandas module alias.
    """
    shape, scale, nu = prms[0], prms[1], prms[2]
    monomer = lognorm(s=shape, loc=0, scale=scale)
    grid = get_x(int(4*scale))  # int makes the key hashable for memoization
    labelled = PMF(grid, monomer.pdf(grid))
    labelled.y *= nu
    labelled.y[0] = (1-nu)
    low, high = convolvePMF(labelled, 11, mx=bins.max(), keep=[0, 10])
    # renormalize PMFs to exclude completely dark
    for pmf in (low, high):
        pmf.y[1:] /= (1-pmf.y[0])
        pmf.y[0] = 0
    low_pdf = pmf2pdf(low)
    high_pdf = pmf2pdf(high)
    return pm*monomer.pdf(x) + pd*low_pdf(x) + po*high_pdf(x)


def mixture_error(ps, x=None, y=None, prms=[0.27, 12500, NU], dx=50):
    """
    Objective for the mixture fit: sum over ``1e9 * (mixture - y)**2``.

    ``ps`` holds the three mixture weights passed to ``mixture`` as
    (pm, pd, po); ``prms`` and ``dx`` are forwarded unchanged.
    """
    weight_m, weight_d, weight_o = ps[0], ps[1], ps[2]
    model = mixture(x, weight_m, weight_d, weight_o, prms=prms, dx=dx)
    # the 1e9 factor brings the error to a reasonable magnitude
    squared = 1e9*(model - y)**2
    return sum(squared)


# Fit the three mixture weights of every sample on the full histogram.
mix_prms = {}
for n in names:
    x, y = xmp, vals[n]
    # Work on a copy: assigning into popts[n] directly would overwrite the
    # fitted nu stored in `popts` as a hidden side effect.
    prms = np.array(popts[n], dtype=float)
    prms[2] = NU # fix nu across samples
    # Start with the sub-cutoff area split evenly over the first two weights.
    x0 = [subareas[n]/2, subareas[n]/2, majorareas[n]]
    res = minimize(
        mixture_error,
        x0=x0,
        args=(x, y, prms, 30),
        # `norm` returns 1 - sum(weights); 'ineq' requires it to be >= 0.
        constraints={'type': 'ineq', 'fun': norm},
        bounds=[(0, 1), (0, 1), (0, 1)],
    )
    print(n)
    print(res, '\n----------\n')
    mix_prms[n] = res.x


# Full distributions with the fitted mixture, two panels per row.
fig, axes = plt.subplots(nrows=math.ceil(N/2), ncols=2, figsize=(8, 22))
styles = {'ls': '--', 'lw': 3, 'alpha': 0.6}
xa = np.linspace(0, bins.max(), 1000)
# Move the monomer control to the end of the plotting order.
namesm.remove('Monomers')
namesm.append('Monomers')
# One y-limit per panel; the two panels of a row share the larger maximum.
# The pairing is taken from `namesm` throughout so that the number of limits
# always equals the number of plotted samples (the original mixed `names`
# and `namesm` and dropped the last panel when the sample count was even).
ylims = []
for i in range(0, len(namesm), 2):
    if i+1 < len(namesm):
        ymx = np.concatenate((vals[namesm[i]], vals[namesm[i+1]])).max()
        ylims.extend([ymx, ymx])
    else:
        ylims.append(max(vals[namesm[i]]))
for ax, n, y in zip(axes.ravel(), namesm, ylims):
    if n != 'Monomers':
        # Copy so that fixing nu does not overwrite the fit stored in popts.
        prms = np.array(popts[n], dtype=float)
        prms[2] = NU
        ax.plot(xa, mixture(xa, *mix_prms[n], prms=prms), label='1, 2, 12-mer', **styles, color='black')
        ax.text(0.6, 0.5, f"% dimer {mix_prms[n][1]*100:.02f}", transform=ax.transAxes)
    else:
        # NOTE(review): hard-coded log-normal parameters, not taken from popts.
        ax.plot(xa, lognorm(0.32, 0, 16800).pdf(xa), label='Monomer', **styles, color='black')
    _, _, p = ax.hist(nets[n], bins=bins, density=1, histtype='stepfilled',
                      color=clu[n], label=n)
    hard_edge(p, 0.3)
    ax.legend()
    ax.set_xlabel('Net Intensity (b16)')
    ax.set_ylabel('Probability')
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
    ax.set_ylim(0, y*1.2)
    ax.set_xlim(0, UB)
    print(n)
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'all fit mixtures.pdf'))
plt.show() if SHOW else plt.close()

import matplotlib.gridspec as gs
from math import pi, sqrt, exp

def gauss(n=11,sigma=1):
    """
    Return a Gaussian kernel as a list of floats.

    The normal density with standard deviation ``sigma`` is evaluated at the
    integers from ``-int(n/2)`` to ``int(n/2)`` inclusive, i.e. at
    ``2*int(n/2) + 1`` points. The values are not rescaled to sum to one.
    """
    half = int(n/2)
    amplitude = 1 / (sigma * sqrt(2*pi))
    two_var = 2*sigma**2
    kernel = []
    for offset in range(-half, half+1):
        kernel.append(amplitude * exp(-float(offset)**2/two_var))
    return kernel

# WT fit: for every WT sample a histogram with the fitted mixture, a table of
# population percentages and, above it, the residuals of the fit.
styles = {'ls': '--', 'lw': 2.5, 'alpha': 0.6}
xa = np.linspace(0, bins.max(), 1000)
wta = [f for f in names if 'WT' in f and 'Act' in f]
wtu = [f for f in names if 'WT' in f and 'Unact' in f]
N = len(wta) + len(wtu)
assert len(wta) == len(wtu)
fig = plt.figure(figsize=(12, 4 * N//2))
widths = [1, 1]
# One height entry per grid row, so any number of WT pairs works (a fixed
# list of three only matched exactly three pairs).
heights = [1] * (N//2)
spec = gs.GridSpec(ncols=2, nrows=N//2, width_ratios=widths, height_ratios=heights, hspace=0.3, wspace=0.5, left=0.05, bottom=0.05, top=0.95, right=0.88)
ymx = max(np.concatenate([vals[n] for n in wta + wtu]))
# Per sample a (residual axis, histogram axis) pair; unactivated samples go in
# the left column, activated samples in the right column.
axs = {}
for i, k in enumerate(wtu):
    subspec = gs.GridSpecFromSubplotSpec(2, 1, height_ratios=[1, 3], subplot_spec=spec[i, 0], hspace=0.2)
    axs[k] = (fig.add_subplot(subspec[0]), fig.add_subplot(subspec[1]))
for i, k in enumerate(wta):
    subspec = gs.GridSpecFromSubplotSpec(2, 1, height_ratios=[1, 3], subplot_spec=spec[i, 1], hspace=0.2)
    axs[k] = (fig.add_subplot(subspec[0]), fig.add_subplot(subspec[1]))
# Fit colours, plus half-transparent variants used in the table.
clu['fit0'] = list(heat(0.5))
clu['fit0a'] = clu['fit0'].copy()
clu['fit0a'][-1] = 0.5
clu['fit1'] = list(heat(0.8))
clu['fit1a'] = clu['fit1'].copy()
clu['fit1a'][-1] = 0.5
for n in axs:
    print(n)
    ax = axs[n][1]
    # Copy so that fixing nu does not overwrite the fit stored in popts.
    prms = np.array(popts[n], dtype=float)
    prms[2] = NU
    pdf = mixture(xa, *mix_prms[n], prms=prms)
    pdfonbins = mixture(xmp, *mix_prms[n], prms=prms)
    res = vals[n] - pdfonbins
    # Smooth the residuals and locate the sign changes of the smoothed curve.
    ma = np.convolve(res, gauss(17, 6), mode='same')
    signs = np.where(np.diff(np.sign(ma)))[0]
    # Region boundaries (bin indices): the cutoff and the last two sign changes.
    # NOTE: signs[-2] raises IndexError if there are fewer than two sign changes.
    x0 = abs(xmp - cutoff).argmin() - 1
    x1 = signs[-2]
    x2 = signs[-1]
    # fig0, ax0 = plt.subplots()
    # ax0.axvspan(0, x0, facecolor='red', alpha=0.1)
    # ax0.axvspan(x0, x1, facecolor='orange', alpha=0.1)
    # ax0.axvspan(x1, x2, facecolor='green', alpha=0.1)
    # ax0.axvspan(x2, len(xmp), facecolor='blue', alpha=0.1)
    # ax0.plot(res)
    # ax0.plot(ma)
    # Percentages: fitted weights plus the integrated smoothed residual of
    # the corresponding region.
    percents = {}
    i0 = range(x0)
    e0 = np.trapz(ma[i0], xmp[i0])
    y0 = mix_prms[n][0] + mix_prms[n][1]
    percents['Monomer/Dimer'] = round(100 * (y0 + e0))
    i1 = range(x0, x1)
    e1 = np.trapz(ma[i1], xmp[i1])
    percents['Intermediate'] = round(100 * e1)
    i2 = range(x1, x2)
    e2 = np.trapz(ma[i2], xmp[i2])
    y2 = mix_prms[n][2]
    percents['Dodecamer'] = round(100 * (y2 + e2))
    i3 = range(x2, len(xmp))
    e3 = np.trapz(ma[i3], xmp[i3])
    percents['Aggregate'] = round(100 * e3)
    total = sum(list(percents.values()))
    lbls = ['Monomer/Dimer', 'Intermediate', 'Dodecamer', 'Aggregate']
    if total < 94:
        print("Warning: large deviation in accounting")
    # Top the rounded percentages up towards 100, one point at a time,
    # cycling through the labels from last to first (at most 7 points).
    i = 0
    while total < 100 and i < 7:
        percents[lbls[3 - i % 4]] += 1
        total = sum(list(percents.values()))
        i += 1
        print(total)
    idxminor = xa < cutoff
    ax.plot(xa[idxminor], pdf[idxminor], label='Monomer/Dimer Fit', lw=2.5, color=clu['fit0'])
    ax.plot(xa[~idxminor], pdf[~idxminor], label='Predicted Dodecamer', lw=2.5, color=clu['fit1'])
    _, _, p = ax.hist(nets[n], bins=bins, density=1, histtype='stepfilled', label='Data', color=clu[n])
    hard_edge(p, 0.3)
    ax.legend()
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3), useMathText=True)
    ax.set_xlabel('Net Intensity (b16)')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, ymx*1.02)
    ax.set_xlim(0, UB)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    pcts = [[f'{percents[k]:.0f}'] for k in lbls]
    clrs = [clu['fit0a'], (1,1,1,1), clu['fit1a'], (1,1,1,1)]
    ax.table(pcts, rowLabels=lbls, colLabels=['%'], loc='right', colWidths=[0.12]*2, bbox=[1.2, 0.1, 0.12, 0.6], rowColours=clrs)
    # residuals
    axr = axs[n][0]
    axr.bar(xmp, res, width=np.diff(bins)[0], color='k', alpha=0.5, label='Other Oligomers')
    # NOTE(review): this re-applies hard_edge to the histogram patches `p`,
    # not to the residual bars -- kept as in the original, confirm intent.
    hard_edge(p, 0.3)
    axr.get_xaxis().set_ticks([])
    axr.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3), useMathText=True)
    axr.spines['bottom'].set_visible(False)
    axr.spines['top'].set_visible(False)
    axr.spines['right'].set_visible(False)
    axr.set_ylabel('Residual')
    # axr.text(0.6, 0.3, f"% other oligomer:\n{other:.02f}", transform=ax.transAxes, color='k')
    axr.set_ylim(-5e-6, 5e-6)
    axr.set_title(n)
# plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'residuals WT.pdf'))
plt.show() if SHOW else plt.close()

# just FA: same figure as for WT, restricted to the first FA repeat.
xa = np.linspace(0, bins.max(), 1000)
wta = ['Activated FA 1']
wtu = ['Unactivated FA 1']
titles = {}
titles['Activated FA 1'] = 'Activated F397A'
titles['Unactivated FA 1'] = 'Unactivated F397A'
N = len(wta) + len(wtu)
assert len(wta) == len(wtu)
fig = plt.figure(figsize=(12, 4 * N//2))
widths = [1] * 2
heights = [1] * (N//2)
spec = gs.GridSpec(ncols=2, nrows=N//2, width_ratios=widths, height_ratios=heights, hspace=0.3, wspace=0.5, left=0.05, bottom=0.15, top=0.92, right=0.85)
ymx = max(np.concatenate([vals[n] for n in wta + wtu]))
# Per sample a (residual axis, histogram axis) pair; unactivated on the left,
# activated on the right.
axs = {}
for i, k in enumerate(wtu):
    subspec = gs.GridSpecFromSubplotSpec(2, 1, height_ratios=[1, 3], subplot_spec=spec[i, 0], hspace=0.2)
    axs[k] = (fig.add_subplot(subspec[0]), fig.add_subplot(subspec[1]))
for i, k in enumerate(wta):
    subspec = gs.GridSpecFromSubplotSpec(2, 1, height_ratios=[1, 3], subplot_spec=spec[i, 1], hspace=0.2)
    axs[k] = (fig.add_subplot(subspec[0]), fig.add_subplot(subspec[1]))
for n in axs:
    ax = axs[n][1]
    # Copy so that fixing nu does not overwrite the fit stored in popts.
    prms = np.array(popts[n], dtype=float)
    prms[2] = NU
    pdf = mixture(xa, *mix_prms[n], prms=prms)
    pdfonbins = mixture(xmp, *mix_prms[n], prms=prms)
    # Residuals on the bin midpoints; also used for the bar plot further down.
    res = vals[n] - pdfonbins
    # Smooth the residuals and locate the sign changes of the smoothed curve.
    ma = np.convolve(res, gauss(17, 6), mode='same')
    signs = np.where(np.diff(np.sign(ma)))[0]
    # Region boundaries (bin indices): the cutoff and the last two sign changes.
    # NOTE: signs[-2] raises IndexError if there are fewer than two sign changes.
    x0 = abs(xmp - cutoff).argmin() - 1
    x1 = signs[-2]
    x2 = signs[-1]
    # fig0, ax0 = plt.subplots()
    # ax0.axvspan(0, x0, facecolor='red', alpha=0.1)
    # ax0.axvspan(x0, x1, facecolor='orange', alpha=0.1)
    # ax0.axvspan(x1, x2, facecolor='green', alpha=0.1)
    # ax0.axvspan(x2, len(xmp), facecolor='blue', alpha=0.1)
    # ax0.plot(res)
    # ax0.plot(ma)
    # Percentages: fitted weights plus the integrated smoothed residual of
    # the corresponding region.
    percents = {}
    i0 = range(x0)
    e0 = np.trapz(ma[i0], xmp[i0])
    y0 = mix_prms[n][0] + mix_prms[n][1]
    percents['Monomer/Dimer'] = round(100 * (y0 + e0))
    i1 = range(x0, x1)
    e1 = np.trapz(ma[i1], xmp[i1])
    percents['Intermediate'] = round(100 * e1)
    i2 = range(x1, x2)
    e2 = np.trapz(ma[i2], xmp[i2])
    y2 = mix_prms[n][2]
    percents['Dodecamer'] = round(100 * (y2 + e2))
    i3 = range(x2, len(xmp))
    e3 = np.trapz(ma[i3], xmp[i3])
    percents['Aggregate'] = round(100 * e3)
    total = sum(list(percents.values()))
    lbls = ['Monomer/Dimer', 'Intermediate', 'Dodecamer', 'Aggregate']
    if total < 94:
        print("Warning: large deviation in accounting")
    # Top the rounded percentages up towards 100, one point at a time,
    # cycling through the labels from last to first (at most 7 points).
    i = 0
    while total < 100 and i < 7:
        percents[lbls[3 - i % 4]] += 1
        total = sum(list(percents.values()))
        i += 1
        print(total)
    idxminor = xa < cutoff
    ax.plot(xa[idxminor], pdf[idxminor], label='Monomer/Dimer Fit', lw=2.5, color=clu['fit0'])
    ax.plot(xa[~idxminor], pdf[~idxminor], label='Predicted Dodecamer', lw=2.5, color=clu['fit1'])
    _, _, p = ax.hist(nets[n], bins=bins, density=1, histtype='stepfilled', label='Data', color=clu[n])
    hard_edge(p, 0.3)
    ax.legend()
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3), useMathText=True)
    print(n)
    ax.set_xlabel('Net Intensity (b16)')
    ax.set_ylabel('Probability')
    ax.set_ylim(0, ymx*1.02)
    ax.set_xlim(0, UB)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    pcts = [[f'{percents[k]:.0f}'] for k in lbls]
    clrs = [clu['fit0a'], (1,1,1,1), clu['fit1a'], (1,1,1,1)]
    ax.table(pcts, rowLabels=lbls, colLabels=['%'], loc='right', colWidths=[0.12]*2, bbox=[1.2, 0.1, 0.12, 0.6], rowColours=clrs)
    # residuals (`res` from above is reused; the mixture is not recomputed)
    axr = axs[n][0]
    axr.bar(xmp, res, width=np.diff(bins)[0], color='k', alpha=0.5, label='Other Oligomers')
    # NOTE(review): this re-applies hard_edge to the histogram patches `p`,
    # not to the residual bars -- kept as in the original, confirm intent.
    hard_edge(p, 0.3)
    axr.get_xaxis().set_ticks([])
    axr.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3), useMathText=True)
    axr.spines['bottom'].set_visible(False)
    axr.spines['top'].set_visible(False)
    axr.spines['right'].set_visible(False)
    axr.set_ylabel('Residual')
    # axr.text(0.6, 0.3, f"% other oligomer:\n{other:.02f}", transform=ax.transAxes, color='k')
    axr.set_ylim(-5e-6, 5e-6)
    axr.set_title(titles[n])
# plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'residuals FA.pdf'))
plt.show() if SHOW else plt.close()


# Full distributions, no fit: every sample in `namesm`, two panels per row.
# The panel count is taken from `namesm` itself: the global `N` has been
# reassigned to the number of WT / FA samples by the residual figures above,
# which would reduce this figure to a single row.
n_all = len(namesm)
fig, axes = plt.subplots(nrows=math.ceil(n_all/2), ncols=2, figsize=(8, 16))
styles = {'ls': '--', 'lw': 3, 'alpha': 0.6}
xa = np.linspace(0, bins.max(), 1000)
# Keep the monomer control at the end of the plotting order.
namesm.remove('Monomers')
namesm.append('Monomers')
# One y-limit per panel; the two panels of a row share the larger maximum.
ylims = []
for i in range(0, n_all, 2):
    if i+1 < n_all:
        ymx = np.concatenate((vals[namesm[i]], vals[namesm[i+1]])).max()
        ylims.extend([ymx, ymx])
    else:
        ylims.append(max(vals[namesm[i]]))
for ax, n, y in zip(axes.ravel(), namesm, ylims):
    _, _, p = ax.hist(nets[n], bins=bins, density=1, histtype='stepfilled',
                      color=clu[n], label=n)
    hard_edge(p, 0.3)
    ax.legend()
    ax.set_xlabel('Net Intensity (b16)')
    ax.set_ylabel('Probability')
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
    ax.set_ylim(0, y*1.2)
    ax.set_xlim(0, UB)
    print(n)
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'all full dists.pdf'))
plt.show() if SHOW else plt.close()


# Three distributions: one WT sample, one FA sample and the monomer control
# overlaid in a single panel (histograms only, no fitted curves).
fig, ax = plt.subplots(figsize=(5,3))
# Legend label: WT names lose their last two characters, others are unchanged.
trim = lambda s: s[:-2] if 'WT' in s else s
styles = {'ls': '--', 'lw': 3, 'alpha': 0.6}
xa = np.linspace(0, bins.max(), 1000)
ns = ['Activated WT 1', 'Activated FA 1', 'Monomers']
cls = [heat(0.2), heat(0.6), earth(0.5)]
for n, c in zip(ns, cls):
    _, _, p = ax.hist(nets[n], bins=bins, density=1, histtype='stepfilled',
                      color=c, label=trim(n))
    hard_edge(p, 0.3)
    print(n)
# Axis decoration is identical for all three samples, so it is applied once
# after everything has been drawn.
ax.legend()
ax.set_xlabel('Net Intensity (b16)')
ax.set_ylabel('Probability')
ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
ax.set_xlim(0, UB)
plt.tight_layout()
plt.savefig(os.path.join(FIGS, 'three fit mixtures.pdf'))
plt.show() if SHOW else plt.close()


import pandas as pd
# Show floats with thousands separators and two decimals, then export the
# fitted mixture weights: one row per sample, columns 1, 2 and 12.
pd.set_option('display.float_format', '{:,.2f}'.format)
weights_by_component = pd.DataFrame(mix_prms, index=[1, 2, 12])
df = weights_by_component.transpose()
df.to_csv('decomposition.csv')


