from collections.abc import Callable, Iterable
from statistics import mean, stdev
from typing import Optional
import jax
import jax.numpy as jnp
from ..hamil import MolecularHamiltonian
from ..molecule import Molecule
from ..parallel import pmap, rng_iterator, select_one_device
from ..physics import pairwise_diffs
from ..types import (
    Ansatz,
    KeyArray,
    ParametrizedWaveFunction,
    Params,
    PhysicalConfiguration,
    SamplerState,
)
from .base import ElectronSampler
from .combined_samplers import (
    MoleculeIdxSampler,
    MultiElectronicStateSampler,
    MultiNuclearGeometrySampler,
)
from .nuclei_samplers import IdleNucleiSampler, no_elec_warp
# Names exported by a star-import of this module.
__all__ = ['combine_samplers']
[docs]
def chain(*samplers) -> ElectronSampler:
    r"""
    Fuse several sampler instances into one sampler object.

    A class inheriting from the classes of all given samplers (in the order
    given) is created on the fly. Its name is built by starting from
    ``'Sampler'`` and successively substituting each sampler's class name for
    ``'Sampler'``. An instance of that class is created without running any of
    the parent constructors, and the instance attributes of every sampler are
    copied onto it; attributes set by later samplers overwrite equally named
    attributes of earlier ones.

    For example :data:`chain(DecorrSampler(10),MetropolisSampler(hamil, tau=1.))`
    will create a :class:`MetropolisSampler`, where the samples
    are taken from every 10th MCMC step. The last element of the sampler chain has
    to be either a :class:`MetropolisSampler` or a :class:`LangevinSampler`.

    Args:
        samplers (~deepqmc.sampling.base.ElectronSampler): one or more sampler instances
            to combine.

    Returns:
        :type:`~deepqmc.sampling.base.ElectronSampler`: the combined sampler.
    """
    sampler_types = tuple(type(sampler) for sampler in samplers)

    # 'Sampler' -> 'DecorrSampler' -> 'DecorrMetropolisSampler' -> ...
    combined_name = 'Sampler'
    for sampler_type in sampler_types:
        combined_name = combined_name.replace('Sampler', sampler_type.__name__)

    # The no-op __init__ lets us instantiate without the parents' arguments;
    # the state is taken over from the already constructed samplers instead.
    combined_type = type(combined_name, sampler_types, {'__init__': lambda self: None})
    combined = combined_type()
    for sampler in samplers:
        combined.__dict__.update(sampler.__dict__)
    return combined  # type: ignore
[docs]
def combine_samplers(
    samplers, hamil: MolecularHamiltonian, wf: ParametrizedWaveFunction
) -> ElectronSampler:
    r"""Build a single sampler from a sequence of sampler components.

    All entries but the last are handed to :func:`chain` unchanged. The last entry
    is first called as ``samplers[-1](hamil, wf)`` and its result is chained after
    the others.

    Args:
        samplers (list): sampler components to combine; every entry except the last
            is used as is, the last entry must be callable with ``(hamil, wf)``.
        hamil (~deepqmc.hamil.MolecularHamiltonian): the molecular Hamiltonian,
            forwarded to the last entry of ``samplers``.
        wf (~deepqmc.types.ParametrizedWaveFunction): the wave function to sample,
            forwarded to the last entry of ``samplers``.

    Returns:
        :type:`~deepqmc.sampling.base.ElectronSampler`: the combined sampler.
    """
    leading = samplers[:-1]
    base_sampler = samplers[-1](hamil, wf)
    return chain(*leading, base_sampler)
def diffs_to_nearest_nuc(r, coords):
    r"""Select, for every row of ``r``, the pairwise difference to its nearest center.

    The nearest entry of ``coords`` is the one minimizing the last component of
    ``pairwise_diffs(r, coords)`` (presumably a squared distance -- confirm against
    ``pairwise_diffs``).

    Args:
        r: positions; its length sets the number of rows that are returned.
        coords: candidate centers the positions are compared against.

    Returns:
        tuple: the selected row of pairwise differences for each position, and the
        index into ``coords`` of the selected center.
    """
    all_diffs = pairwise_diffs(r, coords)
    nearest_idx = jnp.argmin(all_diffs[..., -1], axis=-1)
    row_ids = jnp.arange(len(r))
    return all_diffs[row_ids, nearest_idx], nearest_idx
def crossover_parameter(z, f, charge):
    r"""Compute the crossover parameter from displacements, forces and charges.

    The result is
    ``(1 + f_unit . z_unit) / 2 + Z2z2 / (10 * (4 + Z2z2))`` with
    ``Z2z2 = charge**2 * z[..., 3]``, where ``f_unit`` and ``z_unit`` are the
    normalized force and the normalized first three components of ``z``.

    Both norms are clipped from below by the machine epsilon of ``f``'s dtype
    before dividing, so that a vanishing force *or* a vanishing displacement
    yields a zero unit vector instead of NaN.

    Args:
        z: array whose last axis holds a 3-component displacement followed by a
            fourth component (presumably the squared distance -- confirm against
            the producer of ``z``).
        f: force array with a trailing axis of 3 components.
        charge: charges, broadcastable against ``z[..., 3]``.

    Returns:
        The crossover parameter, with the trailing axis of the inputs reduced.
    """
    z, z2 = z[..., :3], z[..., 3]
    eps = jnp.finfo(f.dtype).eps
    # Guard both normalizations: an unclipped zero norm would give 0 / 0 = NaN.
    z_unit = z / jnp.clip(jnp.linalg.norm(z, axis=-1, keepdims=True), eps, None)
    f_unit = f / jnp.clip(jnp.linalg.norm(f, axis=-1, keepdims=True), eps, None)
    Z2z2 = charge**2 * z2
    return (1 + jnp.sum(f_unit * z_unit, axis=-1)) / 2 + Z2z2 / (10 * (4 + Z2z2))
def clean_force(force, phys_conf, mol, *, tau):
    r"""Damp and cap a force using the distance to the nearest nucleus.

    The force is first scaled by ``2 / (sqrt(1 + 2 * a * |force|^2 * tau) + 1)``,
    where ``a`` is the crossover parameter evaluated for the nearest nucleus.
    It is then scaled down, if necessary, so that ``tau * |force|`` does not
    exceed the square root of the last component of the nearest-nucleus
    difference (presumably the electron-nucleus distance -- confirm against
    ``pairwise_diffs``).

    Args:
        force: the force to clean, with a trailing axis of vector components.
        phys_conf: object providing ``r`` and ``R``, which are mapped over their
            leading axis with :func:`jax.vmap`.
        mol: object providing ``charges``, indexed by the nearest-nucleus index.
        tau: step-size factor entering both scalings (keyword-only).

    Returns:
        The cleaned force, with the same shape as ``force``.
    """
    nearest_diffs, nearest_idx = jax.vmap(diffs_to_nearest_nuc)(
        phys_conf.r, phys_conf.R
    )
    a = crossover_parameter(nearest_diffs, force, mol.charges[nearest_idx])
    force_sq = jnp.sum(force**2, axis=-1)
    av2tau = a * force_sq * tau
    # av2tau can be small or zero, so the following expression must handle that
    damping = 2 / (jnp.sqrt(1 + 2 * av2tau) + 1)
    damped = damping[..., None] * force

    # Clip the norm from below so a vanishing force does not divide by zero.
    eps = jnp.finfo(phys_conf.r.dtype).eps
    damped_norm = jnp.clip(jnp.linalg.norm(damped, axis=-1), eps, None)
    cap = jnp.minimum(1.0, jnp.sqrt(nearest_diffs[..., -1]) / (tau * damped_norm))
    return damped * cap[..., None]
def equilibrate(
    rng: KeyArray,
    params: Params,
    molecule_idx_sampler: MoleculeIdxSampler,
    sampler: MultiNuclearGeometrySampler,
    state: SamplerState,
    criterion: Callable[[PhysicalConfiguration], jax.Array],
    steps: Iterable[int],
    *,
    block_size: int,
    n_blocks: int = 5,
    allow_early_stopping: bool = True,
):
    r"""Advance the sampler step by step, optionally stopping once it has settled.

    This is a generator: after every sampling step it yields
    ``(step, state, mol_idxs, stats)``, with ``mol_idxs`` passed through
    ``select_one_device``.

    With early stopping enabled, the scalar ``criterion(phys_conf).item()`` of
    each step is kept in a window holding the latest ``block_size * n_blocks``
    values. Once the window is full, iteration ends as soon as the means of its
    first and last ``block_size`` values differ by less than the smaller of the
    two blocks' sample standard deviations.

    Args:
        rng: random key from which one key per step is drawn via ``rng_iterator``.
        params: parameters forwarded to ``sampler.sample``.
        molecule_idx_sampler: provides the molecule indices for every step.
        sampler: the sampler whose ``sample`` method is wrapped with ``pmap``.
        state: the initial sampler state; replaced by the new state at each step.
        criterion: maps the sampled configuration to a single-element array.
        steps: the step labels to iterate over; bounds the number of steps.
        block_size: number of values in each of the two compared blocks.
        n_blocks: the window holds ``block_size * n_blocks`` values.
        allow_early_stopping: if ``False``, all of ``steps`` are always run.

    Yields:
        tuple: ``(step, state, mol_idxs, stats)`` after each sampling step.
    """
    sample_wf = pmap(sampler.sample)
    window = block_size * n_blocks
    history: list[float] = []
    for step, step_rng in zip(steps, rng_iterator(rng)):
        mol_idxs = molecule_idx_sampler.sample()
        state, phys_conf, stats = sample_wf(step_rng, state, params, mol_idxs)
        yield step, state, select_one_device(mol_idxs), stats
        if not allow_early_stopping:
            continue
        # Keep the newest (window - 1) values and append the current one.
        history = history[1 - window :] + [criterion(phys_conf).item()]
        if len(history) < window:
            continue
        oldest, newest = history[:block_size], history[-block_size:]
        drift = abs(mean(oldest) - mean(newest))
        if drift < min(stdev(oldest), stdev(newest)):
            return
def initialize_sampling(
    rng: KeyArray,
    hamil: MolecularHamiltonian,
    ansatz: Ansatz,
    mols: list[Molecule],
    electronic_states: int,
    molecule_batch_size: int,
    *,
    elec_sampler,
    nuc_sampler=None,
    elec_warp_fn: Optional[Callable] = None,
    update_nuc_period: Optional[int] = None,
    elec_equilibration_steps: Optional[int] = None,
) -> tuple[MoleculeIdxSampler, MultiNuclearGeometrySampler]:
    r"""Assemble the molecule index sampler and the combined sampler.

    Args:
        rng: random key handed to the :class:`MoleculeIdxSampler`.
        hamil: the Hamiltonian; passed to ``elec_sampler``, and its
            ``mol.charges`` are passed to the nuclei sampler.
        ansatz: the Ansatz whose ``apply`` is passed to ``elec_sampler`` as ``wf``.
        mols: the molecules; only their number is used here.
        electronic_states: passed to :class:`MultiElectronicStateSampler`.
        molecule_batch_size: passed to the :class:`MoleculeIdxSampler`.
        elec_sampler: factory called as ``elec_sampler(hamil=..., wf=...)``.
        nuc_sampler: factory called with the nuclear charges; defaults to
            :class:`IdleNucleiSampler` when ``None``.
        elec_warp_fn: defaults to ``no_elec_warp`` when ``None``.
        update_nuc_period: forwarded to :class:`MultiNuclearGeometrySampler`.
        elec_equilibration_steps: forwarded to
            :class:`MultiNuclearGeometrySampler`.

    Returns:
        tuple: the molecule index sampler and the multi-geometry sampler.
    """
    molecule_idx_sampler = MoleculeIdxSampler(
        rng, len(mols), molecule_batch_size, 'once'
    )

    # Electron sampler, replicated over the requested electronic states.
    base_elec_sampler = elec_sampler(hamil=hamil, wf=ansatz.apply)
    multi_state_elec_sampler = MultiElectronicStateSampler(
        base_elec_sampler, electronic_states
    )

    # Fall back to the defaults for the optional nuclear components.
    if nuc_sampler is None:
        nuc_sampler = IdleNucleiSampler
    nuclei_sampler = nuc_sampler(hamil.mol.charges)
    if elec_warp_fn is None:
        elec_warp_fn = no_elec_warp

    sampler = MultiNuclearGeometrySampler(
        multi_state_elec_sampler,
        nuclei_sampler,
        elec_warp_fn,
        update_nuc_period,
        elec_equilibration_steps,
    )
    return molecule_idx_sampler, sampler
def initialize_sampler_state(rng, sampler, params, electron_batch_size, mols):
    r"""Initialize the sampler state on all devices.

    ``sampler.init`` is run under :func:`jax.pmap`, mapped over the leading axis
    of ``rng`` and ``params``. Each device receives
    ``electron_batch_size // jax.device_count()`` as its batch size, together with
    the stacked ``coords`` of all ``mols``.

    Args:
        rng: random keys with a leading device axis.
        sampler: object whose ``init(rng, params, batch_size, coords)`` creates
            the state.
        params: parameters with a leading device axis.
        electron_batch_size: total batch size, split evenly across devices.
        mols: objects providing ``coords``; these are stacked along a new
            leading axis.

    Returns:
        The output of ``sampler.init``, with a leading device axis.
    """
    per_device_batch_size = electron_batch_size // jax.device_count()

    def init_on_device(device_rng, device_params):
        nuclear_coords = jnp.stack([mol.coords for mol in mols])
        return sampler.init(
            device_rng, device_params, per_device_batch_size, nuclear_coords
        )

    return jax.pmap(init_on_device)(rng, params)