Source code for deepqmc.molecule

import os
import re
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from glob import glob
from importlib import resources
from pathlib import Path
from typing import ClassVar, Optional, cast
from typing_extensions import Self

import jax
import jax.numpy as jnp
import yaml
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import get_original_cwd, to_absolute_path

from .units import angstrom_to_bohr, null

__all__ = ['Molecule']


def mol_conf_dir() -> Path:
    return cast(Path, resources.files('deepqmc').joinpath('conf/hamil/mol'))


def get_all_names() -> set[str]:
    return {filename.replace('.yaml', '') for filename in os.listdir(mol_conf_dir())}


[docs] @dataclass(frozen=True, init=False) class Molecule: r"""Represents a molecule. The array-like arguments accept anything that can be transformed to :class:`jax.Array`. Args: coords (jax.Array | list[float]): nuclear coordinates ((:math:`N_\text{nuc}`, 3), a.u.) as rows charges (jax.Array | list[int | float]): atom charges (:math:`N_\text{nuc}`) charge (int): total charge of a molecule spin (int): total spin multiplicity unit (str): units of the coordinates, either 'bohr' or 'angstrom' data (dict): additional data stored with the molecule """ all_names: ClassVar[set] = get_all_names() coords: jax.Array charges: jax.Array charge: int spin: int data: dict # DERIVED PROPERTIES: n_atom_types: int def __init__( self, *, coords, charges, charge, spin, unit='bohr', data=None, ): def set_attr(**kwargs): for k, v in kwargs.items(): object.__setattr__(self, k, v) unit_multiplier = {'bohr': null, 'angstrom': angstrom_to_bohr}[unit] set_attr( coords=unit_multiplier(jnp.array(coords)), charges=jnp.array(charges, dtype=float), charge=charge, spin=spin, data=data or {}, ) # Derived properties set_attr( n_atom_types=len(jnp.unique(jnp.array(charges))), ) def __len__(self): return len(self.charges) def __iter__(self): yield from zip(self.coords, self.charges) def __repr__(self): return ( 'Molecule(\n' f' coords=\n{self.coords},\n' f' charges={self.charges},\n' f' charge={self.charge},\n' f' spin={self.spin}\n' ')' )
[docs] @classmethod def from_name(cls, name: str) -> Self: """Create a molecule from a database of named molecules. The available names are in :attr:`Molecule.all_names`. Args: name (str): name of the molecule (one of :attr:`Molecule.all_names`) """ if name in cls.all_names: mol = deepcopy(read_molecule_dataset(mol_conf_dir(), whitelist=name)[name]) else: raise ValueError(f'Unknown molecule name: {name}') return mol
[docs] @classmethod def from_file(cls, file: str) -> Self: """Create a molecule from a YAML file. Args: file (str): path to the YAML file """ if not Path(file).is_absolute(): if GlobalHydra().instance().is_initialized(): file = os.path.join(to_absolute_path(get_original_cwd()), file) else: file = to_absolute_path(file) with open(file, 'r') as stream: return cls(**yaml.safe_load(stream))
class MoleculeDict(OrderedDict): r"""Store molecules in the order they were added to the dictionary.""" def __setitem__(self, key: str, value: Molecule): super().__setitem__(key, value) self.move_to_end(key) def read_molecule_dataset( dataset: Path, whitelist: Optional[str] = None ) -> MoleculeDict: molecules = MoleculeDict() for f in sorted(glob(str(dataset / '*.yaml'))): filename = f.split('/')[-1].replace('.yaml', '') if whitelist is not None and not re.search(whitelist, filename): continue with open(f, 'r') as stream: molecules[filename] = Molecule(**yaml.safe_load(stream)) return molecules