Source code for deepqmc.types

from __future__ import annotations

from collections.abc import Callable, MutableMapping
from typing import Any, NamedTuple, Optional, Protocol
from typing_extensions import TypeAlias

import jax
import jax_dataclasses as jdc


class Psi(NamedTuple):
    r"""Wave function value in sign/log-magnitude form.

    The wave function is not stored directly; its sign and the logarithm of
    its absolute value are kept as two separate arrays.
    """

    # Sign of the wave function.
    sign: jax.Array
    # Logarithm of the absolute value of the wave function.
    log: jax.Array


@jdc.pytree_dataclass
class PhysicalConfiguration:
    r"""Container for a physical configuration of electrons and nuclei.

    Holds the nuclear coordinates, the electronic coordinates and
    :data:`mol_idx`, which records the nuclear configuration a given sample
    was obtained from. Indexing an instance applies the same index to all
    three arrays and wraps the results in a new instance.
    """

    R: jax.Array  # nuclear coordinates
    r: jax.Array  # electronic coordinates
    mol_idx: jax.Array  # nuclear configuration each sample was obtained from

    def __getitem__(self, idx):
        r"""Return a new configuration with every field indexed by ``idx``."""
        nuclei = self.R[idx]
        electrons = self.r[idx]
        origin = self.mol_idx[idx]
        return self.__class__(nuclei, electrons, origin)

    def __len__(self):
        r"""Return the length of the leading axis of the electron coordinates."""
        return len(self.r)

    @property
    def batch_shape(self):
        r"""Return the leading (batch) dimensions shared by all fields.

        The last two axes of :data:`r` and :data:`R` are stripped; what
        remains must agree between them and equal the full shape of
        :data:`mol_idx`.
        """
        batch = self.r.shape[:-2]
        # Internal consistency check: all three arrays share the batch axes.
        assert batch == self.R.shape[:-2] == self.mol_idx.shape
        return batch


# Aliases used in annotations throughout the package.
Params: TypeAlias = MutableMapping  # parameters of an Ansatz
Stats: TypeAlias = dict
Weight: TypeAlias = jax.Array
Energy: TypeAlias = jax.Array
KeyArray: TypeAlias = jax.Array  # RNG key
SamplerState: TypeAlias = dict
OptState: TypeAlias = Any
DataDict: TypeAlias = dict
# A batch is a configuration, a weight and an optional data dictionary.
Batch: TypeAlias = tuple[PhysicalConfiguration, Weight, Optional[DataDict]]
WaveFunction: TypeAlias = Callable[[PhysicalConfiguration], Psi]
ParametrizedWaveFunction: TypeAlias = Callable[[Params, PhysicalConfiguration], Psi]


class TrainState(NamedTuple):
    r"""Snapshot of the current state of the training.

    Bundles the sampler state, the parameters and the optimizer state into a
    single immutable tuple.
    """

    # State of the sampler.
    sampler: SamplerState
    # Current parameters.
    params: Params
    # State of the optimizer.
    opt: OptState


class Ansatz(Protocol):
    r"""Structural interface that every wave function Ansatz must satisfy.

    An :class:`~deepqmc.types.Ansatz` object represents a parametrized wave
    function Ansatz. Implement this protocol for any new kind of Ansatz so
    that it is compatible with the DeepQMC software suite.

    An Ansatz is assumed to receive a
    :class:`~deepqmc.types.PhysicalConfiguration` describing a *single*
    sample of electron and nuclei configuration. Batches of samples, e.g.
    during training, are handled by DeepQMC, which ``vmap``-s the Ansatz
    automatically.
    """

    def init(self, rng: KeyArray, phys_conf: PhysicalConfiguration) -> Params:
        r"""Create the initial parameters of the Ansatz.

        Args:
            rng (~deepqmc.types.KeyArray): RNG key used to generate the
                initial parameters.
            phys_conf (~deepqmc.types.PhysicalConfiguration): dummy network
                input holding a single electron and nuclei configuration.
                Its values can be anything; only its shape information is
                read.

        Returns:
            ~deepqmc.types.Params: the initial parameters of the Ansatz.
        """
        ...

    def apply(
        self, params: Params, phys_conf: PhysicalConfiguration, return_mos: bool = False
    ) -> Psi:
        r"""Evaluate the Ansatz on one sample.

        Args:
            params (~deepqmc.types.Params): current parameters with which
                the Ansatz is evaluated.
            phys_conf (~deepqmc.types.PhysicalConfiguration): the single
                sample on which the Ansatz is evaluated.
            return_mos (bool): whether to return the many-body orbitals
                instead of the wave function.

        Returns:
            ~deepqmc.types.Psi: the value of the wave function.
        """
        ...