# coding=utf8
# Copyright (c) MPE/IR-Submm Group. See LICENSE.rst for license information.
#
# Utility functions for fitting
from __future__ import (absolute_import, division, print_function,
unicode_literals)
# Standard library
import logging
import os
# Third party imports
import numpy as np
import astropy.units as u
from scipy.stats import gaussian_kde
from scipy.optimize import fmin
# Dysmalpy imports:
from dysmalpy.instrument import DoubleBeam, Moffat, GaussianBeam
# LOGGER SETTINGS
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('DysmalPy')
logger.setLevel(logging.INFO)
_bayesian_fitting_methods = ['mcmc', 'nested']
def _chisq_generalized(gal, red_chisq=None):
if red_chisq is None:
raise ValueError("'red_chisq' must be True or False!")
chsq_general = 0.0
for obs_name in gal.observations:
obs = gal.observations[obs_name]
if obs.fit_options.fit:
# 3D observation
if obs.instrument.ndim == 3:
# Will have problem with vel shift: data, model won't match...
msk = obs.data.mask
dat = obs.data.data.unmasked_data[:].value[msk]
mod = obs.model_data.data.unmasked_data[:].value[msk]
err = obs.data.error.unmasked_data[:].value[msk]
# Weights:
wgt_data = 1.
if hasattr(obs.data, 'weight'):
if obs.data.weight is not None:
wgt_data = obs.data.weight[msk]
# Artificially inflate zero errors at pixels that are masked out (currently disabled):
#err[((err==0) & (msk==0))] = 99.
chisq_arr_raw = (((dat - mod)/err)**2) * wgt_data
if red_chisq:
if gal.model.nparams_free > np.sum(msk) :
raise ValueError("More free parameters than data points!")
invnu = 1./ (1.*(np.sum(msk) - gal.model.nparams_free))
else:
invnu = 1.
chsq_general += chisq_arr_raw.sum() * invnu * obs.weight
elif ((obs.instrument.ndim == 1) or (obs.instrument.ndim ==2)):
if obs.fit_options.fit_velocity:
#msk = obs.data.mask
if hasattr(obs.data, 'mask_velocity'):
if obs.data.mask_velocity is not None:
msk = obs.data.mask_velocity
else:
msk = obs.data.mask
else:
msk = obs.data.mask
vel_dat = obs.data.data['velocity'][msk]
vel_mod = obs.model_data.data['velocity'][msk]
vel_err = obs.data.error['velocity'][msk]
if obs.fit_options.fit_dispersion:
if hasattr(obs.data, 'mask_vel_disp'):
if obs.data.mask_vel_disp is not None:
msk = obs.data.mask_vel_disp
else:
msk = obs.data.mask
else:
msk = obs.data.mask
disp_dat = obs.data.data['dispersion'][msk]
disp_mod = obs.model_data.data['dispersion'][msk]
disp_err = obs.data.error['dispersion'][msk]
# Correct model for instrument dispersion if the data is instrument corrected:
if 'inst_corr' in obs.data.data.keys():
if obs.data.data['inst_corr']:
disp_mod = np.sqrt(disp_mod**2 -
obs.instrument.lsf.dispersion.to(u.km/u.s).value**2)
disp_mod[~np.isfinite(disp_mod)] = 0   # Set the model dispersion to zero where it falls below the instrumental dispersion
if obs.fit_options.fit_flux:
msk = obs.data.mask
flux_dat = obs.data.data['flux'][msk]
flux_mod = obs.model_data.data['flux'][msk]
try:
flux_err = obs.data.error['flux'][msk]
except:
flux_err = 0.1*obs.data.data['flux'][msk] # PLACEHOLDER
# Weights:
wgt_data = 1.
if hasattr(obs.data, 'weight'):
if obs.data.weight is not None:
wgt_data = obs.data.weight[msk]
#####
fac_mask = 0
chisq_arr_sum = 0
if obs.fit_options.fit_velocity:
fac_mask += 1
### Data includes velocity
# Includes velocity shift
chisq_arr_raw_vel = (((vel_dat - vel_mod)/vel_err)**2) * wgt_data
chisq_arr_sum += chisq_arr_raw_vel.sum()
if obs.fit_options.fit_dispersion:
fac_mask += 1
chisq_arr_raw_disp = (((disp_dat - disp_mod)/disp_err)**2) * wgt_data
chisq_arr_sum += chisq_arr_raw_disp.sum()
if obs.fit_options.fit_flux:
fac_mask += 1
chisq_arr_raw_flux = (((flux_dat - flux_mod)/flux_err)**2) * wgt_data
chisq_arr_sum += chisq_arr_raw_flux.sum()
####
if red_chisq:
if gal.model.nparams_free > fac_mask*np.sum(msk) :
raise ValueError("More free parameters than data points!")
invnu = 1./ (1.*(fac_mask*np.sum(msk) - gal.model.nparams_free))
else:
invnu = 1.
####
chsq_general += (chisq_arr_sum) * invnu * obs.weight
elif obs.instrument.ndim == 0:
msk = obs.data.mask
data = obs.data.data
mod = obs.model_data.data
err = obs.data.error
# Weights:
wgt_data = 1.
if hasattr(obs.data, 'weight'):
if obs.data.weight is not None:
wgt_data = obs.data.weight
chisq_arr = (((data - mod)/err)**2) * wgt_data
if red_chisq:
if gal.model.nparams_free > np.sum(msk):
raise ValueError("More free parameters than data points!")
invnu = 1. / (1. * (np.sum(msk) - gal.model.nparams_free))
else:
invnu = 1.
chsq_general += chisq_arr.sum() * invnu * obs.weight
else:
logger.warning("ndim={} not supported!".format(obs.instrument.ndim))
raise ValueError
return chsq_general
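# Illustrative usage of _chisq_generalized() (a hedged sketch, kept in comments so it
# is not executed on import). It assumes a hypothetical `gal`: a dysmalpy Galaxy whose
# observations already hold both `data` and `model_data`:
#
#     chisq_total = _chisq_generalized(gal, red_chisq=False)  # weighted chi-square summed over fitted observations
#     chisq_red = _chisq_generalized(gal, red_chisq=True)     # same, but each observation's term is divided by its dof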
def _chisq_general_per_type(obs, type=None, red_chisq=True, nparams_free=None, **kwargs):
"""
Evaluate reduced chi square of model for one specific map/profile
(i.e., flux/velocity/dispersion), relative to the data.
"""
# type = 'velocity', 'dispersion', or 'flux'
if ((obs.data.ndim != 1) & (obs.data.ndim != 2)):
msg = "_chisq_general_per_type() can only be called when\n"
msg += "obs.data.ndim = 1 or 2!"
raise ValueError(msg)
if (type.strip().lower() == 'velocity'):
#msk = obs.data.mask
if hasattr(obs.data, 'mask_velocity'):
if obs.data.mask_velocity is not None:
msk = obs.data.mask_velocity
else:
msk = obs.data.mask
else:
msk = obs.data.mask
vel_dat = obs.data.data['velocity'][msk]
vel_mod = obs.model_data.data['velocity'][msk]
vel_err = obs.data.error['velocity'][msk]
if (type.strip().lower() == 'dispersion'):
if hasattr(obs.data, 'mask_vel_disp'):
if obs.data.mask_vel_disp is not None:
msk = obs.data.mask_vel_disp
else:
msk = obs.data.mask
else:
msk = obs.data.mask
disp_dat = obs.data.data['dispersion'][msk]
disp_mod = obs.model_data.data['dispersion'][msk]
disp_err = obs.data.error['dispersion'][msk]
# Correct model for instrument dispersion if the data is instrument corrected:
if 'inst_corr' in obs.data.data.keys():
if obs.data.data['inst_corr']:
disp_mod = np.sqrt(disp_mod**2 -
obs.instrument.lsf.dispersion.to(u.km/u.s).value**2)
disp_mod[~np.isfinite(disp_mod)] = 0   # Set the model dispersion to zero where it falls below the instrumental dispersion
if (type.strip().lower() == 'flux'):
msk = obs.data.mask
flux_dat = obs.data.data['flux'][msk]
flux_mod = obs.model_data.data['flux'][msk]
try:
flux_err = obs.data.error['flux'][msk]
except:
flux_err = 0.1*obs.data.data['flux'][msk] # PLACEHOLDER
# Weights:
wgt_data = 1.
if hasattr(obs.data, 'weight'):
if obs.data.weight is not None:
wgt_data = obs.data.weight[msk]
#####
fac_mask = 0
chisq_arr_sum = 0
if (type.strip().lower() == 'velocity'):
fac_mask += 1
### Data includes velocity
# Includes velocity shift
chisq_arr_raw_vel = (((vel_dat - vel_mod)/vel_err)**2) * wgt_data
chisq_arr_sum += chisq_arr_raw_vel.sum()
if (type.strip().lower() == 'dispersion'):
fac_mask += 1
chisq_arr_raw_disp = (((disp_dat - disp_mod)/disp_err)**2) * wgt_data
chisq_arr_sum += chisq_arr_raw_disp.sum()
if (type.strip().lower() == 'flux'):
fac_mask += 1
chisq_arr_raw_flux = (((flux_dat - flux_mod)/flux_err)**2) * wgt_data
chisq_arr_sum += chisq_arr_raw_flux.sum()
####
if red_chisq:
if nparams_free is None:
raise ValueError("If `red_chisq` = TRUE, then must set `nparams_free`.")
if nparams_free > fac_mask*np.sum(msk) :
raise ValueError("More free parameters than data points!")
invnu = 1./ (1.*(fac_mask*np.sum(msk) - nparams_free)) * obs.weight
else:
invnu = 1. * obs.weight
####
chsq_general = (chisq_arr_sum) * invnu
return chsq_general
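# Illustrative usage of _chisq_general_per_type() (hedged sketch in comments; `obs` is a
# hypothetical 1D or 2D observation holding both `data` and `model_data`, and
# `nparams_free` comes from the model being fit):
#
#     chisq_vel = _chisq_general_per_type(obs, type='velocity', red_chisq=True,
#                                         nparams_free=gal.model.nparams_free)
#     chisq_disp = _chisq_general_per_type(obs, type='dispersion', red_chisq=True,
#                                          nparams_free=gal.model.nparams_free)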
def find_peak_gaussian_KDE(flatchain, initval, weights=None):
"""
Return chain parameters that give peak of the posterior PDF, using KDE.
"""
try:
nparams = flatchain.shape[1]
nrows = nparams
except:
nparams = 1
nrows = 0
if nrows > 0:
peakvals = np.zeros(nparams)
for i in range(nparams):
kern = gaussian_kde(flatchain[:,i], weights=weights)
peakvals[i] = fmin(lambda x: -kern(x), initval[i],disp=False)
return peakvals
else:
kern = gaussian_kde(flatchain, weights=weights)
peakval = fmin(lambda x: -kern(x), initval,disp=False)
try:
return peakval[0]
except:
return peakval
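# Minimal self-contained sketch of find_peak_gaussian_KDE() on a synthetic chain
# (kept in comments; the numbers are made up):
#
#     import numpy as np
#     flatchain = np.random.normal(loc=[1.0, 2.0], scale=0.1, size=(5000, 2))
#     peaks = find_peak_gaussian_KDE(flatchain, np.array([1.0, 2.0]))
#     # peaks should land near [1.0, 2.0]: the maximum of each 1D marginalized KDE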
def find_peak_gaussian_KDE_multiD(flatchain, linked_inds, initval, weights=None):
"""
Return chain parameters that give peak of the posterior PDF *FOR LINKED PARAMETERS*, using KDE.
"""
if weights is not None:
raise ValueError("TEST")
# nparams = len(linked_inds)
kern = gaussian_kde(flatchain[:,linked_inds].T, weights=weights)
peakvals = fmin(lambda x: -kern(x), initval,disp=False)
return peakvals
def find_multiD_pk_hist(flatchain, linked_inds, nPostBins=50):
H2, edges = np.histogramdd(flatchain[:,linked_inds], bins=nPostBins)
# Locate the peak bin of the multi-D histogram (one bin index per dimension):
wh_pk = np.unravel_index(np.argmax(H2), H2.shape)
pk_vals = np.zeros(len(linked_inds))
for k in range(len(linked_inds)):
pk_vals[k] = np.average([edges[k][wh_pk[k]], edges[k][wh_pk[k]+1]])
return pk_vals
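# Sketch of find_multiD_pk_hist() on a synthetic chain (comments only; made-up numbers):
#
#     flatchain = np.random.normal(loc=[0.5, 2.0, -1.0], scale=0.2, size=(20000, 3))
#     pk = find_multiD_pk_hist(flatchain, [0, 2], nPostBins=30)
#     # pk is the center of the peak bin of the joint histogram of parameters 0 and 2,
#     # so roughly [0.5, -1.0] for this input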
def get_linked_posterior_peak_values(flatchain,
guess = None,
linked_posterior_ind_arr=None,
nPostBins=50):
"""
Get linked posterior best-fit values using a multi-D histogram for the
given linked parameter indices.
Input:
flatchain: sampler flatchain, shape (Nsamples, Nparams)
linked_posterior_ind_arr: array of arrays of parameter indices to be analyzed together
eg: analyze ind1+ind2 together, and then ind3+ind4 together
linked_posterior_ind_arr = [ [ind1, ind2], [ind3, ind4] ]
nPostBins: number of bins on each parameter "edge" of the multi-D histogram
Output:
bestfit_theta_linked: array of the linked best-fit parameter values from the multi-D parameter space
eg:
bestfit_theta_linked = [ [best1, best2], [best3, best4] ]
"""
# Use gaussian KDE to get bestfit linked:
bestfit_theta_linked = np.array([])
for k in range(len(linked_posterior_ind_arr)):
bestfit_thetas = find_peak_gaussian_KDE_multiD(flatchain, linked_posterior_ind_arr[k],
guess[linked_posterior_ind_arr[k]])
if len(bestfit_theta_linked) >= 1:
bestfit_theta_linked = np.vstack([bestfit_theta_linked, np.array([bestfit_thetas])])
else:
bestfit_theta_linked = np.array([bestfit_thetas])
return bestfit_theta_linked
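# Sketch of get_linked_posterior_peak_values() (comments only; `flatchain` and `guess`
# are hypothetical arrays). Parameters 0+1 and 2+3 are analyzed jointly, as in the
# docstring example above:
#
#     linked_inds = [[0, 1], [2, 3]]
#     bestfit_linked = get_linked_posterior_peak_values(flatchain, guess=guess,
#                                                       linked_posterior_ind_arr=linked_inds)
#     # -> array of shape (2, 2): one row of joint-KDE peak values per linked bundle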
def get_linked_posterior_indices(nestedResults):
"""
Convert the input set of linked posterior names to set of indices:
Input:
(example structure)
To analyze all parameters together:
linked_posterior_names = 'all'
Alternative: only link some parameters:
linked_posterior_names = [ joint_param_bundle1, joint_param_bundle2 ]
with
joint_param_bundle1 = [ [cmp1, par1], [cmp2, par2] ]
joint_param_bundle2 = [ [cmp3, par3], [cmp4, par4] ]
for a full array of:
linked_posterior_names =
[ [ [cmp1, par1], [cmp2, par2] ], [ [cmp3, par3], [cmp4, par4] ] ]
Note: even a single bundle must be wrapped in an outer list:
linked_posterior_names = [ [ [cmp1, par1], [cmp2, par2] ] ]
Output:
linked_posterior_inds = [ joint_bundle1_inds, joint_bundle2_inds ]
with joint_bundle1_inds = [ ind1, ind2 ], etc
ex:
output = [ [ind1, ind2], [ind3, ind4] ]
"""
linked_posterior_ind_arr = None
try:
if nestedResults.linked_posterior_names.strip().lower() == 'all':
linked_posterior_ind_arr = [range(len(nestedResults.free_param_names))]
except:
pass
if linked_posterior_ind_arr is None:
free_cmp_param_arr = make_arr_cmp_params(nestedResults)
linked_posterior_ind_arr = []
for k in range(len(nestedResults.linked_posterior_names)):
# Loop over *sets* of linked posteriors:
# This is an array of len-2 arrays/tuples with cmp, param names
linked_post_inds = []
for j in range(len(nestedResults.linked_posterior_names[k])):
indp = get_param_index(nestedResults, nestedResults.linked_posterior_names[k][j],
free_cmp_param_arr=free_cmp_param_arr)
linked_post_inds.append(indp)
linked_posterior_ind_arr.append(linked_post_inds)
return linked_posterior_ind_arr
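# Sketch of the name -> index conversion (comments only). `nestedResults` is a
# hypothetical results object with `free_param_names` set; the component/parameter
# names below are illustrative only:
#
#     nestedResults.linked_posterior_names = [ [['disk+bulge', 'total_mass'],
#                                               ['halo', 'mvirial']] ]
#     linked_inds = get_linked_posterior_indices(nestedResults)
#     # -> e.g. [[0, 3]]: the free-parameter indices of the two linked parameters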
def get_param_index(nestedResults, param_name, free_cmp_param_arr=None):
if free_cmp_param_arr is None:
free_cmp_param_arr = make_arr_cmp_params(nestedResults)
cmp_param = param_name[0].strip().lower()+':'+param_name[1].strip().lower()
try:
whmatch = np.where(free_cmp_param_arr == cmp_param)[0][0]
except:
raise ValueError(cmp_param+' component+parameter not found in free parameters of nestedResults')
return whmatch
############################################################
# UTILITY FUNCTIONS
####################
def make_arr_cmp_params(results):
arr = np.array([])
for cmp in results.free_param_names.keys():
for i in range(len(results.free_param_names[cmp])):
param = results.free_param_names[cmp][i]
arr = np.append( arr, cmp.strip().lower()+':'+param.strip().lower() )
return arr
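# Sketch of the output format of make_arr_cmp_params() (comments only; the `results`
# object and its names are hypothetical):
#
#     results.free_param_names = {'disk+bulge': ['total_mass', 'r_eff_disk'],
#                                 'halo': ['mvirial']}
#     make_arr_cmp_params(results)
#     # -> array(['disk+bulge:total_mass', 'disk+bulge:r_eff_disk', 'halo:mvirial'])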
def setup_oversampled_chisq(gal):
# Setup for oversampled_chisq:
for obs_name in gal.observations:
obs = gal.observations[obs_name]
if isinstance(obs.instrument.beam, GaussianBeam):
try:
PSF_FWHM = obs.instrument.beam.major.value
except:
PSF_FWHM = obs.instrument.beam.major
elif isinstance(obs.instrument.beam, Moffat):
try:
PSF_FWHM = obs.instrument.beam.major_fwhm.value
except:
PSF_FWHM = obs.instrument.beam.major_fwhm
elif isinstance(obs.instrument.beam, DoubleBeam):
try:
PSF_FWHM = np.max([obs.instrument.beam.beam1.major.value, obs.instrument.beam.beam2.major.value])
except:
PSF_FWHM = np.max([obs.instrument.beam.beam1.major, obs.instrument.beam.beam2.major])
if obs.instrument.ndim == 1:
rarrtmp = obs.data.rarr.copy()
rarrtmp.sort()
spacing_avg = np.abs(np.average(rarrtmp[1:]-rarrtmp[:-1]))
obs.data.oversample_factor_chisq = PSF_FWHM /spacing_avg
elif obs.instrument.ndim == 2:
obs.data.oversample_factor_chisq = (PSF_FWHM / obs.instrument.pixscale.value)**2
elif obs.instrument.ndim == 3:
spec_step = obs.instrument.spec_step.to(u.km/u.s).value
LSF_FWHM = obs.instrument.lsf.dispersion.to(u.km/u.s).value * (2.*np.sqrt(2.*np.log(2.)))
obs.data.oversample_factor_chisq = (LSF_FWHM / spec_step) * (PSF_FWHM / obs.instrument.pixscale.value)**2
return gal
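# Sketch of setup_oversampled_chisq() usage (comments only; `gal` is a hypothetical
# Galaxy whose observations have instruments with beam / LSF information). The call
# attaches `oversample_factor_chisq` to each observation's data, e.g.
# (PSF_FWHM / pixscale)**2 for a 2D map, so that downstream chi-square evaluation can
# account for neighboring pixels not being independent:
#
#     gal = setup_oversampled_chisq(gal)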
def find_shortest_conf_interval(xarr, percentile_frac):
# Canonical 1sigma: 0.6827
xsort = np.sort(xarr)
N = len(xarr)
i_max = np.int64(np.round(percentile_frac*N))
len_arr = xsort[i_max:] - xsort[0:N-i_max]
argmin = np.argmin(len_arr)
l_val, u_val = xsort[argmin], xsort[argmin+i_max-1]
return l_val, u_val
def shortest_span_bounds(arr, percentile=0.6827):
if len(arr.shape) == 1:
limits = find_shortest_conf_interval(arr, percentile)
else:
limits = np.ones((2, arr.shape[1]))
for j in range(arr.shape[1]):
limits[:, j] = find_shortest_conf_interval(arr[:,j], percentile)
return limits
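# Self-contained sketch of shortest_span_bounds() (comments only): for a standard
# normal sample the shortest 68.27% interval should come out close to [-1, +1].
#
#     import numpy as np
#     samples = np.random.normal(size=100000)
#     lo, hi = shortest_span_bounds(samples, percentile=0.6827)
#     # lo ~ -1.0, hi ~ +1.0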
def _check_existing_files_overwrite(output_options, fit_type=None, fitter=None):
# ---------------------------------------------------
# Check for existing files if overwrite=False:
if (not output_options.overwrite):
fnames = []
fnames_opt = [ output_options.f_model, output_options.f_vcirc_ascii,
output_options.f_mass_ascii, output_options.f_results,
output_options.f_sampler_results,
output_options.f_plot_bestfit ]
if (fit_type.lower() in _bayesian_fitting_methods):
fnames_ext = [output_options.f_plot_trace,
output_options.f_plot_param_corner,
output_options.f_chain_ascii]
for fn in fnames_ext:
fnames_opt.append(fn)
if (fit_type.lower() == 'mcmc'):
fnames_ext = [output_options.f_plot_trace_burnin]
for fn in fnames_ext:
fnames_opt.append(fn)
elif (fit_type.lower() == 'nested'):
fnames_ext = [output_options.f_checkpoint,
output_options.f_plot_run]
for fn in fnames_ext:
fnames_opt.append(fn)
for fname in fnames_opt:
if fname is not None:
fnames.append(fname)
file_bundle_names = ['f_model_bestfit', 'f_vel_ascii', 'f_bestfit_cube']
for fbunname in file_bundle_names:
for obsn in output_options.__dict__[fbunname]:
if output_options.__dict__[fbunname][obsn] is not None:
fnames.append(output_options.__dict__[fbunname][obsn])
for fname in fnames:
if fname is not None:
if os.path.isfile(fname):
logger.warning("overwrite={} & File already exists! Will not save file. \n {}".format(output_options.overwrite, fname))
# Return early if the results file cannot be saved (the fit itself would not be recoverable):
if output_options.f_results is not None:
if os.path.isfile(output_options.f_results):
msg = "overwrite={}, and 'f_results' won't be saved,".format(output_options.overwrite)
msg += " so the fit will not be saved.\n Specify new outfile or delete old files."
logger.warning(msg)
return None
else:
# Overwrite=True: remove old file versions
if (fit_type.lower() in _bayesian_fitting_methods):
fnames_ext = [output_options.f_plot_trace,
output_options.f_plot_param_corner,
output_options.f_chain_ascii]
for fn in fnames_ext:
if fn is not None and os.path.isfile(fn): os.remove(fn)
if (fit_type.lower() == 'mcmc'):
fnames_ext = [output_options.f_plot_trace_burnin]
for fn in fnames_ext:
if fn is not None and os.path.isfile(fn): os.remove(fn)
elif (fit_type.lower() == 'nested'):
fnames_ext = [output_options.f_checkpoint,
output_options.f_plot_run]
for fn in fnames_ext:
if fn is not None and os.path.isfile(fn): os.remove(fn)
# ---------------------------------------------------