"""Module containing classes and functions for posterior analysis of parameters using the ML models from the BO optimization.
This module provides functionality to visualize the posterior distributions of parameters
using various plots, including 1D and 2D posteriors, devil's plots, and density plots."""
######### Package Imports #########################################################################
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import itertools
from scipy.special import logsumexp
from itertools import combinations
import ax
from ax import *
from ax.core.observation import ObservationFeatures
# from ax.core.base_trial import TrialStatus as T
from optimpv.general.general import inv_loss_function
from optimpv.axBOtorch.axUtils import get_df_from_ax
######### Function Definitions ####################################################################
def get_MSE_grid(params, Nres, objective_name, model, loss, optimizer_type = 'ax'):
    """Calculate the Mean Squared Error (MSE) grid for the given parameters and model.

    Only non-fixed parameters are gridded; fixed parameters are held at their
    stored value for every model evaluation.

    Parameters
    ----------
    params : list of FitParam() objects
        List of parameters to explore.
    Nres : int
        Number of points to evaluate for each parameter.
    objective_name : str
        Name of the objective to evaluate.
    model : torch model
        Model to evaluate the objective.
    loss : str
        Loss function used, see optimpv/general/general.py for available loss functions.
    optimizer_type : str, optional
        Type of optimizer used, by default 'ax'

    Returns
    -------
    tuple
        (grid, dims_GP, dims) where grid is an ndarray of MSE values with one
        axis of length Nres per non-fixed parameter, dims_GP holds the axis
        values in model (possibly rescaled/log) space and dims holds the axis
        values in true parameter scale.

    Raises
    ------
    ValueError
        If the optimizer type is not recognized.
    """
    dims_GP, dims = [], []   # model-space grids / true-scale grids (free params only)
    dic_fixed = {}           # fixed parameters, passed unchanged to every evaluation
    free_names = []          # names of the gridded (non-fixed) parameters, in order
    for param in params:
        if param.type != 'fixed':
            if param.axis_type == 'log':
                if param.force_log:
                    # model works directly in log10 space
                    parax = np.linspace(np.log10(param.bounds[0]), np.log10(param.bounds[1]), Nres)
                else:
                    parax = np.logspace(np.log10(param.bounds[0] / param.fscale), np.log10(param.bounds[1] / param.fscale), Nres)
                parax_rescaled = np.logspace(np.log10(param.bounds[0]), np.log10(param.bounds[1]), Nres)
            else:
                parax = np.linspace(param.bounds[0] / param.fscale, param.bounds[1] / param.fscale, Nres)
                parax_rescaled = np.linspace(param.bounds[0], param.bounds[1], Nres)
            dims_GP.append(parax)
            dims.append(parax_rescaled)
            free_names.append(param.name)
        else:
            dic_fixed[param.name] = param.value

    # Cartesian product over the free-parameter axes: one row per grid point.
    Xc = np.array(list(itertools.product(*dims_GP)))

    if optimizer_type.lower() == 'ax':
        observation_features = []
        for line in Xc:
            # NOTE: each row of Xc only contains the free parameters, so zip
            # against free_names (indexing by position in `params` would be
            # wrong whenever fixed parameters are present).
            dum_dic = dict(zip(free_names, line))
            dum_dic.update(dic_fixed)
            observation_features.append(ObservationFeatures(parameters=dum_dic))
        predictions = model.predict(observation_features)
        # predictions[1] holds the (co)variances keyed by metric name.
        mean_predictions = np.array(predictions[1][objective_name][objective_name])
    else:
        raise ValueError('Optimizer type not recognized')

    # Map the model output back from loss space to MSE space.
    mean_predictions = inv_loss_function(mean_predictions, loss)
    # One axis of length Nres per *free* parameter.
    return mean_predictions.reshape(*[Nres] * len(dims_GP)), dims_GP, dims
def calculate_1d_posteriors(mse_array):
    """Compute the 1D marginal posterior along every axis of an MSE grid.

    The MSE values are converted to a (log-)likelihood via -0.5 * MSE,
    normalized in log space with logsumexp for numerical stability, and then
    marginalized over all remaining axes.

    Parameters
    ----------
    mse_array : np.ndarray
        An n-dimensional array of MSE values.

    Returns
    -------
    list
        A list of 1D posterior distributions, one per axis of `mse_array`.
    """
    # Normalized log-posterior (logsumexp avoids underflow of exp(-0.5*MSE)).
    log_posterior = (-0.5 * mse_array)
    log_posterior = log_posterior - logsumexp(log_posterior)
    ndim = mse_array.ndim
    # Marginalize: for each kept axis, sum the posterior over all other axes.
    return [
        np.exp(logsumexp(log_posterior, axis=tuple(a for a in range(ndim) if a != keep)))
        for keep in range(ndim)
    ]
def calculate_2d_posteriors(mse_array):
    """Compute the 2D joint posterior for every pair of axes of an MSE grid.

    The MSE values are converted to a (log-)likelihood via -0.5 * MSE,
    normalized in log space with logsumexp, and then marginalized over all
    axes except the pair of interest. Pairs are ordered like
    itertools.combinations(range(ndim), 2).

    Parameters
    ----------
    mse_array : np.ndarray
        An n-dimensional array of MSE values.

    Returns
    -------
    list
        A list of 2D posterior distributions, one per axis pair.
    """
    # Normalized log-posterior computed in log space for stability.
    log_posterior = (-0.5 * mse_array)
    log_posterior = log_posterior - logsumexp(log_posterior)
    ndim = mse_array.ndim
    pairwise = []
    for axis_a, axis_b in combinations(range(ndim), 2):
        # Sum out every axis that is not part of the current pair.
        drop = tuple(a for a in range(ndim) if a not in (axis_a, axis_b))
        pairwise.append(np.exp(logsumexp(log_posterior, axis=drop)))
    return pairwise
def devils_plot(params, Nres, objective_name, model, loss, best_parameters = None, params_orig = None, grid_MSE = None, dims_GP = None, optimizer_type = 'ax', **kwargs):
    """Generate a devil's plot to visualize the posterior distributions of parameters.

    Diagonal panels show the 1D marginal posterior of each non-fixed
    parameter, lower-triangle panels show the pairwise joint posteriors.

    Parameters
    ----------
    params : list of FitParam() objects
        List of parameters to explore.
    Nres : int
        Number of points to evaluate for each parameter.
    objective_name : str
        Name of the objective to evaluate.
    model : torch model
        Model to evaluate the objective.
    loss : str
        Loss function used, see optimpv/general/general.py for available loss functions.
    best_parameters : dict, optional
        Dictionary of the best parameters, by default None.
    params_orig : dict, optional
        Dictionary of the original parameters, by default None.
    grid_MSE : np.ndarray, optional
        Precomputed MSE grid (one Nres-long axis per non-fixed parameter);
        computed with get_MSE_grid when None, by default None.
    dims_GP : list of np.ndarray, optional
        True-scale axis values matching grid_MSE; computed when None,
        by default None.
    optimizer_type : str, optional
        Type of optimizer used, by default 'ax'.
    **kwargs : dict
        Additional keyword arguments: fig_size (tuple, default (15, 15)).

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the plot.
    numpy.ndarray of matplotlib.axes.Axes
        The n x n grid of axes.
    """
    fig_size = kwargs.get('fig_size', (15, 15))
    if grid_MSE is None or dims_GP is None:
        # Use the true-scale axes (third return value) for plotting; the
        # model-space axes are only needed to build the grid itself.
        grid_MSE, _, dims_GP = get_MSE_grid(params, Nres, objective_name, model, loss, optimizer_type = optimizer_type)
    marginal_posteriors = calculate_1d_posteriors(grid_MSE)
    pairwise_posteriors = calculate_2d_posteriors(grid_MSE)

    # The MSE grid only spans the non-fixed parameters, so every index below
    # must run over the free parameters (not the full params list).
    free_params = [p for p in params if p.type != 'fixed']
    n = len(free_params)
    names = [p.name for p in free_params]
    comb = list(itertools.combinations(names, 2))

    # squeeze=False keeps ax 2D even for a single free parameter.
    fig, ax = plt.subplots(n, n, figsize=fig_size, squeeze=False)
    for i in range(n):
        for j in range(n):
            p_i, p_j = free_params[i], free_params[j]
            if i == j:
                # Diagonal: 1D marginal posterior.
                ax[i, j].plot(dims_GP[j], marginal_posteriors[j])
                if params_orig is not None:
                    ax[i, j].axvline(x=params_orig[p_j.name], color='k', linestyle='-')
                if best_parameters is not None:
                    ax[i, j].axvline(x=best_parameters[p_j.name], color='tab:red', linestyle='--')
                if p_j.axis_type == 'log':
                    ax[i, j].set_xscale('log')
                ax[i, j].set_xlabel(p_j.display_name + ' [' + p_j.unit + ']')
                ax[i, j].set_ylabel("Posterior probability")
            elif i > j:
                # Lower triangle: joint posterior of the (j, i) pair.
                # pairwise_posteriors is ordered like combinations(names, 2);
                # transpose so contourf gets Z[y, x].
                idx2d = comb.index((p_j.name, p_i.name))
                ax[i, j].contourf(dims_GP[j], dims_GP[i], pairwise_posteriors[idx2d].reshape(Nres, Nres).T)
                if params_orig is not None:
                    ax[i, j].axhline(y=params_orig[p_i.name], color='k', linestyle='-')
                    ax[i, j].axvline(x=params_orig[p_j.name], color='k', linestyle='-')
                if best_parameters is not None:
                    ax[i, j].axhline(y=best_parameters[p_i.name], color='tab:red', linestyle='--')
                    ax[i, j].axvline(x=best_parameters[p_j.name], color='tab:red', linestyle='--')
                if p_j.axis_type == 'log':
                    ax[i, j].set_xscale('log')
                if p_i.axis_type == 'log':
                    ax[i, j].set_yscale('log')
                ax[i, j].set_xlabel(p_j.display_name + ' [' + p_j.unit + ']')
                ax[i, j].set_ylabel(p_i.display_name + ' [' + p_i.unit + ']')
            else:
                # Upper triangle is intentionally left empty.
                ax[i, j].set_visible(False)

            # Clamp panel limits to the parameter bounds.
            ax[i, j].set_xlim(p_j.bounds[0], p_j.bounds[1])
            if i != j:
                ax[i, j].set_ylim(p_i.bounds[0], p_i.bounds[1])

            # Only the outer panels keep tick labels and axis labels.
            if j > 0 and i != j:
                ax[i, j].set_yticklabels([])
                ax[i, j].set_yticklabels([], minor=True)
                ax[i, j].set_ylabel('')
            if i < n - 1:
                ax[i, j].set_xticklabels([])
                ax[i, j].set_xticklabels([], minor=True)
                ax[i, j].set_xlabel('')
            if i == n - 1:
                ax[i, j].set_xlabel(p_j.display_name + ' [' + p_j.unit + ']')
                ax[i, j].tick_params(axis='x', rotation=45, which='both')
                if j != 0:
                    ax[i, j].set_yticklabels([])
                    ax[i, j].set_yticklabels([], minor=True)
                    ax[i, j].set_ylabel('')
            if j == 0:
                ax[i, j].set_ylabel(p_i.display_name + ' [' + p_i.unit + ']')
            if i == j:
                # Diagonal panels: title on top, posterior label on the right
                # so labels do not collide with the neighbouring joint plots.
                ax[i, j].set_title(p_i.display_name + ' [' + p_i.unit + ']')
                ax[i, j].set_ylabel('P(' + p_i.display_name + '|Data)')
                ax[i, j].yaxis.set_label_position('right')
                ax[i, j].yaxis.tick_right()
                ax[i, j].yaxis.set_tick_params(which='both', direction='in', left=True, right=True)
                ax[i, j].tick_params(axis='y', labelleft=False, labelright=True)

    # Custom legend for the reference lines.
    legend_elements = []
    if params_orig is not None:
        legend_elements.append(plt.Line2D([0], [0], color='k', label='Original parameters', linestyle='-'))
    if best_parameters is not None:
        legend_elements.append(plt.Line2D([0], [0], color='tab:red', label='Best parameters', linestyle='--'))
    if len(legend_elements) > 0:
        fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.9, 0.5), ncol=1)

    plt.tight_layout()
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.2, wspace=0.2)
    return fig, ax
def plot_1d_posteriors(params, Nres, objective_name, model, loss, best_parameters= None, params_orig = None, optimizer_type = 'ax',**kwargs):
    """Generate 1D posterior plots for each non-fixed parameter.

    Parameters
    ----------
    params : list of FitParam() objects
        List of parameters to explore.
    Nres : int
        Number of points to evaluate for each parameter.
    objective_name : str
        Name of the objective to evaluate.
    model : torch model
        Model to evaluate the objective.
    loss : str
        Loss function used, see optimpv/general/general.py for available loss functions.
    best_parameters : dict, optional
        Dictionary of the best parameters, by default None.
    params_orig : dict, optional
        Dictionary of the original parameters, by default None.
    optimizer_type : str, optional
        Type of optimizer used, by default 'ax'.
    **kwargs : dict
        Additional keyword arguments: n (number of panels), l (number of
        rows), ylim (y-axis limits), fig_size (tuple, default (16, 9)).

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the plot.
    numpy.ndarray of matplotlib.axes.Axes
        The grid of axes containing the plots.
    """
    # The posterior grid only covers the non-fixed parameters.
    free_params = [p for p in params if p.type != 'fixed']
    n = kwargs.get('n', len(free_params))
    l = kwargs.get('l', 1)
    ylim = kwargs.get('ylim', None)
    fig_size = kwargs.get('fig_size', (16, 9))
    if n > 3:
        # Three panels per row.
        l = int(np.ceil(n / 3))

    # Only the true-scale axes (third return value) are needed for plotting.
    grid_MSE, _, dims = get_MSE_grid(params, Nres, objective_name, model, loss, optimizer_type = optimizer_type)
    marginal_posteriors = calculate_1d_posteriors(grid_MSE)

    fig, axes = plt.subplots(l, 3, figsize=fig_size)
    for i, axis in enumerate(np.atleast_1d(axes).flat):
        if i >= n:
            # Hide unused panels.
            axis.set_visible(False)
            continue
        p = free_params[i]
        axis.plot(dims[i], marginal_posteriors[i])
        if params_orig is not None:
            axis.axvline(x=params_orig[p.name], color='k', linestyle='-')
        if best_parameters is not None:
            axis.axvline(x=best_parameters[p.name], color='tab:red', linestyle='--')
        if p.axis_type == 'log':
            axis.set_xscale('log')
        axis.set_xlabel(p.display_name + ' [' + p.unit + ']')
        axis.set_ylabel('P(' + p.display_name + '|Data)')
        if ylim is not None:
            axis.set_ylim(ylim)

    # Custom legend for the reference lines.
    legend_elements = []
    if params_orig is not None:
        legend_elements.append(plt.Line2D([0], [0], color='k', label='Original parameters', linestyle='-'))
    if best_parameters is not None:
        legend_elements.append(plt.Line2D([0], [0], color='tab:red', label='Best parameters', linestyle='--'))
    if len(legend_elements) > 0:
        fig.legend(handles=legend_elements, loc='center', bbox_to_anchor=(0.5, 1.02), ncol=2)

    plt.tight_layout()
    # Return the full axes grid (previously the last loop axis was returned).
    return fig, axes
# def get_df_from_ax(params, optimizer):
# """Get the dataframe from the ax client and rescale the parameters to their true scale.
# The dataframe contains the parameters and the objective values.
# The parameters are rescaled to their true scale.
# The objective values are the mean of the objective values.
# The dataframe is returned as a pandas dataframe.
# Parameters
# ----------
# params : list of FitParam() objects
# List of Fitparam() objects.
# optimizer : object
# Optimizer object from optimpv.axBOtorch.axBOtorch
# The optimizer object contains the ax client and the experiment.
# Returns
# -------
# pd.DataFrame
# Dataframe containing the parameters and the objective values.
# Raises
# ------
# ValueError
# trying to rescale a parameter that is not int or float
# """
# ax_client = optimizer.ax_client
# objective_names = optimizer.all_metrics
# df = get_df_ax_client_metrics(params, ax_client, objective_names)
# return df
# def get_df_ax_client_metrics(params, ax_client, all_metrics):
# """Get the dataframe from the ax client and rescale the parameters to their true scale.
# The dataframe contains the parameters and the objective values.
# The parameters are rescaled to their true scale.
# The objective values are the mean of the objective values.
# The dataframe is returned as a pandas dataframe.
# Parameters
# ----------
# params : list of FitParam() objects
# List of Fitparam() objects.
# ax_client : object
# Ax client object.
# all_metrics : list of str
# List of objective names.
# Returns
# -------
# pd.DataFrame
# Dataframe containing the parameters and the objective values.
# Raises
# ------
# ValueError
# trying to rescale a parameter that is not int or float
# """
# data = ax_client.experiment.fetch_data().df
# objective_names = all_metrics
# dumdic = {}
# # create a dic with the keys of the parameters
# if isinstance(ax_client.experiment.trials[0], BatchTrial):# check if trial is a BatchTrial
# for key in ax_client.experiment.trials[0].arms[0].parameters.keys():
# dumdic[key] = []
# # fill the dic with the values of the parameters
# for i in range(len(ax_client.experiment.trials)):
# if ax_client.experiment.trials[i].status == T.COMPLETED:
# for arm in ax_client.experiment.trials[i].arms:
# if arm.name in data['arm_name'].values: # only add the arm if it is in the data i.e. if it was completed
# for key in arm.parameters.keys():
# dumdic[key].append(arm.parameters[key])
# else:
# for key in ax_client.experiment.trials[0].arm.parameters.keys():
# dumdic[key] = []
# # fill the dic with the values of the parameters
# for i in range(len(ax_client.experiment.trials)):
# if ax_client.experiment.trials[i].status == T.COMPLETED:
# for key in ax_client.experiment.trials[i].arm.parameters.keys():
# dumdic[key].append(ax_client.experiment.trials[i].arm.parameters[key])
# for objective_name in objective_names:
# dumdic[objective_name] = list(data[data['metric_name'] == objective_name]['mean'])
# dumdic['iteration'] = list(data[data['metric_name'] == objective_name]['trial_index'])
# df = pd.DataFrame(dumdic)
# # add iteration column with
# for par in params:
# if par.name in df.columns:
# if par.rescale or par.force_log:
# if par.value_type == 'int':
# df[par.name] = df[par.name] * par.stepsize
# elif par.value_type == 'float':
# if par.force_log:
# df[par.name] = 10 ** df[par.name]
# else:
# df[par.name] = df[par.name] * par.fscale
# else:
# raise ValueError('Trying to rescale a parameter that is not int or float')
# return df
def plot_density_exploration(params, optimizer = None, best_parameters = None, params_orig = None, optimizer_type = 'ax', **kwargs):
    """Generate density plots to visualize the exploration of parameter space.

    Diagonal panels show a 1D density (KDE, falling back to a histogram) of
    the sampled values of each non-fixed parameter; lower-triangle panels
    show the pairwise 2D density of the explored points.

    Parameters
    ----------
    params : list of FitParam() objects
        List of parameters to explore.
    optimizer : object, optional
        Optimizer object, by default None.
    best_parameters : dict, optional
        Dictionary of the best parameters, by default None.
    params_orig : dict, optional
        Dictionary of the original parameters, by default None.
    optimizer_type : str, optional
        Type of optimizer used ('ax' or 'pymoo'), by default 'ax'.
    **kwargs : dict
        Additional keyword arguments: fig_size (tuple, default (15, 15)),
        levels (int, default 100), kind ('kde' or 'scatter' for the
        off-diagonal panels, default 'kde').

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the plot.
    numpy.ndarray of matplotlib.axes.Axes
        The n x n grid of axes.

    Raises
    ------
    ValueError
        If the optimizer type is not supported.
    """
    fig_size = kwargs.get('fig_size', (15, 15))
    levels = kwargs.get('levels', 100)
    # Off-diagonal plot style; previously hard-coded to 'kde'.
    kind = kwargs.get('kind', 'kde')

    if optimizer_type == 'ax':
        df = get_df_from_ax(params, optimizer)
    elif optimizer_type == 'pymoo':
        # Rebuild a dataframe of all evaluated parameter sets.
        resall = optimizer.all_evaluations
        dum_dic = {key: [] for key in resall[0]['params'].keys()}
        for res in resall:
            for key, val in res['params'].items():
                dum_dic[key].append(val)
        df = pd.DataFrame(dum_dic)
    else:
        raise ValueError('This optimizer type is not supported')

    # Collect metadata for the non-fixed parameters, in order.
    names, display_names, log_scale, axis_limits = [], [], [], []
    for p in params:
        if p.type != 'fixed':
            names.append(p.name)
            display_names.append(p.display_name + ' [' + p.unit + ']')
            log_scale.append(p.axis_type == 'log')
            axis_limits.append(p.bounds)

    n = len(names)
    # squeeze=False keeps axes 2D even for a single free parameter.
    fig, axes = plt.subplots(n, n, figsize=fig_size, squeeze=False)

    # Row i uses names[i] on the y axis, column j uses names[j] on the x axis.
    for i, xx in enumerate(names):
        for j, yy in enumerate(names):
            ax = axes[i, j]
            if i == j:
                # Diagonal: 1D density of the sampled values.
                try:
                    sns.kdeplot(x=yy, data=df, ax=ax, fill=True, thresh=0, levels=levels, cmap="rocket", color="#03051A", log_scale=log_scale[j])
                except Exception:
                    # KDE can fail (e.g. too few points or zero variance);
                    # fall back to a histogram.
                    sns.histplot(x=yy, data=df, ax=ax, color="#03051A", log_scale=log_scale[j])
                if params_orig is not None:
                    ax.axvline(x=params_orig[yy], color='yellow', linestyle='-')
                if best_parameters is not None:
                    ax.axvline(x=best_parameters[yy], color='r', linestyle='--')
                if log_scale[j]:
                    ax.set_xscale('log')
                ax.set_xlim(axis_limits[j])
                # Put the x label on top, except on the last row.
                if i < n - 1:
                    ax.xaxis.set_label_position('top')
                    ax.xaxis.tick_top()
            elif i > j:
                # Lower triangle: 2D density of the explored points.
                if kind == 'scatter':
                    sns.scatterplot(x=yy, y=xx, data=df, ax=ax, color="#03051A")
                    ax.set_xscale('log')
                    ax.set_yscale('log')
                else:
                    try:
                        sns.kdeplot(x=yy, y=xx, data=df, ax=ax, fill=True, thresh=0, levels=levels, cmap="rocket", color="#03051A", log_scale=(log_scale[j], log_scale[i]))
                    except Exception as e:
                        print(f"Error in kdeplot: {e}")
                        sns.scatterplot(x=yy, y=xx, data=df, ax=ax, color="#03051A")
                # Reference lines spanning the full axes.
                if params_orig is not None:
                    ax.axhline(y=params_orig[xx], color='yellow', linestyle='-')
                    ax.axvline(x=params_orig[yy], color='yellow', linestyle='-')
                if best_parameters is not None:
                    ax.axhline(y=best_parameters[xx], color='r', linestyle='--')
                    ax.axvline(x=best_parameters[yy], color='r', linestyle='--')
                ax.set_xlim(axis_limits[j])
                ax.set_ylim(axis_limits[i])
            else:
                # Upper triangle is intentionally left empty.
                ax.set_visible(False)

            # Only the outer panels keep tick labels and axis labels.
            if j > 0 and i != j:
                ax.set_yticklabels([])
                ax.set_yticklabels([], minor=True)
                ax.set_ylabel('')
            if i < n - 1:
                ax.set_xticklabels([])
                ax.set_xticklabels([], minor=True)
                ax.set_xlabel('')
            if i == n - 1:
                ax.set_xlabel(display_names[j])
                ax.tick_params(axis='x', rotation=45, which='both')
            if j == 0:
                ax.set_ylabel(display_names[i])
            if i == j:
                # Diagonal panels: title on top, density label on the right.
                ax.set_title(display_names[i])
                ax.set_ylabel('Density')
                ax.yaxis.set_label_position('right')
                ax.yaxis.tick_right()
                ax.yaxis.set_tick_params(which='both', direction='in', left=True, right=True)
                ax.tick_params(axis='y', labelleft=False, labelright=True)

    # Custom legend for the reference lines.
    legend_elements = []
    if params_orig is not None:
        legend_elements.append(plt.Line2D([0], [0], color='yellow', label='Original parameters', linestyle='-'))
    if best_parameters is not None:
        legend_elements.append(plt.Line2D([0], [0], color='r', label='Best parameters', linestyle='--'))
    if len(legend_elements) > 0:
        fig.legend(handles=legend_elements, loc='center right', bbox_to_anchor=(0.9, 0.5), ncol=1)

    plt.tight_layout()
    fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1, hspace=0.2, wspace=0.2)
    return fig, axes
def plot_1D_2D_posterior(params, param_x, param_y, Nres, objective_name, model, loss, best_parameters=None, params_orig=None, optimizer_type='ax', **kwargs):
    """Generate a combined 2D and 1D posterior plot for a specific pair of parameters.

    The main panel shows the joint posterior of (param_x, param_y); the top
    and right panels show the corresponding 1D marginals.

    Parameters
    ----------
    params : list of FitParam() objects
        List of parameters to explore.
    param_x : str
        Name of the (non-fixed) parameter to plot on the x-axis.
    param_y : str
        Name of the (non-fixed) parameter to plot on the y-axis.
    Nres : int
        Number of points to evaluate for each parameter.
    objective_name : str
        Name of the objective to evaluate.
    model : torch model
        Model to evaluate the objective.
    loss : str
        Loss function used, see optimpv/general/general.py for available loss functions.
    best_parameters : dict, optional
        Dictionary of the best parameters, by default None.
    params_orig : dict, optional
        Dictionary of the original parameters, by default None.
    optimizer_type : str, optional
        Type of optimizer used, by default 'ax'.
    **kwargs : dict
        Additional keyword arguments: fig_size (tuple, default (12, 12)),
        levels (int, default Nres).

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the plot.
    numpy.ndarray of matplotlib.axes.Axes
        The 2 x 2 grid of axes.

    Raises
    ------
    ValueError
        If param_x equals param_y or either name is not a non-fixed parameter.
    """
    fig_size = kwargs.get('fig_size', (12, 12))
    levels = kwargs.get('levels', Nres)

    # Only the true-scale axes (third return value) are needed for plotting.
    grid_MSE, _, dims = get_MSE_grid(params, Nres, objective_name, model, loss, optimizer_type=optimizer_type)
    marginal_posteriors = calculate_1d_posteriors(grid_MSE)
    pairwise_posteriors = calculate_2d_posteriors(grid_MSE)

    if param_x == param_y:
        raise ValueError('param_x and param_y must be two different parameters')

    # The grid axes correspond to the non-fixed parameters only.
    free_params = [p for p in params if p.type != 'fixed']
    free_names = [p.name for p in free_params]
    x_idx = free_names.index(param_x)
    y_idx = free_names.index(param_y)
    p_x, p_y = free_params[x_idx], free_params[y_idx]

    # pairwise_posteriors is ordered like combinations(range(ndim), 2); find
    # the entry for this (unordered) pair. The previous closed-form index
    # (x*(n-1)+y-1) was wrong for three or more parameters.
    pairs = list(itertools.combinations(range(len(free_names)), 2))
    lo, hi = min(x_idx, y_idx), max(x_idx, y_idx)
    joint = pairwise_posteriors[pairs.index((lo, hi))]
    # joint has shape (N_lo, N_hi); contourf expects Z[y, x], so transpose
    # only when the x parameter is the first axis of the pair.
    Z = joint.T if x_idx < y_idx else joint

    fig, ax = plt.subplots(2, 2, figsize=fig_size, gridspec_kw={'height_ratios': [1, 4], 'width_ratios': [4, 1], 'hspace': 0.05, 'wspace': 0.05})

    # Main panel: 2D joint posterior.
    ax[1, 0].contourf(dims[x_idx], dims[y_idx], Z, levels=levels)
    ax[1, 0].set_xlim([p_x.bounds[0], p_x.bounds[1]])
    ax[1, 0].set_ylim([p_y.bounds[0], p_y.bounds[1]])
    if params_orig is not None:
        ax[1, 0].axhline(y=params_orig[param_y], color='tab:red', linestyle='-')
        ax[1, 0].axvline(x=params_orig[param_x], color='tab:red', linestyle='-')
    if best_parameters is not None:
        ax[1, 0].axhline(y=best_parameters[param_y], color='tab:orange', linestyle='--')
        ax[1, 0].axvline(x=best_parameters[param_x], color='tab:orange', linestyle='--')
    if p_x.axis_type == 'log':
        ax[1, 0].set_xscale('log')
    if p_y.axis_type == 'log':
        ax[1, 0].set_yscale('log')
    ax[1, 0].set_xlabel(p_x.display_name + ' [' + p_x.unit + ']')
    ax[1, 0].set_ylabel(p_y.display_name + ' [' + p_y.unit + ']')

    # Top panel: 1D marginal posterior of param_x.
    ax[0, 0].plot(dims[x_idx], marginal_posteriors[x_idx])
    ax[0, 0].tick_params(axis='x', rotation=45, which='both')
    if params_orig is not None:
        ax[0, 0].axvline(x=params_orig[param_x], color='tab:red', linestyle='-')
    if best_parameters is not None:
        ax[0, 0].axvline(x=best_parameters[param_x], color='tab:orange', linestyle='--')
    if p_x.axis_type == 'log':
        ax[0, 0].set_xscale('log')
    ax[0, 0].set_xlim([p_x.bounds[0], p_x.bounds[1]])
    ax[0, 0].set_xticklabels([])
    ax[0, 0].set_xticklabels([], minor=True)
    ax[0, 0].set_ylabel('P(' + p_x.display_name + '|Data)')

    # Right panel: 1D marginal posterior of param_y (axes flipped).
    ax[1, 1].plot(marginal_posteriors[y_idx], dims[y_idx])
    ax[1, 1].set_ylim([p_y.bounds[0], p_y.bounds[1]])
    if params_orig is not None:
        ax[1, 1].axhline(y=params_orig[param_y], color='tab:red', linestyle='-')
    if best_parameters is not None:
        ax[1, 1].axhline(y=best_parameters[param_y], color='tab:orange', linestyle='--')
    if p_y.axis_type == 'log':
        ax[1, 1].set_yscale('log')
    ax[1, 1].set_yticklabels([])
    ax[1, 1].set_yticklabels([], minor=True)
    ax[1, 1].set_xlabel('P(' + p_y.display_name + '|Data)')
    ax[1, 1].tick_params(axis='x', rotation=45, which='both')

    # Unused corner panel.
    ax[0, 1].axis('off')

    # Custom legend for the reference lines.
    legend_elements = []
    if params_orig is not None:
        legend_elements.append(plt.Line2D([0], [0], color='tab:red', linestyle='-', label='Original value'))
    if best_parameters is not None:
        legend_elements.append(plt.Line2D([0], [0], color='tab:orange', linestyle='--', label='Best value'))
    fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.95, 0.85))

    plt.tight_layout()
    return fig, ax