Source code for curepy.retrieval_methods.mcmc

"""Markov Chain Monte Carlo (MCMC) retrieval class"""

import copy
from multiprocessing import Pool
from typing import Optional, Union, Sequence

import comet_maths as cm
import emcee
import numpy as np

from curepy.retrieval_methods.base import BaseRetrieval
from curepy.container.retrieval_input import RetrievalInput
from curepy.container.retrieval_result import RetrievalResult


class MCMC(BaseRetrieval):
    """
    MCMC retrieval object.

    Samples the posterior with :class:`emcee.EnsembleSampler` and summarises
    the samples into a
    :class:`~curepy.container.retrieval_result.RetrievalResult`.

    The helpers ``generate_theta_0``, ``lnprob``, ``reshape_outputs`` and
    ``_check_retrieval_input`` are not defined in this class; they are
    presumably inherited from :class:`BaseRetrieval`.
    """

    def __init__(
        self,
        nwalkers: int,
        steps: int,
        burn_in: int,
        progress: bool = True,
        parallel_cores: int = 1,
    ) -> None:
        """
        Initialise the MCMC retrieval object.

        :param nwalkers: Number of ensemble walkers used by :class:`emcee.EnsembleSampler`.
        :param steps: Total number of MCMC steps per walker.
        :param burn_in: Number of initial samples to discard as burn-in (counted on the
            flattened chain, see :meth:`run_MCMC`).
        :param progress: If ``True``, display a progress bar during sampling.
        :param parallel_cores: Number of CPU cores for parallel sampling. Values greater
            than 1 use :class:`multiprocessing.Pool`.
        """
        self.nwalkers = nwalkers
        self.steps = steps
        self.burn_in = burn_in
        self.progress = progress
        self.parallel_cores = parallel_cores

    def _run_retrieval(
        self,
        retrieval_input: RetrievalInput,
        return_samples: bool = False,
        return_corr: bool = False,
        return_b_samples: bool = False,
        reshape_results: bool = True,
        corr_dims: Optional[Union[int, Sequence[int]]] = -99,
    ) -> RetrievalResult:
        """
        Run the MCMC retrieval and return the results.

        If ancillary samples exist and more than one ancillary Monte Carlo step is
        requested, one full MCMC run is performed per ancillary draw and the
        resulting chains are concatenated.

        :param retrieval_input: Object containing all retrieval inputs.
        :param return_samples: If ``True``, the full sample array is stored in the returned
            :class:`~curepy.container.retrieval_result.RetrievalResult`.
        :param return_corr: If ``True``, the parameter correlation matrix is computed from
            the samples and stored in the result.
        :param return_b_samples: If ``True``, the ancillary parameter samples are stored in
            the result.
        :param reshape_results: If ``True``, reshape the flat output arrays to the
            initial-guess shape.
        :param corr_dims: int or sequence of ints forwarded to :meth:`analyse_samples`;
            the default ``-99`` leaves the samples unreshaped for the correlation
            calculation.
        :returns: Object containing retrieved values, uncertainties, and optionally samples
            and correlations.
        """
        self.retrieval_input = retrieval_input
        self._check_retrieval_input()

        # define theta_0
        theta_0 = self.generate_theta_0(
            self.retrieval_input.measurement_function_obj.initial_guess
        )

        # generate b samples if ancillary data exists
        ancillary = self.retrieval_input.ancillary_obj
        ancillary.generate_b_samples()
        b_samples = ancillary.b_samples

        # generate samples with MCMC
        if b_samples is None or ancillary.b_MC_steps == 1:
            samples = self.run_MCMC(theta_0, self.nwalkers, self.steps, self.burn_in)
        else:
            # Number of samples each individual MCMC run contributes after burn-in.
            n_per_run = self.nwalkers * self.steps - self.burn_in
            samples = np.zeros(
                (n_per_run * ancillary.b_MC_steps, len(theta_0)),
                dtype=np.float32,
            )

            # copy.copy gives an independent container for both lists and numpy
            # arrays (a plain ``b[:]`` would only be a view for an array), so the
            # in-place edits below cannot leak into the backup.
            original_b = copy.copy(ancillary.b)
            try:
                for i in range(len(b_samples[0])):
                    # NOTE: each entry of b is replaced by the i-th draw of the
                    # corresponding entry of b_samples, whatever its dimensionality.
                    for ii, b_sample in enumerate(b_samples):
                        ancillary.b[ii] = b_sample[i]
                    samples[i * n_per_run : (i + 1) * n_per_run, :] = self.run_MCMC(
                        theta_0, self.nwalkers, self.steps, self.burn_in
                    )
            finally:
                # Always restore the nominal ancillary values, even if sampling fails.
                ancillary.b = original_b

        return self.analyse_samples(
            samples,
            b_samples,
            return_samples,
            return_corr,
            return_b_samples,
            reshape_results,
            corr_dims=corr_dims,
        )

    def run_MCMC(
        self,
        theta_0: np.ndarray,
        nwalkers: int,
        steps: int,
        burn_in: int,
    ) -> np.ndarray:
        """
        Run :class:`emcee.EnsembleSampler` and return the post-burn-in chain.

        The chain returned by the sampler is flattened to two dimensions before the
        first ``burn_in`` rows are dropped, so ``burn_in`` counts individual walker
        samples of the flattened chain rather than whole sampler steps.

        :param theta_0: Initial state vector around which walkers are initialised.
        :param nwalkers: Number of ensemble walkers.
        :param steps: Total number of sampling steps.
        :param burn_in: Number of initial samples to discard.
        :returns: Array of post-burn-in samples with shape
            ``(nwalkers * steps - burn_in, ndim)``.
        """
        ndimw = len(theta_0)
        pos = [self.generate_theta_i(theta_0) for _ in range(nwalkers)]

        if self.parallel_cores > 1:
            # The context manager shuts the worker processes down once sampling is
            # finished; otherwise every call would leak a pool of processes.
            with Pool(self.parallel_cores) as pool:
                sampler = emcee.EnsembleSampler(
                    nwalkers, ndimw, self.lnprob, pool=pool
                )
                sampler.run_mcmc(pos, steps, progress=self.progress)
        else:
            sampler = emcee.EnsembleSampler(nwalkers, ndimw, self.lnprob)
            sampler.run_mcmc(pos, steps, progress=self.progress)

        # The stored chain stays available after the pool has been closed.
        return sampler.get_chain().reshape((-1, ndimw))[burn_in:]

    def generate_theta_i(
        self,
        theta_0: np.ndarray,
        factor_std: float = 0.1,
        max_attempts: int = 1000,
    ) -> np.ndarray:
        """
        Generate a single walker starting position from ``theta_0``.

        Perturbs ``theta_0`` by a multiplicative Gaussian factor. If the perturbed
        position has a non-finite log-prior, the perturbation width is shrunk by 10%
        and a new position is drawn, until a valid position is found.

        :param theta_0: Initial state vector.
        :param factor_std: Standard deviation of the multiplicative Gaussian perturbation.
        :param max_attempts: Maximum number of draws before giving up.
        :returns: Perturbed starting position that is within the prior support.
        :raises RuntimeError: If no position with a finite log-prior is found within
            ``max_attempts`` draws (which suggests ``theta_0`` itself lies outside the
            prior support).
        """
        lnprior = self.retrieval_input.prior_obj.lnprior
        for _ in range(max_attempts):
            theta_i = theta_0 * np.random.normal(1.0, factor_std, theta_0.shape)
            if all(np.isfinite(lnprior(theta_i)())):
                return theta_i
            # Shrink the perturbation so the next draw lies closer to theta_0.
            factor_std *= 0.9

        raise RuntimeError(
            "MCMC.generate_theta_i: could not find a walker starting position with a "
            f"finite prior after {max_attempts} attempts; check that the initial guess "
            "lies within the support of the prior."
        )

    def analyse_samples(
        self,
        samples: np.ndarray,
        b_samples: Optional[np.ndarray],
        return_samples: bool,
        return_corr: bool,
        return_b_samples: bool,
        reshape_results: bool,
        corr_dims: Optional[Union[int, Sequence[int]]] = -99,
    ) -> RetrievalResult:
        """
        Summarise MCMC samples into a
        :class:`~curepy.container.retrieval_result.RetrievalResult`.

        Computes the median, symmetric uncertainty (average of the distances from the
        median to the 84th and 16th percentiles), and optionally the correlation
        matrix.

        :param samples: Post-burn-in MCMC samples.
        :param b_samples: Ancillary parameter samples.
        :param return_samples: If ``True``, include the raw samples in the result.
        :param return_corr: If ``True``, compute and include the correlation matrix.
        :param return_b_samples: If ``True``, include ancillary samples.
        :param reshape_results: If ``True``, reshape outputs to the initial-guess shape.
        :param corr_dims: int or List of ints, axis to calculate correlation matrix along.
            For any value other than the default ``-99`` the samples are first reshaped
            to ``(n_samples,) + initial_guess.shape``.
        :returns: Retrieved values, uncertainties, and optional extras.
        """
        medians = np.median(samples, axis=0)
        unc_up = np.percentile(samples, 84, axis=0) - medians
        unc_down = medians - np.percentile(samples, 16, axis=0)
        unc_avg = (unc_up + unc_down) / 2.0

        # Always bind corr so later references never depend on return_corr.
        corr = None
        if return_corr:
            if corr_dims != -99:
                initial_guess_shape = (
                    self.retrieval_input.measurement_function_obj.initial_guess.shape
                )
                corr = cm.calculate_corr(
                    samples.reshape((samples.shape[0],) + initial_guess_shape),
                    corr_dims,
                )
            else:
                corr = cm.calculate_corr(samples, corr_dims)

        if reshape_results:
            medians, unc_avg, corr = self.reshape_outputs(medians, unc_avg, corr)

        return RetrievalResult(
            x=medians,
            u_x=unc_avg,
            corr_x=corr if return_corr else None,
            samples=samples if return_samples else None,
            b_samples=b_samples if return_b_samples else None,
            x_names=self.retrieval_input.measurement_function_obj._input_quantities_names,
            retrieval_object=self,
        )