# Source code for OCD_modeling.mcmc.simulate_inference

### Create new simulations from inference of posterior distributions
##  Author: Sebastien Naze
#   QIMR 2023

import argparse
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from datetime import datetime 
#import duckdb
import itertools
from matplotlib import pyplot as plt
import multiprocessing
import numpy as np
import os 
import pandas as pd
import pickle
import sqlite3
#import sqlalchemy as sl
from time import time, sleep

#import OCD_modeling
# import most relevant environment and project variable
from OCD_modeling.utils.utils import proj_dir, today
from OCD_modeling.mcmc.history_analysis import import_results, compute_kdes
from OCD_modeling.mcmc.abc_hpc import unpack_params, get_prior, get_prior_Thal, get_prior_Thal_fc_weak, get_prior_Thal_hc_weak
from OCD_modeling.models.ReducedWongWang import create_sim_df
from OCD_modeling.hpc.parallel_launcher import run_sim

#from OCD_modeling.utils import proj_dir, today
#from OCD_modeling.mcmc import import_results, compute_kdes, unpack_params, get_prior
#from OCD_modeling.models import create_sim_df
#from OCD_modeling.hpc import run_sim

def batched(iterable, n):
    """Yield successive lists of at most ``n`` items drawn from ``iterable``.

    Every chunk holds exactly ``n`` items except possibly the last one, which
    holds whatever is left over.

    Example: batched('ABCDEFG', 3) --> ['A', 'B', 'C'] ['D', 'E', 'F'] ['G']
    """
    source = iter(iterable)

    def next_chunk():
        # Pull up to n further items from the shared iterator.
        return list(itertools.islice(source, n))

    # iter(callable, sentinel) keeps calling next_chunk until it returns an
    # empty list, i.e. until the source iterator is exhausted.
    yield from iter(next_chunk, [])


# Dispatch table keyed by the number of regions in the simulation (args.N).
# Each value is a callable invoked without arguments; the second item of the
# pair it returns is used by create_params as per-parameter (lower, upper) bounds.
priorfunc = {4:get_prior, 6:get_prior_Thal_hc_weak} # other imported priors, presumably alternatives for key 6: get_prior_Thal_fc_weak, get_prior_Thal

def _sample_within_bounds(kde, lower, upper):
    """Draw single samples from ``kde`` until one lies inside [lower, upper]."""
    value = kde.sample().squeeze()
    while value < lower or value > upper:
        value = kde.sample().squeeze()
    return value


def create_params(kdes, cols, test_param, args):
    """Create ``args.n_sims`` new parameter sets sampled from posterior KDEs.

    Each parameter in ``cols`` is drawn from the KDE of ``args.base_cohort``,
    except the parameters named in ``test_param``, which are drawn from the
    KDE of ``args.test_cohort``. Draws falling outside the prior bounds are
    rejected and re-sampled.

    Parameters
    ----------
    kdes: dict
        Kernel density estimates, indexed as ``kdes[cohort][param]['kde']``;
        each KDE must expose ``sample()`` (its result is squeezed and compared
        to the scalar bounds).
    cols: iterable of str
        Names of the parameters to sample.
    test_param: str
        Space-separated names of the parameters taken from the test cohort.
        A string that matches no column (e.g. ``'None'``) swaps nothing.
    args: argparse.Namespace
        Uses ``N`` (key into ``priorfunc``), ``n_sims``, ``base_cohort`` and
        ``test_cohort``.

    Returns
    -------
    tuple of six lists, one entry per simulation:
        ``(model_params, sim_params, control_params, bold_params, params, ids)``.
        The first four hold the outputs of ``unpack_params``, ``params`` holds
        the raw ``{name: value}`` dicts, and ``ids`` the simulation indices.
    """
    model_params, sim_params, control_params, bold_params = [], [], [], []
    params = []  # raw dicts, kept because they convert easily to a DataFrame
    ids = []
    _, bounds = priorfunc[args.N]()

    # Loop-invariant: parse the swapped parameter names once.
    test_cols = set(test_param.split(' '))

    for i in range(args.n_sims):
        param = {}
        for col in cols:
            # Swapped parameters come from the test cohort, everything else
            # from the base cohort; only the posterior actually used is sampled.
            cohort = args.test_cohort if col in test_cols else args.base_cohort
            param[col] = _sample_within_bounds(
                kdes[cohort][col]['kde'], bounds[col][0], bounds[col][1])

        model_pars, sim_pars, ctl_pars, bold_pars = unpack_params(param)

        # One list per argument so they can be fed in lockstep to pool.map.
        model_params.append(model_pars)
        sim_params.append(sim_pars)
        control_params.append(ctl_pars)
        bold_params.append(bold_pars)
        params.append(param)
        ids.append(i)

    return model_params, sim_params, control_params, bold_params, params, ids

def use_optim_params(cols, test_param, args):
    """Build parameter sets from particles stored by the optimization.

    Instead of sampling new values, simulation ``i`` reuses particle
    ``i % n_particles`` of the base cohort; the parameters named in
    ``test_param`` (space-separated string) are overwritten with the value
    held by the particle at the same index in the test cohort.

    Returns the tuple ``(model_params, sim_params, control_params,
    bold_params, params, paired_ids)`` of six lists with one entry per
    simulation; ``paired_ids`` holds the particle index used for each one.
    """
    histories = import_results(args)

    # NOTE(review): population t=9 is hard-coded -- presumably the final
    # generation of the optimization; confirm.
    base_population = histories[args.base_cohort].get_population(t=9).particles
    test_population = histories[args.test_cohort].get_population(t=9).particles
    population_size = len(base_population)

    swapped = test_param.split(' ')

    model_params = []
    sim_params = []
    control_params = []
    bold_params = []
    params = []      # raw dicts, kept because they convert easily to a DataFrame
    paired_ids = []

    for sim_index in range(args.n_sims):
        # Cycle through the particles when n_sims exceeds their number.
        particle_index = sim_index % population_size
        param = {}
        for col in cols:
            value = base_population[particle_index].parameter[col]
            if col in swapped:
                # NOTE(review): the index is derived from the base population
                # size; assumes the test population is at least as large.
                value = test_population[particle_index].parameter[col]
            param[col] = value

        model_pars, sim_pars, ctl_pars, bold_pars = unpack_params(param)

        # One list per argument so they can be fed in lockstep to pool.map.
        model_params.append(model_pars)
        sim_params.append(sim_pars)
        control_params.append(ctl_pars)
        bold_params.append(bold_pars)
        params.append(param)
        paired_ids.append(particle_index)

    return model_params, sim_params, control_params, bold_params, params, paired_ids

def write_outputs_to_db(params, cols, test_param, outputs, paired_ids, args):
    """Append the results of one batch of simulations to the SQLite database.

    The database file is ``<proj_dir>/postprocessing/<args.db_name>.db`` and
    rows go to table ``SIMTEST``, which is created by the first write.

    Parameters
    ----------
    params: list of dict
        Raw ``{name: value}`` parameter sets, one per simulation.
    cols: list
        Unused; kept for interface compatibility.
    test_param: str
        Label stored in the ``test_param`` column.
    outputs: list
        Simulation outputs, passed to ``create_sim_df``.
    paired_ids: list
        Stored in the ``paired_ids`` column, one entry per simulation.
    args: argparse.Namespace
        Uses ``db_name``, ``timeout``, ``base_cohort`` and ``test_cohort``.
    """
    t = time()
    dbpath = os.path.join(proj_dir, 'postprocessing', args.db_name+'.db')

    # isolation_level='EXCLUSIVE' makes implicitly opened transactions start
    # with BEGIN EXCLUSIVE; `timeout` is how long to wait for a locked database.
    conn = sqlite3.connect(dbpath, isolation_level='EXCLUSIVE', timeout=args.timeout)
    try:
        # The connection context manager commits on success and rolls back on
        # error, but never closes the connection -- hence the outer try/finally.
        with conn:
            cursor = conn.execute("SELECT count(*) FROM sqlite_master WHERE type='table' AND name='SIMTEST'")
            table_exists = cursor.fetchone()[0] > 0

            if table_exists:
                # Number of rows already stored for this base cohort; the value
                # is bound as a parameter instead of formatted into the SQL.
                cursor = conn.execute("SELECT COUNT(*) FROM SIMTEST WHERE base_cohort=?",
                                      (args.base_cohort,))
                cnt = cursor.fetchone()[0]
            else:
                # to_sql below creates the table on first use.
                cnt = 0

            # Functional connectivity from the simulations, reshaped to one row
            # per (subj, cohort) with one column per pathway.
            # offset=cnt presumably continues the numbering after rows already
            # in the database -- confirm against create_sim_df.
            df_sims = create_sim_df(outputs, sim_type='sim-'+args.base_cohort[:3], offset=cnt)
            df_sims = df_sims.pivot(index=['subj', 'cohort'], columns='pathway', values='corr').reset_index()
            df_sims['base_cohort'] = args.base_cohort
            df_sims['test_cohort'] = args.test_cohort
            df_sims['test_param'] = test_param
            df_sims['paired_ids'] = paired_ids

            # Attach the parameters that produced each simulation (index join).
            df_params = pd.DataFrame.from_dict(params, dtype=float)
            df_sims = df_sims.join(df_params)

            df_sims.to_sql('SIMTEST', conn, if_exists='append', index=False, method='multi')
    finally:
        conn.close()
    print("took {:.3f}s".format(time()-t))

def launch_sims_parallel(kdes, cols, test_param, args=None):
    """Run simulations from posterior inference in parallel, batch by batch.

    Parameters
    ----------
    kdes: dict
        Kernel Density Estimates from the optimization (structured as
        kdes[group][param]); only used when ``args.use_optim_params`` is off.
    cols: list
        Parameters (columns) for which values are generated.
    test_param: str
        Space-separated names of the parameters whose posterior is swapped
        for the virtual intervention (the string is split on ' ' downstream).
    args: argparse.Namespace
        Options; required in practice since its attributes are read directly
        (use_optim_params, n_sims, n_batch, n_jobs, save_outputs, ...).

    Returns
    -------
    None. Each batch is written to the local SQLite database and, when
    ``args.save_outputs`` is set, also pickled under traces/tmp.
    """
    # Both builders return six aligned lists:
    # (model_params, sim_params, control_params, bold_params, params, paired_ids)
    if args.use_optim_params:
        prepared = use_optim_params(cols, test_param, args)
    else:
        prepared = create_params(kdes, cols, test_param, args)

    tot_batches = int(np.ceil(args.n_sims/args.n_batch))

    print("Starting a pool of {} workers to run {} simulation(s) by batches of {} with test parameters {}".format(args.n_jobs, args.n_sims, args.n_batch, test_param))

    spawn_context = multiprocessing.get_context('spawn')
    with ProcessPoolExecutor(max_workers=args.n_jobs, mp_context=spawn_context) as pool:
        # Batch the per-simulation tuples once, then split every batch back
        # into its six aligned lists.
        per_simulation = zip(*prepared)
        for batch_number, chunk in enumerate(batched(per_simulation, args.n_batch), start=1):
            t = time()
            model_pars, sim_pars, control_pars, bold_pars, pars, p_ids = map(list, zip(*chunk))

            # Materialise the lazy map so the whole batch is done before writing.
            outputs = list(pool.map(run_sim, model_pars, sim_pars, control_pars, bold_pars))
            write_outputs_to_db(pars, cols, test_param, outputs, p_ids, args)
            print(" Batch {}/{} done for test_param {} in {:.2f}s.".format(batch_number, tot_batches, test_param, time()-t))

            if args.save_outputs:
                fname = "sim_objs"+today()+"_batch{:04d}".format(batch_number)+".pkl"
                with open(os.path.join(proj_dir, 'traces/tmp', fname), 'wb') as f:
                    pickle.dump(outputs,f)
                print("Batch saved in traces/tmp/"+fname)

def get_test_param(args):
    """Return the test parameters as a single space-separated string.

    When ``args.test_param_index`` is None the names come from the command
    line (``args.test_param``); an empty list yields the literal string
    ``'None'``. Otherwise the names are entry ``args.test_param_index`` of
    the combinations stored in ``postprocessing/params_combinations.pkl``.
    """
    if args.test_param_index is None:
        if args.test_param:
            return ' '.join(args.test_param)
        return 'None'

    # NOTE: pickle.load executes arbitrary code from the file; only use a
    # params_combinations.pkl generated by this project.
    with open(os.path.join(proj_dir, 'postprocessing', 'params_combinations.pkl'), 'rb') as f:
        combinations = pickle.load(f)
    return ' '.join(list(combinations[args.test_param_index]))


def parse_arguments():
    """Parse the command-line arguments used when the module is run as a script.

    Returns
    -------
    argparse.Namespace
        Parsed options, with ``gens`` converted to an integer numpy array.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_figs', default=False, action='store_true', help='save figures')
    parser.add_argument('--save_outputs', default=False, action='store_true', help='save outputs')
    parser.add_argument('--n_jobs', type=int, default=10, action='store', help="number of parallel processes launched")
    parser.add_argument('--n_sims', type=int, default=50, action='store', help="number of simulations in intervention (e.g. to get distribution that can be campared to clinical observations)")
    parser.add_argument('--n_batch', type=int, default=10, action='store', help="number of sumluation per batch between saving to DB (to reduce i/o overhead)")
    parser.add_argument('--gens', nargs='+', default=[], action='store', help="generation of the optimization (list, must be same length as histories)")
    parser.add_argument('--plot_figs', default=False, action='store_true', help='plot figures')
    parser.add_argument('--histories', nargs='+', default=['rww4D_OU_HPC_20230510', 'rww4D_OU_HPC_20230605'], action='store', help="optimizations to analyse and compare")
    # nargs='+' (not type=list): type=list would split one name into its
    # characters and reject several names.
    parser.add_argument('--history_names', nargs='+', default=['controls', 'patients'], action='store', help="names given to each otpimization loaded")
    parser.add_argument('--save_kdes', default=False, action='store_true', help='save KDEs')
    parser.add_argument('--db_name', type=str, default=None, help="identifier of the sqlite3 database")
    parser.add_argument('--base_cohort', type=str, default='controls', help="Cohort from which to infer posterior as default")
    parser.add_argument('--test_cohort', type=str, default='patients', help="Cohort from which to infer posterior of individual params")
    # NOTE(review): with an empty list get_test_param returns 'None', which
    # matches no parameter name -- the help text below says otherwise; confirm.
    parser.add_argument('--test_param', nargs='+', default=[], help="posterior parameter to swap between base and test cohort, if empty list then all params are tested")
    parser.add_argument('--test_param_index', type=int, default=None, action='store', help='if using a external file for test parameters, use line index ')
    parser.add_argument('--timeout', type=int, default=3600, help="timeout for DB writting in parallel")
    parser.add_argument('--use_optim_params', default=False, action='store_true', help="if flag is on, use parameters from accepted particles of optimization, other draw new params from posterior")
    parser.add_argument('--N', type=int, default=4, action='store', help="Number of regions in simulations (default=4, could also be 6.")
    args = parser.parse_args()
    args.gens = np.array(args.gens, dtype=int)
    return args


if __name__=='__main__':
    args = parse_arguments()
    # load histories and KDEs
    histories = import_results(args)
    kdes,cols = compute_kdes(histories, args=args)
    test_param = get_test_param(args)

    # delay start randomly by up to 3 min to avoid large batches of simulations
    # writing concurrently to DB
    sleep(np.random.randint(0,180))

    launch_sims_parallel(kdes, cols, test_param, args=args)