#### OCD Modeling: Symbolic Reduced Wong-Wang Model Analysis
###
## Author: Sebastien Naze
# QIMR 2023
from abc import ABC
import argparse
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, wait
import copy
import datetime
import importlib
import itertools
import joblib
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import cm
import multiprocessing
import mpmath
import os
import pandas as pd
import pickle
import PyDSTool as dst
import scipy
import sympy as sp
import time
from OCD_modeling.utils import utils
from OCD_modeling.models import ReducedWongWang as RWW
# Interactive sympy helpers, left disabled for scripted runs:
#sp.init_session()
#sp.init_printing()

# Project directory, built from the working directory reported by the project utils.
proj_dir = os.path.join(
    utils.get_working_dir(),
    'lab_lucac/sebastiN/projects/OCD_modeling',
)

# Default parameter values substituted into the symbolic expressions.
params = {
    'a': 270,
    'b': 108,
    'd': 0.154,
    'C_12': 1,
    'C_21': 1,
    'G': 2.5,
    'J_N': 0.2609,
    'I_0': 0.3,
    'tau_S': 100,
    'w': 0.9,
    'gamma': 0.000641,
}

# Raise the number of decimal places mpmath works with.
mpmath.mp.dps = 125
class SymbolicModel(ABC):
    """Abstract class implementing the functions common to all symbolic Reduced Wong-Wang models.

    The methods below read ``self.S1``, ``self.S2`` (symbols) and ``self.dS``
    (2-element vector of rate expressions), which concrete models must define.
    """

    def print_dS(self):
        """Pretty-print ``self.dS`` and print its LaTeX representation."""
        sp.pprint(self.dS)
        print(sp.latex(self.dS))

    def compute_nullclines(self):
        """Solve for the nullclines with ``sympy.solve``.

        ``dS[0] = 0`` is solved for ``S2`` and ``dS[1] = 0`` is solved for ``S1``.

        :return: tuple ``(n1, n2)`` of solution lists (also stored as ``self.n1``, ``self.n2``)
        """
        self.n1 = sp.solve(self.dS[0], self.S2)
        self.n2 = sp.solve(self.dS[1], self.S1)
        return self.n1, self.n2

    def compute_characteristic_eq(self):
        """Build the characteristic equation.

        The first solution of the ``dS[1] = 0`` nullcline (``S1`` expressed from
        ``S2``) is substituted for ``S1`` in ``dS[0]``. Nullclines are computed
        first if they are not available yet.

        :return: the characteristic equation (also stored as ``self.charEq``)
        """
        if not (hasattr(self, 'n1') and hasattr(self, 'n2')):
            self.compute_nullclines()
        self.charEq = self.dS[0].subs({self.S1: self.n2[0]})
        return self.charEq

    def plot_nullclines(self, params, smin=-3, smax=3, n=1000):
        """Plot both nullclines and the characteristic equation.

        :param params: substitutions applied to the symbolic expressions before plotting
        :param smin: lower bound of the plotted variable range and of the y-axis
        :param smax: upper bound of the plotted variable range and of the y-axis
        :param n: value forwarded as ``n`` to ``sympy.plot``
        """
        # the characteristic equation (and nullclines) are needed below
        if not hasattr(self, 'charEq'):
            self.compute_characteristic_eq()
        opts = {'ylim': (smin, smax), 'n': n}
        sp.plot(self.n1[0].subs(params), (self.S1, smin, smax), title='dS_1=0 curve\n', **opts)
        # the range used to be (S2, 3, 3), i.e. an empty interval
        sp.plot(self.n2[0].subs(params), (self.S2, smin, smax), title='dS_2=0 curve\n', **opts)
        sp.plot(self.charEq.subs(params), (self.S2, smin, smax), title='dS=0 curve\n', **opts)
class symRWW_2D(SymbolicModel):
    """Original Reduced Wong-Wang model, written symbolically for two coupled variables S1, S2."""

    def __init__(self):
        """Declare the symbols and build the input, transfer function, rate equations and Jacobian."""
        unit = sp.ones(2, 1)

        # state variables and coupling matrix (no self-coupling on the diagonal)
        (self.x1, self.x2,
         self.S1, self.S2,
         self.C_12, self.C_21) = sp.symbols('x1 x2 S1 S2 C_12 C_21')
        self.S = sp.Matrix([self.S1, self.S2])
        self.x = sp.Matrix([self.x1, self.x2])
        self.C = sp.Matrix([[0, self.C_12], [self.C_21, 0]])

        # total input X: recurrent term + coupled term + constant I_0
        self.w, self.J_N, self.I_0, self.G = sp.symbols('w J_N I_0 G')
        self.X = self.w*self.J_N*self.S + self.G*self.J_N*self.C*self.S + unit*self.I_0

        # transfer function H = (a*X - b) / (1 - exp(-d*(a*X - b))), element by element
        self.a, self.b, self.d = sp.symbols('a b d')
        drive = self.a*self.X - self.b*unit
        gate = unit - (-self.d*drive).applyfunc(sp.exp)
        self.H = sp.Matrix([num/den for num, den in zip(drive, gate)])

        # rate equations dS = -S/tau_S + (1 - S) * gamma * H (element-wise product)
        self.tau_S, self.gamma = sp.symbols('tau_S gamma')
        self.dS = (-self.S/self.tau_S) + sp.matrices.dense.matrix_multiply_elementwise(
            (unit - self.S), self.gamma*self.H)

        # Jacobian as nested lists: one row per rate equation, one column per state variable
        self.jacobian = [[rate.diff(var) for var in (self.S1, self.S2)]
                         for rate in (self.dS[0], self.dS[1])]
class pw_RWW_2D(SymbolicModel):
    """Piecewise linear Reduced Wong-Wang symbolic model (threshold-linear transfer function)."""

    def __init__(self):
        """Declare the symbols and build the input, transfer function, rate equations and Jacobian."""
        unit = sp.ones(2, 1)

        # state variables and coupling matrix (no self-coupling on the diagonal)
        (self.x1, self.x2,
         self.S1, self.S2,
         self.C_12, self.C_21) = sp.symbols('x1 x2 S1 S2 C_12 C_21')
        self.S = sp.Matrix([self.S1, self.S2])
        self.x = sp.Matrix([self.x1, self.x2])
        self.C = sp.Matrix([[0, self.C_12], [self.C_21, 0]])

        # total input X: recurrent term + coupled term + constant I_0
        self.w, self.J_N, self.I_0, self.G = sp.symbols('w J_N I_0 G')
        self.X = self.w*self.J_N*self.S + self.G*self.J_N*self.C@self.S + unit*self.I_0

        # transfer function: 0 up to the threshold theta, then linear with gain a
        # (b and d are declared as symbols but not used by this transfer function)
        self.a, self.b, self.d, self.theta = sp.symbols('a b d theta')

        def rectified(inp):
            return sp.Piecewise((0, inp <= self.theta),
                                (self.a*(inp - self.theta), inp > self.theta))

        self.H = sp.Matrix([rectified(self.X[0]),
                            rectified(self.X[1])])

        # rate equations dS = -S/tau_S + (1 - S) * gamma * H (element-wise product)
        self.tau_S, self.gamma = sp.symbols('tau_S gamma')
        self.dS = (-self.S/self.tau_S) + sp.matrices.dense.matrix_multiply_elementwise(
            (unit - self.S), self.gamma*self.H)

        # Jacobian as nested lists: one row per rate equation, one column per state variable
        self.jacobian = [[rate.diff(var) for var in (self.S1, self.S2)]
                         for rate in (self.dS[0], self.dS[1])]
## Stability analysis
#
def find_roots(f, x, itv=None, slope_thr=0.05):
    """ Find zero crossings of function f (numerical roots) on the sampling grid x.

    A root is reported between two consecutive samples when the real part of f
    strictly changes sign and the jump of f across the two samples is smaller
    than ``slope_thr`` in absolute value.

    :param f: callable evaluated at each element of x; only the real part of its value is used
    :param x: ordered sequence of sample points (list or array)
    :param itv: unused; kept so that existing calls passing it keep working
    :param slope_thr: defines the jump above which the zero crossing is considered an artifact,
                      e.g. in case of hyperbolic function
    :return: list of dicts ``{'x': midpoint of the bracketing samples, 'slope': f(right) - f(left)}``
             (``'slope'`` is the raw difference, not divided by the grid step)
    """
    roots = []
    if len(x) == 0:
        return roots

    x_prev = x[0]
    f_prev = np.real(f(x_prev))
    for x_next in x[1:]:
        f_next = np.real(f(x_next))
        sign_change = ((f_prev > 0) and (f_next < 0)) or ((f_prev < 0) and (f_next > 0))
        if sign_change:
            # a strict sign change implies diff != 0, so one magnitude test is enough
            diff = f_next - f_prev
            if np.abs(diff) < slope_thr:
                # midpoint of the two samples bracketing the crossing
                roots.append({'x': (x_prev + x_next)/2, 'slope': diff})
        x_prev, f_prev = x_next, f_next
    return roots
def find_roots_subs(model, x, default_params):
    """ Find roots of the characteristic equation numerically using sympy substitution
    instead of a lambdified function.

    :param model: object exposing the characteristic equation as ``model.charEq``
    :param x: sample points at which ``S2`` is evaluated
    :param default_params: substitutions applied together with ``S2``; left unmodified
    :return: list of roots as returned by ``find_roots``
    """
    def f(s2):
        # build a fresh mapping on each call so the caller's dict never gains an 'S2' entry
        return model.charEq.subs({**default_params, 'S2': s2})

    return find_roots(f, x)
def get_fixed_point_stability(model, fp, params):
    """ Derive stability based on eigenvalues of the jacobian around the fixed point.

    The trace (``tau``), determinant (``delta``) and both eigenvalues
    (``lambda1``, ``lambda2``) of ``model.jacobian`` evaluated with ``params``
    are stored in ``fp`` together with a ``'type'`` label.

    :param model: object exposing a 2x2 ``jacobian`` (nested lists of sympy expressions)
    :param fp: dict describing the fixed point; updated in place
    :param params: substitutions applied to the jacobian
                   (presumably includes the fixed point coordinates -- confirm against caller)
    :return: the updated ``fp`` dict; ``fp['type']`` is None when the jacobian evaluates
             to nan or when no category applies (e.g. an eigenvalue with zero real part)
    """
    jac = sp.Matrix(model.jacobian).subs(params)
    tau = sp.trace(jac)
    delta = sp.det(jac)
    disc = tau**2 - 4*delta

    fp['tau'] = tau
    fp['delta'] = delta
    fp['lambda1'] = (tau - sp.sqrt(disc))/2
    fp['lambda2'] = (tau + sp.sqrt(disc))/2
    # always define the label so that downstream code can test it safely
    fp['type'] = None

    # handle case of division per 0
    if (tau == sp.nan) or (delta == sp.nan):
        return fp

    l1_re, l1_im = fp['lambda1'].as_real_imag()
    l2_re, l2_im = fp['lambda2'].as_real_imag()

    # special cases
    if (tau == 0) and (delta > 0):
        # purely imaginary eigenvalues; tau == 0 with delta < 0 is a saddle, handled below
        fp['type'] = 'center'
    elif disc == 0:
        if tau > 0:
            fp['type'] = 'degenerate source'
        elif tau < 0:
            fp['type'] = 'degenerate sink'
        else:
            fp['type'] = 'uniform motion'
    # real eigenvalues
    elif (l1_im == 0) and (l2_im == 0):
        if ((l1_re > 0) and (l2_re < 0)) or ((l1_re < 0) and (l2_re > 0)):
            fp['type'] = 'saddle'
        elif (l1_re > 0) and (l2_re > 0):
            fp['type'] = 'unstable node'
        elif (l1_re < 0) and (l2_re < 0):
            fp['type'] = 'stable node'
    # complex eigenvalues
    else:
        if (l1_re > 0) and (l2_re > 0):
            fp['type'] = 'unstable focus'
        elif (l1_re < 0) and (l2_re < 0):
            fp['type'] = 'stable focus'
    return fp
def launch_stability_analysis(model, order_params, default_params, out_queue, args):
    """ Supervisor that launches the stability analysis for a set of defined order parameters,
    creating a child process with a set timeout per child process.

    :param model: symbolic model handed to the child process
    :param order_params: order parameter values handed to the child process
    :param default_params: default parameter values handed to the child process
    :param out_queue: queue handed to the child process for its results
    :param args: namespace providing ``timeout`` (seconds to wait for the child)
    """
    proc = multiprocessing.Process(target=perform_stability_analysis,
                                   args=(model, order_params, default_params, out_queue))
    proc.start()
    # wait for the process until timeout
    proc.join(args.timeout)
    # if process is still running after timeout, force terminating it
    if proc.is_alive():
        print(f"Stability analysis for {order_params} took too long, likely a numerical issue, aborted after {args.timeout}s")
        proc.terminate()
        # reap the terminated child so it does not linger as a zombie process
        proc.join()
def run_stability_analysis(model, order_params, default_params, args):
    """ Run the stability analysis for every combination of order parameter values.

    One supervising thread per combination is submitted to a thread pool; each
    of them runs ``launch_stability_analysis``, and results are collected from
    a shared multiprocessing queue.

    :param model: symbolic model (deep-copied for each combination)
    :param order_params: dict mapping parameter name to an iterable of values;
                         the cartesian product of all values is explored
    :param default_params: default parameter values (deep-copied for each combination)
    :param args: namespace providing ``n_jobs`` (number of threads) and ``timeout``
    :return: tuple ``(outputs, futures)``: objects read from the queue (in arrival order)
             and the futures of the supervising threads
    """
    import queue  # local import: only needed for the queue.Empty exception

    o_pars = [dict(zip(order_params.keys(), vals))
              for vals in itertools.product(*order_params.values())]
    n_pars = len(o_pars)
    print("Run stability analysis...")
    out_queue = multiprocessing.Queue(maxsize=n_pars)
    outputs = []
    with ThreadPoolExecutor(max_workers=args.n_jobs) as pool:
        futures = [pool.submit(launch_stability_analysis,
                               copy.deepcopy(model),
                               copy.deepcopy(order_param),
                               copy.deepcopy(default_params),
                               out_queue,
                               args)
                   for order_param in o_pars]
        # Drain the queue while the children run: a child that has put data on a
        # queue does not exit until that data is flushed, so waiting for all of
        # them before reading could stall each child until its timeout.
        while True:
            # sample completion BEFORE reading: if everything was finished and the
            # queue is still empty afterwards, nothing more can arrive
            finished = all(future.done() for future in futures)
            try:
                outputs.append(out_queue.get(timeout=0.1))
            except queue.Empty:
                if finished:
                    break
    return outputs, futures
def plot_3d_bifurcations(outputs, azim=0, elev=0):
    """ Plot bifurcation diagram in 3D.

    Two 3D scatter plots are drawn side by side: the ``S1`` (left) and ``S2``
    (right) coordinate of every fixed point against the first two order
    parameters (in sorted name order), coloured by fixed point type.

    :param outputs: iterable of dicts, each holding the order parameter values under
                    their names and a list of fixed point dicts under ``'fps'``
                    (each with ``'type'``, ``'S1'`` and ``'S2'``)
    :param azim: azimuth of the 3D view
    :param elev: elevation of the 3D view
    """
    node_colors = {'saddle': 'purple', 'unstable node': 'red', 'stable node': 'blue',
                   'unstable focus': 'magenta', 'stable focus': 'green',
                   'degenerate source': 'black', 'degenerate sink': 'black', 'uniform motion': 'black',
                   'center': 'orange'}
    fig = plt.figure(figsize=[20, 8])
    ax1 = fig.add_subplot(1, 2, 1, projection='3d')
    ax1.view_init(azim=azim, elev=elev)
    ax2 = fig.add_subplot(1, 2, 2, projection='3d')
    ax2.view_init(azim=azim, elev=elev)

    marker = matplotlib.markers.MarkerStyle('o', fillstyle='full')
    o_pars = None  # stays None when outputs is empty
    for output in outputs:
        o_pars = sorted(k for k in output if k != 'fps')
        for fp in output['fps']:
            fp_type = fp.get('type')
            if fp_type is None:
                # fixed points without a stability label are not drawn
                continue
            # unknown labels fall back to gray instead of raising
            opts = {'marker': marker, 'color': node_colors.get(fp_type, 'gray'), 'alpha': 0.2}
            ax1.scatter(output[o_pars[0]], output[o_pars[1]], fp['S1'], **opts)
            ax2.scatter(output[o_pars[0]], output[o_pars[1]], fp['S2'], **opts)

    for ax in (ax1, ax2):
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        if o_pars is not None:
            ax.set_xlabel(o_pars[0])
            ax.set_ylabel(o_pars[1])
    plt.show()
def get_model(args):
    """ Create or load the symbolic model.

    :param args: namespace with the boolean flags ``create_model``,
                 ``compute_nullclines`` and ``save_model``
    :return: the symbolic model
    """
    # load path: no model creation requested
    if not args.create_model:
        # NOTE(review): the file read here ('sym_rww.pkl') differs from the one written
        # below ('sym_pw_rww.pkl') -- confirm this is intended
        with open(os.path.join(proj_dir, 'postprocessing', 'sym_rww.pkl'), 'rb') as pkl:
            model = pickle.load(pkl)
        return model

    print("Creating model..")
    #model = symRWW_2D()
    model = pw_RWW_2D()

    if args.compute_nullclines:
        print("Computing nullclines.. (takes a few minutes)")
        started = time.time()
        model.compute_characteristic_eq()
        elapsed = time.time() - started
        print('Done in {:0.1f}s'.format(elapsed))

    if args.save_model:
        print("Saving model...")
        with open(os.path.join(proj_dir, 'postprocessing', 'sym_pw_rww.pkl'), 'wb') as pkl:
            pickle.dump(model, pkl)

    return model
def lambdify_model(model, default_params, order_params):
    """ Evaluate model for parameters using sympy lambdify.

    :param model: symbolic model (object) exposing ``S1``, ``S2``, ``dS``, ``n1`` and ``n2``
    :param default_params: dictionary of default parameters (left unmodified)
    :param order_params: dictionary of parameter values overriding the defaults
    :return: tuple ``(f, n1_f, n2_f)`` of numerical functions for the rate equations and
             the first solution of each nullcline, or None if lambdification failed
    """
    # order parameters take precedence over the defaults
    params = {**default_params, **order_params}
    try:
        f = sp.lambdify([model.S1, model.S2], model.dS.subs(params), 'numpy')
        n1_f = sp.lambdify(model.S1, model.n1[0].subs(params), 'scipy')  # need scipy because uses LambertW
        n2_f = sp.lambdify(model.S2, model.n2[0].subs(params), 'scipy')
    except Exception as err:
        # best effort: report which parameter set failed and let the caller skip it
        print("{} did encounter problems, discarding ({})".format(order_params, err))
        return None
    return (f, n1_f, n2_f)
def evaluate_params(model, default_params, order_params, args):
    """ Lambdify the model for every combination of order parameter values.

    :param model: symbolic model passed to ``lambdify_model``
    :param default_params: default parameter values passed to ``lambdify_model``
    :param order_params: dict mapping parameter name to an iterable of values;
                         the cartesian product of all values is explored
    :param args: namespace providing ``n_jobs`` for joblib
    :return: tuple ``(evals, o_pars)``: one ``lambdify_model`` result per combination,
             and the combinations themselves as a list of dicts
    """
    names = order_params.keys()
    o_pars = [dict(zip(names, combo)) for combo in itertools.product(*order_params.values())]
    tasks = (joblib.delayed(lambdify_model)(model, default_params, o_par) for o_par in o_pars)
    evals = joblib.Parallel(n_jobs=args.n_jobs)(tasks)
    return evals, o_pars
def plot_quiver(f, n1_f, n2_f, o_par, smin=-3, smax=3, n=30, ax=None, scale=8, args=None):
    """ Plot a single quiver (vector field) together with both nullclines.

    The horizontal axis is S2 and the vertical axis is S1.

    :param f: function of ``(S1, S2)`` returning the pair ``(dS1, dS2)``
    :param n1_f: function of S1 returning S2 (first nullcline)
    :param n2_f: function of S2 returning S1 (second nullcline)
    :param o_par: dict of order parameter values, shown in the title
    :param smin: lower bound of both axes
    :param smax: upper bound of both axes
    :param n: number of arrows per line (curves use 10 times as many points)
    :param ax: axes to draw into; a new figure is created when None
    :param scale: ``scale`` argument forwarded to ``plt.quiver``
    :param args: unused
    """
    s = np.linspace(smin, smax, n*10)  # number of data point for curves
    s_q = np.linspace(smin, smax, n)   # number of arrow per row/column
    s1, s2 = np.meshgrid(s_q, s_q)
    u, v = f(s1, s2)                   # u = dS1, v = dS2 evaluated at (S1=s1, S2=s2)
    if ax is None:
        plt.figure()
    else:
        plt.sca(ax)
    # x-axis is S2 and y-axis is S1 (see nullclines and labels below), so arrows are
    # positioned at (s2, s1) with components (dS2, dS1)
    plt.quiver(s2, s1, v.squeeze(), u.squeeze(), scale=scale)
    plt.plot(s, n2_f(s))   # S1 as a function of S2
    plt.plot(n1_f(s), s)   # S2 as a function of S1
    plt.xlim(smin, smax)
    plt.ylim(smin, smax)
    plt.xlabel('S2')
    plt.ylabel('S1')
    ttl = " ".join(['{} = {:.2f}'.format(k, v_) for k, v_ in o_par.items()])
    plt.title(ttl, fontsize=8)
    plt.grid()
def plot_quivers_grid(evals, o_pars, order_params, args):
    """ Plot a grid of quiver graphs, one per pair of order parameter values.

    Rows follow the values of the first order parameter, columns those of the
    second one. Combinations whose evaluation is None are left blank.

    :param evals: flat list of ``(f, n1_f, n2_f)`` tuples (or None), ordered like the
                  cartesian product of the first two order parameters
    :param o_pars: flat list of order parameter dicts, in the same order as ``evals``
    :param order_params: dict mapping parameter name to its values (first two entries are used)
    :param args: forwarded to ``plot_quiver``
    """
    xs, ys = list(order_params.values())[:2]
    n_rows, n_cols = len(xs), len(ys)
    plt.rcParams.update({'font.size': 8})
    fig = plt.figure(figsize=[30, 30])
    gs = plt.GridSpec(n_rows, n_cols)
    for i in range(n_rows):
        for j in range(n_cols):
            # flat index into the product ordering; indexing the list directly avoids
            # building a numpy array from a mix of None and tuples
            k = i*n_cols + j
            ev = evals[k]
            if ev is None:
                continue
            f, n1_f, n2_f = ev
            ax = fig.add_subplot(gs[i, j])
            plot_quiver(f, n1_f, n2_f, o_pars[k], ax=ax, args=args)
    plt.tight_layout()
    plt.show()
def get_parser():
    """ Build the command line parser for the global arguments.

    :return: ``argparse.ArgumentParser`` with boolean switches for each processing step,
             plus ``--timeout`` (int, default 30) and ``--n_jobs`` (int, default 12)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--create_model', default=False, action='store_true',
                        help='create symbolic model')
    parser.add_argument('--compute_nullclines', default=False, action='store_true',
                        help='compute nullclines and characteristic equation of the symbolic model (takes a few minutes)')
    parser.add_argument('--save_model', default=False, action='store_true',
                        help='save symbolic model with its computed attributes')
    parser.add_argument('--run_stability_analysis', default=False, action='store_true',
                        help='run stability analysis: find fixed point (semi-analytically) and perform linear stability analysis around them')
    parser.add_argument('--plot_quivers', default=False, action='store_true',
                        help='plot grid of quiver graphs over the order parameters')
    parser.add_argument('--plot_figs', default=False, action='store_true',
                        help='plot figures')
    parser.add_argument('--save_outputs', default=False, action='store_true',
                        help='save analysis outputs')
    parser.add_argument('--timeout', type=int, default=30, action='store',
                        help='timeout of the stability analysis (per parameter combination invoked)')
    parser.add_argument('--n_jobs', type=int, default=12, action='store',
                        help='number of processes used in parallelization')
    return parser
if __name__=='__main__':
    # Script entry point: build/load the model, then run the steps selected on the command line.
    args = get_parser().parse_args()
    sym_rww = get_model(args)

    # grid of coupling values explored by the analyses below
    order_params = {'C_12': np.linspace(-0.3,0.3,6), 'C_21': np.linspace(-0.3,0.3,6)}

    if args.run_stability_analysis:
        outputs, futures = run_stability_analysis(sym_rww, order_params, params, args)

        # saving and plotting both need `outputs`, so they only run after the analysis
        if args.save_outputs:
            today = datetime.datetime.now().strftime("%Y%m%d")
            fname = 'outputs_'+today+'.pkl'
            with open(os.path.join(proj_dir, 'postprocessing', fname), 'wb') as f:
                pickle.dump(outputs, f)

        if args.plot_figs:
            plot_3d_bifurcations(outputs, azim=0, elev=0) # yz plan
            plot_3d_bifurcations(outputs, -90, 0) # xz plan
            plot_3d_bifurcations(outputs, 90, -90) # xy plan
    elif args.save_outputs or args.plot_figs:
        print("--save_outputs and --plot_figs require --run_stability_analysis; nothing to save or plot")

    if args.plot_quivers:
        evals, o_pars = evaluate_params(sym_rww, params, order_params, args)
        plot_quivers_grid(evals, o_pars, order_params, args)