# Source code for deepqmc.sampling.sampling_utils

from collections.abc import Callable, Iterable
from statistics import mean, stdev
from typing import Optional

import jax
import jax.numpy as jnp

from ..hamil import MolecularHamiltonian
from ..molecule import Molecule
from ..parallel import pmap, rng_iterator, select_one_device
from ..physics import pairwise_diffs
from ..types import (
    Ansatz,
    KeyArray,
    ParametrizedWaveFunction,
    Params,
    PhysicalConfiguration,
    SamplerState,
)
from .base import ElectronSampler
from .combined_samplers import (
    MoleculeIdxSampler,
    MultiElectronicStateSampler,
    MultiNuclearGeometrySampler,
)
from .nuclei_samplers import IdleNucleiSampler, no_elec_warp

# Public API of this module: only ``combine_samplers`` is exported by
# ``from ... import *``.
__all__ = [
    'combine_samplers',
]


def chain(*samplers) -> ElectronSampler:
    r"""Chain several samplers into a single combined sampler.

    A new class is created on the fly whose bases are the types of
    :data:`samplers` (in the given order). A fresh instance of that class
    receives the merged instance attributes of all given samplers. This lets
    you build advanced sampling schemes. For instance,
    :data:`chain(DecorrSampler(10),MetropolisSampler(hamil, tau=1.))` yields a
    :class:`MetropolisSampler` whose samples are taken from every 10th MCMC
    step. The final element of the chain has to be either a
    :class:`MetropolisSampler` or a :class:`LangevinSampler`.

    The combined class name is built by successively substituting
    ``'Sampler'`` with each base class name. For example, ``DecorrSampler``
    followed by ``MetropolisSampler`` gives ``DecorrMetropolisSampler``. When
    several samplers define the same attribute, the later sampler's value wins.

    Args:
        samplers (~deepqmc.sampling.base.ElectronSampler): one or more
            sampler instances to combine.

    Returns:
        :type:`~deepqmc.sampling.base.ElectronSampler`: the combined sampler.
    """
    base_classes = tuple(type(sampler) for sampler in samplers)
    class_name = 'Sampler'
    for cls in base_classes:
        # e.g. 'Sampler' -> 'DecorrSampler' -> 'DecorrMetropolisSampler'
        class_name = class_name.replace('Sampler', cls.__name__)
    # The constructor is stubbed out: state is copied from the existing
    # instances below instead of re-running any base-class __init__.
    combined_cls = type(class_name, base_classes, {'__init__': lambda self: None})
    combined = combined_cls()
    for sampler in samplers:
        combined.__dict__.update(sampler.__dict__)
    return combined  # type: ignore
def combine_samplers(
    samplers, hamil: MolecularHamiltonian, wf: ParametrizedWaveFunction
) -> ElectronSampler:
    r"""Combine samplers to create more advanced sampling schemes.

    Every element of :data:`samplers` except the last is a sampler instance
    that is placed in front of the chain unchanged. The last element is a
    callable (presumably a sampler class such as :class:`MetropolisSampler`)
    that is invoked as ``factory(hamil, wf)`` to build the final sampler. The
    results are then combined with :func:`chain`.

    Args:
        samplers (list): one or more samplers to combine. All but the last are
            :class:`~deepqmc.sampling.base.ElectronSampler` instances. The last
            is a factory called with ``(hamil, wf)``. Any iterable is accepted.
        hamil (~deepqmc.hamil.MolecularHamiltonian): the molecular Hamiltonian.
        wf (~deepqmc.types.ParametrizedWaveFunction): the wave function to
            sample.

    Returns:
        :type:`~deepqmc.sampling.base.ElectronSampler`: the combined sampler.

    Raises:
        ValueError: if :data:`samplers` is empty.
    """
    samplers = list(samplers)
    if not samplers:
        raise ValueError('combine_samplers requires at least one sampler')
    *prefix, factory = samplers
    return chain(*prefix, factory(hamil, wf))
def diffs_to_nearest_nuc(r, coords):
    r"""Return the pairwise-difference row of each electron to its nearest nucleus.

    Args:
        r: electron positions; ``len(r)`` is taken as the number of electrons.
        coords: nuclear coordinates.

    Returns:
        A tuple ``(z, idx)``. ``z[i]`` is the :func:`pairwise_diffs` entry for
        electron ``i`` and its nearest nucleus, and ``idx[i]`` is the index of
        that nucleus.
    """
    # NOTE(review): assumes pairwise_diffs returns (n_elec, n_nuc, 4) with the
    # squared distance in the last channel (crossover_parameter splits it as
    # z[..., :3] / z[..., 3]) -- confirm against deepqmc.physics.
    z = pairwise_diffs(r, coords)
    idx = jnp.argmin(z[..., -1], axis=-1)
    return z[jnp.arange(len(r)), idx], idx


def crossover_parameter(z, f, charge):
    r"""Compute the per-electron parameter ``a`` used by :func:`clean_force`.

    .. math::
        a = \frac{1 + \hat{\mathbf f}\cdot\hat{\mathbf z}}{2}
            + \frac{Z^2 z^2}{10\,(4 + Z^2 z^2)}

    NOTE(review): this appears to follow the drift-limiting scheme of Umrigar,
    Nightingale & Runge (1993); confirm before relying on that reference.

    Args:
        z: differences to the nearest nucleus. ``z[..., :3]`` is the difference
            vector and ``z[..., 3]`` is used as the squared distance ``z^2``.
        f: force vectors whose last axis holds the 3 components.
        charge: charge ``Z`` of the nearest nucleus, broadcastable against
            ``z[..., 3]``.

    Returns:
        The parameter ``a`` with the leading shape of ``z``.
    """
    z, z2 = z[..., :3], z[..., 3]
    eps = jnp.finfo(f.dtype).eps
    # Both norms are clipped at eps. An electron sitting exactly on a nucleus
    # (z = 0) or a vanishing force then yields a zero "unit" vector instead of
    # 0 / 0 = NaN, which would otherwise propagate through clean_force.
    z_unit = z / jnp.clip(jnp.linalg.norm(z, axis=-1, keepdims=True), eps, None)
    f_unit = f / jnp.clip(jnp.linalg.norm(f, axis=-1, keepdims=True), eps, None)
    Z2z2 = charge**2 * z2
    return (1 + jnp.sum(f_unit * z_unit, axis=-1)) / 2 + Z2z2 / (10 * (4 + Z2z2))


def clean_force(force, phys_conf, mol, *, tau):
    r"""Rescale the force so that a step of length ``tau`` stays bounded.

    The cleaning happens in two stages:

    1. The force is multiplied by ``2 / (sqrt(1 + 2 a |F|^2 tau) + 1)``, where
       ``a`` comes from :func:`crossover_parameter`.
    2. The norm is capped so that ``tau * |F|`` does not exceed the distance to
       the nearest nucleus.

    Args:
        force: force vectors for each electron; the last axis holds the 3
            components.
        phys_conf: physical configuration. ``phys_conf.r`` and ``phys_conf.R``
            are vmapped over their leading axis.
        mol: molecule whose ``charges`` are indexed by the nearest-nucleus
            index.
        tau: time step.

    Returns:
        The cleaned force, with the same shape as ``force``.
    """
    z, idx = jax.vmap(diffs_to_nearest_nuc)(phys_conf.r, phys_conf.R)
    a = crossover_parameter(z, force, mol.charges[idx])
    av2tau = a * jnp.sum(force**2, axis=-1) * tau
    # Algebraically equal to (sqrt(1 + 2x) - 1) / x with x = av2tau. The
    # rationalized form stays finite (-> 1) when av2tau is small or exactly 0.
    factor = 2 / (jnp.sqrt(1 + 2 * av2tau) + 1)
    force = factor[..., None] * force
    eps = jnp.finfo(phys_conf.r.dtype).eps
    # Cap the step tau * |force| at the distance to the nearest nucleus
    # (z[..., -1] is used as the squared distance).
    norm_factor = jnp.minimum(
        1.0,
        jnp.sqrt(z[..., -1])
        / (tau * jnp.clip(jnp.linalg.norm(force, axis=-1), eps, None)),
    )
    return force * norm_factor[..., None]


def equilibrate(
    rng: KeyArray,
    params: Params,
    molecule_idx_sampler: MoleculeIdxSampler,
    sampler: MultiNuclearGeometrySampler,
    state: SamplerState,
    criterion: Callable[[PhysicalConfiguration], jax.Array],
    steps: Iterable[int],
    *,
    block_size: int,
    n_blocks: int = 5,
    allow_early_stopping: bool = True,
):
    r"""Run the sampler for equilibration, yielding after every step.

    On each step, molecule indices are drawn from
    :data:`molecule_idx_sampler`, and the pmapped ``sampler.sample`` advances
    the state.

    With :data:`allow_early_stopping`, the last ``block_size * n_blocks`` values
    of ``criterion(phys_conf)`` are kept in a sliding window. Once the window is
    full, the mean of its first and last blocks are compared. Iteration stops
    when they differ by less than the smaller of the two block standard
    deviations.

    Args:
        rng: random key, split into one key per step via ``rng_iterator``.
        params: wave-function parameters passed to ``sampler.sample``.
        molecule_idx_sampler: provides the molecule indices for each step.
        sampler: the sampler whose ``sample`` method is pmapped.
        state: initial sampler state.
        criterion: maps the sampled physical configuration to a single-element
            array (``.item()`` is called on it).
        steps: step labels; iteration ends when this is exhausted.
        block_size: number of criterion values per block; must be at least 2
            when early stopping is enabled, because ``stdev`` needs two points.
        n_blocks: number of blocks in the sliding window; must be at least 1
            when early stopping is enabled.
        allow_early_stopping: whether to apply the convergence test above.

    Yields:
        ``(step, state, mol_idxs, stats)``, where ``mol_idxs`` is taken from a
        single device.

    Raises:
        ValueError: if early stopping is enabled and ``block_size < 2`` or
            ``n_blocks < 1``. This is raised on the first ``next()``, before any
            sampling happens.
    """
    if allow_early_stopping and (block_size < 2 or n_blocks < 1):
        # Fail fast instead of hitting statistics.StatisticsError from stdev()
        # only after the window has filled up.
        raise ValueError(
            'early stopping requires block_size >= 2 and n_blocks >= 1, '
            f'got block_size={block_size}, n_blocks={n_blocks}'
        )
    sample_wf = pmap(sampler.sample)
    buffer_size = block_size * n_blocks
    buffer: list[float] = []
    for step, step_rng in zip(steps, rng_iterator(rng)):
        mol_idxs = molecule_idx_sampler.sample()
        state, phys_conf, stats = sample_wf(step_rng, state, params, mol_idxs)
        yield step, state, select_one_device(mol_idxs), stats
        if not allow_early_stopping:
            continue
        buffer.append(criterion(phys_conf).item())
        # Sliding window: keep only the most recent ``buffer_size`` values.
        del buffer[:-buffer_size]
        if len(buffer) < buffer_size:
            continue
        b1, b2 = buffer[:block_size], buffer[-block_size:]
        if abs(mean(b1) - mean(b2)) < min(stdev(b1), stdev(b2)):
            break


def initialize_sampling(
    rng: KeyArray,
    hamil: MolecularHamiltonian,
    ansatz: Ansatz,
    mols: list[Molecule],
    electronic_states: int,
    molecule_batch_size: int,
    *,
    elec_sampler,
    nuc_sampler=None,
    elec_warp_fn: Optional[Callable] = None,
    update_nuc_period: Optional[int] = None,
    elec_equilibration_steps: Optional[int] = None,
) -> tuple[MoleculeIdxSampler, MultiNuclearGeometrySampler]:
    r"""Build the molecule-index sampler and the combined electron/nuclei sampler.

    Args:
        rng: random key passed to :class:`MoleculeIdxSampler`.
        hamil: Hamiltonian; ``hamil.mol.charges`` is passed to the nuclei
            sampler.
        ansatz: ansatz whose ``apply`` is given to the electron sampler as
            ``wf``.
        mols: molecules; only ``len(mols)`` is used here.
        electronic_states: passed to :class:`MultiElectronicStateSampler`.
        molecule_batch_size: passed to :class:`MoleculeIdxSampler`.
        elec_sampler: electron sampler factory, called as
            ``elec_sampler(hamil=hamil, wf=ansatz.apply)``.
        nuc_sampler: nuclei sampler factory called with ``hamil.mol.charges``;
            defaults to :class:`IdleNucleiSampler`.
        elec_warp_fn: electron warp function; defaults to :func:`no_elec_warp`.
        update_nuc_period: forwarded to :class:`MultiNuclearGeometrySampler`.
        elec_equilibration_steps: forwarded to
            :class:`MultiNuclearGeometrySampler`.

    Returns:
        ``(molecule_idx_sampler, sampler)``.
    """
    molecule_idx_sampler = MoleculeIdxSampler(
        rng, len(mols), molecule_batch_size, 'once'
    )
    elec_sampler = elec_sampler(hamil=hamil, wf=ansatz.apply)
    multi_state_elec_sampler = MultiElectronicStateSampler(
        elec_sampler, electronic_states
    )
    nuc_sampler = (IdleNucleiSampler if nuc_sampler is None else nuc_sampler)(
        hamil.mol.charges,
    )
    elec_warp_fn = no_elec_warp if elec_warp_fn is None else elec_warp_fn
    sampler = MultiNuclearGeometrySampler(
        multi_state_elec_sampler,
        nuc_sampler,
        elec_warp_fn,
        update_nuc_period,
        elec_equilibration_steps,
    )
    return molecule_idx_sampler, sampler


def initialize_sampler_state(rng, sampler, params, electron_batch_size, mols):
    r"""Initialize the sampler state on every device with :func:`jax.pmap`.

    Args:
        rng: random key(s), mapped over their leading (device) axis by ``pmap``.
        sampler: sampler whose ``init`` creates the state.
        params: wave-function parameters, mapped over their leading axis by
            ``pmap``.
        electron_batch_size: total electron batch size. Each device gets
            ``electron_batch_size // jax.device_count()`` (floor division; any
            remainder is dropped).
        mols: molecules whose ``coords`` are stacked and passed to
            ``sampler.init``.

    Returns:
        The per-device sampler state produced by the pmapped ``sampler.init``.
    """

    @jax.pmap
    def sampler_state_initializer(rng, params):
        return sampler.init(
            rng,
            params,
            electron_batch_size // jax.device_count(),
            jnp.stack([mol.coords for mol in mols]),
        )

    return sampler_state_initializer(rng, params)