# Source code for optimpv.BayesInfEmcee.EmceeOptimizer

import numpy as np
import matplotlib.pyplot as plt
import emcee
import corner

import copy, os, warnings
from collections import defaultdict
from multiprocessing import Pool
from functools import partial
from logging import Logger
import time
import logging
# logging.basicConfig(format='[%(levelname)s %(asctime)s] %(name)s: %(message)s', level=logging.INFO)

# logger = logging.getLogger(__name__)
from optimpv.general.logger import get_logger, _round_floats_for_logging

# Module-level logger, obtained from the project's logging helper under the name 'EmceeOptimizer'.
logger: Logger = get_logger('EmceeOptimizer')
# Number of decimal places handed to the float-rounding helper below.
ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES: int = 6
# Pre-bind decimal_places so callers of round_floats_for_logging do not have to pass it.
round_floats_for_logging = partial(
    _round_floats_for_logging,
    decimal_places=ROUND_FLOATS_IN_LOGS_TO_DECIMAL_PLACES,
)
# Assuming other necessary imports like the base optimizer class or model interface exist
from optimpv.general.BaseAgent import BaseAgent # Import BaseAgent

class EmceeOptimizer(BaseAgent):
    """Optimizer using the emcee library for MCMC Bayesian inference.

    Inherits from BaseAgent and interacts with Agent objects.
    """

    def __init__(self, params=None, agents=None, nwalkers=20, nsteps=1000, burn_in=100, progress=True, name='emcee_opti', **kwargs):
        """Store the sampling settings and derive the search space from ``params``.

        Parameters
        ----------
        params : list of Fitparam() objects, optional
            List of Fitparam() objects, by default None
        agents : list of Agent() objects, optional
            List of Agent() objects see optimpv/general/BaseAgent.py for a base class definition,
            by default None. A single (non-list) agent is wrapped in a list.
        nwalkers : int, optional
            Number of walkers in the MCMC ensemble, by default 20
        nsteps : int, optional
            Number of MCMC steps per walker, by default 1000
        burn_in : int, optional
            Number of steps to discard as burn-in, by default 100
        progress : bool, optional
            Whether to display the progress bar during sampling, by default True
        name : str, optional
            Name for the optimization process, by default 'emcee_opti'
        **kwargs : dict, optional
            Additional keyword arguments. The ones read here are ``use_pool``
            (default True) and ``max_parallelism`` (default: CPU count minus one,
            but never less than one).

        Raises
        ------
        ValueError
            Agents must minimize all targets. Please set minimize=True for all targets.
        ValueError
            Parameter must be of type 'float'. Please set value_type='float' for all parameters.
        ValueError
            Number of dimensions (parameters) cannot be determined.
        """
        # super().__init__() # Call BaseAgent init if needed
        # Treat a missing parameter list as empty so that the documented
        # ValueError (zero dimensions) is raised instead of a TypeError.
        self.params = params if params is not None else []
        if not isinstance(agents, list):
            agents = [agents]
        self.agents = agents
        self.nwalkers = nwalkers
        self.nsteps = nsteps
        self.burn_in = burn_in
        self.progress = progress
        self.name = name
        self.kwargs = kwargs

        # Every target of every agent has to be minimized.
        for agent in self.agents:
            if hasattr(agent, 'minimize') and not all(agent.minimize):
                raise ValueError(f"Agent {agent.name} must minimize all targets. Please set minimize=True for all targets.")

        # Only float-valued parameters are supported.
        for param in self.params:
            if param.value_type != 'float':
                raise ValueError(f"Parameter {param.name} must be of type 'float'. Please set value_type='float' for all parameters.")

        # Extract settings from kwargs
        self.use_pool = kwargs.get('use_pool', True)  # Control whether to use multiprocessing Pool
        # os.cpu_count() may return None, and a single-core host would give 0
        # workers; keep at least one worker in both cases.
        default_workers = max(1, (os.cpu_count() or 1) - 1)
        self.max_parallelism = kwargs.get('max_parallelism', default_workers)

        # Process parameters to get dimensions, bounds, initial guess, names
        self.x0, self.bounds, self.param_mapping, self.log_params_indices = self.create_search_space(self.params)
        self.ndim = len(self.x0)
        # Use display names if available
        self.param_names = [
            p.display_name if hasattr(p, 'display_name') else p.name
            for p in self.params
            if p.name in self.param_mapping
        ]

        if self.ndim == 0:
            raise ValueError("Number of dimensions (parameters) cannot be determined.")

        self.sampler = None
        self.chain = None
        self.flat_samples = None
        self.results = None
        self.all_metrics = self.create_metrics_list()  # Helper to get metric names

[docs] def create_metrics_list(self): """ Create a list of all metrics from all agents. Returns ------- list List of metric names """ metrics = [] for agent in self.agents: for i in range(len(agent.all_agent_metrics)): metrics.append(agent.all_agent_metrics[i]) return metrics
[docs] def create_search_space(self, params): """Create search space details (initial vector, bounds, mapping) from FitParam list. Parameters ---------- params : list of FitParam List of FitParam objects defining the parameters to optimize. Returns ------- tuple x0 : array Initial parameter vector for optimization (potentially log-transformed). bounds : list of tuples List of (lower, upper) bound tuples for optimization vector. param_mapping : list List of parameter names corresponding to x0 elements. log_params_indices : list Indices of parameters optimized in log10 space. Raises ------ ValueError If a parameter type is unsupported (not 'float'). """ # Initialize empty lists for x0, bounds, and parameter mapping x0 = [] bounds = [] param_mapping = [] log_params_indices = [] for i, param in enumerate(params): if param.type == 'fixed': continue param_mapping.append(param.name) current_index = len(x0) # Index in the optimization vector 'x' if param.value_type == 'float': if param.force_log: log_params_indices.append(current_index) x0.append(np.log10(param.value)) # Ensure bounds are positive before log10 lower_bound = np.log10(param.bounds[0]) if param.bounds[0] > 0 else -np.inf upper_bound = np.log10(param.bounds[1]) if param.bounds[1] > 0 else np.inf bounds.append((lower_bound, upper_bound)) else: scale_factor = param.fscale if hasattr(param, 'fscale') and param.fscale is not None else 1.0 x0.append(param.value / scale_factor) bounds.append((param.bounds[0] / scale_factor, param.bounds[1] / scale_factor)) else: raise ValueError(f"Unsupported parameter type: {param.value_type}. Only 'float' is supported.") return np.array(x0), bounds, param_mapping, log_params_indices
[docs] def reconstruct_params(self, x_opt): """Reconstruct a full parameter dictionary from an optimization vector x_opt. Parameters ---------- x_opt : array-like Parameter vector from the optimizer (potentially log-transformed). Returns ------- dict Dictionary mapping full parameter names to their values. Raises ------ ValueError If a parameter type is unsupported (not 'float'). """ # Initialize empty dictionary for reconstructed parameters param_dict = {} opt_idx = 0 for param in self.params: if param.type == 'fixed': param_dict[param.name] = param.value else: # Find the corresponding value in x_opt current_val = x_opt[opt_idx] if param.value_type == 'float': if opt_idx in self.log_params_indices: if param.force_log: # If log10 transformed, convert back to original scale param_dict[param.name] = 10**current_val else: param_dict[param.name] = current_val else: scale_factor = param.fscale if hasattr(param, 'fscale') and param.fscale is not None else 1.0 param_dict[param.name] = current_val * scale_factor else: raise ValueError(f"Unsupported parameter type: {param.value_type}. Only 'float' is supported.") opt_idx += 1 return param_dict
def _log_likelihood(self, theta, agents=None): """Calculate the log-likelihood based on agent evaluations. Assumes agent.run_Ax returns a dictionary where keys match self.all_metrics and values are loss/metric values (e.g., sum of squared errors). Converts loss to log-likelihood assuming Gaussian noise. Parameters ---------- theta : array-like contains the parameters to evaluate agents : list of Agent() objects, optional List of Agent() objects to evaluate the likelihood, by default None Returns ------- float Log-likelihood value. Returns -np.inf for invalid evaluations (e.g., NaN, Inf). """ # param_dict = self.reconstruct_params(theta) param_dict = {} idx = 0 for i, param in enumerate(self.params): if param.type == 'fixed': param_dict[param.name] = param.value else: param_dict[param.name] = theta[idx] idx += 1 total_log_like = 0.0 all_results = {} # Evaluate all agents for the given parameter set # Note: This part is sequential. Parallelism happens at the walker level in emcee. try: for agent in agents: # Assuming run_Ax needs the parameter dictionary agent_results = agent.run_Ax(param_dict) all_results.update(agent_results) # Combine results into a single log-likelihood value # Simple approach: sum of negative losses (assuming loss ~ -2*logL) for metric_name in self.all_metrics: if metric_name in all_results: loss_val = all_results[metric_name] if np.isnan(loss_val) or not np.isfinite(loss_val): return -np.inf # Penalize NaNs or Infs heavily log_like_contribution = -0.5 * loss_val total_log_like += log_like_contribution else: # Metric not found in results, indicates an issue warnings.warn(f"Metric {metric_name} not found in agent results for params {param_dict}, something went wrong.") return -np.inf if not np.isfinite(total_log_like): return -np.inf return total_log_like except Exception as e: # Handle potential errors during agent evaluation (e.g., simulation crashes) # print(f"Error during agent evaluation: {e}") # Optional: log error return -np.inf # Penalize 
parameters causing errors def _log_prior(self, theta): """ Calculate the log-prior probability of the parameters (in optimization space). Uses bounds defined during initialization. Assumes uniform prior within bounds. Parameters ---------- theta : array-like Parameter vector in optimization space. Returns ------- float Log-prior value. Returns -np.inf for invalid evaluations (e.g., outside bounds). """ for i in range(self.ndim): min_val, max_val = self.bounds[i] if not (min_val <= theta[i] <= max_val): return -np.inf # Add other priors if necessary (e.g., Gaussian priors on specific parameters) # Remember theta is potentially log10 transformed for some parameters. # Priors should be defined on the space you are sampling (theta). return 0.0 # Flat prior within bounds def _log_probability(self, theta, agents=None): """ Calculate the total log-probability (log-prior + log-likelihood). This is the function called by emcee. It combines the prior and likelihood evaluations. Handles potential issues with likelihood evaluation and prior violations. Parameters ---------- theta : array-like Parameter vector in optimization space. agents : list of Agent() objects, optional List of Agent() objects to evaluate the likelihood, by default None Returns ------- float Total log-probability value. Returns -np.inf for invalid evaluations (e.g., NaN, Inf). """ lp = self._log_prior(theta) if not np.isfinite(lp): return -np.inf # Prior is violated # Likelihood calculation might fail, handle potential errors try: ll = self._log_likelihood(theta, agents=agents) if not np.isfinite(ll): return -np.inf # Likelihood calculation failed or returned non-finite value except Exception as e: # Catch unexpected errors during likelihood calculation warnings.warn(f"Unexpected error in log_likelihood: {e}") return -np.inf return lp + ll
[docs] def initialize_walkers(self): # use LHS to initialize walkers # LHS is a sampling method that ensures the samples are evenly distributed in the parameter space from scipy.stats import qmc sampler = qmc.LatinHypercube(d=self.ndim) unit_samples = sampler.random(n=self.nwalkers) lower_bounds, upper_bounds = np.array(self.bounds).T scaled_samples = qmc.scale(unit_samples, lower_bounds, upper_bounds) return scaled_samples
[docs] def optimize(self): """ Run the MCMC optimization using emcee. """ verbose_logging = self.kwargs.get('verbose_logging',True) # Initialize walkers # Start walkers in a small ball around the initial guess x0 # pos = self.x0 + 1e-4 * np.random.randn(self.nwalkers, self.ndim) # # Ensure initial positions respect bounds # for i in range(self.nwalkers): # for j in range(self.ndim): # pos[i, j] = np.clip(pos[i, j], self.bounds[j][0], self.bounds[j][1]) # # Optional: Re-check if any clipped position violates prior (e.g., if bounds were -inf/inf) # while not np.isfinite(self._log_prior(pos[i])): # # Resample if prior is still violated (should be rare with clipping if bounds are finite) # pos[i] = self.x0 + 1e-3 * np.random.randn(self.ndim) # for j in range(self.ndim): # pos[i, j] = np.clip(pos[i, j], self.bounds[j][0], self.bounds[j][1]) pos = self.initialize_walkers() if verbose_logging: # Log initial positions print("----------------------------------------------------\n") logger.info(f"Running MCMC with {self.nwalkers} walkers for {self.nsteps} steps...") # Setup multiprocessing pool if enabled pool = None map_fn = map if self.use_pool: pool = Pool(processes=self.max_parallelism) map_fn = pool.map # Use pool's map for parallelization # Create the sampler # Pass the pool to EnsembleSampler for parallel likelihood evaluations njobs = min(self.nwalkers, self.max_parallelism) if self.use_pool else 1 # with Pool() as pool: sampler = emcee.EnsembleSampler( self.nwalkers, self.ndim, self._log_probability, pool=pool, kwargs={'agents': self.agents} ) # Burn-in phase state = sampler.run_mcmc(pos, self.burn_in, progress=True) sampler.reset() # Run MCMC sampler.run_mcmc(state, self.nsteps, progress=self.progress) # Close the pool if it was used if pool is not None: pool.close() pool.join() self.sampler = sampler if verbose_logging: logger.info("MCMC run complete.") # Process results self.chain = sampler.get_chain() # Adjust thin parameter as needed, based on autocorrelation 
time analysis if performed autocorr_time = sampler.get_autocorr_time(tol=0) # Basic estimate thin_factor = int(np.mean(autocorr_time) / 2) if np.all(np.isfinite(autocorr_time)) else 15 thin_factor = max(1, thin_factor) # Ensure thin >= 1 self.log_prob_samples = self.sampler.get_log_prob(discard=self.burn_in, thin=thin_factor, flat=True) # Use same thinning self.flat_samples = sampler.get_chain(discard=self.burn_in, thin=thin_factor, flat=True) # Store results (e.g., median parameters and uncertainties in original parameter space) self.results = {} # Get median parameters in optimization space median_opt_params = np.median(self.flat_samples, axis=0) # Convert median parameters back to original space median_params_dict = self.reconstruct_params(median_opt_params) if verbose_logging: # Log median parameters logger.info("MCMC Results (Median & 16th/84th Percentiles)") for i, name in enumerate(self.param_mapping): # Get samples for this parameter in optimization space param_samples_opt = self.flat_samples[:, i] # Transform samples back to original space if necessary if i in self.log_params_indices: param_samples_orig = 10**param_samples_opt else: # Check if scaling was applied original_param = next(p for p in self.params if p.name == name) scale_factor = original_param.fscale if hasattr(original_param, 'fscale') and original_param.fscale is not None else 1.0 param_samples_orig = param_samples_opt * scale_factor if original_param.value_type == 'int': # Keep as float for distribution analysis, or round if needed pass # param_samples_orig = np.round(param_samples_orig) # Calculate percentiles in original space mcmc = np.percentile(param_samples_orig, [16, 50, 84]) q = np.diff(mcmc) self.results[name] = {'median': mcmc[1], '16th': mcmc[0], '84th': mcmc[2], 'lower_err': q[0], 'upper_err': q[1]} # Find the display name for printing display_name = next((p.display_name for p in self.params if p.name == name and hasattr(p,'display_name')), name) if verbose_logging: # Log results 
logger.info(f"{display_name} ({name}): {mcmc[1]:.4g} (+{q[1]:.3g} / -{q[0]:.3g})") # Update self.params with median values self.update_params_with_best_balance() # Use max likelihood by default if verbose_logging: print("----------------------------------------------------\n") return self.results
[docs] def get_best_params(self, method='max_likelihood'): """Return the 'best' parameters based on the MCMC samples. This method allows the user to specify how to determine the 'best' parameters Parameters ---------- method : str How to determine 'best' params ('median', 'mean', 'max_likelihood'). 'median' - median of the samples 'mean' - mean of the samples 'max_likelihood' - maximum likelihood estimate (MLE) based on the log-probability samples 'max_likelihood' is the default method. Returns ------- dict Dictionary of best parameter values in original space. Raises ------ ValueError If the method is not one of 'median', 'mean', or 'max_likelihood'. """ # Check if optimization has been run if self.flat_samples is None: print("Optimization has not been run yet.") return None if method == 'median': best_opt_params = np.median(self.flat_samples, axis=0) elif method == 'mean': best_opt_params = np.mean(self.flat_samples, axis=0) elif method == 'max_likelihood': max_prob_index = np.argmax(self.log_prob_samples) best_opt_params = self.flat_samples[max_prob_index] else: raise ValueError("Method must be 'median', 'mean', or 'max_likelihood'") # Reconstruct to original parameter space best_params_dict = self.reconstruct_params(best_opt_params) return best_params_dict
[docs] def update_params_with_best_balance(self, method='max_likelihood', return_best_balance=False): """Update the parameters with the best balance based on MCMC results. This method updates the parameters in self.params with the best values determined by the specified method. It can also return the best parameters dictionary if requested. Parameters ---------- method : str, optional method to determine 'best' params ('median', 'mean', 'max_likelihood'), by default 'max_likelihood' return_best_balance : bool, optional If True, return the best parameters dictionary, by default False """ if self.results is None: raise ValueError("Optimization has not run or results not processed.") best_params_dict = self.get_best_params(method=method) # Update the FitParam objects in self.params for param in self.params: if param.name in best_params_dict: param.value = best_params_dict[param.name] if return_best_balance: return best_params_dict # Return the dictionary used for updating
[docs] def get_chain(self, **kwargs): """Return the MCMC chain. kwargs passed to sampler.get_chain()""" if self.sampler: return self.sampler.get_chain(**kwargs) return None
[docs] def get_flat_samples(self): """Return the flattened samples after burn-in and thinning.""" return self.flat_samples
# Add plotting methods if desired (e.g., corner plots, walker traces)
[docs] def plot_corner(self, **kwargs): """Generate a corner plot of the posterior distribution.""" title_fmt = kwargs.get('title_fmt', ".4e") if self.flat_samples is None: print("Optimization has not been run yet.") return None # Get samples in original parameter space for plotting samples_orig = [] labels_orig = [] truths_orig = kwargs.get('True_params', None) # Dictionary to hold true values for parameters for i, name in enumerate(self.param_mapping): param_samples_opt = self.flat_samples[:, i] original_param = next(p for p in self.params if p.name == name) labels_orig.append(original_param.display_name if hasattr(original_param,'display_name') else name) if i in self.log_params_indices: samples_orig.append(10**param_samples_opt) else: scale_factor = original_param.fscale if hasattr(original_param, 'fscale') and original_param.fscale is not None else 1.0 samples_orig.append(param_samples_opt * scale_factor) samples_orig_array = np.vstack(samples_orig).T # Prepare truths list in the correct order for corner if truths_orig is None: truths_list = [None] * len(self.param_mapping) else: truths_list = [truths_orig.get(name, None) for name in self.param_mapping] # Default corner plot settings corner_kwargs = { 'labels': labels_orig, 'show_titles': True, 'title_kwargs': {"fontsize": 10}, 'quantiles': [0.16, 0.5, 0.84], 'truths': truths_list, 'truth_color': 'red', 'color': 'darkblue', 'hist2d_kwargs': { 'cmap': plt.get_cmap('Blues'), }, 'hist_kwargs': { 'color': 'darkblue', }, } corner_kwargs.update(kwargs) # Allow user to override defaults params_axis_type = [] for param in self.params: if hasattr(param, 'axis_type'): params_axis_type.append(param.axis_type) else: params_axis_type.append('linear') fig = corner.corner(samples_orig_array, axes_scale=params_axis_type,title_fmt=title_fmt,**corner_kwargs) return fig
[docs] def plot_traces(self, **kwargs): """Plot the MCMC traces for each parameter.""" import matplotlib.pyplot as plt if self.chain is None: print("Optimization has not been run yet or chain not stored.") return None n_steps, n_walkers, n_dim = self.chain.shape labels = [p.display_name if hasattr(p,'display_name') else p.name for p in self.params if p.name in self.param_mapping] fig, axes = plt.subplots(n_dim, figsize=(10, 2 * n_dim), sharex=True) if n_dim == 1: # Handle case with only one parameter axes = [axes] for i,param in enumerate(self.params): ax = axes[i] # Plot traces for all walkers if param.force_log: # If log10 transformed, plot in original space ax.plot(10**self.chain[:, :, i], "k", alpha=0.2) ax.set_yscale('log') else: ax.plot(self.chain[:, :, i], "k", alpha=0.2) ax.set_xlim(0, n_steps) ax.set_ylabel(labels[i]) ax.yaxis.set_label_coords(-0.1, 0.5) # Add burn-in line ax.axvline(self.burn_in, color='blue', linestyle='--', lw=1, label=f'Burn-in ({self.burn_in})') if i == 0: ax.legend(loc='upper right') axes[-1].set_xlabel("Step number") plt.tight_layout() return fig