Source code for dysmalpy.fitting.mcmc

# coding=utf8
# Copyright (c) MPE/IR-Submm Group. See LICENSE.rst for license information. 
#
# Classes and functions for fitting DYSMALPY kinematic models
#   to the observed data using MCMC
#
# Some handling of MCMC / posterior distribution analysis inspired by speclens,
#    with thanks to Matt George:
#    https://github.com/mrgeorge/speclens/blob/master/speclens/fit.py

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

## Standard library
import logging
try:
    from multiprocess import cpu_count, Pool
except ImportError:
    # Fall back to the standard library if the third-party 'multiprocess'
    # package is not installed:
    from multiprocessing import cpu_count, Pool

# DYSMALPY code
from dysmalpy.data_io import load_pickle, dump_pickle
from dysmalpy import plotting
from dysmalpy import galaxy
from dysmalpy import utils as dpy_utils
from dysmalpy.fitting import base
from dysmalpy.fitting import utils as fit_utils


# Third party imports
import os
import numpy as np
from collections import OrderedDict
import astropy.units as u
import copy

import time, datetime


__all__ = ['MCMCFitter', 'MCMCResults']


# LOGGER SETTINGS
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('DysmalPy')
logger.setLevel(logging.INFO)

try:
    import emcee
    _emcee_loaded = True
    _emcee_version = int(emcee.__version__.split('.')[0])
except ImportError:
    _emcee_loaded = False
    logger.warning("emcee installation not found!")



class MCMCFitter(base.Fitter):
    """
    Class to hold the MCMC fitter attributes + methods
    """
    def __init__(self, **kwargs):
        if not _emcee_loaded:
            raise ValueError("emcee was not loaded!")

        self._emcee_version = _emcee_version

        self._set_defaults()
        super(MCMCFitter, self).__init__(fit_method='MCMC', **kwargs)

    def _set_defaults(self):
        # MCMC specific defaults
        self.nWalkers = 10
        self.nCPUs = 1
        self.cpuFrac = None
        self.scale_param_a = 3.
        # nBurn / nSteps are step counts used directly in range() loops,
        # so they must be integers:
        self.nBurn = 2
        self.nSteps = 10
        self.minAF = 0.2
        self.maxAF = 0.5
        self.nEff = 10
        self.oversampled_chisq = True
        # self.red_chisq = False    # Option not used

        self.save_burn = False
        self.save_intermediate_sampler_results_chain = True
        self.nStep_intermediate_save = 5
        self.continue_steps = False

        self.nPostBins = 50
        self.linked_posterior_names = None

        self.input_sampler_results = None

        # ACOR SETTINGS
        # Force it to run for at least N steps, otherwise acor times might be completely wrong.
        self.acor_force_min = 49
    def fit(self, gal, output_options):
        """
        Fit observed kinematics using MCMC and a DYSMALPY model set.

        Parameters
        ----------
        gal : `Galaxy` instance
            observed galaxy, including kinematics.
            also contains instrument the galaxy was observed with (gal.instrument)
            and the DYSMALPY model set, with the parameters to be fit (gal.model)

        output_options : `config.OutputOptions` instance
            instance holding output options for MCMC fitting.

        Returns
        -------
        mcmcResults : `MCMCResults` instance
            MCMCResults class instance containing the bestfit parameters,
            sampler_results information, etc.
        """
        # --------------------------------
        # Check option validity:
        if self.blob_name is not None:
            valid_blobnames = ['fdm', 'mvirial', 'alpha', 'rb']
            if isinstance(self.blob_name, str):
                # Single blob
                blob_arr = [self.blob_name]
            else:
                # Array of blobs
                blob_arr = self.blob_name[:]
            for blobn in blob_arr:
                if blobn.lower().strip() not in valid_blobnames:
                    raise ValueError("blob_name={} not recognized as option!".format(blobn))

        # # Temporary: testing:
        # if self.red_chisq:
        #     raise ValueError("red_chisq=True is currently *DISABLED* to test lnlike impact vs lnprior")

        # Check the FOV is large enough to cover the data output:
        dpy_utils._check_data_inst_FOV_compatibility(gal)

        # Pre-calculate instrument kernels:
        gal = dpy_utils._set_instrument_kernels(gal)

        # --------------------------------
        # Basic setup:

        # For compatibility with Python 2.7:
        mod_in = copy.deepcopy(gal.model)
        gal.model = mod_in

        # if nCPUs is None:
        if self.cpuFrac is not None:
            self.nCPUs = int(np.floor(cpu_count() * self.cpuFrac))

        # +++++++++++++++++++++++
        # Setup for oversampled_chisq:
        if self.oversampled_chisq:
            gal = fit_utils.setup_oversampled_chisq(gal)

        # +++++++++++++++++++++++
        # Set output options: filenames / which to save, etc
        output_options.set_output_options(gal, self)

        # MUST INCLUDE MCMC-SPECIFICS NOW!
        fit_utils._check_existing_files_overwrite(output_options, fit_type='mcmc', fitter=self)

        # --------------------------------
        # Setup file redirect logging:
        if output_options.f_log is not None:
            loggerfile = logging.FileHandler(output_options.f_log)
            loggerfile.setLevel(logging.INFO)
            logger.addHandler(loggerfile)

        # --------------------------------
        # Split by emcee version:
        if self._emcee_version >= 3:
            mcmcResults = self._fit_emcee_3(gal, output_options)
        else:
            mcmcResults = self._fit_emcee_221(gal, output_options)

        # Clean up logger:
        if output_options.f_log is not None:
            logger.removeHandler(loggerfile)
            loggerfile.close()

        return mcmcResults
    def _fit_emcee_221(self, gal, output_options):
        # --------------------------------
        # Initialize emcee sampler_results
        kwargs_dict = {'fitter': self}

        nBurn_orig = self.nBurn

        nDim = gal.model.nparams_free

        if (not self.continue_steps) & ((not self.save_intermediate_sampler_results_chain) |
                                        (not os.path.isfile(output_options.f_sampler_results_tmp))):
            sampler_results = emcee.EnsembleSampler(self.nWalkers, nDim, base.log_prob,
                                                    args=[gal], kwargs=kwargs_dict,
                                                    a=self.scale_param_a, threads=self.nCPUs)

            # --------------------------------
            # Initialize walker starting positions
            initial_pos = initialize_walkers(gal.model, nWalkers=self.nWalkers)

        elif self.continue_steps:
            self.nBurn = 0
            if self.input_sampler_results is None:
                try:
                    self.input_sampler_results = load_pickle(output_options.f_sampler_results)
                except:
                    message = "Couldn't find existing sampler_results in {}.".format(output_options.f_sampler_results)
                    message += '\n'
                    message += "Must set input_sampler_results if you will restart the sampler."
                    raise ValueError(message)

            sampler_results = reinitialize_emcee_sampler_results(self.input_sampler_results,
                                                                 gal=gal, fitter=self)
            initial_pos = self.input_sampler_results['chain'][:, -1, :]
            if self.blob_name is not None:
                blob = self.input_sampler_results['blobs']

            # Close things
            self.input_sampler_results = None

        elif self.save_intermediate_sampler_results_chain & (os.path.isfile(output_options.f_sampler_results_tmp)):
            self.input_sampler_results = load_pickle(output_options.f_sampler_results_tmp)
            sampler_results = reinitialize_emcee_sampler_results(self.input_sampler_results,
                                                                 gal=gal, fitter=self)
            self.nBurn = nBurn_orig - (self.input_sampler_results['burn_step_cur'] + 1)
            initial_pos = self.input_sampler_results['chain'][:, -1, :]
            if self.blob_name is not None:
                blob = self.input_sampler_results['blobs']

            # If it saved after burn finished, but hasn't saved any of the normal steps: reset sampler_results
            if ((self.nBurn == 0) & (self.input_sampler_results['step_cur'] < 0)):
                blob = None
                sampler_results.reset()
                if self.blob_name is not None:
                    sampler_results.clear_blobs()

            # Close things
            self.input_sampler_results = None

        # --------------------------------
        # Output some fitting info to logger:
        logger.info("*************************************")
        logger.info(" Fitting: {} with MCMC".format(gal.name))
        for obs_name in gal.observations:
            obs = gal.observations[obs_name]
            logger.info("    obs: {}".format(obs.name))
            if obs.data.filename_velocity is not None:
                logger.info("        velocity file: {}".format(obs.data.filename_velocity))
            if obs.data.filename_dispersion is not None:
                logger.info("        dispers. file: {}".format(obs.data.filename_dispersion))
file: {}".format(obs.data.filename_dispersion)) logger.info(' nSubpixels: {}'.format(obs.mod_options.oversample)) logger.info('\n'+'nCPUs: {}'.format(self.nCPUs)) logger.info('nWalkers: {}'.format(self.nWalkers)) # logger.info('lnlike: red_chisq={}'.format(self.red_chisq)) logger.info('lnlike: oversampled_chisq={}'.format(self.oversampled_chisq)) logger.info('\n'+'blobs: {}'.format(self.blob_name)) if ('halo' in gal.model.components.keys()): logger.info('\n'+'mvirial_tied: {}'.format(gal.model.components['halo'].mvirial.tied)) if ('disk+bulge' in gal.model.components.keys()): if 'mhalo_relation' in gal.model.components['disk+bulge'].__dict__.keys(): logger.info('mhalo_relation: {}'.format(gal.model.components['disk+bulge'].mhalo_relation)) if 'truncate_lmstar_halo' in gal.model.components['disk+bulge'].__dict__.keys(): logger.info('truncate_lmstar_halo: {}'.format(gal.model.components['disk+bulge'].truncate_lmstar_halo)) ################################################################ # -------------------------------- # Run burn-in if self.nBurn > 0: logger.info('\nBurn-in:'+'\n' 'Start: {}\n'.format(datetime.datetime.now())) start = time.time() #### pos = initial_pos prob = None state = None blob = None for k in range(nBurn_orig): # -------------------------------- # If recovering intermediate save, only start past existing chain length: if self.save_intermediate_sampler_results_chain: if k < sampler_results.chain.shape[1]: continue logger.info(" k={}, time.time={}, a_frac={}".format( k, datetime.datetime.now(), np.mean(sampler_results.acceptance_fraction) ) ) ### pos_cur = pos.copy() # copy just in case things are set strangely # Run one sample step: if self.blob_name is not None: pos, prob, state, blob = sampler_results.run_mcmc(pos_cur, 1, lnprob0=prob, rstate0=state, blobs0 = blob) else: pos, prob, state = sampler_results.run_mcmc(pos_cur, 1, lnprob0=prob, rstate0=state) # -------------------------------- # Save intermediate steps if set: if self.save_intermediate_sampler_results_chain: if ((k+1) % self.nStep_intermediate_save == 0): sampler_results_dict_tmp = make_emcee_sampler_results_dict(sampler_results, nBurn=0, emcee_vers=2) sampler_results_dict_tmp['burn_step_cur'] = k sampler_results_dict_tmp['step_cur'] = -99 if output_options.f_sampler_results_tmp is not None: # Save stuff to file, for future use: dump_pickle(sampler_results_dict_tmp, filename=output_options.f_sampler_results_tmp, overwrite=True) # -------------------------------- ##### end = time.time() elapsed = end-start try: acor_time = sampler_results.get_autocorr_time(low=5, c=10) except: acor_time = "Undefined, chain did not converge" ####################################################################################### # Return Burn-in info # **** endtime = str(datetime.datetime.now()) nthingsmsg = 'nCPU, nParam, nWalker, nBurn = {}, {}, {}, {}'.format(self.nCPUs, nDim, self.nWalkers, self.nBurn) scaleparammsg = 'Scale param a= {}'.format(self.scale_param_a) timemsg = 'Time= {:3.2f} (sec), {:3.0f}:{:3.2f} (m:s)'.format( elapsed, np.floor(elapsed/60.), (elapsed/60.-np.floor(elapsed/60.))*60. 
            macfracmsg = "Mean acceptance fraction: {:0.3f}".format(np.mean(sampler_results.acceptance_fraction))
            acortimemsg = "Autocorr est: "+str(acor_time)
            logger.info('\nEnd: '+endtime+'\n'
                        '\n******************\n'
                        ''+nthingsmsg+'\n'
                        ''+scaleparammsg+'\n'
                        ''+timemsg+'\n'
                        ''+macfracmsg+'\n'
                        "Ideal acceptance frac: 0.2 - 0.5\n"
                        ''+acortimemsg+'\n'
                        '******************')

            nBurn_nEff = 2
            try:
                if self.nBurn < np.max(acor_time) * nBurn_nEff:
                    nburntimemsg = 'nBurn is less than {}*acorr time'.format(nBurn_nEff)
                    logger.info('\n#################\n'
                                ''+nburntimemsg+'\n'
                                '#################\n')
                    # Give warning if the burn-in is less than say 2-3 times the autocorr time
            except:
                logger.info('\n#################\n'
                            "acorr time undefined -> can't check convergence\n"
                            '#################\n')

            # --------------------------------
            # Save burn-in sampler_results, if desired
            if (self.save_burn) & (output_options.f_burn_sampler_results is not None):
                sampler_results_burn = make_emcee_sampler_results_dict(sampler_results, nBurn=0, emcee_vers=2)
                # Save stuff to file, for future use:
                dump_pickle(sampler_results_burn, filename=output_options.f_burn_sampler_results,
                            overwrite=output_options.overwrite)

            # --------------------------------
            # Plot burn-in trace, if output file set
            if (output_options.do_plotting) & (output_options.f_plot_trace_burnin is not None):
                sampler_results_burn = make_emcee_sampler_results_dict(sampler_results, nBurn=0, emcee_vers=2)
                mcmcResultsburn = MCMCResults(model=gal.model, sampler_results=sampler_results_burn)
                plotting.plot_trace(mcmcResultsburn, fileout=output_options.f_plot_trace_burnin,
                                    overwrite=output_options.overwrite)

            # Reset sampler_results after burn-in:
            sampler_results.reset()
            if self.blob_name is not None:
                sampler_results.clear_blobs()

        else:
            # --------------------------------
            # No burn-in: set initial position:
            if nBurn_orig > 0:
                logger.info('\nUsing previously completed burn-in'+'\n')

            pos = np.array(initial_pos)
            prob = None
            state = None

            if (not self.continue_steps) | (not self.save_intermediate_sampler_results_chain):
                blob = None

        #######################################################################################
        # ****
        # --------------------------------
        # Run sampler_results: Get start time
        logger.info('\nEnsemble sampling:\n'
                    'Start: {}\n'.format(datetime.datetime.now()))
        start = time.time()

        if sampler_results.chain.shape[1] > 0:
            logger.info('\n   Resuming with existing sampler_results chain at iteration ' +
                        str(sampler_results.iterations) + '\n')
            pos = sampler_results.chain[:, -1, :]

        # --------------------------------
        # Run sampler_results: output info at each step
        for ii in range(self.nSteps):
            # --------------------------------
            # If continuing chain, only start past existing chain length:
            if self.continue_steps | self.save_intermediate_sampler_results_chain:
                if ii < sampler_results.chain.shape[1]:
                    continue

            pos_cur = pos.copy()    # copy just in case things are set strangely

            # --------------------------------
            # Only do one step at a time:
            if self.blob_name is not None:
                pos, prob, state, blob = sampler_results.run_mcmc(pos_cur, 1, lnprob0=prob,
                                                                  rstate0=state, blobs0=blob)
            else:
                pos, prob, state = sampler_results.run_mcmc(pos_cur, 1, lnprob0=prob, rstate0=state)
            # --------------------------------

            # --------------------------------
            # Give output info about this step:
            nowtime = str(datetime.datetime.now())
            stepinfomsg = "ii={}, a_frac={}".format(ii, np.mean(sampler_results.acceptance_fraction))
            timemsg = " time.time()={}".format(nowtime)
            logger.info(stepinfomsg + timemsg)

            try:
                acor_time = sampler_results.get_autocorr_time(low=5, c=10)
                logger.info("{}: acor_time ={}".format(ii, np.array(acor_time)))
            except:
                acor_time = "Undefined, chain did not converge"
                logger.info(" {}: Chain too short for acor to run".format(ii))
            # --------------------------------
            # Case: test for convergence and truncate early:
            # Criteria checked: whether acceptance fraction within (minAF, maxAF),
            # and whether total number of steps > nEff * average autocorrelation time:
            # to make sure the parameter space is well explored.
            if ((self.minAF is not None) & (self.maxAF is not None) &
                    (self.nEff is not None) & (acor_time is not None)):
                if ((self.minAF < np.mean(sampler_results.acceptance_fraction) < self.maxAF) &
                        (ii > np.max(acor_time) * self.nEff)):
                    if ii == self.acor_force_min:
                        logger.info(" Enforced min step limit: {}.".format(ii+1))
                    if ii >= self.acor_force_min:
                        logger.info(" Finishing calculations early at step {}.".format(ii+1))
                        break

            # --------------------------------
            # Save intermediate steps if set:
            if self.save_intermediate_sampler_results_chain:
                if ((ii+1) % self.nStep_intermediate_save == 0):
                    sampler_results_dict_tmp = make_emcee_sampler_results_dict(sampler_results,
                                                                               nBurn=0, emcee_vers=2)
                    sampler_results_dict_tmp['burn_step_cur'] = nBurn_orig - 1
                    sampler_results_dict_tmp['step_cur'] = ii
                    if output_options.f_sampler_results_tmp is not None:
                        # Save stuff to file, for future use:
                        dump_pickle(sampler_results_dict_tmp,
                                    filename=output_options.f_sampler_results_tmp,
                                    overwrite=True)
            # --------------------------------

        # --------------------------------
        # Check if it failed to converge before the max number of steps, if doing convergence testing
        finishedSteps = ii + 1
        if (finishedSteps == self.nSteps) & ((self.minAF is not None) &
                (self.maxAF is not None) & (self.nEff is not None)):
            logger.info(" Caution: no convergence within nSteps={}.".format(self.nSteps))

        # --------------------------------
        # Finishing info for fitting:
        end = time.time()
        elapsed = end - start
        logger.info("Finished {} steps".format(finishedSteps)+"\n")

        try:
            acor_time = sampler_results.get_autocorr_time(low=5, c=10)
        except:
            acor_time = "Undefined, chain did not converge"

        #######################################################################################
        # ***********
        # Consider overall acceptance fraction
        endtime = str(datetime.datetime.now())
        nthingsmsg = 'nCPU, nParam, nWalker, nSteps = {}, {}, {}, {}'.format(self.nCPUs,
                     nDim, self.nWalkers, self.nSteps)
        scaleparammsg = 'Scale param a= {}'.format(self.scale_param_a)
        timemsg = 'Time= {:3.2f} (sec), {:3.0f}:{:3.2f} (m:s)'.format(elapsed,
                  np.floor(elapsed/60.), (elapsed/60.-np.floor(elapsed/60.))*60.)
        macfracmsg = "Mean acceptance fraction: {:0.3f}".format(np.mean(sampler_results.acceptance_fraction))
        acortimemsg = "Autocorr est: "+str(acor_time)
        logger.info('\nEnd: '+endtime+'\n'
                    '\n******************\n'
                    ''+nthingsmsg+'\n'
                    ''+scaleparammsg+'\n'
                    ''+timemsg+'\n'
                    ''+macfracmsg+'\n'
                    "Ideal acceptance frac: 0.2 - 0.5\n"
                    ''+acortimemsg+'\n'
                    '******************')

        # --------------------------------
        # Save sampler_results, if output file set:
        #   Burn-in is already cut by resetting the sampler_results at the beginning.
        # Get pickleable format:  # _fit_io.make_emcee_sampler_results_dict
        sampler_results_dict = make_emcee_sampler_results_dict(sampler_results, nBurn=0, emcee_vers=2)

        if output_options.f_sampler_results is not None:
            # Save stuff to file, for future use:
            dump_pickle(sampler_results_dict, filename=output_options.f_sampler_results,
                        overwrite=output_options.overwrite)

        # --------------------------------
        # Cleanup intermediate saves:
        if self.save_intermediate_sampler_results_chain & (output_options.f_sampler_results_tmp is not None):
            if os.path.isfile(output_options.f_sampler_results_tmp):
                os.remove(output_options.f_sampler_results_tmp)

        # --------------------------------
        if self.nCPUs > 1:
            sampler_results.pool.close()

        ##########################################
        ##########################################
        ##########################################

        # --------------------------------
        # Bundle the results up into a results class:
        mcmcResults = MCMCResults(model=gal.model, sampler_results=sampler_results_dict,
                                  linked_posterior_names=self.linked_posterior_names,
                                  blob_name=self.blob_name, nPostBins=self.nPostBins)

        if self.oversampled_chisq:
            mcmcResults.oversample_factor_chisq = OrderedDict()
            for obs_name in gal.observations:
                obs = gal.observations[obs_name]
                mcmcResults.oversample_factor_chisq[obs_name] = obs.data.oversample_factor_chisq

        # Do all analysis, plotting, saving:
        mcmcResults.analyze_plot_save_results(gal, output_options=output_options)

        return mcmcResults

    def _fit_emcee_3(self, gal, output_options):
        # Check length of sampler_results if not overwriting:
        if (not output_options.overwrite):
            if os.path.isfile(output_options.f_sampler_results):
                backend = emcee.backends.HDFBackend(output_options.f_sampler_results, name='mcmc')
                try:
                    if backend.get_chain().shape[0] >= self.nSteps:
                        if output_options.f_results is not None:
                            if os.path.isfile(output_options.f_results):
                                msg = "overwrite={}, and 'f_sampler_results' already contains {} steps,".format(
                                    output_options.overwrite, backend.get_chain().shape[0])
                                msg += " so the fit will not be saved.\n Specify new outfile or delete old files."
                                logger.warning(msg)
                                return None
                    else:
                        pass
                except:
                    pass

        # --------------------------------
        # Initialize emcee sampler_results
        nBurn_orig = self.nBurn

        nDim = gal.model.nparams_free
        kwargs_dict = {'fitter': self}

        # --------------------------------
        # Start pool, moves, backend:
        if (self.nCPUs > 1):
            pool = Pool(self.nCPUs)
        else:
            pool = None

        moves = emcee.moves.StretchMove(a=self.scale_param_a)

        backend_burn = emcee.backends.HDFBackend(output_options.f_sampler_results, name="burnin_mcmc")

        if output_options.overwrite:
            backend_burn.reset(self.nWalkers, nDim)

        sampler_results_burn = emcee.EnsembleSampler(self.nWalkers, nDim, base.log_prob,
                                                     backend=backend_burn, pool=pool, moves=moves,
                                                     args=[gal], kwargs=kwargs_dict)

        nBurnCur = sampler_results_burn.iteration

        self.nBurn = nBurn_orig - nBurnCur

        # --------------------------------
        # Initialize walker starting positions
        if sampler_results_burn.iteration == 0:
            initial_pos = initialize_walkers(gal.model, nWalkers=self.nWalkers)
        else:
            initial_pos = sampler_results_burn.get_last_sample()

        # --------------------------------
        # Output some fitting info to logger:
        logger.info("*************************************")
        logger.info(" Fitting: {} with MCMC".format(gal.name))
        for obs_name in gal.observations:
            obs = gal.observations[obs_name]
            logger.info("    obs: {}".format(obs.name))
            if obs.data.filename_velocity is not None:
                logger.info("        velocity file: {}".format(obs.data.filename_velocity))
            if obs.data.filename_dispersion is not None:
                logger.info("        dispers. file: {}".format(obs.data.filename_dispersion))
            logger.info('        nSubpixels: {}'.format(obs.mod_options.oversample))

        logger.info('\n'+'nCPUs: {}'.format(self.nCPUs))
        logger.info('nWalkers: {}'.format(self.nWalkers))
        # logger.info('lnlike: red_chisq={}'.format(self.red_chisq))
        logger.info('lnlike: oversampled_chisq={}'.format(self.oversampled_chisq))

        logger.info('\n'+'blobs: {}'.format(self.blob_name))

        if ('halo' in gal.model.components.keys()):
            logger.info('\n'+'mvirial_tied: {}'.format(gal.model.components['halo'].mvirial.tied))

        if ('disk+bulge' in gal.model.components.keys()):
            if 'mhalo_relation' in gal.model.components['disk+bulge'].__dict__.keys():
                logger.info('mhalo_relation: {}'.format(gal.model.components['disk+bulge'].mhalo_relation))
            if 'truncate_lmstar_halo' in gal.model.components['disk+bulge'].__dict__.keys():
                logger.info('truncate_lmstar_halo: {}'.format(gal.model.components['disk+bulge'].truncate_lmstar_halo))

        ################################################################
        # --------------------------------
        # Run burn-in
        if self.nBurn > 0:
            logger.info('\nBurn-in:'+'\n'
                        'Start: {}\n'.format(datetime.datetime.now()))
            start = time.time()

            ####
            pos = initial_pos
            for k in range(nBurn_orig):
                # --------------------------------
                # If recovering intermediate save, only start past existing chain length:
                if k < sampler_results_burn.iteration:
                    continue

                logger.info(" k={}, time.time={}, a_frac={}".format(k, datetime.datetime.now(),
                            np.mean(sampler_results_burn.acceptance_fraction)))
                ###
                # Run one sample step:
                pos = sampler_results_burn.run_mcmc(pos, 1)

            #####
            end = time.time()
            elapsed = end - start

            acor_time = sampler_results_burn.get_autocorr_time(tol=10, quiet=True)

            #######################################################################################
            # Return Burn-in info
            # ****
            endtime = str(datetime.datetime.now())
            nthingsmsg = 'nCPU, nParam, nWalker, nBurn = {}, {}, {}, {}'.format(self.nCPUs,
                         nDim, self.nWalkers, self.nBurn)
            scaleparammsg = 'Scale param a= {}'.format(self.scale_param_a)
            timemsg = 'Time= {:3.2f} (sec), {:3.0f}:{:3.2f} (m:s)'.format(elapsed,
                      np.floor(elapsed/60.), (elapsed/60.-np.floor(elapsed/60.))*60.)
            macfracmsg = "Mean acceptance fraction: {:0.3f}".format(np.mean(sampler_results_burn.acceptance_fraction))
            acortimemsg = "Autocorr est: "+str(acor_time)
            logger.info('\nEnd: '+endtime+'\n'
                        '\n******************\n'
                        ''+nthingsmsg+'\n'
                        ''+scaleparammsg+'\n'
                        ''+timemsg+'\n'
                        ''+macfracmsg+'\n'
                        "Ideal acceptance frac: 0.2 - 0.5\n"
                        ''+acortimemsg+'\n'
                        '******************')

            nBurn_nEff = 2
            try:
                if self.nBurn < np.max(acor_time) * nBurn_nEff:
                    nburntimemsg = 'nBurn is less than {}*acorr time'.format(nBurn_nEff)
                    logger.info('\n#################\n'
                                ''+nburntimemsg+'\n'
                                '#################\n')
                    # Give warning if the burn-in is less than say 2-3 times the autocorr time
            except:
                logger.info('\n#################\n'
                            "acorr time undefined -> can't check convergence\n"
                            '#################\n')

            # --------------------------------
            # Plot burn-in trace, if output file set
            if (output_options.do_plotting) & (output_options.f_plot_trace_burnin is not None):
                sampler_results_burn_dict = make_emcee_sampler_results_dict(sampler_results_burn, nBurn=0)
                mcmcResults_burn = MCMCResults(model=gal.model,
                                               sampler_results=sampler_results_burn_dict)
                plotting.plot_trace(mcmcResults_burn, fileout=output_options.f_plot_trace_burnin,
                                    overwrite=output_options.overwrite)

        else:
            # --------------------------------
            # No burn-in: set initial position:
            if nBurn_orig > 0:
                logger.info('\nUsing previously completed burn-in'+'\n')

            pos = initial_pos

        #######################################################################################
        # Setup sampler_results:
        # --------------------------------
        # Start backend:
        backend = emcee.backends.HDFBackend(output_options.f_sampler_results, name="mcmc")

        if output_options.overwrite:
            backend.reset(self.nWalkers, nDim)

        sampler_results = emcee.EnsembleSampler(self.nWalkers, nDim, base.log_prob,
                                                backend=backend, pool=pool, moves=moves,
                                                args=[gal], kwargs=kwargs_dict)

        #######################################################################################
        # *************************************************************************************
        # --------------------------------
        # Run sampler_results: Get start time
        logger.info('\nEnsemble sampling:\n'
                    'Start: {}\n'.format(datetime.datetime.now()))
        start = time.time()

        if sampler_results.iteration > 0:
            logger.info('\n   Resuming with existing sampler_results chain at iteration ' +
                        str(sampler_results.iteration) + '\n')
            pos = sampler_results.get_last_sample()

        # --------------------------------
        # Run sampler_results: output info at each step
        for ii in range(self.nSteps):
            # --------------------------------
            # If continuing chain, only start past existing chain length:
            if ii < sampler_results.iteration:
                continue

            # --------------------------------
            # Only do one step at a time:
            pos = sampler_results.run_mcmc(pos, 1)
            # --------------------------------

            # --------------------------------
            # Give output info about this step:
            nowtime = str(datetime.datetime.now())
            stepinfomsg = "ii={}, a_frac={}".format(ii, np.mean(sampler_results.acceptance_fraction))
            timemsg = " time.time()={}".format(nowtime)
            logger.info(stepinfomsg + timemsg)

            acor_time = sampler_results.get_autocorr_time(tol=10, quiet=True)
            # acor_time = sampler_results.get_autocorr_time(quiet=True)
            logger.info("{}: acor_time ={}".format(ii, np.array(acor_time)))

            # --------------------------------
            # Case: test for convergence and truncate early:
            # Criteria checked: whether acceptance fraction within (minAF, maxAF),
            # and whether total number of steps > nEff * average autocorrelation time:
            # to make sure the parameter space is well explored.
            if ((self.minAF is not None) & (self.maxAF is not None) &
                    (self.nEff is not None) & (acor_time is not None)):
                if ((self.minAF < np.mean(sampler_results.acceptance_fraction) < self.maxAF) &
                        (ii > np.max(acor_time) * self.nEff)):
                    if ii == self.acor_force_min:
                        logger.info(" Enforced min step limit: {}.".format(ii+1))
                    if ii >= self.acor_force_min:
                        logger.info(" Finishing calculations early at step {}.".format(ii+1))
                        break

        # --------------------------------
        # Check if it failed to converge before the max number of steps, if doing convergence testing
        finishedSteps = ii + 1
        if (finishedSteps == self.nSteps) & ((self.minAF is not None) &
                (self.maxAF is not None) & (self.nEff is not None)):
            logger.info(" Caution: no convergence within nSteps={}.".format(self.nSteps))

        # --------------------------------
        # Finishing info for fitting:
        end = time.time()
        elapsed = end - start
        logger.info("Finished {} steps".format(finishedSteps)+"\n")

        acor_time = sampler_results.get_autocorr_time(tol=10, quiet=True)

        #######################################################################################
        # ***********
        # Consider overall acceptance fraction
        endtime = str(datetime.datetime.now())
        nthingsmsg = 'nCPU, nParam, nWalker, nSteps = {}, {}, {}, {}'.format(self.nCPUs,
                     nDim, self.nWalkers, self.nSteps)
        scaleparammsg = 'Scale param a= {}'.format(self.scale_param_a)
        timemsg = 'Time= {:3.2f} (sec), {:3.0f}:{:3.2f} (m:s)'.format(elapsed,
                  np.floor(elapsed/60.), (elapsed/60.-np.floor(elapsed/60.))*60.)
        macfracmsg = "Mean acceptance fraction: {:0.3f}".format(np.mean(sampler_results.acceptance_fraction))
        acortimemsg = "Autocorr est: "+str(acor_time)
        logger.info('\nEnd: '+endtime+'\n'
                    '\n******************\n'
                    ''+nthingsmsg+'\n'
                    ''+scaleparammsg+'\n'
                    ''+timemsg+'\n'
                    ''+macfracmsg+'\n'
                    "Ideal acceptance frac: 0.2 - 0.5\n"
                    ''+acortimemsg+'\n'
                    '******************')

        if self.nCPUs > 1:
            pool.close()
            sampler_results.pool.close()
            sampler_results_burn.pool.close()

        ##########################################
        ##########################################
        ##########################################

        # --------------------------------
        # Setup sampler_results dict:
        sampler_results_dict = make_emcee_sampler_results_dict(sampler_results, nBurn=0)

        # --------------------------------
        # Bundle the results up into a results class:
        mcmcResults = MCMCResults(model=gal.model, sampler_results=sampler_results_dict,
                                  linked_posterior_names=self.linked_posterior_names,
                                  blob_name=self.blob_name, nPostBins=self.nPostBins)

        if self.oversampled_chisq:
            mcmcResults.oversample_factor_chisq = OrderedDict()
            for obs_name in gal.observations:
                obs = gal.observations[obs_name]
                mcmcResults.oversample_factor_chisq[obs_name] = obs.data.oversample_factor_chisq

        # Do all analysis, plotting, saving:
        mcmcResults.analyze_plot_save_results(gal, output_options=output_options)

        return mcmcResults
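

# ----------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# driver for MCMCFitter.fit(). It assumes `gal` is a fully configured
# `galaxy.Galaxy` (with observations, instrument, and model set), and that
# `config.OutputOptions` accepts the keyword arguments shown; both the
# keyword names and values here are assumptions for illustration, so check
# the dysmalpy documentation for the exact signatures.

def _example_run_mcmc_fit(gal, outdir='mcmc_fit_output/'):
    from dysmalpy import config

    # Override a few of the defaults set in MCMCFitter._set_defaults();
    # extra keywords are forwarded to the base Fitter via **kwargs:
    fitter = MCMCFitter(nWalkers=20, nCPUs=4, nBurn=50, nSteps=200)

    # Hypothetical OutputOptions configuration (keywords are assumptions):
    output_options = config.OutputOptions(outdir=outdir,
                                          do_plotting=True,
                                          overwrite=False)

    # Returns an MCMCResults instance with best-fit parameters, chains, etc.
    return fitter.fit(gal, output_options)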
class MCMCResults(base.BayesianFitResults, base.FitResults):
    """
    Class to hold results of MCMC fitting to DYSMALPY models.

    Notes
    -----
    The `emcee` sampler_results object is ported to a dictionary in
    `mcmcResults.sampler_results`

    The name of the free parameters in the chain are accessed through
    `mcmcResults.chain_param_names`, or more generally (separate model +
    parameter names) through `mcmcResults.free_param_names`

    Optional Attribute
    ------------------
    `linked_posterior_names`
        Indicates if best-fit parameters should be measured in
        multi-dimensional histogram space. It takes a list of linked
        parameter sets, where each set consists of len-2 tuples/lists of
        the component + parameter names.

    Structure Explanation
    ---------------------
    #. To analyze component + param 1 and 2 together, and then 3 and 4 together:
       `linked_posterior_names = [joint_param_bundle1, joint_param_bundle2]`
       with `joint_param_bundle1 = [[cmp1, par1], [cmp2, par2]]` and
       `joint_param_bundle2 = [[cmp3, par3], [cmp4, par4]]`, for a full array of:
       `linked_posterior_names = [[[cmp1, par1], [cmp2, par2]], [[cmp3, par3], [cmp4, par4]]]`.

    #. To analyze component + param 1 and 2 together:
       `linked_posterior_names = [joint_param_bundle1]`
       with `joint_param_bundle1 = [[cmp1, par1], [cmp2, par2]]`,
       for a full array of `linked_posterior_names = [[[cmp1, par1], [cmp2, par2]]]`.

    Example: Look at halo: mvirial and disk+bulge: total_mass together:
    `linked_posterior_names = [[['halo', 'mvirial'], ['disk+bulge', 'total_mass']]]`
    """
    def __init__(self, model=None, sampler_results=None, linked_posterior_names=None,
                 blob_name=None, nPostBins=50):

        super(MCMCResults, self).__init__(model=model, blob_name=blob_name,
                                          fit_method='MCMC',
                                          linked_posterior_names=linked_posterior_names,
                                          sampler_results=sampler_results,
                                          nPostBins=nPostBins)

    def __setstate__(self, state):
        # Compatibility hacks
        super(MCMCResults, self).__setstate__(state)

        # # ---------
        # if ('sampler' not in state.keys()) & ('sampler_results' in state.keys()):
        #     self._setup_samples_blobs()

    def _setup_samples_blobs(self):
        # Note:
        #   self.sampler.samples replaces self.sampler_results['flatchain'], and
        #   self.sampler.blobs replaces self.sampler_results['flatblobs']

        if 'blobs' in self.sampler_results.keys():
            blobset = True
        else:
            blobset = False

        if ('flatblobs' not in self.sampler_results.keys()) & (blobset):
            if len(self.sampler_results['blobs'].shape) == 2:
                # Only 1 blob: nSteps, nWalkers:
                flatblobs = self.sampler_results['blobs'].reshape(-1)
            elif len(self.sampler_results['blobs'].shape) == 3:
                # Multiblobs; nSteps, nWalkers, nBlobs
                flatblobs = self.sampler_results['blobs'].reshape(-1, self.sampler_results['blobs'].shape[2])
            else:
                raise ValueError("sampler_results blob length not recognized")
        elif (not blobset):
            flatblobs = None
        else:
            flatblobs = self.sampler_results['flatblobs']

        self.sampler = base.BayesianSampler(samples=self.sampler_results['flatchain'],
                                            blobs=flatblobs)
    def reload_sampler_results(self, filename=None):
        """Reload the MCMC sampler_results saved earlier"""
        if filename is None:
            # filename = self.f_sampler_results
            raise ValueError("Must specify 'filename' of the saved sampler_results!")

        hdf5_aliases = ['h5', 'hdf5']
        pickle_aliases = ['pickle', 'pkl', 'pcl']
        if (filename.split('.')[-1].lower() in hdf5_aliases):
            self.sampler_results = _reload_sampler_results_hdf5(filename=filename)
        elif (filename.split('.')[-1].lower() in pickle_aliases):
            self.sampler_results = _reload_sampler_results_pickle(filename=filename)
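

# ----------------------------------------------------------------------
# Illustrative sketch (not part of the original module): constructing the
# nested `linked_posterior_names` structure described in the MCMCResults
# docstring. Each "bundle" is a list of [component, parameter] pairs whose
# posteriors are analyzed jointly in multi-dimensional histogram space.

def _example_linked_posterior_names():
    # Analyze halo 'mvirial' and disk+bulge 'total_mass' jointly:
    joint_bundle = [['halo', 'mvirial'], ['disk+bulge', 'total_mass']]
    linked_posterior_names = [joint_bundle]
    # This can then be set on the fitter, or passed to MCMCResults, e.g.:
    #   MCMCResults(model=gal.model, sampler_results=sampler_results_dict,
    #               linked_posterior_names=linked_posterior_names)
    return linked_posterior_names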
def initialize_walkers(model, nWalkers=None):
    """
    Initialize a set of MCMC walkers by randomly drawing from the
    model set parameter priors
    """
    stack_rand = []
    pfree_dict = model.get_free_parameter_keys()
    comps_names = pfree_dict.keys()

    for compn in comps_names:
        comp = model.components.__getitem__(compn)
        params_names = pfree_dict[compn].keys()
        for paramn in params_names:
            if (pfree_dict[compn][paramn] >= 0):
                # Free parameter: randomly sample from prior nWalker times:
                param_rand = comp.__getattribute__(paramn).prior.sample_prior(
                    comp.__getattribute__(paramn), modelset=model, N=nWalkers)
                stack_rand.append(param_rand)

    pos = np.array(list(zip(*stack_rand)))    # should have shape: (nWalkers, nDim)
    return pos


def make_emcee_sampler_results_dict(sampler_results, nBurn=0, emcee_vers=3):
    """
    Save chain + key results from an emcee sampler_results instance to a dict,
    as the emcee sampler_results objects aren't pickleable.
    """
    if emcee_vers == 3:
        return _make_emcee_sampler_results_dict_v3(sampler_results, nBurn=nBurn)
    elif emcee_vers == 2:
        return _make_emcee_sampler_results_dict_v2(sampler_results, nBurn=nBurn)
    else:
        raise ValueError("Emcee version {} not supported!".format(emcee_vers))


def _make_emcee_sampler_results_dict_v2(sampler_results, nBurn=0):
    """ Syntax for emcee v2.2.1 """
    # Cut first nBurn steps, to avoid the edge cases that are rarely explored.
    chain = sampler_results.chain[:, nBurn:, :]
    flatchain = chain.reshape((-1, sampler_results.dim))    # Walkers, iterations
    probs = sampler_results.lnprobability[:, nBurn:]
    flatprobs = probs.reshape(-1)

    try:
        acor_time = sampler_results.get_autocorr_time(low=5, c=10)
    except:
        acor_time = None

    # Make a dictionary:
    sampler_results_dict = {'chain': chain,
                            'flatchain': flatchain,
                            'lnprobability': probs,
                            'flatlnprobability': flatprobs,
                            'nIter': sampler_results.iterations,
                            'nParam': sampler_results.dim,
                            'nCPU': sampler_results.threads,
                            'nWalkers': len(sampler_results.chain),
                            'acceptance_fraction': sampler_results.acceptance_fraction,
                            'acor_time': acor_time}

    if sampler_results.blobs is not None:
        if len(sampler_results.blobs) > 0:
            sampler_results_dict['blobs'] = np.array(sampler_results.blobs[nBurn:])

            if len(np.shape(sampler_results.blobs)) == 2:
                # Only 1 blob: nSteps, nWalkers:
                sampler_results_dict['flatblobs'] = np.array(sampler_results_dict['blobs']).reshape(-1)
            elif len(np.shape(sampler_results.blobs)) == 3:
                # Multiblobs; nSteps, nWalkers, nBlobs
                sampler_results_dict['flatblobs'] = np.array(sampler_results_dict['blobs']).reshape(
                    -1, np.shape(sampler_results.blobs)[2])
            else:
                raise ValueError("sampler_results blob length not recognized")

    return sampler_results_dict
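

# ----------------------------------------------------------------------
# Note on axis conventions (illustrative, not part of the original module):
# emcee v2 stores `sampler.chain` as (nWalkers, nSteps, nParam), while emcee
# v3's `get_chain()` returns (nSteps, nWalkers, nParam). The v3 builder below
# therefore swaps axes 0 and 1 so the dysmalpy 'chain' entry always has the
# (nWalkers, nSteps, nParam) layout. A minimal sketch of the reshape:

def _example_chain_axis_convention(nWalkers=10, nSteps=100, nParam=3):
    rng = np.random.default_rng(42)
    chain_v3 = rng.normal(size=(nSteps, nWalkers, nParam))   # emcee v3 layout
    chain = np.swapaxes(chain_v3, 0, 1)                      # dysmalpy layout
    flatchain = chain.reshape((-1, nParam))                  # (nWalkers*nSteps, nParam)
    assert chain.shape == (nWalkers, nSteps, nParam)
    return flatchain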
def _make_emcee_sampler_results_dict_v3(sampler_results, nBurn=0):
    """ Syntax for emcee v3 """
    # Cut first nBurn steps, to avoid the edge cases that are rarely explored.
    samples = np.swapaxes(sampler_results.get_chain(), 0, 1)[:, nBurn:, :].reshape(
        (-1, sampler_results.ndim))    # Walkers, iterations
    # get_log_prob() returns (nSteps, nWalkers); swap to (nWalkers, nSteps)
    # to match the v2 dict layout before cutting the burn steps:
    probs = np.swapaxes(sampler_results.get_log_prob(), 0, 1)[:, nBurn:].reshape(-1)

    acor_time = sampler_results.get_autocorr_time(tol=10, quiet=True)

    try:
        nCPUs = sampler_results.pool._processes    # sampler_results.threads
    except:
        nCPUs = 1

    # Make a dictionary:
    sampler_results_dict = {'chain': np.swapaxes(sampler_results.get_chain(), 0, 1)[:, nBurn:, :],
                            'lnprobability': np.swapaxes(sampler_results.get_log_prob(), 0, 1)[:, nBurn:],
                            'flatchain': samples,
                            'flatlnprobability': probs,
                            'nIter': sampler_results.iteration,
                            'nParam': sampler_results.ndim,
                            'nCPU': nCPUs,
                            'nWalkers': sampler_results.nwalkers,
                            'acceptance_fraction': sampler_results.acceptance_fraction,
                            'acor_time': acor_time}

    if sampler_results.get_blobs() is not None:
        if len(sampler_results.get_blobs()) > 0:
            if len(np.shape(sampler_results.get_blobs())) == 2:
                # Only 1 blob: nSteps, nWalkers:
                sampler_results_dict['blobs'] = sampler_results.get_blobs()[nBurn:, :]
                flatblobs = np.array(sampler_results_dict['blobs']).reshape(-1)
            elif len(np.shape(sampler_results.get_blobs())) == 3:
                # Multiblobs; nSteps, nWalkers, nBlobs
                sampler_results_dict['blobs'] = sampler_results.get_blobs()[nBurn:, :, :]
                flatblobs = np.array(sampler_results_dict['blobs']).reshape(
                    -1, np.shape(sampler_results.get_blobs())[2])
            else:
                raise ValueError("sampler_results blob shape not recognized")

            sampler_results_dict['flatblobs'] = flatblobs

    return sampler_results_dict


def _reload_sampler_results_hdf5(filename=None, backend_name='mcmc'):
    # Load backend from file
    backend = emcee.backends.HDFBackend(filename, name=backend_name)
    return _make_sampler_results_dict_from_hdf5(backend)


def _make_sampler_results_dict_from_hdf5(b):
    """
    Construct a dysmalpy 'sampler_results_dict' out of the chain info
    stored in the emcee v3 HDF5 file
    """
    nwalkers = b.shape[0]
    ndim = b.shape[1]

    chain = np.swapaxes(b.get_chain(), 0, 1)
    flatchain = chain.reshape((-1, ndim))    # Walkers, iterations

    probs = np.swapaxes(b.get_log_prob(), 0, 1)
    flatprobs = probs.reshape(-1)

    acor_time = b.get_autocorr_time(tol=10, quiet=True)

    # Make a dictionary:
    sampler_results_dict = {'chain': chain,
                            'flatchain': flatchain,
                            'lnprobability': probs,
                            'flatlnprobability': flatprobs,
                            'nIter': b.iteration,
                            'nParam': ndim,
                            'nCPU': None,
                            'nWalkers': nwalkers,
                            'acceptance_fraction': b.accepted / float(b.iteration),
                            'acor_time': acor_time}

    if b.has_blobs():
        sampler_results_dict['blobs'] = b.get_blobs()

        if len(b.get_blobs().shape) == 2:
            # Only 1 blob: nSteps, nWalkers:
            flatblobs = np.array(sampler_results_dict['blobs']).reshape(-1)
        elif len(b.get_blobs().shape) == 3:
            # Multiblobs; nSteps, nWalkers, nBlobs
            flatblobs = np.array(sampler_results_dict['blobs']).reshape(
                -1, np.shape(sampler_results_dict['blobs'])[2])
        else:
            raise ValueError("sampler_results blob shape not recognized")

        sampler_results_dict['flatblobs'] = flatblobs

    return sampler_results_dict


def _reload_sampler_results_pickle(filename=None):
    return load_pickle(filename)
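

# ----------------------------------------------------------------------
# Illustrative sketch (not part of the original module): reloading a saved
# emcee v3 chain directly from its HDF5 backend, mirroring what
# _reload_sampler_results_hdf5() does internally. The filename is a
# hypothetical placeholder. The fit above saves the burn-in chain under
# name="burnin_mcmc" and the main chain under name="mcmc".

def _example_reload_hdf5(filename='galaxy_mcmc_sampler_results.h5'):
    backend = emcee.backends.HDFBackend(filename, name='mcmc')
    chain = backend.get_chain()          # (nSteps, nWalkers, nParam)
    log_prob = backend.get_log_prob()    # (nSteps, nWalkers)
    acor = backend.get_autocorr_time(tol=10, quiet=True)
    return chain, log_prob, acor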
""" kwargs_dict = {'fitter': fitter} # This will break for updated version of emcee # works for emcee v2.2.1 if emcee.__version__ == '2.2.1': sampler_results = emcee.EnsembleSampler(fitter.nWalkers, fitter.nParam, base.log_prob, args=[gal], kwargs=kwargs_dict, a=fitter.scale_param_a, threads=sampler_results_dict['nCPU']) sampler_results._chain = copy.deepcopy(sampler_results_dict['chain']) sampler_results._blobs = list(copy.deepcopy(sampler_results_dict['blobs'])) sampler_results._lnprob = copy.deepcopy(sampler_results_dict['lnprobability']) sampler_results.iterations = sampler_results_dict['nIter'] sampler_results.naccepted = np.array(sampler_results_dict['nIter']*copy.deepcopy(sampler_results_dict['acceptance_fraction']), dtype=np.int64) ### elif int(emcee.__version__[0]) >= 3: # This is based off of HDF5 files, which automatically makes it easy to reload + resetup the sampler_results raise ValueError("emcee >=3 uses HDF5 files, so re-initialization not necessary!") ### else: try: backend = emcee.Backend() backend.nwalkers = sampler_results_dict['nWalkers'] backend.ndim = sampler_results_dict['nParam'] backend.iteration = sampler_results_dict['nIter'] backend.accepted = np.array(sampler_results_dict['nIter']*sampler_results_dict['acceptance_fraction'], dtype=np.int64) backend.chain = sampler_results_dict['chain'] backend.log_prob = sampler_results_dict['lnprobability'] backend.blobs = sampler_results_dict['blobs'] backend.initialized = True sampler_results = emcee.EnsembleSampler(sampler_results_dict['nWalkers'], sampler_results_dict['nParam'], base.log_prob, args=[gal], kwargs=kwargs_dict, backend=backend, a=fitter.scale_param_a, threads=sampler_results_dict['nCPU']) except: raise ValueError return sampler_results def _reload_all_fitting_mcmc(filename_galmodel=None, filename_results=None): gal = galaxy.load_galaxy_object(filename=filename_galmodel) results = MCMCResults() results.reload_results(filename=filename_results) return gal, results