Source code for OCD_modeling.hpc.parallel_launcher

###     Parallel simulations launched on HPC or local machine 
##
##      Author: Sebastien Naze
#
#       QIMR Berghofer 2023

import argparse 
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, wait, as_completed
from datetime import datetime
import joblib
import json
import multiprocessing
import numpy as np
import os
import pandas as pd
import pickle

from OCD_modeling.models import ReducedWongWang as RWW
from OCD_modeling.utils.utils import *

def run_sim(model_pars, sim_pars, control_pars=None, bold_pars=None):
    """ Run a single simulation.

    Parameters
    ----------
        model_pars: dict
            Model parameters (e.g. couplings, noise, external inputs, etc.),
            forwarded as keyword arguments to ``RWW.ReducedWongWangOU``.
        sim_pars: dict
            Simulation parameters (e.g. simulation time, recording intervals,
            sampling frequency, etc.), forwarded as keyword arguments to the
            model's ``run`` method.
        control_pars: dict, optional
            Control parameters. In case parameters need to be manually updated
            throughout the simulation, provide the parameters' timeseries here.
            Defaults to an empty dictionary.
        bold_pars: dict, optional
            BOLD parameters, forwarded as keyword arguments to
            ``RWW.compute_bold``. Dictionary with keys t_range (the time range
            on which to compute the BOLD signal), and transient (the time to
            discard due to transient at the beginning of the convolution).
            Defaults to an empty dictionary.

    Returns
    -------
        rww_sim: OCD_modeling.models.ReducedWongWangOU
            Simulation object after the run and the BOLD computation.

    """
    # None sentinels instead of `{}` defaults: a mutable default is created
    # once and shared by every call, so a mutation downstream would leak
    # into later simulations.
    control_pars = {} if control_pars is None else control_pars
    bold_pars = {} if bold_pars is None else bold_pars

    # run simulation
    rww_sim = RWW.ReducedWongWangOU(**model_pars)
    rww_sim.set_control_params(control_pars)
    rww_sim.run(**sim_pars)

    # analyze traces
    RWW.compute_bold(rww_sim, **bold_pars)
    # NOTE: the transition and strFr statistics steps are currently disabled.
    #RWW.compute_transitions(rww_sim)
    #RWW.compute_strFr_stats(rww_sim)
    return rww_sim
def launch_simulations(args):
    """ Launch N simulations in parallel using joblib.

    Parameters
    ----------
        args: Argparse.Namespace
            Must provide n_jobs (number of joblib workers), n_sims (number of
            repetitions) and the model_pars, sim_pars, control_pars and
            bold_pars dictionaries handed to ``run_sim``.

    Returns
    -------
        The value returned by the ``joblib.Parallel`` call, i.e. the outputs
        of the ``args.n_sims`` calls to ``run_sim``.

    """
    runner = joblib.Parallel(n_jobs=args.n_jobs)
    delayed_run = joblib.delayed(run_sim)
    # Every job receives the very same parameter sets; only the number of
    # repetitions is driven by n_sims.
    jobs = (
        delayed_run(args.model_pars, args.sim_pars, args.control_pars, args.bold_pars)
        for _ in range(args.n_sims)
    )
    return runner(jobs)
def launch_pool_simulations(args):
    """ Launch N simulations in parallel using a Pool Executor.

    Worker processes are created with the 'spawn' start method and each of
    them runs ``run_sim`` with the same parameter sets.

    Parameters
    ----------
        args: Argparse.Namespace
            Structure containing the information of the simulation to be
            performed: model, simulation, control and BOLD parameters
            (model_pars, sim_pars, control_pars, bold_pars), together with
            n_sims (number of repetitions) and n_jobs (maximum number of
            worker processes).

    Returns
    -------
        sim_objs: list of OCD_modeling.models.ReducedWongWangOU
            Simulation objects after processing, one per simulation.

    Raises
    ------
        ValueError
            If ``args.n_sims`` is negative.

    """
    n_sims = args.n_sims
    if n_sims < 0:
        raise ValueError("n_sims must be non-negative")

    # Replicate each parameter set with a plain list rather than np.repeat:
    # np.repeat coerces its input through np.asarray, which would flatten a
    # sequence-valued parameter set and hand its individual elements to the
    # workers instead of the whole object.
    model_pars = [args.model_pars] * n_sims
    sim_pars = [args.sim_pars] * n_sims
    control_pars = [args.control_pars] * n_sims
    bold_pars = [args.bold_pars] * n_sims

    spawn_context = multiprocessing.get_context('spawn')
    with ProcessPoolExecutor(max_workers=args.n_jobs, mp_context=spawn_context) as pool:
        results = pool.map(run_sim, model_pars, sim_pars, control_pars, bold_pars)
        # Materialise while the pool is still open; a worker exception is
        # re-raised here.
        return list(results)
def save_batch(sim_objs, args):
    """ Save simulation runs as an objects list in a pickle file.

    Nothing is written unless ``args.save_outputs`` is truthy. The file is
    named ``sim_objs_<YYYYMMDD>.pkl`` (today's date) and is placed in the
    ``postprocessing`` sub-directory of ``proj_dir``.

    Parameters
    ----------
        sim_objs: list
            Simulation objects to serialize.
        args: Argparse.Namespace
            Only the ``save_outputs`` flag is read here.

    """
    if not args.save_outputs:
        return
    date_stamp = datetime.today().strftime('%Y%m%d')
    # proj_dir is not defined in this file; presumably it is provided by the
    # star import from OCD_modeling.utils.utils -- TODO confirm.
    out_path = os.path.join(proj_dir, 'postprocessing', 'sim_objs_' + date_stamp + '.pkl')
    with open(out_path, 'wb') as out_file:
        pickle.dump(sim_objs, out_file)


def parse_arguments():
    """ Build the command line parser and parse the process arguments.

    Returns
    -------
        args: Argparse.Namespace
            Parsed command line options. The four ``*_pars`` options are
            decoded from JSON strings and default to empty dictionaries.

    """
    parser = argparse.ArgumentParser()

    # output switches ('store_true' flags are False unless given)
    parser.add_argument('--save_figs', action='store_true', help='save figures')
    parser.add_argument('--save_outputs', action='store_true', help='save outputs')
    parser.add_argument('--save_scores', action='store_true', help='save scores')

    # batch sizing
    parser.add_argument('--n_jobs', type=int, default=10,
                        help="number of parallel processes launched")
    parser.add_argument('--n_sims', type=int, default=50,
                        help="number of simulations ran with the same parameters (e.g. to get distribution that can be campared to clinical observations)")
    parser.add_argument('--params_idx', type=int, default=0,
                        help="index of the parameter set")

    parser.add_argument('--plot_figs', action='store_true', help='plot figures')
    parser.add_argument('--batch_id', type=str, default="test",
                        help='batch number unique to each batched job launched on cluster, typically YYYYMMDDD_hhmmss')

    # JSON-encoded parameter dictionaries, registered in the original order
    json_options = (
        ('--model_pars', "dictionary of model parameters"),
        ('--sim_pars', "dictionary of simulation (run) parameters"),
        ('--control_pars', "dictionary of control parameters"),
        ('--bold_pars', "dictionary of BOLD recording parameters"),
    )
    for flag, description in json_options:
        # `{}` is evaluated on every iteration, so each option gets its own dict
        parser.add_argument(flag, type=json.loads, default={}, help=description)

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_arguments()
    sim_objs = launch_pool_simulations(args)
    save_batch(sim_objs, args)