Source code for deepqmc.sampling.electron_samplers

from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
from jax import lax

from deepqmc.sampling.sampling_utils import clean_force

from ..hamil import MolecularHamiltonian
from ..physics import pairwise_self_distance
from ..types import (
    KeyArray,
    ParametrizedWaveFunction,
    Params,
    PhysicalConfiguration,
    SamplerState,
    Stats,
)
from ..utils import multinomial_resampling, split_dict

__all__ = [
    'MetropolisSampler',
    'LangevinSampler',
    'DecorrSampler',
    'ResampledSampler',
]


class MetropolisSampler:
    r"""
    Metropolis--Hastings Monte Carlo sampler.

    The :meth:`sample` method of this class returns electron coordinate samples
    from the distribution defined by the square of the sampled wave function.

    Args:
        hamil (~deepqmc.hamil.MolecularHamiltonian): the Hamiltonian of the
            physical system.
        wf: the :data:`apply` method of the :data:`haiku` transformed ansatz object.
        tau (float): optional, the proposal step size scaling factor.
            Adjusted during every step if :data:`target_acceptance` is specified.
        target_acceptance (float): optional, if specified the proposal step size
            will be scaled such that the ratio of accepted proposal steps
            approaches :data:`target_acceptance`.
        max_age (int): optional, if specified the next proposed step will always
            be accepted for a walker that hasn't moved in the last
            :data:`max_age` steps.
    """

    WALKER_STATE = ['r', 'psi', 'age']

    def __init__(
        self,
        hamil: MolecularHamiltonian,
        wf: ParametrizedWaveFunction,
        *,
        tau: float = 1.0,
        target_acceptance: float = 0.57,
        max_age: Optional[int] = None,
    ):
        self.hamil = hamil
        self.initial_tau = tau
        self.target_acceptance = target_acceptance
        self.max_age = max_age
        self.wf = wf

    def _update(
        self, state: SamplerState, params: Params, R: jax.Array
    ) -> SamplerState:
        psi = jax.vmap(self.wf, (None, 0))(params, self.phys_conf(R, state['r']))
        state = {**state, 'psi': psi}
        return state

    def update(
        self, state: SamplerState, params: Params, R: jax.Array
    ) -> SamplerState:
        return self._update(state, params, R)

    def init(
        self, rng: KeyArray, params: Params, n: int, R: jax.Array
    ) -> SamplerState:
        state = {
            'r': self.hamil.init_sample(rng, R, n).r,
            'age': jnp.zeros(n, jnp.int32),
            'tau': jnp.array(self.initial_tau),
        }
        return self._update(state, params, R)

    def _proposal(self, state: SamplerState, rng: KeyArray) -> jax.Array:
        r = state['r']
        return r + state['tau'] * jax.random.normal(rng, r.shape)

    def _acc_log_prob(self, state: SamplerState, prop: SamplerState) -> jax.Array:
        return 2 * (prop['psi'].log - state['psi'].log)

    def sample(
        self, rng: KeyArray, state: SamplerState, params: Params, R: jax.Array
    ) -> tuple[SamplerState, PhysicalConfiguration, Stats]:
        rng_prop, rng_acc = jax.random.split(rng)
        prop = {
            'r': self._proposal(state, rng_prop),
            'age': jnp.zeros_like(state['age']),
            **{k: v for k, v in state.items() if k not in self.WALKER_STATE},
        }
        prop = self._update(prop, params, R)
        log_prob = self._acc_log_prob(state, prop)
        accepted = log_prob > jnp.log(jax.random.uniform(rng_acc, log_prob.shape))
        if self.max_age:
            accepted = accepted | (state['age'] >= self.max_age)
        acceptance = accepted.astype(int).sum() / accepted.shape[0]
        if self.target_acceptance:
            prop['tau'] /= self.target_acceptance / jnp.max(
                jnp.stack([acceptance, jnp.array(0.05)])
            )
        state = {**state, 'age': state['age'] + 1}
        (prop, other), (state, _) = (
            split_dict(d, lambda k: k in self.WALKER_STATE) for d in (prop, state)
        )
        state = {
            **jax.tree_util.tree_map(
                lambda xp, x: jax.vmap(jnp.where)(accepted, xp, x), prop, state
            ),
            **other,
        }
        stats = {
            'sampling/acceptance': acceptance,
            'sampling/tau': state['tau'],
            'sampling/age/mean': jnp.mean(state['age']),
            'sampling/age/max': jnp.max(state['age']),
            'sampling/log_psi/mean': jnp.mean(state['psi'].log),
            'sampling/log_psi/std': jnp.std(state['psi'].log),
            'sampling/dists/mean': jnp.mean(pairwise_self_distance(state['r'])),
        }
        return state, self.phys_conf(R, state['r']), stats

    def phys_conf(self, R: jax.Array, r: jax.Array, **kwargs) -> PhysicalConfiguration:
        if r.ndim == 2:
            return PhysicalConfiguration(R, r, jnp.array(0))  # type: ignore
        n_smpl = len(r)
        return PhysicalConfiguration(
            jnp.tile(R[None], (n_smpl, 1, 1)),  # type: ignore
            r,
            jnp.zeros(n_smpl, dtype=jnp.int32),
        )
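

# A minimal usage sketch (not part of the library source): it assumes `hamil`,
# `params` and `wf` (the ``apply`` method of the haiku-transformed ansatz) are
# already constructed, and runs a plain Metropolis chain over the nuclear
# coordinates `R` for a few steps.
def _example_metropolis_loop(rng, hamil, wf, params, R, n_walkers=64, n_steps=100):
    sampler = MetropolisSampler(hamil, wf, tau=1.0, target_acceptance=0.57)
    rng, rng_init = jax.random.split(rng)
    state = sampler.init(rng_init, params, n_walkers, R)
    for rng_step in jax.random.split(rng, n_steps):
        # each call proposes a Gaussian step per walker, accepts or rejects it,
        # and adapts tau towards the target acceptance ratio
        state, phys_conf, stats = sampler.sample(rng_step, state, params, R)
    return state, phys_conf, stats

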
class LangevinSampler(MetropolisSampler):
    r"""
    Metropolis adjusted Langevin Monte Carlo sampler.

    Derived from :class:`MetropolisSampler`.

    Args:
        hamil (~deepqmc.hamil.MolecularHamiltonian): the Hamiltonian of the
            physical system.
        wf: the :data:`apply` method of the :data:`haiku` transformed ansatz object.
        tau (float): optional, the proposal step size scaling factor.
            Adjusted during every step if :data:`target_acceptance` is specified.
        target_acceptance (float): optional, if specified the proposal step size
            will be scaled such that the ratio of accepted proposal steps
            approaches :data:`target_acceptance`.
        max_age (int): optional, if specified the next proposed step will always
            be accepted for a walker that hasn't moved in the last
            :data:`max_age` steps.
    """

    WALKER_STATE = MetropolisSampler.WALKER_STATE + ['force']

    def _update(
        self, state: SamplerState, params: Params, R: jax.Array
    ) -> SamplerState:
        @jax.vmap
        @partial(jax.value_and_grad, has_aux=True)
        def wf_and_force(r):
            psi = self.wf(params, self.phys_conf(R, r))
            return psi.log, psi

        (_, psi), force = wf_and_force(state['r'])
        force = clean_force(
            force, self.phys_conf(R, state['r']), self.hamil.mol, tau=state['tau']
        )
        state = {**state, 'psi': psi, 'force': force}
        return state

    def _proposal(self, state: SamplerState, rng: KeyArray) -> jax.Array:
        r, tau = state['r'], state['tau']
        r = r + tau * state['force'] + jnp.sqrt(tau) * jax.random.normal(rng, r.shape)
        return r

    def _acc_log_prob(self, state: SamplerState, prop: SamplerState) -> jax.Array:
        log_G_ratios = jnp.sum(
            (state['force'] + prop['force'])
            * (
                (state['r'] - prop['r'])
                + state['tau'] / 2 * (state['force'] - prop['force'])
            ),
            axis=tuple(range(1, len(state['r'].shape))),
        )
        return log_G_ratios + 2 * (prop['psi'].log - state['psi'].log)
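

# Worked form of the Langevin acceptance probability above: the proposal is a
# drifted Gaussian, r' = r + tau * F(r) + sqrt(tau) * xi with xi ~ N(0, 1), i.e.
# the Green's function is G(r'|r) ~ exp(-|r' - r - tau * F(r)|^2 / (2 tau)),
# where F is the (cleaned) drift force.  Writing d = r' - r, the ratio of Green's
# functions reduces to
#     log G(r|r') - log G(r'|r) = (F + F') . (-d + tau/2 * (F - F')),
# which is exactly the `log_G_ratios` term in `_acc_log_prob`; the remaining
# 2 * (log|psi'| - log|psi|) term is the usual |psi|^2 density ratio inherited
# from the Metropolis rule.

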
class DecorrSampler:
    r"""
    Insert decorrelating steps into chained samplers.

    This sampler cannot be used as the last element of a sampler chain.

    Args:
        length (int): the samples will be taken in every :data:`length` MCMC step,
            that is, :data:`length` :math:`-1` decorrelating steps are inserted.
    """

    def __init__(self, *, length):
        self.length = length

    def sample(
        self, rng: KeyArray, state: SamplerState, params: Params, R: jax.Array
    ) -> tuple[SamplerState, PhysicalConfiguration, Stats]:
        sample = super().sample  # type: ignore
        state, stats = lax.scan(
            lambda state, rng: sample(rng, state, params, R)[::2],
            state,
            jax.random.split(rng, self.length),
        )
        stats = {k: v[-1] for k, v in stats.items()}
        return state, self.phys_conf(R, state['r']), stats  # type: ignore
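

# DecorrSampler has no `sample` implementation of its own to fall back on: the
# `super().sample` call above resolves through the MRO of a combined class.  A
# minimal sketch of how such a chain can be assembled (illustrative only, not
# necessarily the library's own chaining helper):
def _example_chain(hamil, wf, length=20):
    # DecorrSampler must precede the concrete sampler in the MRO so that its
    # `sample` wraps (and repeatedly invokes) MetropolisSampler.sample
    Chained = type('DecorrMetropolisSampler', (DecorrSampler, MetropolisSampler), {})
    sampler = Chained.__new__(Chained)
    # initialise the two pieces separately and merge their instance attributes
    vars(sampler).update(vars(MetropolisSampler(hamil, wf)))
    vars(sampler).update(vars(DecorrSampler(length=length)))
    return sampler

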
class ResampledSampler:
    r"""
    Add resampling to chained samplers.

    This sampler cannot be used as the last element of a sampler chain.
    The resampling is performed by accumulating weights on each MCMC walker in
    each step. Based on a fixed resampling period :data:`period` and/or a
    threshold :data:`threshold` on the normalized effective sample size, the
    walker positions are sampled according to the multinomial distribution
    defined by these weights, and the weights are reset to one. Either
    :data:`period` or :data:`threshold` has to be specified.

    Args:
        period (int): optional, if specified the walkers are resampled every
            :data:`period` MCMC steps.
        threshold (float): optional, if specified the walkers are resampled if
            the effective sample size normalized with the batch size is below
            :data:`threshold`.
    """

    def __init__(
        self, *, period: Optional[int] = None, threshold: Optional[float] = None
    ):
        assert period is not None or threshold is not None
        self.period = period
        self.threshold = threshold

    def update(
        self, state: SamplerState, params: Params, R: jax.Array
    ) -> SamplerState:
        state['log_weight'] -= 2 * state['psi'].log
        state = self._update(state, params, R)  # type: ignore
        state['log_weight'] += 2 * state['psi'].log
        state['log_weight'] -= state['log_weight'].max()
        return state

    def init(self, *args, **kwargs):
        state = super().init(*args, **kwargs)  # type: ignore
        state = {
            **state,
            'step': jnp.array(0),
            'log_weight': jnp.zeros_like(state['psi'].log),
        }
        return state

    def resample_walkers(self, rng_re: KeyArray, state: SamplerState) -> SamplerState:
        idx = multinomial_resampling(rng_re, jnp.exp(state['log_weight']))
        state, other = split_dict(state, lambda k: k in self.WALKER_STATE)  # type: ignore
        state = {
            **jax.tree_util.tree_map(lambda x: x[idx], state),
            **other,
            'step': jnp.array(0),
            'log_weight': jnp.zeros_like(other['log_weight']),
        }
        return state

    def sample(
        self, rng: KeyArray, state: SamplerState, params: Params, R: jax.Array
    ) -> tuple[SamplerState, PhysicalConfiguration, Stats]:
        rng_re, rng_smpl = jax.random.split(rng)
        state, _, stats = super().sample(rng_smpl, state, params, R)  # type: ignore
        state['step'] += 1
        weight = jnp.exp(state['log_weight'])
        ess = jnp.sum(weight) ** 2 / jnp.sum(weight**2)
        stats['sampling/effective sample size'] = ess
        state = jax.lax.cond(
            (self.period is not None and state['step'] >= self.period)
            | (self.threshold is not None and ess / len(weight) < self.threshold),
            self.resample_walkers,
            lambda rng, state: state,
            rng_re,
            state,
        )
        return state, self.phys_conf(R, state['r']), stats  # type: ignore
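

# How the reweighting above works: when `update` is called with new wave-function
# parameters, it removes the old 2 * log|psi| from each walker's log-weight and
# adds the newly evaluated one, so each parameter update multiplies the weight of
# walker i by
#     |psi_new(r_i)|^2 / |psi_old(r_i)|^2,
# i.e. an importance weight correcting samples drawn from the stale density.
# The quantity monitored in `sample` is the effective sample size
#     ESS = (sum_i w_i)^2 / (sum_i w_i^2),
# whose batch-normalized value ESS / N equals 1 for uniform weights and decays as
# the weights degenerate, triggering a multinomial resampling of the walkers.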