# Source code for deepqmc.hamil

from collections.abc import Callable
from functools import partial
from itertools import count
from typing import Any, Optional, Protocol

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc

from .ecp.gaussian_type_ecp import GaussianTypeECP
from .molecule import Molecule
from .physics import (
    NuclearCoulombPotential,
    electronic_potential,
    laplacian,
    nuclear_energy,
    pairwise_distance,
)
from .types import (
    Energy,
    KeyArray,
    ParametrizedWaveFunction,
    Params,
    PhysicalConfiguration,
    Stats,
)
from .utils import argmax_random_choice

# Names exported by ``from deepqmc.hamil import *``.
__all__ = ['MolecularHamiltonian', 'LaplacianFactory']


class LaplacianFactory(Protocol):
    r"""Protocol class for Laplacian factories.

    A Laplacian factory takes as input a function and returns a function that
    computes the Laplacian and the gradient of the input function.
    """

    def __call__(
        self, f: Callable[[jax.Array], jax.Array]
    ) -> Callable[[jax.Array], tuple[jax.Array, jax.Array]]:
        r"""Build an evaluator of the Laplacian and gradient of :data:`f`.

        Args:
            f (Callable[[jax.Array], jax.Array]): the function to differentiate.

        Returns:
            :class:`Callable[[jax.Array], tuple[jax.Array, jax.Array]]`: a
            function mapping an input array to the tuple
            ``(laplacian, gradient)`` of :data:`f` at that input.
        """
        ...
def get_shell(z):
    r"""Return the number of (at least partially) occupied shells for ``z`` electrons.

    The :math:`n`-th shell holds :math:`2n^2` electrons (2, 8, 18, 32, ...).
    Consequently ``get_shell(z + 1) - 1`` yields the number of *fully*
    occupied shells for ``z`` electrons.

    Args:
        z (int): the number of electrons.

    Returns:
        int: the number of shells needed to hold ``z`` electrons (``0`` for
        ``z <= 0``).
    """
    max_elec = 0  # cumulative capacity of the first ``n`` shells
    n = 0
    for n in count():
        if z <= max_elec:
            break
        max_elec += 2 * (1 + n) ** 2
    return n


class Hamiltonian(Protocol):
    r"""Protocol for :class:`~deepqmc.types.Hamiltonian` objects.

    :class:`~deepqmc.types.Hamiltonian` objects represent the Hamiltonian of the
    system under investigation. New Hamiltonians should implement this protocol
    to be compatible with the DeepQMC software suite. The
    :class:`~deepqmc.types.Hamiltonian` object holds information about the
    system and implements the local energy factory.
    """

    def local_energy(self, ansatz: ParametrizedWaveFunction) -> Callable[
        [Optional[KeyArray], Params, PhysicalConfiguration],
        tuple[Energy, Stats],
    ]:
        r"""Return a function that calculates the local energy of the wave function.

        Args:
            ansatz (~deepqmc.types.ParametrizedWaveFunction): the wave function
                ansatz.

        Returns:
            :class:`Callable[[rng, params, phys_conf], tuple[Energy, Stats]]`:
            a function that evaluates the local energy of :data:`ansatz` with
            parameters :data:`params` at the physical configuration
            :data:`phys_conf`, together with a dictionary of statistics.
            :data:`rng` may be :data:`None`.
        """
        ...
class MolecularHamiltonian(Hamiltonian):
    r"""
    Hamiltonian of non-relativistic molecular systems.

    The system consists of nuclei with fixed positions and electrons moving
    around them. The total energy is defined as the sum of the nuclear-nuclear
    and electron-electron repulsion, the nuclear-electron attraction, and the
    kinetic energy of the electrons:
    :math:`E=V_\text{nuc-nuc} + V_\text{el-el} + V_\text{nuc-el} + E_\text{kin}`.

    Args:
        mol (~deepqmc.molecule.Molecule): the molecule to consider.
        ecp_type (str): If set, use the appropriate effective core potential
            (ECP). The string is passed to :func:`pyscf.gto.M()` as
            :data:`'ecp'` argument. Supports ECPs that are implemented in the
            pyscf package, e.g. :data:`'bfd'` [Burkatzki et al. 2007] or
            :data:`'ccECP'` [Bennett et al. 2017].
        ecp_mask (list[bool]): list of True and False values
            (:math:`N_\text{nuc}`) specifying whether to use an ECP for each
            nucleus. Ignored if :data:`ecp_type` is :data:`None`. If omitted
            while :data:`ecp_type` is set, an ECP is used for every nucleus
            with a charge larger than 2.
        elec_std (float): optional, a default value of the scaling factor of
            the spread of electrons around the nuclei.
        laplacian_factory (~deepqmc.hamil.LaplacianFactory): creates a function
            that returns a tuple containing the laplacian and gradient of the
            wave function.
    """

    def __init__(
        self,
        *,
        mol: Molecule,
        ecp_type: Optional[str] = None,
        ecp_mask: Optional[list[bool]] = None,
        elec_std: float = 1.0,
        laplacian_factory: LaplacianFactory = laplacian,
    ):
        self.mol = mol
        self.elec_std = elec_std
        self.ecp_type = ecp_type
        if ecp_type is None:
            # without an ECP type no nucleus can use an ECP, regardless of the mask
            ecp_mask = [False] * len(mol.charges)
        elif ecp_mask is None:
            # use ECP only for atoms larger than He
            ecp_mask = list(mol.charges > 2)
        assert len(ecp_mask) == len(mol.charges), "Incompatible shape of 'ecp_mask'!"
        self.ecp_mask = jnp.array(ecp_mask)
        self.laplacian = laplacian_factory
        self.potential: (
            NuclearCoulombPotential | GaussianTypeECP
        )  # mypy otherwise complains about the following assignment
        if self.ecp_mask.any():
            self.potential = GaussianTypeECP(mol.charges, ecp_type, self.ecp_mask)
        else:
            self.potential = NuclearCoulombPotential(mol.charges)
        # only the valence electrons of the chosen potential are treated explicitly
        n_elec = int(sum(self.potential.ns_valence) - mol.charge)
        assert not (
            n_elec + mol.spin
        ) % 2, 'The number of electrons and the spin must have the same parity.'
        assert n_elec > 1, 'The system must contain at least two active electrons.'
        self.n_nuc = len(mol.charges)
        self.n_up = (n_elec + mol.spin) // 2
        self.n_down = (n_elec - mol.spin) // 2
        self.ns_valence = self.potential.ns_valence
        # shells (at least partially) occupied by as many electrons as the
        # nuclear charge
        self.mol_shells = [get_shell(z) for z in self.mol.charges]
        # fully occupied shells of the electrons not treated as valence
        # electrons (charge minus valence count)
        self.mol_ecp_shells = [
            get_shell(z + 1) - 1 for z in self.mol.charges - self.ns_valence
        ]

    def init_sample(
        self, rng: KeyArray, R: jax.Array, n: int, elec_std: Optional[float] = None
    ) -> PhysicalConfiguration:
        r"""
        Guess some initial electron positions.

        Tries to make an educated guess about plausible initial electron
        configurations. Places electrons according to normal distributions
        centered on the nuclei. If the molecule is not neutral, extra electrons
        are placed on or removed from random nuclei. The resulting
        configurations are usually very crude, a subsequent, thorough
        equilibration is needed.

        Args:
            rng (~deepqmc.types.KeyArray): key used for PRNG.
            R (jax.Array): nuclear coordinates of a single molecular geometry
                (:math:`N_\text{nuc}`, 3)
            n (int): the number of initial electron configurations to generate.
            elec_std (float): optional, overrides the scaling factor of the
                spread of electrons given at construction (:attr:`elec_std`).

        Returns:
            :class:`~deepqmc.types.PhysicalConfiguration`: initial electron and
            nuclei configurations
        """
        assert R.ndim == 2
        Rs = jnp.tile(R[None], (n, 1, 1))
        return jax.vmap(self.init_single_sample, (0, 0, None))(
            jax.random.split(rng, n), Rs, elec_std
        )

    def init_single_sample(
        self, rng: KeyArray, R: jax.Array, elec_std: Optional[float]
    ) -> PhysicalConfiguration:
        r"""Generate one initial electron configuration for the geometry ``R``.

        Args:
            rng (~deepqmc.types.KeyArray): key used for PRNG.
            R (jax.Array): nuclear coordinates (:math:`N_\text{nuc}`, 3).
            elec_std (float): optional, scaling factor of the electron spread;
                :attr:`elec_std` is used if :data:`None`.

        Returns:
            :class:`~deepqmc.types.PhysicalConfiguration`: a single electron
            configuration around the nuclei ``R``.
        """
        rng_remainder, rng_normal, rng_spin = jax.random.split(rng, 3)
        # spread the molecular charge evenly over the nuclei; the fractional
        # remainders are handed out one electron at a time below
        valence_electrons = self.potential.ns_valence - self.mol.charge / self.n_nuc
        electrons_of_atom = jnp.floor(valence_electrons).astype(jnp.int32)
        n_active = self.potential.ns_valence.sum() - self.mol.charge

        def cond_fn(value):
            _, electrons_of_atom = value
            return n_active - electrons_of_atom.sum() > 0

        def body_fn(value):
            rng, electrons_of_atom = value
            rng, rng_categorical = jax.random.split(rng)
            # jax.random.categorical treats the remainders as logits, so every
            # nucleus keeps a non-zero probability of receiving the electron
            atom_idx = jax.random.categorical(
                rng_categorical, valence_electrons - electrons_of_atom, shape=()
            )
            electrons_of_atom = electrons_of_atom.at[atom_idx].add(1)
            return rng, electrons_of_atom

        _, electrons_of_atom = jax.lax.while_loop(
            cond_fn, body_fn, (rng_remainder, electrons_of_atom)
        )
        up, down = self.distribute_spins(rng_spin, R, electrons_of_atom)
        # turn per-nucleus electron counts into the nucleus index of each electron
        up = (jnp.cumsum(up)[:, None] <= jnp.arange(self.n_up)).sum(axis=0)
        down = (jnp.cumsum(down)[:, None] <= jnp.arange(self.n_down)).sum(axis=0)
        idxs = jnp.concatenate([up, down])
        centers = R[idxs]
        # an explicit ``None`` check so that a user-supplied value is never
        # silently replaced (``or`` would also discard 0.0)
        if elec_std is None:
            elec_std = self.elec_std
        std = elec_std * jnp.sqrt(self.mol.charges)[idxs][..., None]
        r = centers + std * jax.random.normal(rng_normal, centers.shape)
        return PhysicalConfiguration(R, r, jnp.array(0))  # type: ignore

    def distribute_spins(
        self, rng: KeyArray, R: jax.Array, elec_of_atom: jax.Array
    ) -> tuple[jax.Array, jax.Array]:
        r"""Split the electrons on each nucleus into spin-up and spin-down counts.

        Electron pairs are first placed on the nuclei as evenly as possible.
        The remaining electrons are then assigned while hopping to the nearest
        nucleus with electrons left, alternating the spin, so that opposite-spin
        electrons end up close to each other.

        Args:
            rng (~deepqmc.types.KeyArray): key used to pick the starting nucleus
                for the unpaired electrons.
            R (jax.Array): nuclear coordinates.
            elec_of_atom (jax.Array): the number of electrons on each nucleus.

        Returns:
            tuple[jax.Array, jax.Array]: the number of spin-up and spin-down
            electrons on each nucleus.
        """
        up, down = jnp.zeros_like(elec_of_atom), jnp.zeros_like(elec_of_atom)

        # try to distribute electron pairs evenly across atoms
        def pair_cond_fn(value):
            i, *_ = value
            return i < jnp.max(elec_of_atom)

        def pair_body_fn(value):
            i, up, down = value
            mask = elec_of_atom >= 2 * (i + 1)
            # add a pair to every eligible nucleus only if all of them still fit
            increment = jnp.where(mask & (mask.sum() + down.sum() <= self.n_down), 1, 0)
            up = up + increment
            down = down + increment
            return i + 1, up, down

        _, up, down = jax.lax.while_loop(pair_cond_fn, pair_body_fn, (0, up, down))

        # distribute remaining electrons such that opposite spin electrons
        # end up close in an attempt to mimic covalent bonds
        # a nucleus is never its own neighbor, hence the infinite self-distance
        dists = pairwise_distance(R, R).at[jnp.diag_indices(len(R))].set(jnp.inf)
        nearest_neighbor_indices = jnp.argsort(dists)

        def spin_cond_fn(value):
            _, _, up, down = value
            return (up + down < elec_of_atom).any()

        def spin_body_fn(value):
            i, center, up, down = value
            # alternate spins while spin-down slots are left
            is_down = (i % 2) & (down.sum() < self.n_down)
            up = up.at[center].add(1 - is_down)
            down = down.at[center].add(is_down)
            # hop to the closest nucleus that still has unassigned electrons
            ordering = nearest_neighbor_indices[center]
            ordered_has_remainder = (elec_of_atom - up - down)[ordering] > 0
            first_ordered_has_remainder = jnp.argmax(ordered_has_remainder)
            center = ordering[first_ordered_has_remainder]
            return i + 1, center, up, down

        center = argmax_random_choice(rng, elec_of_atom - up - down)
        *_, up, down = jax.lax.while_loop(
            spin_cond_fn, spin_body_fn, (jnp.array(0), center, up, down)
        )
        return up, down

    def local_energy(self, ansatz: ParametrizedWaveFunction) -> Callable[
        [Optional[KeyArray], Params, PhysicalConfiguration],
        tuple[Energy, Stats],
    ]:
        r"""Return a function that calculates the local energy of the wave function.

        Args:
            ansatz (~deepqmc.types.ParametrizedWaveFunction): the wave function
                ansatz.

        Returns:
            :class:`Callable[[rng, params, phys_conf], tuple[Energy, Stats]]`:
            a function returning the local energy of :data:`ansatz` at
            :data:`phys_conf` and a dictionary with its individual terms.
        """

        def loc_ene(
            rng: Optional[KeyArray], params: Params, phys_conf: PhysicalConfiguration
        ) -> tuple[Energy, Stats]:
            wf = partial(ansatz, params)

            def wave_function(r: jax.Array) -> jax.Array:
                # the ``.log`` output of the ansatz as a function of the
                # flattened electron coordinates
                pc = jdc.replace(phys_conf, r=r.reshape(-1, 3))
                return wf(pc).log

            lap_log_psis, quantum_force = self.laplacian(wave_function)(
                phys_conf.r.flatten()
            )
            quantum_force_sq = (quantum_force**2).sum(axis=-1)
            # (laplacian psi) / psi = laplacian(log psi) + |grad(log psi)|^2
            Es_kin = -0.5 * (lap_log_psis + quantum_force_sq)
            Es_nuc = nuclear_energy(phys_conf, self.ns_valence)
            Vs_el = electronic_potential(phys_conf)
            Vs_loc = self.potential.local_potential(phys_conf)
            Vs_nl = self.potential.nonloc_potential(rng, phys_conf, wf)
            Es_loc = Es_kin + Vs_loc + Vs_nl + Vs_el + Es_nuc
            stats = {
                'hamil/V_el': Vs_el,
                'hamil/E_kin': Es_kin,
                'hamil/V_loc': Vs_loc,
                'hamil/V_nl': Vs_nl,
                'hamil/lap': lap_log_psis,
                'hamil/quantum_force': quantum_force_sq,
            }
            return Es_loc, stats

        return loc_ene

    def as_pyscf(self, *, coords: Optional[jax.Array] = None) -> dict[str, Any]:
        r"""Return the hamiltonian parameters in format pyscf can parse.

        Args:
            coords (jax.Array): optional, nuclear coordinates
                (:math:`N_\text{nuc}`, 3). Defaults to the coordinates of
                :attr:`mol`.

        Returns:
            dict[str, Any]: keyword arguments (``atom``, ``charge``, ``spin``,
            ``ecp``, ``unit``) with coordinates given in bohr.
        """
        coords = coords if coords is not None else self.mol.coords
        pyscf_kwargs = {
            'atom': [(int(c), r.tolist()) for c, r in zip(self.mol.charges, coords)],
            'charge': self.mol.charge,
            'spin': self.mol.spin,
            # only nuclei selected by the ECP mask receive an ECP
            'ecp': {int(c): self.ecp_type for c in self.mol.charges[self.ecp_mask]},
            'unit': 'bohr',
        }
        return pyscf_kwargs