""" This module provides several useful utilities: - Use central differencing to calculate the curvature of a function at a point in a large number of dimensions (the approximate Hessian). - Perform approximate inverse transform sampling in multiple dimensions (does not scale well in dimensionality). - Creating scatterplot matrixes and simliar array of plots for contours of marginals when provided a gridded distribution. Created in June-Oct. 2019, author: Sean T. Smith """ from numpy import empty, zeros, linspace, meshgrid, interp, searchsorted, sqrt from numpy.random import rand, normal import matplotlib.pyplot as plt def curvature(func, x0, *args, **kwargs): # Approximate the curvature of the -ln(posterior) at the mode: n = x0.shape[0] δi, δj = zeros(n), zeros(n) Σinv = empty((n, n)) fmid = func(x0, *args, **kwargs) for i in range(n): δi[i] = max(1e-6 * abs(x0[i]), 1e-12) # step size for finite diff. fplus = func(x0 + δi, *args, **kwargs) fminus = func(x0 - δi, *args, **kwargs) Σinv[i, i] = (fplus - 2 * fmid + fminus) / δi[i]**2 for j in range(i): δj[j] = max(1e-6 * abs(x0[j]), 1e-12) # step for the mixed diff. fpp = func(x0 + δi + δj, *args, **kwargs) fpm = func(x0 + δi - δj, *args, **kwargs) fmp = func(x0 - δi + δj, *args, **kwargs) fmm = func(x0 - δi - δj, *args, **kwargs) Σinv[i, j] = Σinv[j, i] = (fpp-fpm-fmp+fmm) / (4 * δi[i] * δj[j]) δj[j] = 0 δi[i] = 0 return Σinv def inverse_transform(pdf, x_grid, U=None, ns=100, fast=False): n_dim = pdf.ndim if U is None: U = rand(ns, n_dim) else: ns = U.shape[0] # Calculate the marginal for the 1st dimension: marg_x0 = pdf.copy() for i in range(1, n_dim): shape = (1, -1,) + (1,) * (n_dim - (i + 1)) Δxi = (x_grid[i][1:] - x_grid[i][:-1]).reshape(shape) # trapezoid rule... marg_x0 = 0.5 * (Δxi*marg_x0[:,:-1] + Δxi*marg_x0[:,1:]).sum(axis=1) # Calculate the cumulative across the 1st dimension: Δx0 = x_grid[0][+1:] - x_grid[0][:-1] cum_x0 = empty(pdf.shape[0]) cum_x0[0] = 0 cum_x0[1:] = 0.5 * (Δx0 * marg_x0[:-1] + Δx0 * marg_x0[1:]).cumsum() cum_x0 /= cum_x0[-1] # Perform inverse transform sampling on the marginal: X = empty((ns, n_dim)) X[:, 0] = interp(U[:, 0], cum_x0, x_grid[0]) if n_dim > 1: # TODO: Optionally parallelize this loop in dask. for i in range(ns): # Condition on sample: ind = searchsorted(x_grid[0], X[i, 0]) α = ((X[i, 0] - x_grid[0][ind-1]) / (x_grid[0][ind] - x_grid[0][ind-1])) # incorrect when ind==0 if fast or ind == 0: # Nearest neighbor interpolation: if α <= 0.5 and ind > 0: cond_pdf = pdf[ind-1] else: cond_pdf = pdf[ind] else: # Linear interpolation: cond_pdf = (1 - α) * pdf[ind - 1] + α * pdf[ind] # This is the bottleneck for high-dims. with many samples. # Recurse: X[i, 1:] = inverse_transform(cond_pdf, x_grid[1:], U[i:i+1, 1:]) return X def scatterplot_matrix(x, labels, ax_label_font=14, plot_type='scatter', fig_options={}, marginal_options={}, joint_options={}): ndim, nsamples = x.shape if type(fig_options) is tuple: fig, axes = fig_options else: fig, axes = plt.subplots(ndim, ndim, sharex='col', sharey='row', gridspec_kw=dict(wspace=0, hspace=0), **fig_options) # Row & column formatting for i in range(ndim): axes[i][0].set_ylabel(labels[i], fontsize=ax_label_font) axes[i][0].set_ylim([x[i].min(), x[i].max()]) for j in range(ndim): axes[-1][j].set_xlabel(labels[j], fontsize=ax_label_font) axes[-1][j].set_xlim([x[j].min(), x[j].max()]) # Remove unwanted frames & ticks from the upper triangle for i in range(ndim-1): for j in range(i+1, ndim): axes[i][j].spines['top'].set_visible(False) axes[i][j].spines['bottom'].set_visible(False) axes[i][j].spines['left'].set_visible(False) axes[i][j].spines['right'].set_visible(False) axes[i][j].tick_params(axis='both', which='both', left=False, bottom=False) # Marginals nbins = max(min(nsamples // 75, 75), 10) for i in range(ndim): ax = axes[i][i].twinx() ax.hist(x[i], bins=nbins, density=True, **marginal_options) ax.set_ylim([0, None]) ax.get_yaxis().set_ticks([]) axes[0][0].tick_params(axis='y', which='both', left=False, right=False, labelleft=False) # Pairwise plots: nbins = max(min(int(sqrt(nsamples / 25)), 50), 10) for i in range(ndim): for j in range(i): ax = axes[i][j] if plot_type == 'scatter': ax.scatter(x[j], x[i], **joint_options) elif plot_type == 'hist': ax.hist2d(x[j], x[i], bins=nbins, **joint_options) elif plot_type == 'contour': xbins = linspace(x[j].min(), x[j].max(), nbins + 1) ybins = linspace(x[i].min(), x[i].max(), nbins + 1) freq, _, _, im = ax.hist2d(x[j], x[i], bins=[xbins, ybins]) X, Y = meshgrid(xbins[:-1], ybins[:-1], indexing='xy') ax.contour(X, Y, freq.T, **joint_options) im.set_visible(False) return fig, axes def contour_matrix(pdf, x_grids, labels, ax_label_font=14, fig_options={}, marginal_options={}, joint_options={}): ndim = len(labels) if type(fig_options) is tuple: fig, axes = fig_options else: fig, axes = plt.subplots(ndim, ndim, sharex='col', sharey='row', gridspec_kw=dict(wspace=0, hspace=0), **fig_options) # Row & column formatting for i in range(ndim): axes[i][0].set_ylabel(labels[i], fontsize=ax_label_font) axes[i][0].set_ylim([x_grids[i][0], x_grids[i][-1]]) for j in range(ndim): axes[-1][j].set_xlabel(labels[j], fontsize=ax_label_font) axes[-1][j].set_xlim([x_grids[j][0], x_grids[j][-1]]) # Remove unwanted frames & ticks from the upper triangle for i in range(ndim-1): for j in range(i+1, ndim): axes[i][j].spines['top'].set_visible(False) axes[i][j].spines['bottom'].set_visible(False) axes[i][j].spines['left'].set_visible(False) axes[i][j].spines['right'].set_visible(False) axes[i][j].tick_params(axis='both', which='both', left=False, bottom=False) # Marginals for i in range(ndim): marginal = pdf.copy() for k in range(i): shape = (-1,) + (1,) * (ndim - (k + 1)) Δxk = (x_grids[k][1:] - x_grids[k][:-1]).reshape(shape) marginal = 0.5 * (Δxk * marginal[:-1] + Δxk * marginal[+1:]).sum(axis=0) for k in range(i + 1, ndim): shape = (1, -1) + (1,) * (ndim - (k + 1)) Δxk = (x_grids[k][1:] - x_grids[k][:-1]).reshape(shape) marginal = 0.5 * (Δxk * marginal[:, :-1] + Δxk * marginal[:, +1:]).sum(axis=1) ax = axes[i][i].twinx() ax.plot(x_grids[i], marginal, **marginal_options) ax.set_ylim([0, None]) ax.get_yaxis().set_ticks([]) axes[0][0].tick_params(axis='y', which='both', left=False, right=False, labelleft=False) # Pairwise plots: for i in range(ndim): for j in range(i): joint = pdf.copy() for k in range(j): shape = (-1,) + (1,) * (ndim - (k + 1)) Δxk = (x_grids[k][1:] - x_grids[k][:-1]).reshape(shape) joint = 0.5 * (Δxk * joint[:-1] + Δxk * joint[+1:]).sum(axis=0) for k in range(j + 1, i): shape = (1, -1) + (1,) * (ndim - (k + 1)) Δxk = (x_grids[k][1:] - x_grids[k][:-1]).reshape(shape) joint = 0.5 * (Δxk * joint[:, :-1] + Δxk * joint[:, +1:]).sum(axis=1) for k in range(i + 1, ndim): shape = (1, 1, -1) + (1,) * (ndim - (k + 1)) Δxk = (x_grids[k][1:] - x_grids[k][:-1]).reshape(shape) joint = 0.5 * (Δxk * joint[:, :, :-1] + Δxk * joint[:, :, +1:]).sum(axis=2) X1, X2 = meshgrid(x_grids[j], x_grids[i], indexing='xy') ax = axes[i][j] ax.contour(X1, X2, joint.T, **joint_options) return fig, axes