# Source code for OCD_modeling.models.ReducedWongWang

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import os
import pandas as pd
import pickle
import platform
import scipy
import scipy.stats
import sklearn
from sklearn import preprocessing
import time

from OCD_modeling.utils.utils import get_working_dir
#from OCD_modeling.models.HemodynamicResponseModeling.BalloonWindkessel import balloonWindkessel
#from OCD_modeling.utils.neurolib.neurolib.models.bold.timeIntegration import simulateBOLD
#from OCD_modeling.utils import simulateBOLD
from neurolib.models.bold import simulateBOLD

# resolve the machine-dependent root directory (it depends on the host computer)
working_dir = get_working_dir()

# project root, built by plain string concatenation: working_dir presumably ends
# with a path separator -- NOTE(review): confirm in get_working_dir
proj_dir = ''.join((working_dir, 'lab_lucac/sebastiN/projects/OCD_modeling'))

class ReducedWongWang:
    """ Reduced Wong Wang model (1-dimensional).

    All methods are functions of the input ``x`` (a scalar or a numpy array):
    ``H`` is the gating function, ``S_i`` inverts the input equation to obtain the
    gating variable, ``dS_i`` is the ODE right-hand side and ``S_nc`` the nullcline.
    """
    def __init__(self, a=270., b=108., d=0.154,
                 I_0=0.3, J_N=0.2609, w=0.9, G=1., C=0., S_j=0.,
                 tau_S=100., gamma=0.000641, sigma=0.001, v_i=0.):

        # synaptic gating params
        self.a = a             # slope (n/C); default=270
        self.b = b             # offset (Hz); default=108
        self.d = d             # decay (s);   default=0.154

        # firing rate params
        self.I_0 = I_0         # external input (nA); default=0.3
        self.J_N = J_N         # synaptic coupling (nA); default=0.2609
        self.w = w             # local exc recurrence (n/a); default=0.9
        self.G = G             # global scaling factor (n/a); default=1
        self.C = C             # connectivity matrix (n/a); default=0
        self.S_j = S_j         # coupled population firing rates (Hz); default=0

        # ODE params
        self.tau_S = tau_S      # kinetic parameter of local population (ms); default=100
        self.gamma = gamma     # kinetic parameter of coupled population (ms); default=0.000641
        self.sigma = sigma     # noise amplitude (in node) (nA); default=0.001
        self.v_i = v_i         # gaussian noise (n/a); default=0

    def H(self, x):
        """ Average synaptic gating: H(x) = (a*x - b) / (1 - exp(-d*(a*x - b))).

        a: slope (n/C); default=270
        b: offset (Hz); default=108
        d: decay (s);   default=0.154

        The expression is 0/0 where a*x == b; its limit 1/d is returned there
        instead of nan. Returns a numpy scalar for scalar ``x`` and an array for
        array ``x``.
        """
        y = np.asarray(self.a * x - self.b, dtype=np.float64)
        with np.errstate(divide='ignore', invalid='ignore'):
            h = y / (1. - np.exp(-self.d * y))
            # removable singularity: y / (1 - exp(-d*y)) -> 1/d as y -> 0
            h = np.where(y == 0., np.float64(1.) / self.d, h)
        return h[()]  # unwrap 0-d results to a scalar

    def S_i(self, x, S_j=0):
        """ Firing rate: invert x = w*J_N*S + G*J_N*C*S_j + I_0 for S.

        I_0: external input (nA); default=0.3
        J_N: synaptic coupling (nA); default=0.2609
        w: local exc recurrence (n/a); default=0.9
        G: global scaling factor (n/a); default=1
        C: connectivity matrix (n/a); default=0
        S_j: coupled population firing rates (Hz); default=0 (argument, not self.S_j)
        """
        return (x - self.I_0 - self.G * self.J_N * self.C * S_j) / (self.w * self.J_N)

    def dS_i(self, x, v_i=0):
        """ ODE of firing rate, evaluated at input ``x``.

        dS/dt = -S/tau_S + (1 - S) * gamma * H(x) + sigma * v_i,   with S = S_i(x)

        tau_S: kinetic parameter of local population (ms); default=100
        gamma: kinetic parameter of coupled population (ms); default=0.000641
        sigma: noise amplitude (nA); default=0.001
        v_i: gaussian noise (n/a); default=0

        S is computed with ``S_i(x)`` and its default S_j=0 (the instance
        attribute ``S_j`` is not used here).
        """
        s = self.S_i(x)
        return -s / self.tau_S + (1 - s) * self.gamma * self.H(x) + self.sigma * v_i

    def S_nc(self, x):
        """ Nullcline (dS/dt=0 without noise): S = k / (1 + k), with k = tau_S*gamma*H(x). """
        k = self.tau_S * self.gamma * self.H(x)
        return k / (1. + k)





class ReducedWongWangND:
    """ Reduced Wong Wang model (N-dimensional).

    N units coupled through the connectivity matrix ``C`` and integrated with a fixed
    step ``dt`` (see ``integrate``). Parameters can be ramped during a run with
    ``set_control_params`` and extra variables recorded with ``run(rec_vars=...)``.
    """

    def __init__(self, a=270., b=108., d=0.154, I_0=0.3, J_N=0.2609, w=0.9, G=1.,
                 tau_S=100., gamma=0.000641, sigma=0.001, N=2, dt=0.01,
                 C=None, S=None, x=None, *args, **kwargs):
        """ Set parameters and initial state.

        Parameters
        ----------
        a, b, d: float
            Synaptic gating slope (n/C), offset (Hz) and decay (s).
        I_0: float or numpy.ndarray
            External input (nA); a non-array value is broadcast to all N units.
        J_N: float
            Synaptic coupling (nA).
        w: float
            Local excitatory recurrence.
        G: float
            Global scaling factor of the connectivity.
        tau_S, gamma: float
            Kinetic parameters of the gating ODE.
        sigma: float
            Noise amplitude (nA).
        N: int
            Number of units.
        dt: float
            Integration step, in the same time unit as ``t_tot`` in ``run``.
        C: array-like, optional
            N x N connectivity matrix. Default: standard-normal random weights
            off the diagonal, zeros on the diagonal.
        S, x: array-like, optional
            Initial gating variables and inputs. Default: uniform in [0.05, 0.15).
        *args, **kwargs:
            Accepted and ignored.
        """
        # synaptic gating params
        self.a = a             # slope (n/C); default=270
        self.b = b             # offset (Hz); default=108
        self.d = d             # decay (s);   default=0.154

        # firing rate params
        if isinstance(I_0, np.ndarray):
            self.I_0 = I_0                 # external input (nA), one value per unit
        else:
            self.I_0 = np.ones(N,)*I_0     # scalar input broadcast to all units; default=0.3
        self.J_N = J_N         # synaptic coupling (nA); default=0.2609
        self.w = w             # local exc recurrence (n/a); default=0.9
        self.G = G             # global scaling factor (n/a); default=1

        # state variables (random draws happen in the order S, x, C)
        if S is None:
            self.S = np.array(np.random.rand(N,)*0.1+0.05, dtype=np.float64)
        else:
            self.S = np.array(S).squeeze().astype(np.float64)
        if x is None:
            self.x = np.array(np.random.rand(N,)*0.1+0.05, dtype=np.float64)
        else:
            self.x = np.array(x).squeeze().astype(np.float64)

        # ODE params
        self.tau_S = tau_S     # kinetic parameter of local population (ms); default=100
        self.gamma = gamma     # kinetic parameter of coupled population (ms); default=0.000641
        self.sigma = sigma     # noise amplitude (nA); default=0.001
        self.dt = dt
        self.N = N             # number of units

        # connectivity matrix; default: randn off-diagonal weights, no self-connection
        if C is None:
            self.C = np.random.randn(N,N) * (1-np.eye(N,N))
        elif isinstance(C, list):
            self.C = np.array(C)
        else:
            self.C = C

        # empty control parameters, set after the model is instantiated
        self.control_params = {}

    def H(self, x):
        """ Average synaptic gating: (a*x - b) / (1 - exp(-d*(a*x - b))).

        a: slope (n/C); default=270
        b: offset (Hz); default=108
        d: decay (s);   default=0.154

        NOTE(review): yields nan where a*x == b exactly (0/0); not guarded here
        to keep the integration loop cheap.
        """
        return (self.a * x - self.b) / (1. - np.exp(-self.d * (self.a * x - self.b)))

    def v(self):
        """ Draw one sample of standard Gaussian white noise per unit. """
        return np.array(np.random.randn(self.N,), dtype=np.float64)

    def dS(self):
        """ ODE of firing rate at the current state (same right-hand side as ``integrate``).

        dS/dt = -S/tau_S + (1 - S) * gamma * H(x) + sigma * v   (element-wise)
        """
        return (-self.S/self.tau_S + (1 - self.S) * (self.gamma * self.H(x=self.x)) + self.sigma*self.v())

    def integrate(self):
        """ One Euler(-Maruyama) integration step of the ODE.

        NOTE(review): the noise term is multiplied by dt together with the drift
        (sigma*v*dt); a textbook Euler-Maruyama step scales it by sqrt(dt).
        Left unchanged because it sets the effective noise amplitude.
        """
        self.x = self.w * self.J_N * self.S + self.G * self.J_N * self.C @ self.S + self.I_0
        H = self.H(x=self.x)
        v = self.v()
        dS = (-self.S/self.tau_S + (1 - self.S) * (self.gamma * H) + self.sigma*v)
        self.S = self.S + dS*self.dt

    def run(self, t_tot=1000, sf=100, t_rec=None, rec_vars=None):
        """ Runs the model.

        Parameters
        ----------
        t_tot: int
            Total simulation time (s).
        sf: int
            Sampling frequency of the recording (Hz). If 1/sf is not a multiple of
            dt, the closest achievable frequency is used (and printed).
        t_rec: list
            Interval of recording (s) in the form ``[start, stop]``; ``stop`` is
            exclusive. Default: ``[0, t_tot]``.
        rec_vars: list
            Variables to record (note that S is always recorded). Names that do
            not exist are removed from this list in place.

        Sets ``self.t`` and ``self.S_rec`` (always the same number of samples),
        plus ``self.rec_<var>`` for each recorded variable.
        """
        if rec_vars is None:
            rec_vars = []
        n_ts = int(t_tot/self.dt)  # number of integration steps

        # number of integration steps between two recorded samples
        sf_dt = 1./(sf*self.dt)
        if sf_dt.is_integer():
            self.sf = sf
        else:
            if sf_dt < 1:
                # sf higher than dt permits: record every step
                sf_dt = 1
            else:
                # sf not a multiple of dt: round the step count down
                sf_dt = np.floor(sf_dt)
            self.sf = 1. / self.dt / sf_dt
            print("Sampling frequency not a multiple of dt, used sf={} instead".format(self.sf))

        if t_rec is None:
            t_rec = [0, t_tot]
        n_rec = int((t_rec[1]-t_rec[0])*self.sf)  # number of recorded samples

        # recording buffers; t is truncated so it always matches S_rec row for row
        # (floating-point rounding can make arange one element longer)
        self.t = np.arange(t_rec[0], t_rec[1], 1./self.sf)[:n_rec]
        self.S_rec = np.zeros((n_rec, self.N))
        self.prepare_auxiliary_variables(rec_vars, n_rec)
        self.prepare_control_params()

        rec_idx = 0
        for i in range(n_ts):
            self.integrate()
            t_i = i*self.dt
            # record every sf_dt-th step inside [t_rec[0], t_rec[1]); rec_idx is
            # capped so that never more than n_rec samples are written
            if (not i % sf_dt) and (t_rec[0] <= t_i < t_rec[1]) and (rec_idx < n_rec):
                self.S_rec[rec_idx, :] = self.S.copy()
                self.record_auxiliary_variables(rec_vars, rec_idx)
                self.update_control_params(rec_idx)
                rec_idx += 1

    def set_control_params(self, params: dict):
        """ Set the parameters to be updated during the simulation (e.g. a slow control parameter).

        Parameters
        ----------
        params: dict
            Parameters to be updated. Keys must match parameter names of the model
            (or be ``C_ij`` for a connectivity entry). Values are lists of tuples of
            times and parameter values, i.e.:
            ``params = {I_0: [ (t0,v0), (t1,v1), (t2,v2), ... ]}``
            The update is linear between referenced points, and is applied at the
            sampling frequency (SF). Control parameters can only be changed during
            the recording period.

        Unknown keys are reported and skipped; the caller's dict is not modified.
        """
        valid = {}
        for k, v in params.items():
            if hasattr(self, k) or k.startswith('C_'):
                valid[k] = v
            else:
                print(f"Control parameter {k} does not exist in model, it won't be updated")
        self.control_params = valid

    def load_C_param(self, par, vals):
        """ Build the schedule of one connectivity entry ``C_ij`` into ``self.update_C``.

        ``update_C`` has one N x N matrix per sample of ``self.t``. It is shared by
        all ``C_ij`` control parameters of the current preparation. ``i`` and ``j``
        are single-digit, 1-based indices.
        """
        # reuse the schedule if other C entries were already loaded, otherwise start from C
        if hasattr(self, 'update_C'):
            update_par = self.update_C
        else:
            update_par = np.repeat(self.C[np.newaxis], self.t.shape[0], axis=0)
        ij = par.split('_')[1]
        i, j = int(ij[0])-1, int(ij[1])-1
        ind_0 = 0
        val_0 = update_par[0][i, j]
        for t, v in vals:
            if (self.t.min() > t) or (self.t.max() < (t-1)):
                print(f"Control parameter {par} cannot be changed if simulation is not recorded during those set times, change recording times or parameter update time ")
                break
            ind_t = np.abs(self.t - t).argmin()  # closest recorded time index
            n = ind_t - ind_0                    # number of values to fill between times
            update_par[ind_0:ind_t, i, j] = np.linspace(val_0, v, n)
            ind_0, val_0 = ind_t, v
        self.update_C = update_par

    def prepare_control_params(self):
        """ Build, for each control parameter, its time course aligned with ``self.t``.

        Stores ``update_<par>`` (``update_C`` for connectivity entries). Starting from
        the parameter's current value at index 0, values are linearly ramped towards
        each reference point in turn; samples after the last reference point keep
        the value the parameter had when this method was called.
        """
        # a schedule left from a previous run was built for another time axis
        if hasattr(self, 'update_C'):
            del self.update_C
        for par, vals in self.control_params.items():
            vals = sorted(vals)  # makes sure the times are in increasing order
            if par.startswith('C_'):
                self.load_C_param(par, vals)
                continue
            update_par = np.ones(self.t.shape) * getattr(self, par)  # default value
            ind_0 = 0
            val_0 = update_par[0]
            for t, v in vals:
                if (self.t.min() > t) or (self.t.max() < t):
                    print(f"Control parameter {par} cannot be changed if simulation is not recorded during those set times, change recording times or parameter update time ")
                    break
                ind_t = np.abs(self.t - t).argmin()  # closest recorded time index
                n = ind_t - ind_0                    # number of values to fill between times
                update_par[ind_0:ind_t] = np.linspace(val_0, v, n)
                ind_0, val_0 = ind_t, v
            setattr(self, 'update_'+par, update_par)

    def prepare_auxiliary_variables(self, rec_vars, n_rec):
        """ Create the ``rec_<var>`` buffers for the supplementary variables to record.

        The buffer of an attribute has shape ``(n_rec,) + shape(attribute)``, so
        scalars give ``(n_rec,)`` and vectors give ``(n_rec, N)``. ``C_ij`` entries give
        ``(n_rec,)``. Unknown names are reported and removed from ``rec_vars`` in place.
        """
        for var in list(rec_vars):  # iterate over a copy: rec_vars may shrink
            if hasattr(self, var):
                shape = (n_rec,) + np.shape(getattr(self, var))
                setattr(self, 'rec_'+var, np.zeros(shape))
            elif var.startswith('C_'):
                setattr(self, 'rec_'+var, np.zeros((n_rec,)))
            else:
                print(f"Variable {var} does not seem to exist, it cannot be recorded.")
                rec_vars.remove(var)

    def record_auxiliary_variables(self, rec_vars, rec_idx):
        """ Record variables other than S at sample ``rec_idx``. """
        for var in rec_vars:
            if hasattr(self, var):
                val = getattr(self, var)
            elif var.startswith('C_'):
                ij = var.split('_')[1]
                i, j = int(ij[0])-1, int(ij[1])-1
                val = self.C[i, j]
            else:
                continue
            # item assignment copies into the buffer (works for floats and arrays alike)
            getattr(self, 'rec_'+var)[rec_idx] = val

    def update_control_params(self, index):
        """ Set each control parameter to its scheduled value at sample ``index``. """
        for par in self.control_params.keys():
            # all C_ij entries share one schedule: C is reassigned once per C_ij (harmless)
            if par.startswith('C_'):
                par = 'C'
            update_par = getattr(self, 'update_'+par)
            setattr(self, par, update_par[index])
class ReducedWongWangOU(ReducedWongWangND):
    """ Reduced Wong-Wang model with an Ornstein-Uhlenbeck process on the coupling (N dimensions).

    The effective connectivity ``vC`` starts at ``C``. At each step it relaxes towards
    ``C`` at rate ``eta_C`` and is driven by Gaussian noise of amplitude ``sigma_C``
    (both element-wise).
    """

    def __init__(self, N=4, sigma_C=None, eta_C=None, *args, **kwargs):
        """ Parameters
        ----------
        N: int
            Number of units.
        sigma_C, eta_C: array-like, optional
            N x N noise amplitudes and relaxation rates of the coupling. ``None`` or
            empty means zeros (no fluctuation).
        *args, **kwargs:
            Forwarded to ``ReducedWongWangND``.
        """
        super().__init__(N=N, *args, **kwargs)
        self.vC = self.C.copy()  # variable connectivity, starts at C
        self.sigma_C = self._as_coupling_param(sigma_C, N)
        self.eta_C = self._as_coupling_param(eta_C, N)

    @staticmethod
    def _as_coupling_param(value, N):
        """ Return ``value`` as an array, or an N x N zero matrix when it is None or empty. """
        if value is None or np.size(value) == 0:
            return np.zeros((N, N))
        return np.array(value)

    def integrate(self):
        """ One Euler(-Maruyama) step of the gating ODE and of the OU coupling process. """
        self.x = self.w * self.J_N * self.S + self.G * self.J_N * self.vC @ self.S + self.I_0
        H = self.H(x=self.x)
        v = self.v()
        dS = (-self.S/self.tau_S + (1 - self.S) * (self.gamma * H) + self.sigma*v)
        dvC = -self.eta_C*(self.vC - self.C) + self.sigma_C*np.random.randn(self.N, self.N)
        self.S = self.S + dS*self.dt
        self.vC = self.vC + dvC*self.dt

    def record_auxiliary_variables(self, rec_vars, rec_idx):
        """ Record auxiliary variables other than S; ``C_ij`` reads the fluctuating ``vC``. """
        for var in rec_vars:
            if hasattr(self, var):
                val = getattr(self, var)
            elif var.startswith('C_'):
                ij = var.split('_')[1]
                i, j = int(ij[0])-1, int(ij[1])-1
                val = self.vC[i, j]
            else:
                continue
            # item assignment copies into the buffer (works for floats and arrays alike)
            getattr(self, 'rec_'+var)[rec_idx] = val
# POST PROCESSING FUNCTIONS # # --------------------------- #
def compute_bold(model, t_range=None, transient=30):
    """ BOLD timeseries and functional connectivity between regions.

    Parameters
    ----------
    model: OCD_modeling.models.ReducedWongWang.
        Model object (must have been run).
    t_range: list
        Times of interest (in sec), in the form ``[start, stop]``. Default: all recorded time.
    transient: int
        Time discarded at the beginning of t_range due to BOLD transient (in sec). Default: 30s.

    Sets ``model.bold_ts`` (regions x time), ``model.bold_t`` and ``model.bold_fc``.
    """
    inds = get_inds(model, t_range)
    ts = model.S_rec[inds, :]
    bold_ts, _, _, _, _ = simulateBOLD(ts.T, 1./model.sf, voxelCounts=None)
    # discard the first `transient` seconds of the BOLD response
    n_transient = int(model.sf*transient)
    model.bold_ts = bold_ts[:, n_transient:]
    model.bold_t = model.t[inds[n_transient:]]
    model.bold_fc = np.corrcoef(model.bold_ts)
def compute_transitions(model, threshold=0.3, min_diff=3, t_range=None):
    """ Compute the transitions between low and high activity states of each unit.

    A unit is "high" when its recorded S exceeds ``threshold``. Transition times are
    in seconds relative to the start of ``t_range``, truncated to int. A transition is
    kept only if it occurs more than ``min_diff`` s after the preceding detected
    transition; the first one is always kept.

    Stores ``model.transitions``: a dict ``{'S1': times, 'S2': times, ...}``.
    """
    inds = get_inds(model, t_range)
    ts_thr = model.S_rec[inds, :] > threshold
    transitions = dict()
    for i in range(model.N):
        bin_ts = ts_thr[:, i]
        # sample k is a transition when its state differs from sample k-1
        trans = (np.flatnonzero(bin_ts[1:] != bin_ts[:-1]) + 1) / model.sf
        if len(trans) > 0:
            keep = np.concatenate([[0], np.flatnonzero(np.diff(trans) > min_diff) + 1])
        else:
            keep = np.array([], dtype=int)
        transitions['S'+str(i+1)] = np.array(trans[keep], dtype=int)
    model.transitions = transitions


def compute_strFr_stats(model, t_range=None, thr=0, rec_vars=None):
    """ Compute transition times, dwell times, and number of transitions in striato-frontal circuit(s).

    For each recorded variable (default ``['C_12']``), values above ``thr`` count as
    the "direct" state and the others as "indirect". Only samples within ``t_range``
    (in sec, default: all) are used.

    Stores ``model.strFr_stats``: ``{var: {'n_transitions', 'transitions_times', 'dwell_times'}}``.
    """
    if rec_vars is None:
        rec_vars = ['C_12']
    inds = get_inds(model, t_range)
    t_sel = model.t[inds]
    output = dict()
    for var in rec_vars:
        ts_thr = getattr(model, 'rec_'+var)[inds] > thr
        # True where the state flips between consecutive samples
        transitions_inds, = np.where(ts_thr[1:] != ts_thr[:-1])
        transitions_times = t_sel[transitions_inds]
        dwell_time = np.diff(transitions_times)
        dwell_times = dict()
        if len(ts_thr) > 0 and ts_thr[0]:  # starts direct
            dwell_times['direct'] = dwell_time[0::2]
            dwell_times['indirect'] = dwell_time[1::2]
        else:  # starts indirect
            dwell_times['indirect'] = dwell_time[0::2]
            dwell_times['direct'] = dwell_time[1::2]
        output[var] = {'n_transitions': len(transitions_inds),
                       'transitions_times': transitions_times,
                       'dwell_times': dwell_times}
    model.strFr_stats = output
def create_sim_df(sim_objs, sim_type = 'sim-con', offset=0):
    """ Build a long-format DataFrame of simulated functional connectivity.

    One row per (simulation, region pair j<k) with columns ``subj``
    (``sim_type`` + zero-padded index starting at ``offset+1``), ``cohort``,
    ``pathway`` and ``corr`` (taken from ``bold_fc``). The number of regions is
    read from the first simulation and must be 4 or 6; otherwise a message is
    printed and None is returned.
    """
    all_regions = ['OFC', 'PFC', 'NAcc', 'Put', 'DP', 'VA']
    # region-pair name -> empirical pathway name (the 4-region map is the first 6 entries)
    pair_to_pathway = {'OFC-PFC': 'OFC_PFC', 'OFC-NAcc': 'Acc_OFC', 'OFC-Put': 'dPut_OFC',
                       'OFC-DP': 'dpThal_OFC', 'OFC-VA': 'vaThal_OFC',
                       'PFC-NAcc': 'Acc_PFC', 'PFC-Put': 'dPut_PFC',
                       'PFC-DP': 'dpThal_PFC', 'PFC-VA': 'vaThal_PFC',
                       'NAcc-Put': 'Acc_dPut', 'NAcc-DP': 'Acc_dpThal', 'NAcc-VA': 'Acc_vaThal',
                       'Put-DP': 'dPut_dpThal', 'Put-VA': 'dPut_vaThal', 'DP-VA': 'vaThal_dpThal'}

    n_regions = sim_objs[0].N
    if n_regions not in (4, 6):
        print('Cannot create sim_df if N!=4 or N!=6')
        return None
    region_names = all_regions[:n_regions]

    records = []
    for sim_idx, sim in enumerate(sim_objs):
        subj = sim_type + '{:06d}'.format(offset + sim_idx + 1)
        for row in range(sim.N):
            for col in range(row + 1, sim.N):
                pair = region_names[row] + '-' + region_names[col]
                records.append({'subj': subj,
                                'cohort': sim_type,
                                'pathway': pair_to_pathway[pair],
                                'corr': sim.bold_fc[row, col]})
    return pd.DataFrame(records)
def distance(x, y):
    """ Distance to minimize based on the score: the ``'RMSE'`` entry of ``x``.

    ``y`` is accepted for interface compatibility and is not used.
    """
    return x['RMSE']


def get_inds(model, t_range=None):
    """ Indices of ``model.t`` falling within ``t_range`` (in sec, both bounds included).

    ``t_range`` may be a list, tuple or numpy array ``[start, stop]``.
    Default: all recorded time points.
    """
    if t_range is None:
        t_range = [model.t.min(), model.t.max()]
    inds, = np.where((model.t >= t_range[0]) & (model.t <= t_range[1]))
    return inds
def score_model(rww, coh='con'):
    """ Score a single model against the empirical FC (group mean only).

    Used when optimizing single models, not populations of models.

    :param rww: model instance to score; needs ``N`` and ``bold_fc``.
    :param coh: cohort to be scored against ('con' or 'pat').
    :returns: ``(r, corr_MAE, corr_RMSE)``, i.e. the Pearson correlation, mean absolute
        error and root mean square error between the upper triangles of the
        empirical and model FC.
    """
    # empirical FC (project file)
    with open(os.path.join(proj_dir, 'postprocessing', 'R.pkl'), 'rb') as f:
        R = pickle.load(f)

    # reorder empirical ROIs NAcc Put OFC PFC -> OFC PFC NAcc Put (model order)
    reorder = np.ix_([2, 3, 0, 1], [2, 3, 0, 1])
    upper = np.triu_indices(rww.N, k=1)
    data_vals = R[coh][reorder][upper].flatten()
    model_vals = rww.bold_fc[upper].flatten()

    residuals = data_vals - model_vals
    corr_MAE = np.mean(np.abs(residuals))
    corr_RMSE = np.sqrt(np.mean(residuals**2))
    r, _ = scipy.stats.pearsonr(data_vals, model_vals)
    return r, corr_MAE, corr_RMSE
def score_population_models(sim_objs, cohort='controls'):
    """ Score a population of simulated models (one parameter set) against experimental observations.

    The whole distribution of model outputs is scored against the distribution of
    observations. For each pathway, the squared differences of the means and of
    the (population) standard deviations of the correlations are pooled into a
    single root mean square error, which is returned.

    Raises ValueError if the models have neither 4 nor 6 regions.
    """
    n_regions = sim_objs[0].N
    if n_regions == 4:
        fname = 'df_roi_corr_avg_2023.pkl'
    elif n_regions == 6:
        fname = 'df_roi_corr_avg_2024_Thal.pkl'
    else:
        raise ValueError(f"Cannot score populations of models with N={n_regions} (only N=4 or N=6 are supported)")

    # empirical FC (project file)
    with open(os.path.join(proj_dir, 'postprocessing', fname), 'rb') as f:
        df_roi_corr = pickle.load(f)

    # simulated FC dataframe, stacked with the observations of the same cohort
    sim_type = 'sim-'+cohort
    df_sim_fc = create_sim_df(sim_objs, sim_type=sim_type)
    df = df_roi_corr[df_roi_corr.cohort==cohort].merge(df_sim_fc, how='outer')

    sq_errors = []
    for pathway in df.pathway.unique():
        in_pathway = df.pathway == pathway
        obs = df[in_pathway & (df.cohort==cohort)]['corr']
        sim = df[in_pathway & (df.cohort==sim_type)]['corr']
        sq_errors.append((np.mean(obs)-np.mean(sim))**2)
        sq_errors.append((np.std(obs)-np.std(sim))**2)
    return np.sqrt(np.sum(sq_errors)/len(sq_errors))
# PLOTTING FUNCTIONS # #----------------------#
def plot_timeseries(model, t_range=None, labels=['OFC', 'PFC', 'NAcc', 'Put']):
    """ Plot the recorded gating variables ``S_rec`` of every unit over time.

    :param model: ReducedWongWang model object (must have been run).
    :param t_range: ``[start, stop]`` in sec; default: all recorded time.
    :param labels: legend entries, in unit order (the list is not modified).
    """
    plt.figure(figsize=[16, 4])
    sel = get_inds(model, t_range)
    plt.plot(model.t[sel], model.S_rec[sel, :])
    # legend placed outside the axes, to the right
    plt.legend(labels, loc='upper left', bbox_to_anchor=[1, 1])
    plt.show()
def plot_control_params(model, t_range=None, labels=[]):
    """ Plot the scheduled time course of each control parameter, one subplot per parameter.

    :param model: ReducedWongWang model object (after ``run``).
    :param t_range: ``[start, stop]`` in sec; default: all recorded time.
    :param labels: unused; kept for interface compatibility.
    """
    names = list(model.control_params)
    n_pars = len(names)
    sel = get_inds(model, t_range)
    plt.figure(figsize=[16, 2*n_pars])
    for row, name in enumerate(names, start=1):
        if name.startswith('C_'):
            # connectivity entry C_ij: single-digit, 1-based indices into update_C
            digits = name.split('_')[1]
            src, dst = int(digits[0]) - 1, int(digits[1]) - 1
            series = model.update_C[sel][:, src, dst]
        else:
            series = getattr(model, 'update_' + name)[sel]
        plt.subplot(n_pars, 1, row)
        plt.plot(model.t[sel], series)
        plt.legend([name], loc='upper right')
    plt.show()
def plot_auxiliary_variables(model, t_range=None, rec_vars=[]):
    """ Plot the recorded auxiliary variables ``rec_<var>`` over time.

    :param model: ReducedWongWang model object (run with the same ``rec_vars``).
    :param t_range: ``[start, stop]`` in sec; default: all recorded time.
    :param rec_vars: names of the recorded variables to plot (not modified).

    The title shows ``eta_C`` and ``sigma_C`` when the model has them (OU model).
    """
    n = len(rec_vars)
    if n == 0:
        return
    plt.figure(figsize=[16, 2*n])
    inds = get_inds(model, t_range)
    for var in rec_vars:
        ts = getattr(model, 'rec_'+var)
        # restrict the series to the same samples as the time axis
        plt.plot(model.t[inds], ts[inds])
    plt.legend(rec_vars)
    if hasattr(model, 'eta_C') and hasattr(model, 'sigma_C'):
        plt.title("eta_C={} sigma_C={}".format(model.eta_C, model.sigma_C))
    plt.show()
def plot_bold(model, labels=['OFC', 'PFC', 'NAcc', 'Putamen'], colors=['blue', 'green', 'red', 'magenta']):
    """ Plot BOLD timeseries (z-scored, vertically offset) and the BOLD functional connectivity.

    :param model: model object with ``N``, ``bold_ts``, ``bold_t`` and ``bold_fc``
        (see compute_bold).
    :param labels: region names, in model order (the list is not modified).
    :param colors: one line color per region (the list is not modified).
    """
    spacing = 3  # vertical distance between consecutive z-scored traces
    fig = plt.figure(figsize=[16, 4])
    gs = plt.GridSpec(1, 2, width_ratios=[4, 1])

    ax1 = fig.add_subplot(gs[0, 0])
    for i in range(model.N):
        ts = model.bold_ts.T[:, i]
        ts = (ts-np.mean(ts)) / np.std(ts) - spacing*i
        ax1.plot(model.bold_t, ts, color=colors[i], alpha=0.3)
    # one tick per trace, placed at the trace's offset (trace i is centred on -spacing*i)
    offsets = -spacing*np.arange(model.N)
    ax1.set_yticks(offsets[::-1])
    ax1.set_yticklabels(labels[:model.N][::-1])
    ax1.legend(labels, loc='upper right')
    ax1.spines.top.set_visible(False)
    ax1.spines.right.set_visible(False)
    ax1.set_xlabel('time (s)')
    ax1.set_ylabel('Normalized BOLD signal')

    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(model.bold_fc, vmin=-1, vmax=1, cmap='RdBu_r')
    plt.xticks(np.arange(len(labels)), labels, rotation=60)
    plt.yticks(np.arange(len(labels)), labels)
    ax2.set_title('Functional Connectivity')
    plt.show()
def plot_correlations(rww, t_range=None, labels=None):
    """ Visualize the correlation between timeseries generated by the model (S_rec, not BOLD).

    :param rww: ReducedWongWang model object (must have been run).
    :param t_range: ``[start, stop]`` in sec; default: all recorded time.
    :param labels: region names used as tick labels, in model order.
        Default: ``['OFC', 'PFC', 'NAcc', 'Put']``.
    """
    if labels is None:
        labels = ['OFC', 'PFC', 'NAcc', 'Put']
    plt.figure(figsize=[4, 4])
    inds = get_inds(rww, t_range)
    corr = np.corrcoef(rww.S_rec[inds, :].T)
    plt.imshow(corr, vmin=-1, vmax=1, cmap='RdBu_r')
    ticks = np.arange(len(labels))
    plt.xticks(ticks, labels)
    plt.yticks(ticks, labels)
    plt.show()