# Source code for OCD_modeling.analysis.rww_dst_analysis

####    OCD Modeling: Reduced Wong-Wang Model Analysis using Dynamical Systems Toolbox
###      
##      Author: Sebastien Naze
#       QIMR 2023

from abc import ABC
import argparse
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, wait
import copy
import dill

import itertools
import joblib
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
import multiprocessing
import mpmath
import os
import pandas as pd
#from pathos.multiprocessing import ProcessingPool
import pickle
import PyDSTool as dst
from PyDSTool.Toolbox import phaseplane as pp
import scipy
import sympy as sp
import time


from OCD_modeling.utils.utils import *
from OCD_modeling.models import ReducedWongWang as RWW


# Module-level NumPy random generator, created without a seed (draws are not
# reproducible across runs). Functions in this module call ``rng.uniform()``
# to pick random initial conditions for the state variables S1 and S2.
rng = np.random.default_rng()

def _rww_rhs(s_i, s_j, c_ij, sign):
    """Return the right-hand side of dS_i/dt as a PyDSTool expression string.

    ``sign`` is '-' or '+' and selects how the input ``I_1`` enters the drive.
    """
    # u is the argument (a*x_i - b) of the transfer function H
    u = f"(a*(w*J_N*{s_i}+G*J_N*{c_ij}*{s_j}+I_0{sign}I_1)-b)"
    return f"-{s_i}/tau_S + (1-{s_i})*gam*{u}/(1-exp(-d*{u}))"


def _rww_jacobian_row(s_i, s_j, c_ij, sign):
    """Return the Jacobian entries of dS_i/dt w.r.t. (S_i, S_j) as expression strings."""
    u = f"(a*({c_ij}*G*J_N*{s_j} + I_0 {sign} I_1 + J_N*{s_i}*w) - b)"
    e = f"exp(-d*{u})"
    d_self = (f"-J_N*a*d*gam*w*(1 - {s_i})*{u}*{e}/(1 - {e})**2"
              f" + J_N*a*gam*w*(1 - {s_i})/(1 - {e})"
              f" - gam*{u}/(1 - {e}) - 1/tau_S")
    d_other = (f"-{c_ij}*G*J_N*a*d*gam*(1 - {s_i})*{u}*{e}/(1 - {e})**2"
               f" + {c_ij}*G*J_N*a*gam*(1 - {s_i})/(1 - {e})")
    return d_self, d_other


def create_model(params, args=None):
    r"""Create the two-variable reduced Wong-Wang dynamical system in PyDSTool.

    .. math::

        \dot{S_i} = - \cfrac{S_i}{\tau_S} + (1 - S_i) \gamma H(x_i) + \sigma v_i

    with

    .. math::

        H(x) = \cfrac{ax-b}{1-\exp{(-d(ax-b))}}

    and

    .. math::

        x_i = w J_N S_i + G J_N \sum_j{C_{ij} S_j} + I_0

    The equations handed to PyDSTool contain no noise term, and an extra
    input ``I_1`` is subtracted from the drive of ``S1`` and added to the
    drive of ``S2``. Initial conditions of ``S1`` and ``S2`` are drawn from
    the module-level random generator.

    Parameters
    ----------
    params: dict
        Parameters of the model (names used in the equations: ``a``, ``b``,
        ``d``, ``gam``, ``tau_S``, ``w``, ``J_N``, ``G``, ``C_12``, ``C_21``,
        ``I_0``, ``I_1``).
    args: Argparse.Namespace
        Optional arguments (currently unused).

    Returns
    -------
    rww: PyDSTool.Vode_ODEsystem
        PyDSTool object containing the dynamical system.
    """
    icdict = {'S1': rng.uniform(), 'S2': rng.uniform()}

    S1eq = _rww_rhs('S1', 'S2', 'C_12', '-')
    S2eq = _rww_rhs('S2', 'S1', 'C_21', '+')

    # Analytical Jacobian, laid out as [[dS1/dS1, dS1/dS2], [dS2/dS1, dS2/dS2]]
    dS1_dS1, dS1_dS2 = _rww_jacobian_row('S1', 'S2', 'C_12', '-')
    dS2_dS2, dS2_dS1 = _rww_jacobian_row('S2', 'S1', 'C_21', '+')
    jacobian = f"[[{dS1_dS1}, {dS1_dS2}], [{dS2_dS1}, {dS2_dS2}]]"

    DSargs = dst.args(name='rww')  # struct-like data
    DSargs.pars = params
    DSargs.tdata = [0, 100]
    DSargs.algparams = {'max_pts': 100000, 'init_step': 0.01, 'stiff': True}
    DSargs.varspecs = {'S1': S1eq, 'S2': S2eq}
    DSargs.xdomain = {'S1': [-1, 1], 'S2': [-1, 1]}
    DSargs.fnspecs = {'Jacobian': (['t', 'S1', 'S2'], jacobian)}
    DSargs.ics = icdict
    rww = dst.Vode_ODEsystem(DSargs)
    return rww
def get_fixed_points(model, params, xdomain=None, args=None):
    r"""Get the model's nullclines and fixed points for a given set of parameters.

    Nullclines are the curves :math:`\frac{dS_1}{dt}=0` and
    :math:`\frac{dS_2}{dt}=0`; fixed points satisfy
    :math:`\frac{dS_1}{dt}=\frac{dS_2}{dt}=0`.

    Parameters
    ----------
    model: PyDSTool.Vode_ODEsystem
        Model object in PyDSTool. It is updated in place with ``params`` and ``xdomain``.
    params: dict
        Model parameters.
    xdomain: dict
        Variable and lower/upper bounds. Defaults to ``{'S1': [0, 1], 'S2': [0, 1]}``.
    args: Argparse.Namespace
        Optional extra arguments (currently unused).

    Returns
    -------
    model: PyDSTool.Vode_ODEsystem
        Updated model object with given parameters.
    fps: list of PyDSTool.Toolbox.phaseplane.fixedpoint_2D
        Fixed points of the system.
    (nulls_x, nulls_y): tuple
        Tuple containing nullclines (arrays of paired xs and ys).
    """
    # A fresh dict per call: a dict literal in the signature would be shared across calls
    if xdomain is None:
        xdomain = {'S1': [0, 1], 'S2': [0, 1]}
    model.set(pars=params, xdomain=xdomain)

    # fixed points (using n starting points along the domain)
    fp_coords = pp.find_fixedpoints(model, n=20, eps=1e-8)
    fps = [pp.fixedpoint_2D(model, dst.Point(fp_coord), eps=1e-8) for fp_coord in fp_coords]

    # nullclines (using n starting points along the domain)
    nulls_x, nulls_y = pp.find_nullclines(model, 'S1', 'S2', n=5, eps=1e-8, max_step=0.01, fps=fp_coords)

    return model, fps, (nulls_x, nulls_y)
def compute_trajectories(model, n, tdata=None):
    """Compute n trajectories from model, each with different initial conditions.

    Initial conditions of ``S1`` and ``S2`` are redrawn from the module-level
    random generator before each run.

    Parameters
    ----------
    model: PyDSTool.Vode_ODEsystem
        PyDSTool model object. Its initial conditions and ``tdata`` are overwritten.
    n: int
        Number of trajectories to compute.
    tdata: list
        Time interval of the saved trajectories. Defaults to ``[0, 1000]``.

    Returns
    -------
    points: list
        One sampled trajectory (result of ``traj.sample()``) per run.
    """
    # A fresh list per call: a list literal in the signature would be shared across calls
    if tdata is None:
        tdata = [0, 1000]
    points = []
    for i in range(n):
        model.set(ics={'S1': rng.uniform(), 'S2': rng.uniform()}, tdata=tdata)
        traj = model.compute(f"traj {i}")
        points.append(traj.sample())
    return points
def compute_equilibrium_point_curve(model, fps, pdomain):
    """Continue one equilibrium point curve from every fixed point of the system.

    Parameters
    ----------
    model: PyDSTool.Vode_ODEsystem
        Model object in PyDSTool; its parameter domain is set to ``pdomain``.
    fps: list of PyDSTool.Toolbox.phaseplane.fixedpoint_2D
        Fixed points of the system, each used as the initial point of a curve.
    pdomain: dict
        Free variable (or order parameter) to perform the bifurcation analysis from.

    Returns
    -------
    cont: PyDSTool.ContClass
        PyDSTool Continuation Class object populated with equilibrium point
        curves named ``EQ0``, ``EQ1``, ...
    """
    model.set(pdomain=pdomain)
    continuation = dst.ContClass(model)
    free_pars = list(pdomain)

    for idx, fixed_point in enumerate(fps):
        curve_name = f'EQ{idx}'
        curve_args = dst.args(name=curve_name, type='EP-C')
        settings = (
            ('initpoint', fixed_point.point),
            ('freepars', free_pars),
            ('StepSize', 1e-5),
            ('MaxNumPoints', 1000000),
            ('MaxStepSize', 1e-4),
            ('LocBifPoints', 'all'),
            ('StopAtPoints', 'B'),
            ('SaveEigen', True),
            ('verbosity', 0),
        )
        for key, value in settings:
            setattr(curve_args, key, value)
        continuation.newCurve(curve_args)

        # follow the curve in both directions from the fixed point, and time it
        print('Computing curve...')
        t0 = dst.perf_counter()
        continuation[curve_name].forward()
        continuation[curve_name].backward()
        elapsed = dst.perf_counter() - t0
        print('done in %.3f seconds!' % elapsed)

    return continuation
def stability_analysis(order_params, default_params, out_queue, args, pdomain=None):
    """Create model and analyse dynamics using PyDSTool.

    Parameters
    ----------
    order_params: dict
        Values of the fixed (order) parameters for this run; they override the
        entries of ``default_params`` with the same name.
    default_params: dict
        Default model's parameters. Not modified by this function.
    out_queue: Queue
        Queue to put results in (used for parallel computation) for each values of order parameter.
    args: Argparse.Namespace
        Structure of necessary options: ``args.n_trajs`` (number of trajectories)
        and ``args.compute_epc`` (``True`` to compute equilibrium point curves).
    pdomain: dict
        Free variable (or order parameter) to perform the bifurcation analysis from.
        Defaults to ``{'C_12': [-0.5, 1.5]}``.

    Returns
    -------
    None
        A dilled dict is put on the queue. It holds the order parameters plus an
        ``'output'`` entry with the model (``model``), nullclines (``ncs``), fixed
        points (``fps``), trajectories (``trajs``) and, when requested and
        successful, a dilled continuation object (``dilled_cont``).
    """
    if pdomain is None:
        pdomain = {'C_12': [-0.5, 1.5]}

    # Merge into a new dict so that the caller's default parameters are left untouched
    params = {**default_params, **order_params}

    out = dict()
    # fixed point and nullclines
    model = create_model(params, args)
    model, fps, nullclines = get_fixed_points(model, params)
    points = compute_trajectories(model, args.n_trajs)
    out['model'], out['fps'], out['ncs'], out['trajs'] = model, fps, nullclines, points

    # EP-C
    if args.compute_epc:  # takes a bit of ressource and memory
        try:
            cont = compute_equilibrium_point_curve(model, fps, pdomain)
            out['dilled_cont'] = dill.dumps(cont, byref=True)
        except Exception as err:
            # Best effort: the rest of the analysis is still returned without 'dilled_cont'
            print(f"Error in computing Equilibrium Point Curve for {order_params}: {err!r}")

    # put output to output queue
    output = copy.deepcopy(order_params)
    output['output'] = out
    out_queue.put(dill.dumps(output))
def launch_stability_analysis(order_params, default_params, out_queue, args):
    """Run one stability analysis in a child process, with a timeout.

    Ghost process that launches the stability analysis for a set of defined
    order parameters, creating a child process with a set timeout, such that
    the stability analysis does not hang waiting for the continuation to
    terminate if it does not converge.

    Parameters
    ----------
    order_params: dict
        Values of the fixed (order) parameters for this run.
    default_params: dict
        Default model's parameters.
    out_queue: Queue
        Output queue on which to append the results.
    args: Argparse.Namespace
        Structure of necessary options. Must include ``args.timeout`` (seconds)
        in addition to the options read by ``stability_analysis``.
    """
    proc = multiprocessing.Process(target=stability_analysis, args=(order_params, default_params, out_queue, args))
    proc.start()

    # wait for the process until timeout
    proc.join(args.timeout)

    # if process is still running after timeout, force terminating it
    if proc.is_alive():
        print(f"Stability analysis for {order_params} took too long, aborted after {args.timeout}s")
        proc.terminate()
        # reap the terminated child so that it does not linger as a zombie process
        proc.join()
def run_stability_analysis(order_params, default_params, args):
    """Start a pool of parallel processes to run the stability analysis.

    One analysis is launched for every combination of the values given in
    ``order_params`` (cartesian product).

    Parameters
    ----------
    order_params: dict
        Fixed parameters, for which to analyse the system using discretized values, for
        example ``{'C_21': np.linspace(0.2,0.8,4)}``.
    default_params: dict
        Default model's parameters.
    args: Argparse.Namespace
        Structure of necessary options. Must include ``args.n_jobs`` in addition
        to the options read by the launched analysis.

    Returns
    -------
    outputs: list
        Dilled results (one ``bytes`` object per analysis that completed), in
        the order in which they were put on the queue.
    futures: list of concurrent.futures.Future
        (deprecated) Future objects of the submitted jobs
        (https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.Future).
    """
    o_pars = [dict(zip(order_params.keys(), vals)) for vals in itertools.product(*order_params.values())]
    print("Run stability analysis...")

    # /!\ Standard queues seems (at least) 10x slower than managed queues, resulting in timeouts..
    # if performance is primordial, it could be changed to Pipes with locks (even faster)
    # The manager is used as a context manager so that its server process is shut down on exit.
    with multiprocessing.Manager() as manager:
        out_queue = manager.Queue()
        futures = []
        with ProcessPoolExecutor(max_workers=args.n_jobs) as pool:
            for order_param in o_pars:
                future = pool.submit(launch_stability_analysis, copy.deepcopy(order_param),
                                     copy.deepcopy(default_params), out_queue, args)
                futures.append(future)

        # Leaving the pool context waited for all jobs; report those that raised
        for order_param, future in zip(o_pars, futures):
            err = future.exception()
            if err is not None:
                print(f"Stability analysis launcher failed for {order_param}: {err!r}")

        # Drain the queue while the manager is still alive
        outputs = []
        while not out_queue.empty():
            outputs.append(out_queue.get_nowait())

    return outputs, futures
def plot_phasespace(model, fps, nullclines, trajs, ax=None, args=None):
    """Plot vector field, nullclines, fixed points and trajectories of a model.

    The model must have been previously set with its parameters.

    Parameters
    ----------
    model: PyDSTool.Vode_ODEsystem
        PyDSTool model object.
    fps: list of PyDSTool.Toolbox.phaseplane.fixedpoint_2D
        Fixed points of the system.
    nullclines: list
        Pair of nullcline arrays; column 0 holds the xs and column 1 the ys.
        The first is drawn in blue, the second in green.
    trajs: list of dict
        Simulated trajectories, indexable by ``'S1'`` and ``'S2'``.
    ax: matplotlib.Axis
        Axis in which to plot the phasespace. Defaults to the current axes.
    args: Argparse.Namespace
        Optional extra arguments (currently unused).
    """
    # plt.sca() cannot take None: fall back to the current axes
    if ax is None:
        ax = plt.gca()
    plt.sca(ax)
    pp.plot_PP_vf(model, 'S1', 'S2', scale_exp=-2)
    pp.plot_PP_fps(fps, do_evecs=True, markersize=6)
    plt.plot(nullclines[0][:, 0], nullclines[0][:, 1], 'b', lw=3)
    plt.plot(nullclines[1][:, 0], nullclines[1][:, 1], 'g', lw=3)
    for traj in trajs:
        plt.plot(traj['S1'], traj['S2'], linewidth=1)
def get_grid_inds(output, order_params):
    """Locate an analysis output on the 2-D grid of order parameters.

    Parameters
    ----------
    output: dict
        Un-dilled analysis result: order parameter names mapped to the value
        used for that run, plus an ``'output'`` entry which is ignored here.
    order_params: dict
        Order parameter names mapped to the discretized values of the grid.

    Returns
    -------
    o_pars: numpy.ndarray
        Alphabetically sorted names of the order parameters found in ``output``.
    i, j: int
        Index of the run's value in ``order_params`` for the first and second
        sorted parameter name respectively (first exact match).
    """
    param_names = np.sort([key for key in output if key != 'output'])
    matches = []
    for name in param_names[:2]:
        target = output[name]
        matches.append([idx for idx, candidate in enumerate(order_params[name])
                        if candidate == target])
    return param_names, matches[0][0], matches[1][0]
def plot_phasespace_grid(outputs, order_params, args=None):
    """Plot a grid of phasespaces from stability analysis outputs.

    Parameters
    ----------
    outputs: list
        Dilled outputs from stability analysis.
    order_params: dict
        Two fixed parameters, for which the system was analysed using discretized
        values. Rows follow the alphabetically first parameter name and columns
        the second, matching the indices returned by ``get_grid_inds``.
    args: Argparse.Namespace
        Optional extra arguments in argparse Namespace, such as ``args.save_figs=True`` to
        save figure and ``args.plot_figs=True`` to plot figures.
    """
    plt.rcParams['svg.fonttype'] = 'none'
    plt.rcParams.update({'font.size': 11, 'axes.titlesize': 'medium', 'mathtext.default': 'regular'})
    fig = plt.figure(figsize=[16, 16])

    # get_grid_inds() indexes parameters by sorted name: use the same order for
    # the grid dimensions (identical to insertion order for sorted dicts)
    names = sorted(order_params)
    p1s, p2s = order_params[names[0]], order_params[names[1]]
    gs = plt.GridSpec(nrows=len(p1s), ncols=len(p2s))

    for res in outputs:
        out = dill.loads(res)
        output = out['output']
        _, i, j = get_grid_inds(out, order_params)
        ax = fig.add_subplot(gs[i, j])
        plot_phasespace(output['model'], output['fps'], output['ncs'], output['trajs'], ax=ax, args=args)
        plt.axis('tight')
        ttl = "$C_{12}$=%s $C_{21}$=%s" % ("{:.2f}".format(p1s[i]), "{:.2f}".format(p2s[j]))
        plt.title(ttl, fontdict={'fontsize': 13})
        # axis labels only on the bottom row and on the left column
        if i == len(p1s)-1:
            plt.xlabel('$S_1$', fontsize=13)
        if j == 0:
            plt.ylabel('$S_2$', fontsize=13)

        # makes quiver a bit more transparent
        # (presumably the first child artist of the axes is the quiver -- TODO confirm)
        plt.getp(ax, 'children')[0].set(alpha=0.7)

    plt.tight_layout()

    # args is optional: treat missing options as False
    if getattr(args, 'save_figs', False):
        plt.savefig(os.path.join(proj_dir, 'img', 'phase_space'+today()+'.svg'), format='svg', transparent=True)
    if getattr(args, 'plot_figs', False):
        plt.show()
def plot_phasespace_row(outputs, order_params, rww=None, t_range=None, args=None):
    """Plot a row of phasespaces (ie. a 1 by n grid) from stability analysis outputs.

    Parameters
    ----------
    outputs: list
        Dilled outputs from stability analysis.
    order_params: dict
        Order parameters of the analysis; only the first entry is used.
    rww: object
        Optional model instance that ran; when given, its recorded traces
        (``rec_C_12`` against both columns of ``S_rec``) are overlaid on each panel.
    t_range: list
        [start, stop] timestamp values of the traces to overlay; they are
        multiplied by ``rww.sf`` to get sample indices. Defaults to the whole recording.
    args: Argparse.Namespace
        Optional extra arguments, such as ``args.save_figs=True`` to save figure
        and ``args.plot_figs=True`` to plot figures.
    """
    plt.rcParams['svg.fonttype'] = 'none'
    plt.rcParams.update({'font.size': 11, 'axes.titlesize': 'medium', 'mathtext.default': 'regular'})
    fig = plt.figure(figsize=[12, 3])

    o_par = list(order_params.keys())[0]
    p1s = order_params[o_par]
    gs = plt.GridSpec(nrows=1, ncols=len(p1s))

    for pos, res in enumerate(outputs):
        out = dill.loads(res)
        output = out['output']

        # Outputs arrive in completion order: place each panel according to its
        # order parameter value, falling back to its position in the list.
        matches = [x for x, val in enumerate(p1s) if val == out.get(o_par)]
        i = matches[0] if matches else pos

        ax = fig.add_subplot(gs[0, i])
        plot_phasespace(output['model'], output['fps'], output['ncs'], output['trajs'], ax=ax, args=args)
        plt.axis('tight')
        ttl = r"$C_{2 \leftarrow 1}$=%s" % ("{:.3f}".format(p1s[i]))
        plt.title(ttl, fontdict={'fontsize': 11})
        plt.xlabel('$S_1$')
        if i == 0:
            plt.ylabel('$S_2$')

        # makes quiver a bit more transparent
        # (presumably the first child artist of the axes is the quiver -- TODO confirm)
        plt.getp(ax, 'children')[0].set(alpha=0.7)

        if rww is not None:
            if t_range is not None:
                start = t_range[0]*rww.sf
                stop = t_range[1]*rww.sf
            else:
                start = 0
                stop = rww.S_rec.shape[0]
            plt.plot(rww.rec_C_12[start:stop], rww.S_rec[start:stop, 0], lw=0.5, color='blue', alpha=0.4)
            plt.plot(rww.rec_C_12[start:stop], rww.S_rec[start:stop, 1], lw=0.5, color='red', alpha=0.4)

    plt.tight_layout()

    # args is optional: treat missing options as False
    if getattr(args, 'save_figs', False):
        # local import: guarantees the name refers to the datetime module
        import datetime
        stamp = datetime.datetime.now().strftime("%Y%m%d")
        plt.savefig(os.path.join(proj_dir, 'img', 'phase_space'+stamp+'.svg'), format='svg', transparent=True)
    if getattr(args, 'plot_figs', False):
        plt.show()
def plot_bifurcation_grid(outputs, order_params, args=None):
    """Plot a grid of bifurcation diagrams.

    Parameters
    ----------
    outputs: list
        Outputs from stability analysis, either dilled (``bytes``) or already
        loaded dicts.
    order_params: dict
        Two fixed parameters, for which the system was analysed using discretized
        values. Rows follow the alphabetically first parameter name and columns
        the second, matching the indices returned by ``get_grid_inds``.
    args: Argparse.Namespace
        Optional extra arguments (currently unused).
    """
    plt.rcParams.update({'font.size': 6, 'axes.titlesize': 'medium'})
    fig = plt.figure(figsize=[20, 20])

    # get_grid_inds() indexes parameters by sorted name: use the same order for
    # the grid dimensions (identical to insertion order for sorted dicts)
    names = sorted(order_params)
    p1s, p2s = order_params[names[0]], order_params[names[1]]
    gs = plt.GridSpec(nrows=len(p1s), ncols=len(p2s))

    def blank_panel(ax, title):
        """Strip labels and ticks from a panel that has no diagram to show."""
        plt.sca(ax)
        plt.xlabel("")
        plt.xticks([])
        plt.ylabel("")
        plt.yticks([])
        plt.title(title)

    for output in outputs:
        # the stability analysis puts dilled dicts on the queue
        if isinstance(output, (bytes, bytearray)):
            output = dill.loads(output)
        o_pars, i, j = get_grid_inds(output, order_params)
        ax = fig.add_subplot(gs[i, j])
        title = "{}={:.3f} {}={:.3f}".format(o_pars[0], p1s[i], o_pars[1], p2s[j])

        # 'dilled_cont' is absent when the continuation was not requested or failed
        dilled_cont = (output.get('output') or {}).get('dilled_cont')
        if dilled_cont is None:
            blank_panel(ax, title)
            continue

        cont = dill.loads(dilled_cont)
        try:
            cont.display(axes=ax, coords=['C_12', 'S2'], stability=True, color='blue')
            cont.display(axes=ax, coords=['C_12', 'S1'], stability=True, color='orange')
        except Exception as err:
            # best effort: leave the panel empty rather than aborting the whole figure
            print(f"Could not display bifurcation diagram for {title}: {err!r}")
            blank_panel(ax, title)
            continue

        plt.sca(ax)
        plt.axis('tight')
        plt.title(title)
        # axis labels only on the bottom row and on the left column
        if i < len(p1s)-1:
            plt.xlabel("")
            plt.xticks([])
        if j > 0:
            plt.ylabel("")
            plt.yticks([])
        plt.xlim([-1, 1])
        plt.ylim([0, 1])

    plt.show(block=False)
def plot_bifurcation_row(outputs, order_params, rww=None, t_range=None, args=None):
    """Plot a row of bifurcation diagrams (ie. a 1 by n grid).

    Parameters
    ----------
    outputs: list
        Outputs from stability analysis, either dilled (``bytes``) or already
        loaded dicts.
    order_params: dict
        Order parameters of the analysis in `{'param_name': np.array}` format where `np.array`
        is the list of order parameters `param_name`. Only the first entry is used.
    rww: OCD_modeling.models.ReducedWongWangOU
        Model instance that ran; when given, its recorded traces are overlaid
        on each diagram.
    t_range: list
        [start, stop] timestamp values of the RWW model traces to plot in the diagrams;
        they are multiplied by ``rww.sf`` to get sample indices. Defaults to the whole recording.
    args: Argparse.Namespace
        Extra options: ``args.save_figs`` to save the figure, ``args.plot_figs``
        to show it (the figure is closed otherwise).
    """
    plt.rcParams.update({'font.size': 10, 'axes.titlesize': 'medium'})
    fig = plt.figure(figsize=[15, 3])

    o_par = list(order_params.keys())[0]
    p1s = order_params[o_par]
    gs = plt.GridSpec(nrows=1, ncols=len(p1s))

    def blank_panel(ax, title):
        """Strip labels and ticks from a panel that has no diagram to show."""
        plt.sca(ax)
        plt.xlabel("")
        plt.xticks([])
        plt.ylabel("")
        plt.yticks([])
        plt.title(title)

    for output in outputs:
        # the stability analysis puts dilled dicts on the queue
        if isinstance(output, (bytes, bytearray)):
            output = dill.loads(output)
        i = [x for x, val in enumerate(p1s) if val == output[o_par]][0]
        ax = fig.add_subplot(gs[0, i])
        title = "{}={:.3f}".format(o_par, p1s[i])

        # 'dilled_cont' is absent when the continuation was not requested or failed
        dilled_cont = (output.get('output') or {}).get('dilled_cont')
        if dilled_cont is None:
            blank_panel(ax, title)
        else:
            cont = dill.loads(dilled_cont)
            try:
                cont.display(axes=ax, coords=['C_12', 'S1'], stability=True, color='blue')
                cont.display(axes=ax, coords=['C_12', 'S2'], stability=True, color='red')
            except Exception as err:
                # best effort: leave the panel empty rather than aborting the whole figure
                print(f"Could not display bifurcation diagram for {title}: {err!r}")
                blank_panel(ax, title)
                continue
            plt.sca(ax)
            plt.axis('tight')
            plt.title(title)
            if i > 0:
                plt.ylabel("")
                plt.yticks([])
            plt.xlim([-0.5, 0.5])
            plt.ylim([0, 1])

        if rww is not None:
            if t_range is not None:
                start = t_range[0]*rww.sf
                stop = t_range[1]*rww.sf
            else:
                start = 0
                stop = rww.S_rec.shape[0]
            plt.plot(rww.rec_C_12[start:stop], rww.S_rec[start:stop, 0], lw=0.5, color='blue', alpha=0.4)
            plt.plot(rww.rec_C_12[start:stop], rww.S_rec[start:stop, 1], lw=0.5, color='red', alpha=0.4)

    plt.tight_layout()

    # args is optional: treat missing options as False
    if getattr(args, 'save_figs', False):
        fname = os.path.join(proj_dir, 'img', 'bifurcation_diagram_023_'+today()+'.svg')
        plt.savefig(fname)

    if getattr(args, 'plot_figs', False):
        plt.show(block=False)
    else:
        # Figure objects have no close() method; closing goes through pyplot
        plt.close(fig)
def plot_timeseries_phasespace_bif(outputs, rww, df_eta_sigma, args):
    """ Show :math:`S1`, :math:`S2`, :math:`C_{12}` timeseries, :math:`S1-S2` phase space with trajectories
    and :math:`C_{12} - S1|S2` phase space with bifurcation diagram.

    The figure has the three time series on the top row and, on the bottom
    row, the S1-S2 state space, the S - C_12 state space with the equilibrium
    curve ``EQ0``, and the two contour maps drawn by ``plot_transitions_FC``.

    Parameters
    ----------
    outputs: list
        Dilled outputs from stability analysis. Only the element at index 5 is
        used, so the list must hold at least six entries.
    rww: OCD_modeling.models.ReducedWongWangOU
        Model instance that ran. The attributes ``sf``, ``t``, ``S_rec`` and
        ``rec_C_12`` are read.
    df_eta_sigma: pandas.DataFrame
        Data from the simulations varying eta and sigma parameters.
    args: Argparse.Namespace
        Extra options; ``args.save_figs`` saves the figure as svg.

    """
    # Displayed time window (hard-coded) and its conversion to sample indices
    # (presumably rww.sf is the sampling frequency of the recordings -- TODO confirm)
    t_range = [2800,6000]
    ticks = [0 ,0.3, 0.6, 0.9]
    start = t_range[0]*rww.sf
    stop = t_range[1]*rww.sf

    fig = plt.figure(figsize=[10,5])
    gs = plt.GridSpec(nrows=2, ncols=4, height_ratios=[1,0.8])
    sub_gs = gs[0,1:4].subgridspec(3,1) # time series
    sub_gs_ = gs[1,:].subgridspec(1,4, width_ratios=[1,1,1.2,1.2]) # transitions & FC

    # Time series
    #------------
    ax = fig.add_subplot(sub_gs[0,0])
    ax.plot(rww.t[start:stop], rww.S_rec[start:stop,0], color='dodgerblue')
    ax.spines.top.set_visible(False)
    ax.spines.bottom.set_visible(False)
    ax.spines.right.set_visible(False)
    plt.xticks([], label='')
    plt.yticks([0,1])
    plt.ylabel('$S_1$', rotation=0, labelpad=15, fontsize=12)

    ax = fig.add_subplot(sub_gs[1,0])
    ax.plot(rww.t[start:stop], rww.S_rec[start:stop,1], color='forestgreen')
    ax.spines.top.set_visible(False)
    ax.spines.bottom.set_visible(False)
    ax.spines.right.set_visible(False)
    plt.xticks([], label='')
    plt.yticks([0,1])
    plt.ylabel('$S_2$', rotation=0, labelpad=15, fontsize=12)

    # Bottom time series: C_12 trace with a dashed reference line at zero
    ax = fig.add_subplot(sub_gs[2,0])
    ax.hlines(y=0, xmin=t_range[0], xmax=t_range[1], lw=0.75, linestyle='--', color='gray')
    ax.plot(rww.t[start:stop], rww.rec_C_12[start:stop], lw=0.25, color='orange')
    ax.spines.top.set_visible(False)
    ax.spines.right.set_visible(False)
    #plt.yticks([-0.5,0.75])
    plt.yticks([-1.2,1.6])
    plt.ylabel('$C_{12}$', rotation=0, labelpad=0, fontsize=12)
    # Tick labels are offsets from the window start divided by 60
    # (assumes rww.t is in seconds, per the 'time (min)' axis label -- TODO confirm)
    xticks = np.arange(t_range[0], t_range[1], 600) # 10min ticks
    plt.xticks(xticks, labels=np.array((xticks-t_range[0])/60, dtype=int))
    ax.set_xlabel('time (min)', fontsize=11)

    # Transitions and FC against eta & sigma
    # --------------------------------------
    ax1 = fig.add_subplot(sub_gs_[0,2])
    ax2 = fig.add_subplot(sub_gs_[0,3])
    plot_transitions_FC(df_eta_sigma, ax1, ax2, args)

    # S1-S2 State space
    #------------------
    # NOTE(review): the analysis output to illustrate is hard-coded to index 5
    output = dill.loads(outputs[5])['output']
    ax = fig.add_subplot(sub_gs_[0,0])
    ax.plot(output['ncs'][1][:,0], output['ncs'][1][:,1], color='forestgreen', lw=3)
    ax.plot(output['ncs'][0][:,0], output['ncs'][0][:,1], color='dodgerblue', lw=3)
    ax.plot(rww.S_rec[start:stop,0], rww.S_rec[start:stop,1], lw=1, color='gray', alpha=1)
    for fp in output['fps']:
        x,y = fp.toarray()
        ax.scatter(x, y, color='black', s=40)
    ax.spines.top.set_visible(False)
    ax.spines.right.set_visible(False)
    ax.set_xlabel('$S_1$', fontsize=12)
    ax.set_ylabel('$S_2$', rotation=0, labelpad=15, fontsize=12)
    ax.set_xticks(ticks)
    ax.set_xticklabels(ticks)
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    ax.set_xlim([-0.05,1])
    ax.set_ylim([-0.05,1])
    # Nullcline annotations, colored like the corresponding curves
    ax.text(x=0, y=0.8, s='$\\frac{dS_2}{dt}$', color='forestgreen', fontsize=16)
    ax.text(x=0.25, y=0.8, s='$=0$', color='forestgreen', fontsize=12)
    ax.text(x=0.7, y=0.05, s='$\\frac{dS_1}{dt}$', color='dodgerblue', fontsize=16)
    ax.text(x=0.95, y=0.05, s='$=0$', color='dodgerblue', fontsize=12)
    #plt.tight_layout()

    # S - C_12 state space
    ax = fig.add_subplot(sub_gs_[0,1])
    # Requires the equilibrium point curves to have been computed for this output
    cont = dill.loads(output['dilled_cont'])
    #cont.display(axes=ax, coords=['C_12', 'S1'], stability=True, color='blue', linewidth=1.6, points=False)
    #cont.display(axes=ax, coords=['C_12', 'S2'], stability=True, color='red', linewidth=1.6, points=False)
    # No axes are passed to display(): presumably it draws on the current axes,
    # i.e. the subplot created just above -- TODO confirm
    cont.curves['EQ0'].display(coords=['C_12', 'S1'], stability=True, color='blue', linewidth=1.6, points=True)
    cont.curves['EQ0'].display(coords=['C_12', 'S2'], stability=True, color='red', linewidth=1.6, points=True)
    plt.plot(rww.rec_C_12[start:stop], rww.S_rec[start:stop,0], lw=0.25, color='blue', alpha=0.2, label='$S_1$')
    plt.plot(rww.rec_C_12[start:stop], rww.S_rec[start:stop,1], lw=0.25, color='red', alpha=0.2, label='$S_2$')
    ax.spines.top.set_visible(False)
    # Left/right spines take the colors of the S1/S2 curves
    ax.spines.left.set_color('blue')
    ax.spines.right.set_color('red')
    ax.set_yticks(ticks)
    ax.set_yticklabels(ticks)
    ax.set_ylim([-0.05,1])
    ax.set_xlim([-0.6,1.6])
    ax.tick_params(axis='y', which='both', labelleft='on', left=True)
    ax.tick_params(axis='y', which='both', labelright='on', right=True)
    plt.xlabel('$C_{12}$', fontsize=12)
    plt.ylabel('$S_1$', rotation=0, fontsize=12)
    plt.title('')

    #plt.tight_layout(pad=0)
    plt.subplots_adjust(wspace=0.5, hspace=0.4)

    # adjust FC and transitions
    #plt.sca(ax1)
    #plt.subplots_adjust(left=0.2, right=0.8)

    if args.save_figs:
        fname = os.path.join(proj_dir, 'img', 'single_pathway_model'+today()+'.svg')
        plt.savefig(fname)
def plot_transitions_FC(df_eta_sigma, ax1=None, ax2=None, args=None):
    r"""Plot contour maps of mean FC and mean transition rate over the (eta, sigma) grid.

    Parameters
    ----------
    df_eta_sigma: pandas.DataFrame
        Simulated data with columns ``eta``, ``sigma``, ``fc`` and
        ``transitions_per_minute``. Values are averaged over all rows sharing
        the same (eta, sigma) pair.
    ax1: matplotlib.Axis
        Axis for the FC map. A new figure is created when ``ax1`` or ``ax2`` is missing.
    ax2: matplotlib.Axis
        Axis for the transition rate map.
    args: Argparse.Namespace
        Optional extra arguments (currently unused).
    """
    etas = np.sort(np.unique(df_eta_sigma['eta']))
    sigmas = np.sort(np.unique(df_eta_sigma['sigma']))

    # fcs[i, j] and trs[i, j] hold the means for (etas[i], sigmas[j])
    fcs = np.zeros((len(etas), len(sigmas)))
    trs = np.zeros((len(etas), len(sigmas)))
    for i, eta in enumerate(etas):
        for j, sigma in enumerate(sigmas):
            cell = df_eta_sigma[(df_eta_sigma.eta == eta) & (df_eta_sigma.sigma == sigma)]
            fcs[i, j] = cell.fc.mean()
            trs[i, j] = cell.transitions_per_minute.mean()

    if ax1 is None or ax2 is None:
        plt.figure(figsize=[8, 3])
        ax1 = plt.subplot(1, 2, 1)
        ax2 = plt.subplot(1, 2, 2)

    # 'ij' indexing gives X[i, j] = etas[i] and Y[i, j] = sigmas[j], i.e. the same
    # layout as fcs/trs. The default 'xy' indexing would return transposed grids.
    X, Y = np.meshgrid(etas, sigmas, indexing='ij')

    cntr = ax1.contourf(X, Y, fcs, levels=20, cmap='Greens')
    plt.sca(ax1)
    plt.xlabel(r'$\eta_{12}$', fontsize=12)
    plt.xticks([0, 0.025, 0.05, 0.075, 0.1], labels=['0', '', '0.05', '', '0.1'])
    plt.yticks(np.arange(0, 0.5, 0.1))
    plt.ylabel(r'$\sigma_{12}$', fontsize=12)
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    cbar = plt.colorbar(cntr, fraction=0.1, label='R', pad=0.025, drawedges=False)
    cbar.set_ticks([0, 0.2, 0.4, 0.6])

    cntr = ax2.contourf(X, Y, trs, levels=20, cmap='Oranges')
    plt.sca(ax2)
    plt.xticks([0, 0.025, 0.05, 0.075, 0.1], labels=['0', '', '0.05', '', '0.1'])
    plt.xlabel(r'$\eta_{12}$', fontsize=12)
    # the y axis is shared visually with the left panel: ticks without labels
    plt.yticks(np.arange(0, 0.5, 0.1), labels=[])
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)
    cbar = plt.colorbar(cntr, fraction=0.1, label='$min^{-1}$', pad=0.025, drawedges=False)
    cbar.set_ticks(np.arange(0, 50, 20))


def load_df_eta_sigma():
    r"""Load pre-computed transition rates and functional connectivity of the model
    along the :math:`\eta` and :math:`\sigma` parameter space.

    Every file of the ``postprocessing/eta_sigma_20241029/`` project directory
    is un-pickled and the records are gathered in a single table.

    Returns
    -------
    df_eta_sigma: pandas.DataFrame
        Preprocessed simulated data, with an added ``transitions_per_minute`` column.
    """
    pth = os.path.join(proj_dir, 'postprocessing', 'eta_sigma_20241029/')
    lines = []
    # sorted: os.listdir() order is arbitrary, this keeps the row order reproducible
    for file in sorted(os.listdir(pth)):
        with open(os.path.join(pth, file), 'rb') as f:
            lines.append(pickle.load(f))
    df_eta_sigma = pd.DataFrame(list(np.array(lines).flatten()))
    df_eta_sigma['transitions_per_minute'] = df_eta_sigma['n_transitions']/120  # 120 because 2h simulations, i.e. 120 minutes
    return df_eta_sigma
def get_parser():
    """Build the command-line parser of the script.

    Returns
    -------
    parser: argparse.ArgumentParser
        Parser with boolean switches (all off by default) and integer options.
    """
    parser = argparse.ArgumentParser()

    def add_switch(name, description):
        parser.add_argument('--' + name, default=False, action='store_true', help=description)

    def add_integer(name, value, description):
        parser.add_argument('--' + name, type=int, default=value, action='store', help=description)

    switches = (
        ('create_model', 'create symbolic model'),
        ('compute_nullclines', 'compute nullclines numerically'),
        ('compute_epc', 'compute equilibrium point curves numerically'),
        ('save_model', 'save symbolic model with its computed attributes'),
        ('run_stability_analysis', 'run stability analysis: find fixed points (semi-analytically) and perform linear stability analysis around them'),
        ('load_stability_analysis', 'load previously completed stability analysis'),
        ('plot_figs', 'plot figures'),
        ('plot_phasespace_grid', 'plot grid of phase spaces using discretized order parameters'),
        ('plot_bifurcation_diagrams', 'plot grid of bifurcation diagrams using discretized order parameters'),
        ('save_figs', 'save figures'),
        ('save_outputs', 'save analysis outputs'),
    )
    integers = (
        ('timeout', 3000, 'timeout of the stability analysis (per parameter combination invoked)'),
        ('n_jobs', 20, 'number of processes used in parallelization'),
        ('n_trajs', 10, 'number of trajectories (traces) to compute for phase space projections'),
        ('n_op', 5, 'number of values taken for each order parameters'),
    )
    late_switches = (
        ('load_sample_rww', 'load a sample of ReducedWongWang model object previously ran for illustration'),
        ('plot_timeseries_phasespace_bif', 'plot neat figure of timeseries, phase space and bifurcations with trajectories (paper quality)'),
    )

    # registration order is kept as it drives the order of the --help listing
    for name, description in switches:
        add_switch(name, description)
    for name, value, description in integers:
        add_integer(name, value, description)
    for name, description in late_switches:
        add_switch(name, description)

    return parser
if __name__=='__main__':
    # This is to be able to run multiple stability analysis in parallel.
    # In short, PyDSTool objects need to be serialized to be passed across processes through a Queue.
    # but pyCont objects (used for continuation) are not pickable. We need to set multiprocessing library to
    # use dill to serialize the objects.
    dill.Pickler.dumps, dill.Pickler.loads = dill.dumps, dill.loads
    multiprocessing.reduction.ForkingPickler = dill.Pickler
    multiprocessing.reduction.dump = dill.dump
    multiprocessing.queues._ForkingPickler = dill.Pickler

    args = get_parser().parse_args()
    default_params = {'a':270, 'b': 108, 'd': 0.154, 'C_12': 0.25, 'G':2.5, 'J_N':0.2609, 'I_0':0.3, 'I_1':0.0, 'tau_S':100, 'w':0.9, 'gam':0.000641}
    # alternative grids:
    #order_params = {'C_12': np.linspace(-1,1,args.n_op), 'I_0': np.linspace(0.2,0.5,args.n_op)}
    #order_params = {'C_12': np.linspace(-0.5,0.5,args.n_op), 'C_21': np.linspace(-0.5,0.5,args.n_op)}
    order_params = {'C_21': np.linspace(0.2,0.8,args.n_op)}

    # Defined up front so that a missing option gives a clear message below
    outputs = None
    rww = None

    if args.load_stability_analysis:
        fname = os.path.join(proj_dir, 'postprocessing', 'outputs_dst_20241207.pkl')
        with open(fname, 'rb') as f:
            outputs = pickle.load(f)
    elif args.run_stability_analysis:
        outputs, futures = run_stability_analysis(order_params, default_params, args)
        if args.save_outputs:
            fname = 'outputs_dst'+today()+'.pkl'
            with open(os.path.join(proj_dir, 'postprocessing', fname), 'wb') as f:
                pickle.dump(outputs, f)

    needs_outputs = (args.plot_bifurcation_diagrams or args.plot_phasespace_grid
                     or args.plot_timeseries_phasespace_bif)
    if needs_outputs and outputs is None:
        raise SystemExit("No stability analysis outputs available: "
                         "use --run_stability_analysis or --load_stability_analysis")

    # The row functions take (outputs, order_params, rww, t_range, args): args must
    # be passed by keyword, otherwise it would be bound to the rww parameter.
    if args.plot_bifurcation_diagrams:
        if len(order_params.keys())==1:
            plot_bifurcation_row(outputs, order_params, args=args)
        elif len(order_params.keys())==2:
            plot_bifurcation_grid(outputs, order_params, args=args)
        else:
            print("can't plot for the number of order parameters")

    if args.plot_phasespace_grid:
        if len(order_params.keys())==1:
            plot_phasespace_row(outputs, order_params, args=args)
        elif len(order_params.keys())==2:
            plot_phasespace_grid(outputs, order_params, args=args)

    if args.load_sample_rww:
        with open(os.path.join(proj_dir, 'postprocessing', 'rww_sample001.pkl'), 'rb') as f:
            rww = pickle.load(f)

    if args.plot_timeseries_phasespace_bif:
        if rww is None:
            raise SystemExit("No sample model available: use --load_sample_rww")
        df_eta_sigma = load_df_eta_sigma()
        plot_timeseries_phasespace_bif(outputs, rww, df_eta_sigma, args)