Using mcmcsampler: a demo
import numpy as np
import matplotlib.pyplot as plt
import pickle
import m3c2.proposal as proposal
from m3c2.sampler import Sampler
from m3c2.tools import get_autocorr_len, get_mcmc_stats
import multiprocessing as mp
# 'fork' start method: needed for MacOS (per the original note).
# force=True keeps this cell re-runnable: without it, a second call in the
# same interpreter raises RuntimeError("context has already been set").
mp.set_start_method('fork', force=True)
import warnings
# Silence all warnings for the rest of the session so the demo output stays short.
warnings.filterwarnings('ignore')
Define a log-likelihood function
Here we take a simple multidimensional Gaussian.
We use a class
to easily store all the metadata needed to evaluate the log-likelihood.
class TestLogLik:
    """A multi-dimensional Gaussian log-likelihood with a constant log-prior.

    The mean vector and the covariance matrix are drawn at random when the
    object is built and kept on the instance as ``mu`` and ``cov``.
    """

    def __init__(self, ndim, rng=None):
        """Draw a random mean and covariance for an ``ndim``-dimensional Gaussian.

        Parameters
        ----------
        ndim : int
            Number of dimensions (parameters).
        rng : optional
            Random source exposing ``random(size)``, e.g. a
            ``numpy.random.Generator`` or ``numpy.random.RandomState``.
            When omitted, the global ``np.random.rand`` is used, exactly as
            before this argument existed.
        """
        # Uniform [0, 1) sampler; the default keeps the legacy global stream.
        draw = np.random.rand if rng is None else rng.random

        self.ndim = ndim
        # Parameter names p0, p1, ... (a list, despite the attribute name).
        self.param_dic = [f"p{i}" for i in range(ndim)]

        means = draw(ndim)

        # Random entries in (-0.5, 0.5], mirrored from the upper triangle so
        # the matrix is symmetric.
        cov = 0.5 - draw(ndim ** 2).reshape((ndim, ndim))
        cov = np.triu(cov)
        cov += cov.T - np.diag(cov.diagonal())
        # Squaring a symmetric matrix gives a positive semi-definite one.
        cov = np.dot(cov, cov)

        self.mu = means
        self.cov = cov

    def loglik(self, x, **kwargs):
        """Return the (unnormalised) Gaussian log-likelihood at point ``x``.

        Computes ``-0.5 * (x - mu)^T cov^{-1} (x - mu)``; extra keyword
        arguments are accepted and ignored.
        """
        diff = x - self.mu
        # Solve the linear system instead of forming the inverse explicitly.
        return -0.5 * np.dot(diff, np.linalg.solve(self.cov, diff))

    def logpi(self, p, **kwargs):
        """Return the log-prior at point ``p``.

        The prior is flat: the same constant is returned for every ``p``;
        extra keyword arguments are accepted and ignored.
        """
        return -7.0
Define the sampler
# Number of chains handled by the sampler
Nchains = 5

# Target distribution: a 3-dimensional Gaussian test likelihood
ndim = 3
T = TestLogLik(ndim)

# One [-3, 3] interval per parameter, stacked into an (ndim, 2) array
priors = np.tile([-3, 3], (T.ndim, 1))

# One standard-normal starting point per chain
x0 = [np.random.randn(T.ndim) for _ in range(Nchains)]

# Build the sampler from the likelihood, prior and parameter names
S = Sampler(
    Nchains,
    T.loglik,
    T.logpi,
    T.param_dic,
    profiling=True,
    kde=False,  # the original note suggests True as the alternative
)
S.set_starting_point(x0)
Define the proposals
# Define proposals: bound jump methods for the slice and SCAM proposals
SL = proposal.Slice(T.param_dic).slice
SC = proposal.SCAM(T.param_dic).SCAM
# Disabled alternative proposal, kept for reference:
#KDE = proposal.AdaptiveKDE(T.param_dic).kde_jump
# One independent dict per chain, mapping each proposal to the value 50
# (presumably its relative weight -- confirm against Sampler.set_proposals).
# A comprehension is used instead of `[{...}] * Nchains`, which would put the
# very same dict object in every slot, so a change made for one chain would
# show up in all of them.
p_dict = [{SC: 50, SL: 50} for _ in range(Nchains)]
S.set_proposals(p_dict)
Run the sampler
# Run the sampler for `niter` iterations, single-process;
# the captured log shows a progress report every 10_000 iterations.
niter = 150_000
c = S.run_mcmc(
    niter,
    printN=10_000,
    multiproc=False,
)
INFO:root:iter 0
INFO:root:chain 0
INFO:root:current loglik: -0.3, best: -0.3, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -261.2, best: -261.2, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -134.6, best: -134.6, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -251.0, best: -251.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -11.7, best: -11.7, temp: 1.0, ratio: 1.0
INFO:root:iter 10000
INFO:root:chain 0
INFO:root:current loglik: -1.5, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -2.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -0.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -1.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -0.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 20000
INFO:root:chain 0
INFO:root:current loglik: -2.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -0.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -1.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 30000
INFO:root:chain 0
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.0, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -0.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -2.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 40000
INFO:root:chain 0
INFO:root:current loglik: -0.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -0.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -4.8, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -0.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 50000
INFO:root:chain 0
INFO:root:current loglik: -0.5, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -2.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -1.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -2.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 60000
INFO:root:chain 0
INFO:root:current loglik: -2.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -8.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -2.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 70000
INFO:root:chain 0
INFO:root:current loglik: -1.7, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -1.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -0.7, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -2.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 80000
INFO:root:chain 0
INFO:root:current loglik: -1.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -3.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.8, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -0.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -0.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 90000
INFO:root:chain 0
INFO:root:current loglik: -2.0, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -0.8, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -5.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -0.4, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.8, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 100000
INFO:root:chain 0
INFO:root:current loglik: -4.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -0.8, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -0.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -0.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.8, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 110000
INFO:root:chain 0
INFO:root:current loglik: -3.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -5.0, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -0.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -2.7, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 120000
INFO:root:chain 0
INFO:root:current loglik: -0.6, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -3.0, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -2.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -3.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 130000
INFO:root:chain 0
INFO:root:current loglik: -1.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -0.1, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -4.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -2.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -0.9, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:iter 140000
INFO:root:chain 0
INFO:root:current loglik: -1.3, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 1
INFO:root:current loglik: -0.5, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 2
INFO:root:current loglik: -1.7, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 3
INFO:root:current loglik: -1.2, best: -0.0, temp: 1.0, ratio: 1.0
INFO:root:chain 4
INFO:root:current loglik: -1.4, best: -0.0, temp: 1.0, ratio: 1.0
Show posterior for 1st parameter
# Histogram of the first parameter's samples, taken from chain 0 only
C0 = np.array(S.chains[0].chn)
first_param = C0[:, 0]

plt.figure()
plt.hist(first_param, bins=100, histtype="step", color="k")
# Mark the mean used to build the test likelihood
plt.axvline(x=T.mu[0], label='True')
plt.legend()
<matplotlib.legend.Legend at 0x7f91e20b0f70>
Show likelihood evolution
# Read the 'logL' field of the chain file for chain 0 from the working
# directory (presumably written by the sampler run -- confirm) and plot it
# against the sample index.
chain0 = np.load("chain_0.npy")
L = chain0['logL']

plt.figure()
plt.plot(L)
[<matplotlib.lines.Line2D at 0x7f91df9185b0>]
Show statistics
Acceptance rate for each proposal
# Per-proposal statistics read back from the current directory
jumps_t, ar_t, jk = get_mcmc_stats('./', p_dict)
full_names = jumps_t.dtype.names
chain_ids = np.arange(0, Nchains)

plt.figure(figsize=(10,5))

# Left panel: accepted / jumps for each proposal, stacked per chain
plt.subplot(121)
y_offset = np.zeros((Nchains))
for n in full_names:
    ratio = ar_t[n] / jumps_t[n]
    plt.bar(chain_ids, height=ratio, bottom=y_offset, label=n)
    y_offset += ratio
plt.legend()
plt.ylabel("# accepted / # jumps")

# Right panel: raw ar_t counts for each proposal, stacked per chain
plt.subplot(122)
y_offset = np.zeros((Nchains))
for n in full_names:
    counts = ar_t[n]
    plt.bar(chain_ids, height=counts, bottom=y_offset, label=n)
    y_offset += counts
plt.legend()
plt.ylabel("Number of MH steps")
Text(0, 0.5, 'Number of MH steps')
Auto-correlation
import numpy.lib.recfunctions as recf

cmap = plt.get_cmap("tab10")

# Sample counts at which tau is estimated: 20 log-spaced values from 100 up
# to the length of chain 0.
C = np.load("chain_0.npy")
N = np.exp(np.linspace(np.log(100), np.log(C.shape[0]), 20)).astype(int)

# One (len(N), Nchains) table per estimator, keyed by get_autocorr_len's `opt`
acor, emcee, acor2 = (np.empty((len(N), Nchains)) for _ in range(3))
estimates = {'acor-c': acor, 'emcee': emcee, 'acor-p': acor2}

for j in range(Nchains):
    # Structured chain file -> plain 2-D float array
    C = recf.structured_to_unstructured(np.load(f"chain_{j}.npy"))
    for i, n in enumerate(N):
        # First n samples, without the last two columns
        # (presumably non-parameter fields such as logL -- confirm)
        window = C[:n, :-2]
        for opt, tau in estimates.items():
            tau[i, j] = get_autocorr_len(window, burn=0.25, opt=opt)

# Average over chains and compare the three estimators
for tau, label in ((acor, 'acor'), (emcee, 'emcee'), (acor2, 'my-acor')):
    plt.loglog(N, np.mean(tau, axis=1), label=label)
plt.xlabel("number of samples, $N$")
plt.ylabel(r"$\tau$ estimates")
plt.legend(fontsize=14);
# IPython magics: time each estimator on the same input. C and n are whatever
# the preceding loops left behind (the last chain loaded and the last entry of N).
%timeit get_autocorr_len(C[:n,:-2], burn=0.25, opt='acor-c')
%timeit get_autocorr_len(C[:n,:-2], burn=0.25, opt='emcee')
%timeit get_autocorr_len(C[:n,:-2], burn=0.25, opt='acor-p')
4.24 ms ± 33.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.7 ms ± 868 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
3.86 ms ± 28.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)