Source code for deepqmc.sampling.base

from typing import Protocol

import jax

from deepqmc.types import (
    KeyArray,
    Params,
    PhysicalConfiguration,
    SamplerState,
    Stats,
)


class ElectronSampler(Protocol):
    r"""Structural interface that every electron sampler satisfies.

    An :class:`~deepqmc.sampling.base.ElectronSampler` is a Markov chain
    sampler over electron positions. An instance is assumed to drive a batch
    of walkers for a single electronic state of a single molecule; callers may
    vmap it to fit the context it is used in. Several electron samplers can be
    composed with :func:`~deepqmc.sampling.chain`.
    """

    def init(
        self,
        rng: KeyArray,
        params: Params,
        n: int,
        R: jax.Array,
    ) -> SamplerState:
        r"""Create a fresh sampler state.

        Args:
            rng (~deepqmc.types.KeyArray): random key used to draw the initial
                electron positions.
            params (~deepqmc.types.Params): parameters of the wave function
                that is to be sampled.
            n (int): how many walkers are propagated in parallel.
            R (jax.Array): positions of the nuclei of the molecular
                configuration.

        Returns:
            :type:`~deepqmc.types.SamplerState`: a state containing the
            electron positions together with data about the sampler
            trajectory.
        """
        ...

    def sample(
        self,
        rng: KeyArray,
        state: SamplerState,
        params: Params,
        R: jax.Array,
    ) -> tuple[SamplerState, PhysicalConfiguration, Stats]:
        r"""Propagate the sampler state.

        Args:
            rng (~deepqmc.types.KeyArray): random key used to propose new
                electron positions.
            state (~deepqmc.types.SamplerState): sampler state produced by the
                previous step.
            params (~deepqmc.types.Params): parameters of the wave function
                that is to be sampled.
            R (jax.Array): positions of the nuclei of the molecular
                configuration.

        Returns:
            tuple[:type:`~deepqmc.types.SamplerState`, :type:`~deepqmc.types.PhysicalConfiguration`, :type:`~deepqmc.types.Stats`]:
            the new sampler state, a physical configuration, and statistics
            describing the sampling trajectory.
        """
        ...

    def update(
        self,
        state: SamplerState,
        params: Params,
        R: jax.Array,
    ) -> SamplerState:
        r"""Refresh the sampler state after the wave function has changed.

        Intended to be called after a parameter update, so that the state
        accounts for the resulting change of the wave function.

        Args:
            state (~deepqmc.types.SamplerState): sampler state as it was before
                the parameter update.
            params (~deepqmc.types.Params): the updated wave function
                parameters.
            R (jax.Array): positions of the nuclei of the molecular
                configuration.

        Returns:
            :type:`~deepqmc.types.SamplerState`: the refreshed state containing
            the electron positions together with data about the sampler
            trajectory.
        """
        ...
class NucleiSampler(Protocol):
    r"""Structural interface for samplers that act on nuclei coordinates."""

    def init(self, nuc_coords: jax.Array) -> SamplerState:
        r"""Build the initial sampler state.

        Args:
            nuc_coords (jax.Array): nuclear coordinates the sampler starts
                from.

        Returns:
            :type:`~deepqmc.types.SamplerState`: the initial sampler state.
        """
        ...

    def sample(
        self,
        rng: KeyArray,
        state: SamplerState,
    ) -> tuple[SamplerState, jax.Array, Stats]:
        r"""Propagate the nuclei sampler state.

        Args:
            rng (~deepqmc.types.KeyArray): random key consumed by this step.
            state (~deepqmc.types.SamplerState): the current sampler state.

        Returns:
            tuple[:type:`~deepqmc.types.SamplerState`, jax.Array, :type:`~deepqmc.types.Stats`]:
            the new sampler state, an array (presumably the sampled nuclear
            coordinates -- TODO confirm against implementations), and
            sampling statistics.
        """
        ...


class ElectronWarp(Protocol):
    r"""Structural interface for electron warps."""

    def __call__(
        self,
        rng: KeyArray,
        R: jax.Array,
        dR: jax.Array,
        smpl_state: SamplerState,
    ) -> SamplerState:
        r"""Apply the warp to an electron sampler state.

        Args:
            rng (~deepqmc.types.KeyArray): random key available to the warp.
            R (jax.Array): presumably the nuclei positions, as ``R`` denotes
                elsewhere in this module -- confirm.
            dR (jax.Array): NOTE(review): appears to be a displacement of
                ``R``; confirm against implementations.
            smpl_state (~deepqmc.types.SamplerState): the sampler state to
                be warped.

        Returns:
            :type:`~deepqmc.types.SamplerState`: the resulting sampler state.
        """
        ...