Source code for deepqmc.molecule
import os
from copy import deepcopy
from dataclasses import dataclass
from importlib import resources
from typing import ClassVar
import jax.numpy as jnp
import yaml
# Conversion factor from angstrom to bohr: a length given in angstrom is
# multiplied by this value to express it in bohr (0.52917721092 is the Bohr
# radius in angstrom).
angstrom = 1 / 0.52917721092
# Names exported by ``from deepqmc.molecule import *``.
__all__ = ['Molecule']
def parse_molecules():
    """Load the bundled molecule definitions.

    Every file in the ``conf/hamil/mol`` resource directory of the
    ``deepqmc`` package is parsed with :func:`yaml.safe_load`.

    Returns:
        dict: maps each file name without its extension to the parsed
        content of that file.
    """
    path = resources.files('deepqmc').joinpath('conf/hamil/mol')
    data = {}
    for f in os.listdir(path):
        # Drop only the extension. ``str.strip('.yaml')`` would treat its
        # argument as a *set of characters* and eat any leading/trailing
        # '.', 'y', 'a', 'm' or 'l' of the molecule name itself
        # (e.g. 'methanol.yaml' -> 'ethano').
        name = os.path.splitext(f)[0]
        with open(path.joinpath(f), 'r') as stream:
            data[name] = yaml.safe_load(stream)
    return data
# Database of named molecules, loaded once when this module is imported;
# read by ``Molecule.all_names`` and ``Molecule.from_name``.
_SYSTEMS = parse_molecules()
@dataclass(frozen=True, init=False)
class Molecule:
    r"""Container for the nuclei and the electronic state of a molecule.

    Array-like inputs are converted with :func:`jax.numpy.asarray`.

    Args:
        coords (float, (:math:`N_\text{nuc}`, 3)): nuclear coordinates, one
            nucleus per row, expressed in the length unit named by ``unit``;
            they are stored in atomic units (bohr)
        charges (int, (:math:`N_\text{nuc}`)): atom charges
        charge (int): total charge of a molecule
        spin (int): total spin multiplicity
        unit (str): length unit of ``coords``, either ``'bohr'`` (default)
            or ``'angstrom'``
        data (dict): optional extra data stored on the instance; a falsy
            value is replaced by an empty dict
    """

    # Names accepted by :meth:`from_name`.
    all_names: ClassVar[set] = set(_SYSTEMS.keys())

    coords: jnp.ndarray
    charges: jnp.ndarray
    charge: int
    spin: int
    data: dict

    # DERIVED PROPERTIES:
    n_atom_types: int

    def __init__(
        self,
        *,
        coords,
        charges,
        charge,
        spin,
        unit='bohr',
        data=None,
    ):
        # An unsupported unit fails here, before anything is stored.
        to_bohr = {'bohr': 1.0, 'angstrom': angstrom}[unit]

        attrs = {
            'coords': to_bohr * jnp.asarray(coords),
            # The 1.0 factor promotes integer charges to a floating dtype.
            'charges': 1.0 * jnp.asarray(charges),
            'charge': charge,
            'spin': spin,
            'data': data or {},
        }
        # The dataclass is frozen, so ordinary attribute assignment is
        # blocked; go through object.__setattr__ instead.
        for attr_name, attr_value in attrs.items():
            object.__setattr__(self, attr_name, attr_value)

        # Derived properties: number of distinct nuclear charges.
        distinct_charges = jnp.unique(jnp.asarray(charges))
        object.__setattr__(self, 'n_atom_types', len(distinct_charges))

    def __len__(self):
        """Return the number of entries in ``charges``."""
        return len(self.charges)

    def __iter__(self):
        """Iterate over ``(coordinate row, charge)`` pairs."""
        for position, nuclear_charge in zip(self.coords, self.charges):
            yield position, nuclear_charge

    def __repr__(self):
        pieces = [
            'Molecule(\n',
            f' coords=\n{self.coords},\n',
            f' charges={self.charges},\n',
            f' charge={self.charge},\n',
            f' spin={self.spin}\n',
            ')',
        ]
        return ''.join(pieces)

    @classmethod
    def from_name(cls, name, **kwargs):
        """Create a molecule from a database of named molecules.

        The available names are in :attr:`Molecule.all_names`. Keyword
        arguments override the corresponding entries of the stored
        definition before the molecule is constructed.

        Raises:
            ValueError: if ``name`` is not in :attr:`Molecule.all_names`.
        """
        if name not in cls.all_names:
            raise ValueError(f'Unknown molecule name: {name}')
        # Work on a copy so the shared database entry is never mutated.
        spec = deepcopy(_SYSTEMS[name])
        spec.update(kwargs)
        coords = spec.pop('coords')
        return cls(coords=coords, **spec)