import logging
import os
from functools import partial
from itertools import count
import h5py
import jax
import jax.numpy as jnp
from tqdm.auto import tqdm, trange
from uncertainties import ufloat
from deepqmc.optimizer import construct_optimizer
from .ewm import init_ewm
from .fit import fit_wf
from .log import CheckpointStore, H5LogTable, TensorboardMetricLogger
from .parallel import (
gather_electrons_on_one_device,
select_one_device,
split_on_devices,
split_rng_key_to_devices,
)
from .physics import pairwise_self_distance
from .pretrain import pretrain
from .sampling import equilibrate, initialize_sampler_state, initialize_sampling
from .wf.base import init_wf_params
# Public API of this module: only the `train` entry point is exported.
__all__ = ['train']
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class NanError(Exception):
    """Signal that a NaN was detected during training/evaluation.

    The exception carries no message and no arguments; it is used purely as
    a control-flow signal.
    """

    def __init__(self):
        super().__init__()
class TrainingCrash(Exception):
    """Raised when training/evaluation could not be completed.

    Args:
        train_state: the training state to attach to the exception, so that
            the caller can inspect it or resume from it.

    Attributes:
        train_state: the object passed to the constructor, stored unchanged.
    """

    def __init__(self, train_state):
        super().__init__()
        self.train_state = train_state
def train(  # noqa: C901
    hamil,
    ansatz,
    opt,
    sampler,
    steps,
    seed,
    electron_batch_size,
    molecule_batch_size=1,
    mols=None,
    workdir=None,
    train_state=None,
    init_step=0,
    max_restarts=3,
    max_eq_steps=1000,
    pretrain_steps=None,
    pretrain_kwargs=None,
    pretrain_sampler=None,
    opt_kwargs=None,
    fit_kwargs=None,
    chkptdir=None,
    chkpts_kwargs=None,
    metric_logger=None,
):
    r"""Train or evaluate a JAX wave function model.

    It initializes and equilibrates the MCMC sampling of the wave function ansatz,
    then optimizes or samples it using the variational principle. It optionally
    saves checkpoints and rewinds the training/evaluation if an error is encountered.
    If an optimizer is supplied, the Ansatz is optimized, otherwise the Ansatz is
    only sampled.

    Args:
        hamil (~deepqmc.hamil.Hamiltonian): the Hamiltonian of the physical system.
        ansatz (~deepqmc.wf.WaveFunction): the wave function Ansatz.
        opt (``kfac_jax`` or ``optax`` optimizers, :class:`str` or :data:`None`):
            the optimizer. Possible values are:

            - :class:`kfac_jax.Optimizer`: the partially initialized KFAC optimizer
              is used
            - an :data:`optax` optimizer instance: the supplied :data:`optax`
              optimizer is used.
            - :class:`str`: the name of the optimizer to use (:data:`'kfac'` or an
              :data:`optax` optimizer name). Arguments to the optimizer can be
              passed in :data:`opt_kwargs`.
            - :data:`None`: no optimizer is used, e.g. the evaluation of the Ansatz
              is performed.

        sampler (~deepqmc.sampling.Sampler): a sampler instance
        steps (int): number of optimization steps.
        seed (int): the seed used for PRNG.
        electron_batch_size (int): the number of electron samples considered in a batch
        molecule_batch_size (int): optional, the number of molecules considered in a
            batch. Only needed for transferable training.
        mols (Sequence(~deepqmc.molecule.Molecule)): optional, a sequence of molecules
            to consider for transferable training. If None the default molecule from
            hamil is used.
        workdir (str): optional, path, where results should be saved.
        train_state (~deepqmc.types.TrainState): optional, training checkpoint to
            restore training or run evaluation.
        init_step (int): optional, initial step index, useful if
            calculation is restarted from checkpoint saved on disk.
        max_restarts (int): optional, the maximum number of attempts made at running
            the training before a :class:`TrainingCrash` is raised.
        max_eq_steps (int): optional, maximum number of equilibration steps if not
            detected earlier. If :data:`None`, the number of steps is unbounded.
        pretrain_steps (int): optional, the number of pretraining steps wrt. to the
            Baseline wave function obtained with pyscf.
        pretrain_kwargs (dict): optional, extra arguments for pretraining. The keys
            ``'opt'``, ``'opt_kwargs'`` and ``'baseline_kwargs'`` are read from it;
            the supplied dictionary itself is not modified.
        pretrain_sampler: optional, the sampler used during pretraining. It is
            passed through :func:`~.sampling.initialize_sampling` together with
            :data:`sampler`.
        opt_kwargs (dict): optional, extra arguments passed to the optimizer.
        fit_kwargs (dict): optional, extra arguments passed to the :func:`~.fit.fit_wf`
            function.
        chkptdir (str): optional, path, where checkpoints should be saved. Checkpoints
            are only saved if :data:`workdir` is not :data:`None`. Default:
            :data:`workdir`.
        chkpts_kwargs (dict): optional, extra arguments for checkpointing.
        metric_logger: optional, an object that consumes metric logging information.
            If not specified, the default `~.log.TensorboardMetricLogger` is used
            to create tensorboard logs.

    Returns:
        The last training state yielded by :func:`~.fit.fit_wf`.

    Raises:
        TrainingCrash: if every one of the :data:`max_restarts` attempts was
            aborted because a NaN was found in the log of the wave function.
            NaNs are only checked for when :data:`workdir` is set. The exception
            carries the training state the run was last rewound to.
    """
    mode = 'evaluation' if opt is None else 'training'
    rng = jax.random.PRNGKey(seed)
    rng, rng_smpl = jax.random.split(rng)
    mols, molecule_idx_sampler, sampler, pretrain_sampler = initialize_sampling(
        rng_smpl,
        hamil,
        mols,
        sampler,
        pretrain_sampler,
        electron_batch_size,
        molecule_batch_size,
    )
    opt = construct_optimizer(opt, opt_kwargs)
    if workdir:
        # Results of training and evaluation are kept in separate subdirectories.
        workdir = os.path.join(workdir, mode)
        chkptdir = os.path.join(chkptdir, mode) if chkptdir else workdir
        os.makedirs(workdir, exist_ok=True)
        os.makedirs(chkptdir, exist_ok=True)
        chkpts = CheckpointStore(chkptdir, **(chkpts_kwargs or {}))
        # Here `metric_logger` is treated as a factory and instantiated.
        # NOTE(review): without a workdir, `metric_logger` is used below exactly
        # as it was passed in -- confirm that callers pass a ready-to-use logger
        # in that case.
        metric_logger = (metric_logger or TensorboardMetricLogger)(
            workdir, len(sampler)
        )
        log.debug('Setting up HDF5 file...')
        h5file = h5py.File(os.path.join(workdir, 'result.h5'), 'a', libver='v110')
        h5file.swmr_mode = True
        table = H5LogTable(h5file)
        table.resize(init_step)
        h5file.flush()
    pbar = None
    try:
        if train_state:
            log.info(
                {
                    'training': f'Restart training from step {init_step}',
                    'evaluation': 'Start evaluation',
                }[mode]
            )
            params = train_state[1]
        else:
            rng, rng_init = jax.random.split(rng)
            params = init_wf_params(rng_init, hamil, ansatz)
            if pretrain_steps and mode == 'training':
                log.info('Pretraining wrt. baseline wave function')
                rng, rng_pretrain = jax.random.split(rng)
                # Work on a copy: the entries are popped below and the caller's
                # dictionary must not be emptied as a side effect.
                pretrain_kwargs = dict(pretrain_kwargs or {})
                opt_pretrain = construct_optimizer(
                    pretrain_kwargs.pop('opt', 'adamw'),
                    pretrain_kwargs.pop('opt_kwargs', None),
                    wrap=False,
                )
                ewm_state, update_ewm = init_ewm(decay_alpha=1.0)
                # One running-mean state per entry of the pretraining sampler.
                ewm_states = len(pretrain_sampler) * [ewm_state]
                pbar = tqdm(range(pretrain_steps), desc='pretrain', disable=None)
                for step, params, per_sample_losses, mol_idxs in pretrain(  # noqa: B007
                    rng_pretrain,
                    hamil,
                    mols,
                    ansatz,
                    params,
                    opt_pretrain,
                    molecule_idx_sampler,
                    pretrain_sampler,
                    steps=pbar,
                    electron_batch_size=electron_batch_size,
                    baseline_kwargs=pretrain_kwargs.pop('baseline_kwargs', {}),
                ):
                    per_mol_losses = per_sample_losses.mean(axis=1)
                    ewm_means = []
                    for loss, mol_idx in zip(per_mol_losses, mol_idxs):
                        ewm_states[mol_idx] = update_ewm(loss, ewm_states[mol_idx])
                        ewm_means.append(ewm_states[mol_idx].mean)
                    pretrain_stats = {
                        'MSE': per_mol_losses,
                        'MSE/ewm': jnp.array(ewm_means),
                    }
                    mse_rep = '|'.join(
                        f'{ewm.mean if ewm.mean is not None else jnp.nan:0.2e}'
                        for ewm in ewm_states
                    )
                    pbar.set_postfix(MSE=mse_rep)
                    if metric_logger:
                        metric_logger.update(
                            step, pretrain_stats, {}, mol_idxs, prefix='pretraining'
                        )
                log.info(f'Pretraining completed with MSE = {mse_rep}')
        rng = split_rng_key_to_devices(rng)
        if not train_state or train_state[0] is None:
            # No sampler state is available yet: create one and equilibrate it.
            rng, rng_eq, rng_smpl_init = split_on_devices(rng, 3)
            smpl_state = initialize_sampler_state(
                rng_smpl_init, sampler, ansatz, params, electron_batch_size
            )
            log.info('Equilibrating sampler...')
            pbar = tqdm(
                count() if max_eq_steps is None else range(max_eq_steps),
                desc='equilibrate sampler',
                disable=None,
            )
            for step, smpl_state, mol_idxs, smpl_stats in equilibrate(  # noqa: B007
                rng_eq,
                partial(ansatz.apply, select_one_device(params)),
                molecule_idx_sampler,
                sampler,
                smpl_state,
                lambda phys_conf: pairwise_self_distance(phys_conf.r).mean(),
                pbar,
                block_size=10,
            ):
                tau_rep = '|'.join(
                    f'{tau:.3f}' for tau in smpl_state['tau'].mean(axis=0)
                )
                pbar.set_postfix(tau=tau_rep)
                if metric_logger:
                    metric_logger.update(
                        step, {}, smpl_stats, mol_idxs, prefix='equilibration'
                    )
            pbar.close()
            train_state = smpl_state, params, None
            if workdir and mode == 'training':
                chkpts.update(init_step, train_state)
        log.info(f'Start {mode}')
        best_ene = None
        ewm_state, update_ewm = init_ewm()
        ewm_states = len(sampler) * [ewm_state]
        # Make sure `step` is bound for the crash message below even if the
        # loop over attempts does not execute a single step.
        step = init_step
        for attempt in range(max_restarts):  # noqa: B007
            try:
                pbar = trange(
                    init_step,
                    steps,
                    initial=init_step,
                    total=steps,
                    desc=mode,
                    disable=None,
                )
                for step, train_state, E_loc, mol_idxs, stats in fit_wf(  # noqa: B007
                    rng,
                    hamil,
                    ansatz,
                    opt,
                    molecule_idx_sampler,
                    sampler,
                    electron_batch_size,
                    pbar,
                    train_state,
                    **(fit_kwargs or {}),
                ):
                    per_mol_energy = E_loc.mean(axis=1)
                    ewm_energies = []
                    for energy, mol_idx in zip(per_mol_energy, mol_idxs):
                        ewm_states[mol_idx] = update_ewm(energy, ewm_states[mol_idx])
                        ewm_energies.append(ewm_states[mol_idx].mean)
                    # An explicit None check is used (as in the pretraining
                    # report) so that a mean of exactly zero is not mistaken
                    # for a missing value.
                    ene = [
                        (
                            ufloat(ewm.mean, jnp.sqrt(ewm.sqerr))
                            if ewm.mean is not None
                            else ufloat(jnp.nan, 0)
                        )
                        for ewm in ewm_states
                    ]
                    # Energies are reported only once every entry has a
                    # non-zero error estimate.
                    if all(e.s for e in ene):
                        energies = '|'.join(f'{e:S}' for e in ene)
                        pbar.set_postfix(E=energies)
                        # Log progress only when some error estimate has at
                        # least halved since the last logged report.
                        if best_ene is None or any(
                            new.s < 0.5 * old.s for new, old in zip(ene, best_ene)
                        ):
                            best_ene = ene
                            log.info(
                                f'Progress: {step + 1}/{steps}, energy = {energies}'
                            )
                    if workdir:
                        if mode == 'training':
                            # the convention is that chkpt-i contains the step i-1 -> i
                            chkpts.update(
                                step + 1,
                                train_state,
                                stats['E_loc/std'].mean(),
                            )
                        table.row['mol_idxs'] = mol_idxs
                        table.row['E_loc'] = E_loc
                        table.row['E_ewm'] = jnp.array(ewm_energies)
                        psi = gather_electrons_on_one_device(train_state.sampler['psi'])
                        if jnp.isnan(psi.log).any():
                            raise NanError()
                        table.row['sign_psi'] = psi.sign[mol_idxs]
                        table.row['log_psi'] = psi.log[mol_idxs]
                        h5file.flush()
                    if metric_logger:
                        single_device_stats = {
                            'energy/ewm': jnp.array([e.n for e in ene]),
                            'energy/ewm_error': jnp.array([e.s for e in ene]),
                        }
                        metric_logger.update(
                            step, single_device_stats, stats, mol_idxs
                        )
                log.info(f'The {mode} has been completed!')
                return train_state
            except NanError:
                pbar.close()
                log.warning('Restarting due to a NaN...')
                # Rewind to the last stored checkpoint. This is done after
                # every failed attempt, including the final one, so the
                # TrainingCrash raised below carries the rewound state.
                # NOTE(review): this function stores checkpoints only in
                # training mode -- confirm what `chkpts.last` yields when a
                # NaN occurs during evaluation.
                init_step, train_state = chkpts.last
        log.warning(
            f'The {mode} has crashed before all steps were completed '
            f'({step}/{steps})!'
        )
        raise TrainingCrash(train_state)
    finally:
        # Compare against None explicitly: the truth value of a tqdm bar
        # depends on its total/iterable, not on whether the bar exists.
        if pbar is not None:
            pbar.close()
        if workdir:
            chkpts.close()
            metric_logger.close()
            h5file.close()