from itertools import count

import jax.numpy as jnp
import jax_dataclasses as jdc
from jax import lax, random, vmap

from ..physics import (
    NuclearCoulombPotential,
    electronic_potential,
    laplacian,
    nuclear_energy,
    pairwise_distance,
)
from ..pp.ecp_potential import EcpTypePseudopotential
from ..types import PhysicalConfiguration
from ..utils import argmax_random_choice
from .base import Hamiltonian

# Public API of this module (what ``from ... import *`` exposes).
__all__ = ["MolecularHamiltonian"]

def get_shell(z):
    """Return the number of (at least partially) occupied shells for ``z`` electrons.

    Shell ``n`` (1-based) holds ``2 * n**2`` electrons, so the cumulative
    capacities are 2, 10, 28, 60, ...  Consequently ``get_shell(z + 1) - 1``
    yields the number of *fully* occupied shells for ``z`` electrons.

    Args:
        z: number of electrons (any scalar comparable with an int).

    Returns:
        int: the smallest ``n`` such that the first ``n`` shells can hold
        ``z`` electrons (0 for ``z <= 0``).
    """
    max_elec = 0
    for n in count():
        # stop at the first shell count whose total capacity fits z electrons
        if z <= max_elec:
            break
        max_elec += 2 * (1 + n) ** 2
    return n

class MolecularHamiltonian(Hamiltonian):
    r"""
    Hamiltonian of non-relativistic molecular systems.

    The system consists of nuclei with fixed positions and electrons moving
    around them. The total energy is defined as the sum of the nuclear-nuclear
    and electron-electron repulsion, the nuclear-electron attraction, and the
    kinetic energy of the electrons:
    :math:`E=V_\text{nuc-nuc} + V_\text{el-el} + V_\text{nuc-el} + E_\text{kin}`.

    Args:
        mol (~deepqmc.Molecule): the molecule to consider
        pp_type (str): If set, use the appropriate pseudopotential. The string
            is passed to :func:`pyscf.gto.M()` as :data:`'ecp'` argument.
            Supports pseudopotentials that are implemented in the pyscf
            package, e.g. :data:`'bfd'` [Burkatzki et al. 2007] or
            :data:`'ccECP'` [Bennett et al. 2017].
        pp_mask (list, (:math:`N_\text{nuc}`)): list of True and False values
            specifying whether to use a pseudopotential for each nucleus
        elec_std (float): optional, a default value of the scaling factor of
            the spread of electrons around the nuclei.
    """

    def __init__(self, *, mol, pp_type=None, pp_mask=None, elec_std=1.0):
        self.mol = mol
        self.elec_std = elec_std
        self.pp_type = pp_type
        if pp_type is None:
            pp_mask = [False] * len(mol.charges)
        elif pp_mask is None:
            # use PP only for atoms larger than He
            pp_mask = mol.charges > 2
        assert len(pp_mask) == len(mol.charges), "Incompatible shape of 'pp_mask'!"
        self.pp_mask = jnp.array(pp_mask)

        # Derived properties
        if self.pp_type is None:
            self.potential = NuclearCoulombPotential(mol.charges)
        else:
            self.potential = EcpTypePseudopotential(
                mol.charges, pp_type, self.pp_mask
            )
        n_elec = int(sum(self.potential.ns_valence) - mol.charge)
        assert not (n_elec + mol.spin) % 2
        assert n_elec > 1, 'The system must contain at least two active electrons.'
        self.n_nuc = len(mol.charges)
        self.n_up = (n_elec + mol.spin) // 2
        self.n_down = (n_elec - mol.spin) // 2
        self.ns_valence = self.potential.ns_valence
        # number of (at least partially) occupied shells per nucleus
        self.mol_shells = [get_shell(z) for z in self.mol.charges]
        # number of fully occupied shells among the electrons removed from the
        # valence count (charge minus valence electrons) per nucleus
        self.mol_pp_shells = [
            get_shell(z + 1) - 1 for z in self.mol.charges - self.ns_valence
        ]

    def init_sample(self, rng, R, n, elec_std=None):
        r"""
        Guess some initial electron positions.

        Tries to make an educated guess about plausible initial electron
        configurations. Places electrons according to normal distributions
        centered on the nuclei. If the molecule is not neutral, extra electrons
        are placed on or removed from random nuclei. The resulting
        configurations are usually very crude, a subsequent, thorough
        equilibration is needed.

        Args:
            rng (jax.random.PRNGKey): key used for PRNG.
            R (float, (:math:`N_\text{nuc}`, 3)): nuclear coordinates of a
                single molecular geometry
            n (int): the number of configurations to generate.
            elec_std (float): optional, a factor for scaling the spread of
                electrons around the nuclei. If falsy (e.g. ``None``), the
                value given at construction is used.

        Returns:
            :class:`PhysicalConfiguration` batched along a leading axis of
            size ``n``.
        """
        assert R.ndim == 2
        Rs = jnp.tile(R[None], (n, 1, 1))
        return vmap(self.init_single_sample, (0, 0, None))(
            random.split(rng, n), Rs, elec_std
        )

    def init_single_sample(self, rng, R, elec_std):
        r"""Sample a single initial electron configuration.

        Args:
            rng (jax.random.PRNGKey): key used for PRNG.
            R (float, (:math:`N_\text{nuc}`, 3)): nuclear coordinates.
            elec_std (float): optional scaling factor of the electron spread;
                if falsy, ``self.elec_std`` is used.

        Returns:
            :class:`PhysicalConfiguration` holding ``R`` and electron
            coordinates of shape ``(n_up + n_down, 3)``.
        """
        rng_remainder, rng_normal, rng_spin = random.split(rng, 3)
        # spread the molecular charge evenly over the nuclei; every nucleus
        # first receives the integer part of its (fractional) electron count
        valence_electrons = self.potential.ns_valence - self.mol.charge / self.n_nuc
        electrons_of_atom = jnp.floor(valence_electrons).astype(jnp.int32)

        def cond_fn(value):
            # continue while fewer electrons are placed than the system has
            _, electrons_of_atom = value
            return (
                self.potential.ns_valence.sum()
                - self.mol.charge
                - electrons_of_atom.sum()
                > 0
            )

        def body_fn(value):
            # hand one remaining electron to a randomly chosen nucleus
            rng, electrons_of_atom = value
            rng, rng_categorical = random.split(rng)
            # NOTE(review): the fractional remainders are passed as *logits*,
            # so selection probabilities are proportional to exp(remainder),
            # not to the remainder itself -- confirm this is intended.
            atom_idx = random.categorical(
                rng_categorical, valence_electrons - electrons_of_atom, shape=()
            )
            electrons_of_atom = electrons_of_atom.at[atom_idx].add(1)
            return rng, electrons_of_atom

        _, electrons_of_atom = lax.while_loop(
            cond_fn, body_fn, (rng_remainder, electrons_of_atom)
        )
        # NOTE(review): random.split(key, 1) yields a batch holding one key,
        # not a single key; confirm argmax_random_choice (called from
        # distribute_spins) accepts a batched key.
        rng_spin = random.split(rng_spin, 1)
        up, down = self.distribute_spins(rng_spin, R, electrons_of_atom)
        # turn per-nucleus electron counts into one nucleus index per electron
        up = (jnp.cumsum(up)[:, None] <= jnp.arange(self.n_up)).sum(axis=0)
        down = (jnp.cumsum(down)[:, None] <= jnp.arange(self.n_down)).sum(axis=0)
        idxs = jnp.concatenate([up, down])
        centers = R[idxs]
        # Gaussian spread around each nucleus scales with sqrt(nuclear charge)
        std = (elec_std or self.elec_std) * jnp.sqrt(self.mol.charges)[idxs][..., None]
        r = centers + std * random.normal(rng_normal, centers.shape)
        return PhysicalConfiguration(R, r, jnp.array(0))

    def distribute_spins(self, rng, R, elec_of_atom):
        r"""Split each nucleus' electrons into spin-up and spin-down counts.

        Args:
            rng: key forwarded to ``argmax_random_choice`` to pick the nucleus
                where unpaired electrons start being assigned.
            R (float, (:math:`N_\text{nuc}`, 3)): nuclear coordinates.
            elec_of_atom (int, (:math:`N_\text{nuc}`)): electrons per nucleus.

        Returns:
            tuple: ``(up, down)`` integer arrays with the number of spin-up and
            spin-down electrons placed on each nucleus.
        """
        up, down = jnp.zeros_like(elec_of_atom), jnp.zeros_like(elec_of_atom)

        # try to distribute electron pairs evenly across atoms
        def pair_cond_fn(value):
            i, *_ = value
            return i < jnp.max(elec_of_atom)

        def pair_body_fn(value):
            # in round i, every nucleus holding at least 2 * (i + 1) electrons
            # receives one more pair, provided the whole round still fits
            # within the spin-down budget
            i, up, down = value
            mask = elec_of_atom >= 2 * (i + 1)
            increment = jnp.where(mask & (mask.sum() + down.sum() <= self.n_down), 1, 0)
            up = up + increment
            down = down + increment
            return i + 1, up, down

        _, up, down = lax.while_loop(pair_cond_fn, pair_body_fn, (0, up, down))

        # distribute remaining electrons such that opposite spin electrons
        # end up close in an attempt to mimic covalent bonds
        dists = pairwise_distance(R, R).at[jnp.diag_indices(len(R))].set(jnp.inf)
        # self-distances are inf, so each nucleus ranks itself last
        nearest_neighbor_indices = jnp.argsort(dists)

        def spin_cond_fn(value):
            _, _, up, down = value
            return (up + down < elec_of_atom).any()

        def spin_body_fn(value):
            # alternate spins (down on odd steps while spin-down capacity
            # remains), then move to the nearest nucleus that still has
            # unassigned electrons
            i, center, up, down = value
            is_down = (i % 2) & (down.sum() < self.n_down)
            up = up.at[center].add(1 - is_down)
            down = down.at[center].add(is_down)
            ordering = nearest_neighbor_indices[center]
            ordered_has_remainder = (elec_of_atom - up - down)[ordering] > 0
            first_ordered_has_remainder = jnp.argmax(ordered_has_remainder)
            center = ordering[first_ordered_has_remainder]
            return i + 1, center, up, down

        center = argmax_random_choice(rng, elec_of_atom - up - down)
        *_, up, down = lax.while_loop(
            spin_cond_fn, spin_body_fn, (jnp.array(0), center, up, down)
        )
        return up, down

    def local_energy(self, wf, return_grad=False):
        r"""Build the local-energy function for the wave function ``wf``.

        Args:
            wf: callable mapping a :class:`PhysicalConfiguration` to an object
                with a ``log`` attribute (presumably log|psi|).
            return_grad (bool): if True, the returned function also yields the
                quantum force computed alongside the Laplacian.

        Returns:
            callable: ``loc_ene(rng, phys_conf)`` returning
            ``(E_loc, stats)``, or ``((E_loc, quantum_force), stats)`` when
            ``return_grad`` is set.
        """

        def loc_ene(rng, phys_conf):
            def wave_function(r):
                pc = jdc.replace(phys_conf, r=r.reshape(-1, 3))
                return wf(pc).log

            # laplacian() returns a pair used below as the Laplacian and the
            # gradient of log psi w.r.t. the flattened electron coordinates
            # (assumed from usage -- see ..physics.laplacian)
            lap_log_psis, quantum_force = laplacian(wave_function)(
                phys_conf.r.flatten()
            )
            quantum_force_sq = (quantum_force**2).sum(axis=-1)
            # E_kin = -1/2 (lap log psi + |grad log psi|^2)
            Es_kin = -0.5 * (lap_log_psis + quantum_force_sq)
            Es_nuc = nuclear_energy(phys_conf, self.ns_valence)
            Vs_el = electronic_potential(phys_conf)
            Vs_loc = self.potential.local_potential(phys_conf)
            Vs_nl = self.potential.nonloc_potential(rng, phys_conf, wf)
            Es_loc = Es_kin + Vs_loc + Vs_nl + Vs_el + Es_nuc
            stats = {
                'hamil/V_el': Vs_el,
                'hamil/E_kin': Es_kin,
                'hamil/V_loc': Vs_loc,
                'hamil/V_nl': Vs_nl,
                'hamil/lap': lap_log_psis,
                'hamil/quantum_force': quantum_force_sq,
            }
            result = (Es_loc, quantum_force) if return_grad else Es_loc
            return result, stats

        return loc_ene

    def as_pyscf(self, coords):
        r"""Return nuclear charges and coordinates in a format pyscf can parse.

        Args:
            coords (jax.Array): nuclear coordinates, shape [n_nuc, 3].

        Returns:
            list: ``(int(charge), coord)`` tuples, one per nucleus.
        """
        return [(int(charge), coord) for coord, charge in zip(coords, self.mol.charges)]