diff --git a/arc/job/adapters/common.py b/arc/job/adapters/common.py index e689ba97b7..f8ebc39fe3 100644 --- a/arc/job/adapters/common.py +++ b/arc/job/adapters/common.py @@ -9,7 +9,6 @@ import sys import re -from pprint import pformat from typing import TYPE_CHECKING from arc.common import get_logger @@ -470,21 +469,26 @@ def set_job_args(args: dict | None, """ Set the job args considering args from ``level`` and from ``trsh``. + The caller (e.g. :meth:`arc.scheduler.Scheduler.run_job`) is expected to + have already merged any ``level.args`` content into ``args`` before calling + this function — ``run_job`` does so via ``args.update(level.args)``. When + the caller passes empty ``args`` and the level supplies ``args``, we fall + back to ``level.args`` for convenience. + Args: - args (dict): The job specific arguments. + args (dict): The job-specific arguments. level (Level): The level of theory. job_name (str): The job name. Returns: - dict: The initialized job specific arguments. + dict: The initialized job-specific arguments, guaranteed to carry the + ``'keyword'``, ``'block'``, and ``'trsh'`` buckets (each a dict). """ - # Ignore user-specified additional job arguments when troubleshooting. - if args is not None and args and any(val for val in args.values()) \ - and level is not None and level.args and any(val for val in level.args.values()): - logger.warning(f'When troubleshooting {job_name}, ARC ignores the following user-specified options:\n' - f'{pformat(level.args)}') - elif not args and level is not None: + # Convenience fallback: empty (or None) caller-args inherits level.args. + if not args and level is not None and level.args is not None: args = level.args + if args is None: + args = dict() for key in ['keyword', 'block', 'trsh']: if key not in args.keys(): args[key] = dict() diff --git a/arc/job/adapters/common_test.py b/arc/job/adapters/common_test.py index 322e352a65..44b04b3316 100644 --- a/arc/job/adapters/common_test.py +++ b/arc/job/adapters/common_test.py @@ -5,6 +5,7 @@ This module contains unit tests of the arc.job.adapters.common module """ +import logging import os import shutil import unittest @@ -166,6 +167,29 @@ def test_set_job_args(self): args = common.set_job_args(args={'keyword': 'k1'}, level=Level(repr='CBS-QB3'), job_name='j1') self.assertEqual(args, {'keyword':'k1', 'block': dict(), 'trsh': dict()}) + def test_set_job_args_no_spurious_warning_when_level_has_args(self): + """Regression: the previous "ARC ignores user-specified options" warning + fired on every first-run job whose level carried args, because + ``run_job`` had already merged ``level.args`` into ``args`` before + calling — nothing was actually being ignored. The warning should now + be silent on a normal first-run path.""" + merged_args = {'keyword': {'core': 'core,0,0,0,0,0,0,0,0;'}, 'block': {}} + level_with_args = Level(method='ccsd(t)', basis='cc-pCVTZ', + args=merged_args) + with self.assertNoLogs(logger='arc', level=logging.WARNING): + result = common.set_job_args(args=merged_args, + level=level_with_args, job_name='j_first_run') + # Args content is preserved (not dropped). 
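For reference, the contract under test, as a minimal standalone sketch (illustrative values; assumes the patched `set_job_args` and the `Level` import used elsewhere in these tests):

```python
from arc.job.adapters import common
from arc.level import Level

level = Level(method='ccsd(t)', basis='cc-pCVTZ',
              args={'keyword': {'core': 'core,0,0,0,0,0,0,0,0;'}, 'block': {}})

# Normal path: the caller (run_job) has already merged level.args into args.
merged = common.set_job_args(args=dict(level.args), level=level, job_name='j1')
assert set(merged) >= {'keyword', 'block', 'trsh'}  # buckets are guaranteed

# Convenience path: empty or None caller args fall back to level.args.
fallback = common.set_job_args(args=None, level=level, job_name='j1')
assert fallback['keyword'] == {'core': 'core,0,0,0,0,0,0,0,0;'}
```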
+ self.assertEqual(result['keyword'], {'core': 'core,0,0,0,0,0,0,0,0;'}) + self.assertEqual(result['trsh'], {}) # bucket added by guarantee + + def test_set_job_args_args_none_preserves_level_args(self): + """When the caller passes None, fall back to level.args (legacy convenience).""" + level = Level(method='ccsd(t)', basis='cc-pVTZ', + args={'keyword': {'general': 'foo'}, 'block': {}}) + result = common.set_job_args(args=None, level=level, job_name='j1') + self.assertEqual(result['keyword'], {'general': 'foo'}) + def test_which(self): """Test the which() function""" ans = common.which(command='python', return_bool=True, raise_error=False) diff --git a/arc/job/adapters/molpro.py b/arc/job/adapters/molpro.py index 0ed556bf55..d72ab73c29 100644 --- a/arc/job/adapters/molpro.py +++ b/arc/job/adapters/molpro.py @@ -35,6 +35,20 @@ settings['default_job_settings'], settings['global_ess_settings'], settings['input_filenames'], \ settings['output_filenames'], settings['servers'], settings['submit_filenames'] +# Methods that native Molpro does not support but its MRCC plugin does. +# When the level's method matches one of these (case-insensitive), the adapter +# emits a ``{mrcc,method=...}`` plugin call instead of a bare directive that +# Molpro's input parser would reject with "Unknown command or directive". +# Compared against the lowercased ``Level.method``. +MRCC_ROUTED_METHODS = frozenset({ + 'ccsdt', + 'ccsdt(q)', + 'ccsdtq', + 'ccsdtq(p)', + 'ccsdtqp', +}) + + input_template = """***,${label} memory,Total=${memory},m; @@ -47,7 +61,7 @@ ${cabs} int; -{hf;${shift} +{${hf_method};${shift} maxit,999; wf,spin=${spin},charge=${charge}; } @@ -229,10 +243,37 @@ def write_input_file(self) -> None: input_dict['spin'] = self.multiplicity - 1 input_dict['xyz'] = xyz_to_str(self.xyz) input_dict['orbitals'] = '\ngprint,orbitals;\n' + input_dict['hf_method'] = 'hf' # default; overridden below for open-shell MRCC if not is_restricted(self): input_dict['restricted'] = 'u' + if self.level.method in MRCC_ROUTED_METHODS: + # Restriction is implicit from the preceding {hf;...} block; the + # MRCC plugin call does not accept a 'u'/'r' prefix. + input_dict['method'] = '{mrcc,method=' + self.level.method.upper() + '}' + input_dict['restricted'] = '' + if not is_restricted(self): + # Open-shell wavefunction + MRCC's approximate-CC family + # (CCSDT(Q), CCSDTQ(P), and the perturbative-(T) variants) + # refuses standard ROHF orbitals: + # "Approximate CC methods are not implemented for standard + # ROHF orbitals! Use semicanonical orbitals!" + # Solution: use UHF instead of (RO)HF as the SCF reference. + # UHF orbitals are semicanonical by construction (alpha and + # beta Fock matrices are separately diagonal) and live at the + # default record 2100.2, which MRCC reads. MRCC then reports + # ``Type=UHF/CANONICAL`` and accepts. + # + # An earlier attempt at this fix prepended ``{uccsd}`` to the + # MRCC call. {uccsd} does run UCCSD on top of ROHF, but the + # post-UCCSD canonical orbitals go to a separate record while + # the default 2100.2 still holds the original ROHF orbitals — + # MRCC reads 2100.2 by default and complained. Switching the + # SCF reference to UHF avoids this orbital-record bookkeeping + # entirely. 
+ input_dict['hf_method'] = 'uhf' + # Job type specific options if self.job_type in ['opt', 'optfreq', 'conf_opt']: keywords = ['optg', 'root=2', 'method=qsd', 'readhess', "savexyz='geometry.xyz'"] if self.is_ts \ diff --git a/arc/job/adapters/molpro_test.py b/arc/job/adapters/molpro_test.py index 113dc77582..801f55ff39 100644 --- a/arc/job/adapters/molpro_test.py +++ b/arc/job/adapters/molpro_test.py @@ -97,6 +97,24 @@ def setUpClass(cls): 'closed': [1, 0, 0, 0, 0, 0, 0, 0]})], testing=True, ) + cls.job_mrcc_ccsdt = MolproAdapter(execution_type='queue', + job_type='sp', + level=Level(method='CCSDT', basis='cc-pVDZ'), + project='test', + project_directory=os.path.join(ARC_TESTING_PATH, + 'test_MolproAdapter_mrcc_ccsdt'), + species=[ARCSpecies(label='spc1', xyz=['O 0 0 1'], multiplicity=3)], + testing=True, + ) + cls.job_mrcc_ccsdtq = MolproAdapter(execution_type='queue', + job_type='sp', + level=Level(method='CCSDT(Q)', basis='cc-pVDZ'), + project='test', + project_directory=os.path.join(ARC_TESTING_PATH, + 'test_MolproAdapter_mrcc_ccsdtq'), + species=[ARCSpecies(label='spc1', xyz=['O 0 0 1'], multiplicity=1)], + testing=True, + ) def test_set_cpu_and_mem(self): """Test assigning number of cpu's and memory""" @@ -441,6 +459,107 @@ def test_write_mrci_input_file(self): """ self.assertEqual(content_7, job_7_expected_input_file) + def test_write_input_file_mrcc_routing(self): + """Methods unsupported by native Molpro but supported by MRCC are routed through the MRCC plugin. + + For an open-shell wavefunction, the SCF reference is switched from + ``{hf;...}`` (which gives Molpro's ROHF for open-shell) to + ``{uhf;...}``. MRCC's approximate-CC family (``CCSDT(Q)``, + ``CCSDTQ(P)``, and the perturbative-``(T)`` variants) refuses + standard ROHF orbitals with the error:: + + Approximate CC methods are not implemented for standard ROHF orbitals! + Use semicanonical orbitals! + + UHF orbitals are semicanonical by construction (alpha and beta Fock + matrices are separately diagonal), saved to the default record 2100.2 + which MRCC reads — MRCC then reports ``Type=UHF/CANONICAL`` and runs + the requested approximate-CC method. + """ + self.job_mrcc_ccsdt.cpu_cores = 48 + self.job_mrcc_ccsdt.set_input_file_memory() + self.job_mrcc_ccsdt.write_input_file() + with open(os.path.join(self.job_mrcc_ccsdt.local_path, + input_filenames[self.job_mrcc_ccsdt.job_adapter]), 'r') as f: + content_ccsdt = f.read() + # spc1 has multiplicity=3 (open-shell triplet) — UHF reference expected. + expected_ccsdt = """***,spc1 +memory,Total=438,m; + +geometry={angstrom; +O 0.00000000 0.00000000 1.00000000} + +gprint,orbitals; + +basis=cc-pvdz + + + +int; + +{uhf; + maxit,999; + wf,spin=2,charge=0; +} + +{mrcc,method=CCSDT} + + + +---; + +""" + self.assertEqual(content_ccsdt, expected_ccsdt) + # Sanity: the bare directive Molpro rejects must NOT appear on its own line. + self.assertNotIn('\nccsdt;\n', content_ccsdt) + self.assertNotIn('\nuccsdt;\n', content_ccsdt) + # An earlier (insufficient) fix used `{uccsd}` between HF and MRCC — + # this contract has been replaced with UHF, so {uccsd} must NOT appear. + self.assertNotIn('{uccsd}', content_ccsdt) + # UHF must replace HF as the only SCF reference (no {hf;...} block). 
+        self.assertNotIn('{hf;', content_ccsdt)
+        self.assertIn('{uhf;', content_ccsdt)
+
+        self.job_mrcc_ccsdtq.cpu_cores = 48
+        self.job_mrcc_ccsdtq.set_input_file_memory()
+        self.job_mrcc_ccsdtq.write_input_file()
+        with open(os.path.join(self.job_mrcc_ccsdtq.local_path,
+                               input_filenames[self.job_mrcc_ccsdtq.job_adapter]), 'r') as f:
+            content_ccsdtq = f.read()
+        expected_ccsdtq = """***,spc1
+memory,Total=438,m;
+
+geometry={angstrom;
+O       0.00000000    0.00000000    1.00000000}
+
+gprint,orbitals;
+
+basis=cc-pvdz
+
+
+
+int;
+
+{hf;
+maxit,999;
+wf,spin=0,charge=0;
+}
+
+{mrcc,method=CCSDT(Q)}
+
+
+
+---;
+
+"""
+        self.assertEqual(content_ccsdtq, expected_ccsdtq)
+        self.assertNotIn('\nccsdt(q);\n', content_ccsdtq)
+        # spc1 here has multiplicity=1 (closed-shell) — RHF gives canonical
+        # orbitals MRCC accepts directly. No UHF/UCCSD pre-step needed.
+        self.assertNotIn('{uccsd}', content_ccsdtq)
+        self.assertNotIn('{uhf;', content_ccsdtq)
+        self.assertIn('{hf;', content_ccsdtq)
+
     def test_set_files(self):
         """Test setting files"""
         job_1_files_to_upload = [{'file_name': 'submit.sub',
diff --git a/arc/job/trsh.py b/arc/job/trsh.py
index a1ec9a8eab..33b927d55b 100644
--- a/arc/job/trsh.py
+++ b/arc/job/trsh.py
@@ -393,10 +393,41 @@ def determine_ess_status(output_path: str,
                 return 'errored', keywords, error, line
 
     elif software == 'molpro':
+        # MRCC ROHF-incompatibility check BEFORE the generic reverse scan,
+        # because the underlying cause ("Use semicanonical orbitals!")
+        # appears earlier in the file than the downstream "Fatal error in
+        # mrcc." line — reverse iteration would otherwise classify the
+        # latter (generic) before the former (specific). The adapter-side
+        # fix switches the SCF reference to UHF, whose orbitals are
+        # semicanonical; this keyword surfaces the diagnostic for any
+        # legacy run that predates that fix.
+        joined = '\n'.join(lines) if isinstance(lines, list) else str(lines)
+        if 'standard ROHF orbitals' in joined or 'Use semicanonical orbitals' in joined:
+            rohf_line = next(
+                (ln for ln in lines if 'standard ROHF orbitals' in ln
+                 or 'Use semicanonical orbitals' in ln),
+                '',
+            )
+            return ('errored', ['MRCCRequiresSemicanonical'],
+                    'MRCC requires semicanonical orbitals; ROHF orbitals '
+                    'are not supported for approximate CC.',
+                    rohf_line)
         for line in reverse_lines:
             if 'molpro calculation terminated' in line.lower() \
                     or 'variable memory released' in line.lower():
                 return 'done', list(), '', ''
+            elif 'Fatal error in xmrcc' in line or 'Fatal error in mrcc' in line:
+                # MRCC bailed for a tiny system where the requested CC
+                # excitation rank exceeds the determinant space (e.g.
+                # atomic H or H2 at CCSDT(Q)). The composite framework
+                # should short-circuit a δ-term high leg with this
+                # keyword to the corresponding low-leg energy (δ = 0,
+                # which is correct for a degenerate-method case).
+                keywords = ['MRCCDegenerateSystem']
+                error = ('MRCC xmrcc fatal — the requested CC excitation '
+                         'rank exceeds the determinant space for this '
+                         'system (degenerate / too few electrons).')
+                break
             elif 'No convergence' in line and '?No convergence in rhfpr' not in line:
                 keywords = ['Unconverged']
                 error = 'Unconverged'
@@ -1684,13 +1715,12 @@ def scan_quality_check(label: str,
             logger.warning(message)
             return invalidate, invalidation_reason, message, actions
         else:
-            logger.warning(f'The maximal barrier for rotor {pivots} of {label} is '
-                           f'{(np.max(energies) - np.min(energies)):.2f} kJ/mol, which is higher than the set threshold '
-                           f'of {maximum_barrier} kJ/mol. Since this mode when treated as torsion has {num_wells}, '
-                           f'this mode is not invalidated: treating it as a vibrational mode will be less accurate than '
-                           f'the hindered rotor treatment, since the entropy contribution from the population of '
-                           f'this species at the higher wells will not be taken into account. NOT invalidating this '
-                           f'torsional mode.')
+            barrier_kJmol = np.max(energies) - np.min(energies)
+            logger.warning(f'Rotor {pivots} of {label}: barrier {barrier_kJmol:.2f} kJ/mol '
+                           f'exceeds the {maximum_barrier} kJ/mol threshold, but the mode has '
+                           f'{num_wells} wells. Keeping the hindered-rotor treatment — '
+                           f'demoting to a harmonic vibration would miss the entropic '
+                           f'contribution from the upper well(s).')
 
     if preserve_params is not None:
         success = True
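The two new Molpro branches give callers stable keywords to dispatch on. A minimal sketch of that dispatch, with a hypothetical output path (the `...` bodies stand in for the scheduler's actual reactions):

```python
from arc.job import trsh

status, keywords, error, line = trsh.determine_ess_status(
    output_path='/tmp/legacy_rohf_run/output.out',  # hypothetical legacy run
    species_label='OH',
    job_type='sp',
)
if 'MRCCRequiresSemicanonical' in keywords:
    ...  # resubmit with a UHF reference (new runs get this by default)
elif 'MRCCDegenerateSystem' in keywords:
    ...  # composite framework: short-circuit the high leg, take delta = 0
```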
diff --git a/arc/job/trsh_test.py b/arc/job/trsh_test.py
index d974874e9c..556a724623 100644
--- a/arc/job/trsh_test.py
+++ b/arc/job/trsh_test.py
@@ -171,6 +171,32 @@ def test_determine_ess_status(self):
         self.assertEqual(error, "Unrecognized basis set 6-311G**")
         self.assertIn(" ? Basis library exhausted", line)  # line includes '\n'
 
+        # Molpro + MRCC: degenerate small system (e.g. atomic H, H2 at CCSDT(Q)).
+        # MRCC's xmrcc bails because there's no determinant space at the
+        # requested excitation rank. Trsh must classify this so the framework
+        # knows to short-circuit the sub-job (delta = 0) instead of cycling
+        # the generic ladder (shift / vdz / memory).
+        path = os.path.join(self.base_path["molpro"], "mrcc_xmrcc_fatal.out")
+        status, keywords, error, line = trsh.determine_ess_status(
+            output_path=path, species_label="H", job_type="sp"
+        )
+        self.assertEqual(status, "errored")
+        self.assertEqual(keywords, ["MRCCDegenerateSystem"])
+        self.assertIn("xmrcc", error.lower())
+        self.assertIn("Fatal error in xmrcc", line)
+
+        # Molpro + MRCC: ROHF orbitals incompatible with approximate CC methods
+        # (open-shell radicals). Trsh classifies the failure; the adapter's UHF
+        # reference should prevent it from happening on new runs, so this
+        # keyword is the diagnostic for legacy runs that still used an (RO)HF
+        # reference.
+        path = os.path.join(self.base_path["molpro"], "mrcc_rohf_unsupported.out")
+        status, keywords, error, line = trsh.determine_ess_status(
+            output_path=path, species_label="OH", job_type="sp"
+        )
+        self.assertEqual(status, "errored")
+        self.assertEqual(keywords, ["MRCCRequiresSemicanonical"])
+        self.assertIn("semicanonical", error.lower())
+
         # Orca
 
         # test detection of a successful job
diff --git a/arc/job/zombie.py b/arc/job/zombie.py
new file mode 100644
index 0000000000..c65326bf09
--- /dev/null
+++ b/arc/job/zombie.py
@@ -0,0 +1,114 @@
+"""Zombie-job detection helpers.
+
+A "zombie" is a queue-running job that has produced no output within the
+grace period: the scheduler reports it as RUNNING, but the ESS process has
+wedged or never started. The orchestration (kill + resubmit + per-(species,
+job_type) cap) lives on the Scheduler; the pure decision logic and ESS
+classification live here.
+"""
+
+import datetime
+import os
+
+from arc.common import get_logger
+from arc.imports import settings
+from arc.job.ssh import SSHClient
+
+
+logger = get_logger()
+
+
+ZOMBIE_GRACE_SECONDS = 3600
+
+ZOMBIE_OUTPUT_FILENAME_FALLBACK = 'out.txt'
+
+# ESS that flush login-visible output as the job runs (per SCF / per CC iter
+# / per opt step). For these, absence of any output traffic after the grace
+# period is a strong "zombie" signal. Incore-only or near-instant ESS
+# (xtb / torchani / openbabel / mockter) are exempt.
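Because `is_zombie` (defined later in this file) is a pure decision over a duck-typed job, it can be exercised without a scheduler. A minimal sketch with a stub job, values illustrative and mirroring the unit tests further down:

```python
import datetime
from types import SimpleNamespace
from unittest.mock import patch

from arc.job import zombie

job = SimpleNamespace(job_name='sp_a1', job_id=42, job_adapter='molpro',
                      execution_type='queue', server='server1',
                      initial_time=datetime.datetime.now() - datetime.timedelta(hours=2),
                      local_path='/tmp', local_path_to_output_file='/tmp/output.out',
                      remote_path='/remote/path')

# Queue says RUNNING, two hours past spawn, and no output ever appeared:
with patch('arc.job.zombie.output_mtime', return_value=None):
    assert zombie.is_zombie(job, server_job_ids=[42])
```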
+ESS_PERIODIC_WRITERS = frozenset({
+    'cfour', 'gaussian', 'molpro', 'orca', 'psi4', 'qchem', 'terachem',
+})
+
+
+def output_mtime(job) -> datetime.datetime | None:
+    """Return the latest mtime of the job's ESS output file.
+
+    Tries the configured ESS output filename first and falls back to the
+    wrapper log. Local jobs use ``os.path.getmtime``; remote jobs use
+    ``SSHClient.get_last_modified_time`` against ``job.remote_path``.
+
+    Args:
+        job: A ``JobAdapter`` (duck-typed). Required attributes: ``job_adapter``,
+            ``server``, ``local_path``, ``local_path_to_output_file``,
+            ``remote_path``, ``job_name``.
+
+    Returns:
+        datetime.datetime | None: The output file's mtime, or ``None`` if no
+            candidate output file exists or the remote stat failed.
+    """
+    out_filename = settings.get('output_filenames', {}).get(job.job_adapter)
+    if job.server is None or job.server in ('', 'local'):
+        candidates = [job.local_path_to_output_file]
+        if out_filename:
+            candidates.append(os.path.join(job.local_path, out_filename))
+        for path in candidates:
+            if path and os.path.isfile(path):
+                return datetime.datetime.fromtimestamp(os.path.getmtime(path))
+        return None
+    try:
+        with SSHClient(job.server) as ssh:
+            p1 = os.path.join(job.remote_path, out_filename) if out_filename else None
+            p2 = os.path.join(job.remote_path, ZOMBIE_OUTPUT_FILENAME_FALLBACK)
+            return ssh.get_last_modified_time(remote_file_path_1=p1 or p2,
+                                              remote_file_path_2=p2)
+    except Exception as exc:
+        logger.warning(
+            f'Could not stat remote output for job {job.job_name} on '
+            f'{job.server} ({type(exc).__name__}: {exc}); skipping zombie check.'
+        )
+        return None
+
+
+def is_zombie(job, server_job_ids, now: datetime.datetime | None = None) -> bool:
+    """Decide whether a job is a zombie.
+
+    Pure decision: takes the queue's running set rather than reaching into a
+    ``Scheduler``. A job is a zombie iff all of these hold:
+
+    * Its ``execution_type`` is not ``'incore'``.
+    * Its ESS is in :data:`ESS_PERIODIC_WRITERS`.
+    * The queue still reports it as running (``job.job_id in server_job_ids``).
+    * At least :data:`ZOMBIE_GRACE_SECONDS` seconds have elapsed since spawn
+      (``job.initial_time``).
+    * Its output file is missing, or its mtime is at-or-before spawn time.
+
+    Args:
+        job: A ``JobAdapter`` (duck-typed). Required attributes: ``execution_type``,
+            ``job_adapter``, ``job_id``, ``initial_time``, plus everything
+            :func:`output_mtime` needs.
+        server_job_ids: A collection of queue job IDs the scheduler currently
+            considers running. Membership is tested with ``in``.
+        now (datetime.datetime, optional): Reference "current time" for the
+            grace-period check. Defaults to ``datetime.datetime.now()``;
+            override in tests for determinism.
+
+    Returns:
+        bool: ``True`` if the job is a zombie, ``False`` otherwise.
+ """ + if job.execution_type == 'incore': + return False + adapter_name = (getattr(job, 'job_adapter', None) or '').lower() + if adapter_name not in ESS_PERIODIC_WRITERS: + return False + if job.job_id is None or job.job_id not in server_job_ids: + return False + if job.initial_time is None: + return False + now = now or datetime.datetime.now() + if (now - job.initial_time).total_seconds() < ZOMBIE_GRACE_SECONDS: + return False + mtime = output_mtime(job) + if mtime is None: + return True + return mtime <= job.initial_time diff --git a/arc/job/zombie_test.py b/arc/job/zombie_test.py new file mode 100644 index 0000000000..8a0ac68aca --- /dev/null +++ b/arc/job/zombie_test.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +"""Unit tests for arc.job.zombie — pure helpers and ESS classification.""" + +import datetime +import os +import tempfile +import unittest +from types import SimpleNamespace +from unittest.mock import patch + +from arc.job import zombie + + +def _stub_job(job_adapter='molpro', job_type='sp', execution_type='queue', + initial_offset_seconds=7200, job_name='sp_a3177', job_id=12345, + server='server1', remote_path='/remote/no/such/path', + local_path='/tmp/no/such/path', + local_path_to_output_file='/tmp/no/such/output.out'): + return SimpleNamespace( + job_name=job_name, job_type=job_type, job_id=job_id, + job_adapter=job_adapter, execution_type=execution_type, + initial_time=datetime.datetime.now() - datetime.timedelta(seconds=initial_offset_seconds), + server=server, + local_path=local_path, local_path_to_output_file=local_path_to_output_file, + remote_path=remote_path, + ) + + +class TestEssPeriodicWritersClassification(unittest.TestCase): + def test_periodic_writers_set(self): + self.assertEqual( + zombie.ESS_PERIODIC_WRITERS, + frozenset({'cfour', 'gaussian', 'molpro', 'orca', 'psi4', 'qchem', 'terachem'}), + ) + + def test_grace_period_default(self): + self.assertEqual(zombie.ZOMBIE_GRACE_SECONDS, 3600) + + +class TestIsZombie(unittest.TestCase): + def test_zombie_when_no_output_after_grace(self): + job = _stub_job() + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertTrue(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_not_zombie_when_output_fresh(self): + job = _stub_job() + fresh = job.initial_time + datetime.timedelta(seconds=2000) + with patch('arc.job.zombie.output_mtime', return_value=fresh): + self.assertFalse(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_zombie_when_output_mtime_at_spawn_time(self): + """An output file whose mtime equals spawn_time means ARC's own input + write — no ESS progress. 
Treat as zombie.""" + job = _stub_job() + with patch('arc.job.zombie.output_mtime', return_value=job.initial_time): + self.assertTrue(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_grace_period_blocks(self): + job = _stub_job(initial_offset_seconds=1800) # 30 min + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertFalse(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_non_periodic_writer_skipped(self): + job = _stub_job(job_adapter='xtb') + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertFalse(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_incore_skipped(self): + job = _stub_job(execution_type='incore') + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertFalse(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_queue_done_skipped(self): + job = _stub_job() + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertFalse(zombie.is_zombie(job, server_job_ids=[])) + + def test_no_initial_time_skipped(self): + job = _stub_job() + job.initial_time = None + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertFalse(zombie.is_zombie(job, server_job_ids=[job.job_id])) + + def test_now_argument_overrides_clock(self): + """Pass an explicit ``now`` to remove wall-clock dependency in tests.""" + job = _stub_job(initial_offset_seconds=0) + spawn = job.initial_time + within_grace = spawn + datetime.timedelta(seconds=zombie.ZOMBIE_GRACE_SECONDS - 1) + past_grace = spawn + datetime.timedelta(seconds=zombie.ZOMBIE_GRACE_SECONDS + 1) + with patch('arc.job.zombie.output_mtime', return_value=None): + self.assertFalse(zombie.is_zombie(job, [job.job_id], now=within_grace)) + self.assertTrue(zombie.is_zombie(job, [job.job_id], now=past_grace)) + + +class TestOutputMtimeLocal(unittest.TestCase): + def test_local_output_present(self): + with tempfile.TemporaryDirectory() as tmp: + out_path = os.path.join(tmp, 'output.out') + with open(out_path, 'w') as fh: + fh.write('x') + job = _stub_job(server='local', local_path=tmp, local_path_to_output_file=out_path) + mtime = zombie.output_mtime(job) + self.assertIsNotNone(mtime) + self.assertIsInstance(mtime, datetime.datetime) + + def test_local_output_missing(self): + with tempfile.TemporaryDirectory() as tmp: + job = _stub_job(server='local', local_path=tmp, + local_path_to_output_file=os.path.join(tmp, 'nope.out')) + self.assertIsNone(zombie.output_mtime(job)) + + def test_local_server_none_treated_as_local(self): + with tempfile.TemporaryDirectory() as tmp: + out_path = os.path.join(tmp, 'output.out') + with open(out_path, 'w') as fh: + fh.write('x') + job = _stub_job(server=None, local_path=tmp, local_path_to_output_file=out_path) + self.assertIsNotNone(zombie.output_mtime(job)) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/level/__init__.py b/arc/level/__init__.py new file mode 100644 index 0000000000..5310bc21e0 --- /dev/null +++ b/arc/level/__init__.py @@ -0,0 +1,68 @@ +""" +``arc.level`` — level-of-theory abstractions for ARC. + +This package groups everything related to specifying *how* an electronic-structure +calculation is performed: + +* The legacy :class:`~arc.level.level.Level` class, which represents a single QM level + (method, basis, dispersion, solvation, ESS-specific options) and is unchanged from + ``arc/level.py`` prior to its relocation into this package. 
+* New composite single-point abstractions added in Phase 1 of the ``sp_composite`` work:
+  protocols, terms, presets, CBS extrapolation, and reporting helpers. These let a
+  user define the final electronic energy of a stationary point as a sum of multiple
+  SP corrections — a HEAT-style focal-point analysis (Tajti et al.,
+  *J. Chem. Phys.* **121**, 11599 (2004); DOI: 10.1063/1.1811608).
+
+Backwards compatibility
+-----------------------
+
+All public symbols that historically lived in ``arc/level.py`` are re-exported here so
+that existing call sites (``from arc.level import Level`` etc.) continue to work
+without modification. New code should prefer the qualified imports
+``from arc.level.protocol import CompositeProtocol`` etc. when reaching for the new
+machinery.
+
+References
+----------
+
+* Allen, East, Császár, in *Structures and Conformations of Non-Rigid Molecules*
+  (Kluwer, Dordrecht, 1993), p. 343 — review of focal-point analysis methodology.
+* Tajti, Szalay, Császár, Kállay, Gauss, Valeev, Flowers, Vázquez, Stanton,
+  *J. Chem. Phys.* **121**, 11599 (2004). DOI: 10.1063/1.1811608 — HEAT protocol.
+* Helgaker, Klopper, Koch, Noga, *J. Chem. Phys.* **106**, 9639 (1997).
+  DOI: 10.1063/1.473863 — two-point correlation-energy CBS extrapolation.
+* Halkier, Helgaker, Jørgensen, Klopper, Koch, Olsen, Wilson,
+  *Chem. Phys. Lett.* **286**, 243-252 (1998). DOI: 10.1016/S0009-2614(98)00111-0 —
+  extends the two-point correlation-energy CBS extrapolation to Ne, N₂, H₂O.
+* Halkier, Helgaker, Jørgensen, Klopper, Olsen,
+  *Chem. Phys. Lett.* **302**, 437-446 (1999). DOI: 10.1016/S0009-2614(99)00179-7 —
+  two-point HF-energy CBS extrapolation; source of the fitted ``α = 1.63``.
+* Martin, *Chem. Phys. Lett.* **259**, 669-678 (1996).
+  DOI: 10.1016/0009-2614(96)00898-6 — three-point Schwartz-style extrapolation.
+* Dunning, *J. Chem. Phys.* **90**, 1007 (1989). DOI: 10.1063/1.456153 — correlation-
+  consistent basis-set families used by the cardinal-number deduction logic.
+"""
+
+from arc.level.level import (
+    Level,
+    assign_frequency_scale_factor,
+    levels_ess,
+    logger,
+    supported_ess,
+)
+from arc.level.species_state import (
+    INHERIT,
+    SP_COMPOSITE_STATES,
+    active_composite_for,
+)
+
+__all__ = [
+    "Level",
+    "assign_frequency_scale_factor",
+    "levels_ess",
+    "logger",
+    "supported_ess",
+    "INHERIT",
+    "SP_COMPOSITE_STATES",
+    "active_composite_for",
+]
diff --git a/arc/level/cbs.py b/arc/level/cbs.py
new file mode 100644
index 0000000000..88da961be1
--- /dev/null
+++ b/arc/level/cbs.py
@@ -0,0 +1,389 @@
+"""
+``arc.level.cbs`` — Complete-Basis-Set extrapolation primitives.
+
+This module implements the building blocks needed by
+:class:`~arc.level.protocol.CBSExtrapolationTerm`: the cardinal-number deduction from
+basis-set names, the three built-in extrapolation formulas shipped with ARC, and a
+sandboxed evaluator for user-supplied formula strings.
+
+The CBS step in a focal-point analysis takes ≥2 single-point energies computed at the
+*same* method but at *different* basis-set cardinalities X (cc-pVDZ → 2, cc-pVTZ → 3,
+cc-pVQZ → 4, ...) and combines them according to a closed-form expression that
+extrapolates to the (formally infinite) basis-set limit.
+
+Built-in formulas
+-----------------
+
+``helgaker_corr_2pt``
+    Two-point correlation-energy extrapolation
+    ``E_CBS = (X^3·E_X − Y^3·E_Y) / (X^3 − Y^3)``.
+    Helgaker, Klopper, Koch, Noga, *J. Chem. Phys.* **106**, 9639 (1997),
+    Eq. 4. DOI: 10.1063/1.473863.
+``helgaker_hf_2pt``
+    Two-point HF-energy extrapolation
+    ``E(X) = E_CBS + A·exp(-α·X)``, default ``α = 1.63``.
+    Halkier, Helgaker, Jørgensen, Klopper, Olsen,
+    *Chem. Phys. Lett.* **302**, 437-446 (1999), "Basis-set convergence of the
+    energy in molecular Hartree–Fock calculations".
+    DOI: 10.1016/S0009-2614(99)00179-7.
+
+``martin_3pt``
+    Three-point Schwartz-style extrapolation
+    ``E(L) = E_CBS + b·(L+½)^(-4) + c·(L+½)^(-6)`` solved exactly for the three
+    unknowns. Martin, *Chem. Phys. Lett.* **259**, 669-678 (1996), Eq. 5.
+    DOI: 10.1016/0009-2614(96)00898-6.
+
+Cardinal numbers follow the Dunning correlation-consistent convention introduced in
+Dunning, *J. Chem. Phys.* **90**, 1007 (1989). DOI: 10.1063/1.456153.
+"""
+
+import ast
+import math
+import re
+from collections.abc import Callable, Mapping
+
+import numpy as np
+
+from arc.exceptions import InputError
+
+
+# --------------------------------------------------------------------------- #
+#                         Cardinal-number deduction                            #
+# --------------------------------------------------------------------------- #
+
+# Map letter labels in correlation-consistent basis sets to cardinal numbers.
+# D=2, T=3, Q=4 (Dunning, J. Chem. Phys. 90, 1007 (1989)).
+_LETTER_CARDINAL = {"D": 2, "T": 3, "Q": 4}
+
+# Pattern: optional aug- prefix, cc-p, optional C, V, then cardinal letter or digit, Z.
+# Accepts cc-pVDZ, cc-pVTZ, cc-pVQZ, cc-pV5Z, cc-pV6Z, cc-pV7Z, cc-pCV*, aug-cc-pV*.
+_DUNNING_RE = re.compile(
+    r"^(?:aug-)?cc-p(?:c)?v(?P<card>[dtq2-7])z(?:-[a-z0-9]+)?$",
+    re.IGNORECASE,
+)
+
+# Pattern for the def2 family (Weigend & Ahlrichs): SVP=2, TZVP=3, QZVP=4, plus PP variants.
+_DEF2_RE = re.compile(
+    r"^def2-(?P<card>s|tz|qz)vp+(?:d?)?$",
+    re.IGNORECASE,
+)
+
+_DEF2_CARDINAL = {"S": 2, "TZ": 3, "QZ": 4}
+
+
+def cardinal_from_basis(basis: str) -> int:
+    """Return the cardinal number X for a correlation-consistent or def2 basis set.
+
+    Parameters
+    ----------
+    basis : str
+        Basis-set name (case-insensitive). Supported families:
+
+        * ``cc-pV{D,T,Q,5,6,7}Z`` — Dunning correlation-consistent.
+        * ``aug-cc-pV{D,T,Q,5,6,7}Z`` — diffuse-augmented variants.
+        * ``cc-pCV{D,T,Q,5,6}Z`` and ``aug-cc-pCV*`` — core-valence variants.
+        * ``def2-{SVP,TZVP,QZVP}`` and the ``...PP`` variants (Weigend & Ahlrichs).
+
+    Returns
+    -------
+    int
+        Cardinal X (2 for double-zeta, 3 for triple-zeta, etc.).
+
+    Raises
+    ------
+    arc.exceptions.InputError
+        If ``basis`` does not match a known correlation-consistent or def2 pattern.
+        CBS extrapolation requires a known cardinal; non-systematic basis sets such
+        as ``6-31G*`` or ``STO-3G`` are rejected explicitly.
+    """
+    if not basis:
+        raise InputError("Cannot deduce cardinal number from an empty basis-set name.")
+    text = basis.strip()
+    m = _DUNNING_RE.match(text)
+    if m:
+        card = m.group("card").upper()
+        if card.isdigit():
+            return int(card)
+        return _LETTER_CARDINAL[card]
+    m = _DEF2_RE.match(text)
+    if m:
+        return _DEF2_CARDINAL[m.group("card").upper()]
+    raise InputError(
+        f"Cannot deduce a CBS cardinal number from basis '{basis}'. "
+        "Only correlation-consistent (cc-pV*Z, aug-cc-pV*Z, cc-pCV*Z) and def2 "
+        "(def2-SVP, def2-TZVP, def2-QZVP) families are supported. Use one of "
+        "these families for the levels of a cbs_extrapolation term, or add a "
+        "new pattern to this function if you need a different basis family."
+ ) + + +# ----------------------------------------------------------------------------- # +# Built-in CBS formulas # +# ----------------------------------------------------------------------------- # + + +def _sorted_pairs(energies: Mapping[int, float], expected: int) -> list: + """Return ``[(X, E_X), ...]`` sorted by cardinal, validating count & uniqueness.""" + pairs = sorted(energies.items()) + if len(pairs) != expected: + raise InputError( + f"Expected exactly {expected} (cardinal, energy) pairs, got {len(pairs)}." + ) + cardinals = [X for X, _ in pairs] + if len(set(cardinals)) != len(cardinals): + raise InputError(f"Cardinals must be distinct, got {cardinals}.") + return pairs + + +def helgaker_corr_2pt(energies: Mapping[int, float]) -> float: + """Two-point correlation-energy CBS extrapolation. + + Implements ``E_CBS = (X³·E_X − Y³·E_Y) / (X³ − Y³)`` per + Helgaker, Klopper, Koch, Noga, *J. Chem. Phys.* **106**, 9639 (1997), Eq. 4. + DOI: 10.1063/1.473863. + + Parameters + ---------- + energies : Mapping[int, float] + Mapping ``{cardinal: energy}`` with exactly two entries. Insertion order is + irrelevant: pairs are sorted by ascending cardinal internally. + + Returns + ------- + float + Extrapolated energy in the same units as the inputs. + """ + (X, E_X), (Y, E_Y) = _sorted_pairs(energies, expected=2) + return (X ** 3 * E_X - Y ** 3 * E_Y) / (X ** 3 - Y ** 3) + + +def helgaker_hf_2pt(energies: Mapping[int, float], alpha: float = 1.63) -> float: + """Two-point HF (or other exponentially-converging) CBS extrapolation. + + Solves ``E(X) = E_CBS + A·exp(-α·X)`` for two cardinals analytically: + ``E_CBS = (E_X·exp(-α·Y) − E_Y·exp(-α·X)) / (exp(-α·Y) − exp(-α·X))``. + + Halkier, Helgaker, Jørgensen, Klopper, Olsen, *Chem. Phys. Lett.* **302**, + 437-446 (1999), "Basis-set convergence of the energy in molecular + Hartree–Fock calculations" reports the fitted value ``α = 1.63`` averaged + across small molecules. DOI: 10.1016/S0009-2614(99)00179-7. + + Parameters + ---------- + energies : Mapping[int, float] + Mapping ``{cardinal: energy}`` with exactly two entries. + alpha : float, optional + Exponential decay parameter. Defaults to 1.63 (Halkier et al. 1999). + + Returns + ------- + float + Extrapolated energy. + """ + (X, E_X), (Y, E_Y) = _sorted_pairs(energies, expected=2) + e_x = math.exp(-alpha * X) + e_y = math.exp(-alpha * Y) + return (E_X * e_y - E_Y * e_x) / (e_y - e_x) + + +def martin_3pt(energies: Mapping[int, float]) -> float: + """Three-point Schwartz-style CBS extrapolation. + + Solves the linear system + + E(L) = E_CBS + b·(L+½)⁻⁴ + c·(L+½)⁻⁶ + + exactly for ``E_CBS`` given three (L, E(L)) pairs. + + Martin, *Chem. Phys. Lett.* **259**, 669-678 (1996), Eq. 5. + DOI: 10.1016/0009-2614(96)00898-6. + + Parameters + ---------- + energies : Mapping[int, float] + Mapping ``{cardinal: energy}`` with exactly three entries. + + Returns + ------- + float + Extrapolated energy. + """ + pairs = _sorted_pairs(energies, expected=3) + A = np.array( + [[1.0, (L + 0.5) ** -4, (L + 0.5) ** -6] for L, _ in pairs], + dtype=float, + ) + b = np.array([E for _, E in pairs], dtype=float) + e_cbs, _b, _c = np.linalg.solve(A, b) + return float(e_cbs) + + +# String → callable registry advertised to user input. New built-in formulas are +# added by inserting an entry here (and a corresponding test). 
+BUILTIN_FORMULAS: dict[str, Callable[..., float]] = { + "helgaker_corr_2pt": helgaker_corr_2pt, + "helgaker_hf_2pt": helgaker_hf_2pt, + "martin_3pt": martin_3pt, +} + + +# ----------------------------------------------------------------------------- # +# Safe AST evaluator for user-supplied formula strings # +# ----------------------------------------------------------------------------- # + +# Functions a user formula may call. Restricted to a tiny math whitelist; no +# I/O, no introspection, no attribute access whatsoever. +_ALLOWED_CALLS = { + "exp": math.exp, + "log": math.log, + "sqrt": math.sqrt, + "pow": math.pow, +} + +# AST node classes the walker accepts. Anything else is rejected with InputError. +# Notably absent: Attribute, Subscript, Lambda, Comprehensions, NamedExpr (walrus), +# Starred, JoinedStr, FormattedValue, IfExp, Compare, BoolOp. +_ALLOWED_NODES = ( + ast.Expression, + ast.BinOp, + ast.UnaryOp, + ast.Constant, + ast.Name, + ast.Load, + ast.Call, + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.Pow, + ast.Mod, + ast.FloorDiv, + ast.UAdd, + ast.USub, +) + + +def _validate_ast(node: ast.AST, env_names: set) -> None: + """Raise :class:`InputError` if any descendant of ``node`` is non-whitelisted.""" + for child in ast.walk(node): + if not isinstance(child, _ALLOWED_NODES): + raise InputError( + f"Disallowed expression element {type(child).__name__!r} in user " + "formula. Only basic arithmetic (+ - * / ** %), unary +/-, " + "numeric literals, named variables, and calls to " + f"{sorted(_ALLOWED_CALLS)} are permitted." + ) + if isinstance(child, ast.Constant) and not isinstance(child.value, (int, float)): + raise InputError( + f"Only numeric constants are allowed in user formulas; got " + f"{type(child.value).__name__} ({child.value!r})." + ) + if isinstance(child, ast.Name) and child.id not in env_names \ + and child.id not in _ALLOWED_CALLS: + raise InputError( + f"Unknown name '{child.id}' in user formula. Allowed names: " + f"variables {sorted(env_names)} and functions {sorted(_ALLOWED_CALLS)}." + ) + if isinstance(child, ast.Call): + if not isinstance(child.func, ast.Name) or child.func.id not in _ALLOWED_CALLS: + raise InputError( + f"Disallowed function call in user formula. Only " + f"{sorted(_ALLOWED_CALLS)} may be called." + ) + + +def validate_formula(expression: str, allowed_names: set) -> None: + """Parse and whitelist-validate ``expression`` without evaluating it. + + Useful at construction time to surface malformed user formulas eagerly, + independent of any specific numeric inputs (which might cause spurious + runtime errors like division by zero on a probe environment). + + Raises :class:`InputError` on any non-whitelisted construct. + """ + try: + tree = ast.parse(expression, mode="eval") + except SyntaxError as exc: + raise InputError(f"User formula failed to parse: {expression!r} ({exc})") + _validate_ast(tree, set(allowed_names)) + + +def safe_eval_formula(expression: str, env: Mapping[str, float]) -> float: + """Evaluate an arithmetic expression against ``env`` without using :func:`eval`. + + Parses ``expression`` to an AST, validates every node against a strict whitelist + (basic arithmetic, unary ±, numeric literals, named variables drawn from + ``env``, and calls to :func:`math.exp`, :func:`math.log`, :func:`math.sqrt`, + :func:`math.pow`), then walks the tree to compute the result. + + Parameters + ---------- + expression : str + Arithmetic expression. Examples: + ``"(X**3 * E_X - Y**3 * E_Y) / (X**3 - Y**3)"``, + ``"E_X - sqrt(E_Y)"``. 
+ env : Mapping[str, float] + Variable bindings. Names referenced by ``expression`` must appear here + (or be one of the allowed function names). + + Returns + ------- + float + Numerical value of the expression. + + Raises + ------ + arc.exceptions.InputError + If the expression is syntactically invalid, references unknown names, or + uses any AST construct outside the whitelist (attribute access, + subscript, lambdas, comprehensions, walrus, string literals, etc.). + """ + try: + tree = ast.parse(expression, mode="eval") + except SyntaxError as exc: + raise InputError(f"User formula failed to parse: {expression!r} ({exc})") + env_names = set(env.keys()) + _validate_ast(tree, env_names) + return _eval_node(tree.body, env) + + +def _eval_node(node: ast.AST, env: Mapping[str, float]) -> float: + """Recursively evaluate a whitelisted AST node.""" + if isinstance(node, ast.Constant): + return node.value + if isinstance(node, ast.Name): + if node.id in env: + return env[node.id] + # _validate_ast already rejected unknown names, so this is unreachable. + raise InputError(f"Unknown name '{node.id}'.") + if isinstance(node, ast.UnaryOp): + operand = _eval_node(node.operand, env) + if isinstance(node.op, ast.UAdd): + return +operand + if isinstance(node.op, ast.USub): + return -operand + raise InputError(f"Unsupported unary operator {type(node.op).__name__}.") + if isinstance(node, ast.BinOp): + left = _eval_node(node.left, env) + right = _eval_node(node.right, env) + if isinstance(node.op, ast.Add): + return left + right + if isinstance(node.op, ast.Sub): + return left - right + if isinstance(node.op, ast.Mult): + return left * right + if isinstance(node.op, ast.Div): + return left / right + if isinstance(node.op, ast.Pow): + return left ** right + if isinstance(node.op, ast.Mod): + return left % right + if isinstance(node.op, ast.FloorDiv): + return left // right + raise InputError(f"Unsupported binary operator {type(node.op).__name__}.") + if isinstance(node, ast.Call): + func = _ALLOWED_CALLS[node.func.id] + args = [_eval_node(a, env) for a in node.args] + return func(*args) + raise InputError(f"Unsupported AST node {type(node).__name__}.") diff --git a/arc/level/cbs_test.py b/arc/level/cbs_test.py new file mode 100644 index 0000000000..5284256def --- /dev/null +++ b/arc/level/cbs_test.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Unit tests for ``arc.level.cbs`` — basis-set cardinal inference, built-in CBS +extrapolation formulas, and the safe AST evaluator for user-supplied formulas. + +References whose values are checked here: + +* Helgaker, Klopper, Koch, Noga, *J. Chem. Phys.* **106**, 9639 (1997). + DOI: 10.1063/1.473863 — two-point correlation extrapolation. +* Halkier, Helgaker, Jørgensen, Klopper, Olsen, *Chem. Phys. Lett.* **302**, + 437-446 (1999). DOI: 10.1016/S0009-2614(99)00179-7 — two-point HF + extrapolation; source of the fitted α = 1.63. +* Martin, *Chem. Phys. Lett.* **259**, 669-678 (1996). + DOI: 10.1016/0009-2614(96)00898-6 — three-point Schwartz expansion. 
+""" + +import math +import unittest + +from arc.exceptions import InputError +from arc.level.cbs import ( + BUILTIN_FORMULAS, + cardinal_from_basis, + helgaker_corr_2pt, + helgaker_hf_2pt, + martin_3pt, + safe_eval_formula, +) + + +class TestCardinalFromBasis(unittest.TestCase): + """``cardinal_from_basis`` covers the common Dunning families and def2.""" + + def test_cc_pvxz(self): + self.assertEqual(cardinal_from_basis("cc-pVDZ"), 2) + self.assertEqual(cardinal_from_basis("cc-pVTZ"), 3) + self.assertEqual(cardinal_from_basis("cc-pVQZ"), 4) + self.assertEqual(cardinal_from_basis("cc-pV5Z"), 5) + self.assertEqual(cardinal_from_basis("cc-pV6Z"), 6) + + def test_aug_cc_pvxz(self): + self.assertEqual(cardinal_from_basis("aug-cc-pVDZ"), 2) + self.assertEqual(cardinal_from_basis("aug-cc-pVTZ"), 3) + self.assertEqual(cardinal_from_basis("aug-cc-pVQZ"), 4) + self.assertEqual(cardinal_from_basis("aug-cc-pV5Z"), 5) + + def test_cc_pcvxz_core_valence(self): + self.assertEqual(cardinal_from_basis("cc-pCVDZ"), 2) + self.assertEqual(cardinal_from_basis("cc-pCVTZ"), 3) + self.assertEqual(cardinal_from_basis("cc-pCVQZ"), 4) + self.assertEqual(cardinal_from_basis("aug-cc-pCVTZ"), 3) + + def test_def2_family(self): + self.assertEqual(cardinal_from_basis("def2-SVP"), 2) + self.assertEqual(cardinal_from_basis("def2-TZVP"), 3) + self.assertEqual(cardinal_from_basis("def2-QZVP"), 4) + self.assertEqual(cardinal_from_basis("def2-TZVPP"), 3) + self.assertEqual(cardinal_from_basis("def2-QZVPP"), 4) + + def test_case_insensitive(self): + self.assertEqual(cardinal_from_basis("cc-pvtz"), 3) + self.assertEqual(cardinal_from_basis("CC-PVTZ"), 3) + self.assertEqual(cardinal_from_basis("Aug-CC-pVQZ"), 4) + self.assertEqual(cardinal_from_basis("DEF2-tzvp"), 3) + + def test_unknown_basis_raises(self): + with self.assertRaises(InputError): + cardinal_from_basis("6-31G*") + with self.assertRaises(InputError): + cardinal_from_basis("STO-3G") + with self.assertRaises(InputError): + cardinal_from_basis("not-a-basis-set") + with self.assertRaises(InputError): + cardinal_from_basis("") + + +class TestHelgakerCorr2Pt(unittest.TestCase): + """``helgaker_corr_2pt`` implements (X^3·E_X − Y^3·E_Y) / (X^3 − Y^3).""" + + def test_known_values(self): + # E_T = 1.0, E_Q = 1.05 -> (27*1.0 - 64*1.05) / (27 - 64) = -40.2 / -37 + result = helgaker_corr_2pt({3: 1.0, 4: 1.05}) + self.assertAlmostEqual(result, 40.2 / 37, places=12) + + def test_invariance_to_dict_insertion_order(self): + a = helgaker_corr_2pt({3: -1.0, 4: -1.05}) + b = helgaker_corr_2pt({4: -1.05, 3: -1.0}) + self.assertAlmostEqual(a, b, places=12) + + def test_higher_basis_dominates(self): + # E_CBS should be closer to E_Q than to E_T (since cc-pVQZ is more accurate). + e_t, e_q = -100.0, -100.05 + cbs = helgaker_corr_2pt({3: e_t, 4: e_q}) + self.assertLess(abs(cbs - e_q), abs(cbs - e_t)) + + def test_real_h2o_correlation_extrapolation(self): + # Synthetic but representative: CCSD(T) corr energy at TZ vs QZ. 
+ # E_corr_TZ = -0.30, E_corr_QZ = -0.31 (Hartree) -> CBS ≈ -0.31730 + result = helgaker_corr_2pt({3: -0.30, 4: -0.31}) + expected = (27 * (-0.30) - 64 * (-0.31)) / (27 - 64) + self.assertAlmostEqual(result, expected, places=12) + self.assertAlmostEqual(result, -0.31729729729729728, places=10) + + def test_requires_exactly_two_points(self): + with self.assertRaises(InputError): + helgaker_corr_2pt({3: -1.0}) + with self.assertRaises(InputError): + helgaker_corr_2pt({3: -1.0, 4: -1.05, 5: -1.06}) + + def test_rejects_equal_cardinals(self): + with self.assertRaises(InputError): + helgaker_corr_2pt({3: -1.0, 3: -1.05}) # noqa: F601 — Python collapses; size=1 path + + def test_q5_pair_reproduces_formula(self): + # X=4, Y=5; E_Q = -0.310, E_5 = -0.315 + result = helgaker_corr_2pt({4: -0.310, 5: -0.315}) + expected = (4**3 * -0.310 - 5**3 * -0.315) / (4**3 - 5**3) + self.assertAlmostEqual(result, expected, places=12) + + +class TestHelgakerHF2Pt(unittest.TestCase): + """``helgaker_hf_2pt`` extrapolates HF energies via E(X) = E_CBS + A·exp(-α·X).""" + + def test_default_alpha_is_halkier_value(self): + # Halkier et al. 1999 fitted α = 1.63. + # Pick numbers and verify the formula uses α=1.63 by default. + e_t, e_q = -76.0500, -76.0510 + from_default = helgaker_hf_2pt({3: e_t, 4: e_q}) + from_explicit = helgaker_hf_2pt({3: e_t, 4: e_q}, alpha=1.63) + self.assertAlmostEqual(from_default, from_explicit, places=12) + + def test_known_value(self): + # E_CBS = (E_X · exp(-α·Y) - E_Y · exp(-α·X)) / (exp(-α·Y) - exp(-α·X)) + e_t, e_q = -76.0500, -76.0510 + alpha = 1.63 + expected = ( + e_t * math.exp(-alpha * 4) - e_q * math.exp(-alpha * 3) + ) / (math.exp(-alpha * 4) - math.exp(-alpha * 3)) + result = helgaker_hf_2pt({3: e_t, 4: e_q}) + self.assertAlmostEqual(result, expected, places=12) + + def test_alpha_override(self): + e_t, e_q = -76.0500, -76.0510 + alpha = 1.50 + expected = ( + e_t * math.exp(-alpha * 4) - e_q * math.exp(-alpha * 3) + ) / (math.exp(-alpha * 4) - math.exp(-alpha * 3)) + self.assertAlmostEqual(helgaker_hf_2pt({3: e_t, 4: e_q}, alpha=alpha), expected, places=12) + + def test_invariance_to_dict_insertion_order(self): + a = helgaker_hf_2pt({3: -76.05, 4: -76.051}) + b = helgaker_hf_2pt({4: -76.051, 3: -76.05}) + self.assertAlmostEqual(a, b, places=12) + + def test_requires_exactly_two_points(self): + with self.assertRaises(InputError): + helgaker_hf_2pt({3: -76.05}) + with self.assertRaises(InputError): + helgaker_hf_2pt({3: -76.05, 4: -76.051, 5: -76.0512}) + + +class TestMartin3Pt(unittest.TestCase): + """``martin_3pt`` solves E(L) = E_CBS + b·(L+½)⁻⁴ + c·(L+½)⁻⁶ exactly.""" + + def test_recovers_constant_term(self): + # If we feed E(L) = -1.0 + 0.05/(L+0.5)**4 + 0.01/(L+0.5)**6 for L=2,3,4 + # then E_CBS must come back as -1.0 to high precision. 
+ def model(L): + return -1.0 + 0.05 / (L + 0.5) ** 4 + 0.01 / (L + 0.5) ** 6 + + result = martin_3pt({2: model(2), 3: model(3), 4: model(4)}) + self.assertAlmostEqual(result, -1.0, places=10) + + def test_higher_cardinals(self): + def model(L): + return -100.0 + 0.123 / (L + 0.5) ** 4 - 0.045 / (L + 0.5) ** 6 + + result = martin_3pt({3: model(3), 4: model(4), 5: model(5)}) + self.assertAlmostEqual(result, -100.0, places=10) + + def test_invariance_to_dict_insertion_order(self): + e = {3: -1.0, 4: -1.05, 5: -1.06} + a = martin_3pt(e) + b = martin_3pt({5: e[5], 3: e[3], 4: e[4]}) + self.assertAlmostEqual(a, b, places=12) + + def test_requires_exactly_three_points(self): + with self.assertRaises(InputError): + martin_3pt({3: -1.0, 4: -1.05}) + with self.assertRaises(InputError): + martin_3pt({3: -1.0, 4: -1.05, 5: -1.06, 6: -1.065}) + + +class TestBuiltinFormulasRegistry(unittest.TestCase): + """The string→callable registry advertised to user input.""" + + def test_helgaker_corr_2pt_registered(self): + self.assertIs(BUILTIN_FORMULAS["helgaker_corr_2pt"], helgaker_corr_2pt) + + def test_helgaker_hf_2pt_registered(self): + self.assertIs(BUILTIN_FORMULAS["helgaker_hf_2pt"], helgaker_hf_2pt) + + def test_martin_3pt_registered(self): + self.assertIs(BUILTIN_FORMULAS["martin_3pt"], martin_3pt) + + def test_no_other_entries(self): + self.assertEqual( + set(BUILTIN_FORMULAS.keys()), + {"helgaker_corr_2pt", "helgaker_hf_2pt", "martin_3pt"}, + ) + + +class TestSafeEvalFormula(unittest.TestCase): + """``safe_eval_formula`` accepts arithmetic + math whitelist; rejects everything else.""" + + def test_basic_arithmetic(self): + self.assertEqual(safe_eval_formula("1 + 2", {}), 3) + self.assertEqual(safe_eval_formula("3 * 4 - 5", {}), 7) + self.assertEqual(safe_eval_formula("10 / 4", {}), 2.5) + self.assertEqual(safe_eval_formula("2 ** 8", {}), 256) + self.assertEqual(safe_eval_formula("-5 + 3", {}), -2) + self.assertEqual(safe_eval_formula("+(7)", {}), 7) + + def test_helgaker_corr_2pt_via_safe_eval(self): + # Reproduce the helgaker_corr_2pt formula by string. 
+ formula = "(X**3 * E_X - Y**3 * E_Y) / (X**3 - Y**3)" + env = {"X": 3, "Y": 4, "E_X": -0.30, "E_Y": -0.31} + result = safe_eval_formula(formula, env) + self.assertAlmostEqual(result, helgaker_corr_2pt({3: -0.30, 4: -0.31}), places=12) + + def test_allowed_math_calls(self): + self.assertAlmostEqual(safe_eval_formula("exp(1)", {}), math.e, places=12) + self.assertAlmostEqual(safe_eval_formula("log(exp(2.5))", {}), 2.5, places=12) + self.assertAlmostEqual(safe_eval_formula("sqrt(16)", {}), 4.0, places=12) + self.assertAlmostEqual(safe_eval_formula("pow(2, 10)", {}), 1024.0, places=12) + + def test_user_variables_resolved(self): + self.assertEqual(safe_eval_formula("E_X * 2", {"E_X": 5}), 10) + + def test_unknown_name_raises(self): + with self.assertRaises(InputError): + safe_eval_formula("os.system('rm')", {}) + with self.assertRaises(InputError): + safe_eval_formula("E_Z", {"E_X": 1}) + + def test_dunder_attribute_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("(0).__class__", {}) + + def test_attribute_access_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("(0.0).real", {}) + + def test_subscript_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("[1,2,3][0]", {}) + + def test_lambda_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("(lambda x: x)(1)", {}) + + def test_comprehension_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("[i for i in range(3)]", {}) + + def test_call_to_unwhitelisted_function_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("eval('1')", {}) + with self.assertRaises(InputError): + safe_eval_formula("__import__('os')", {}) + + def test_walrus_rejected(self): + with self.assertRaises(InputError): + safe_eval_formula("(x := 5)", {}) + + def test_string_literal_rejected(self): + # Numeric constants only. + with self.assertRaises(InputError): + safe_eval_formula("'hello'", {}) + + def test_syntax_error_propagates_as_input_error(self): + with self.assertRaises(InputError): + safe_eval_formula("1 +", {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/arc/level/examples_test.py b/arc/level/examples_test.py new file mode 100644 index 0000000000..179b0a855b --- /dev/null +++ b/arc/level/examples_test.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Tests that every ``examples/Composite/*/input.yml`` example is valid YAML and +that its ``sp_composite`` block (or per-species ``sp_composite`` entries) +builds a valid :class:`CompositeProtocol` via +:meth:`CompositeProtocol.from_user_input`. Keeps the docs + examples honest. 
+""" + +import glob +import os +import unittest + +import yaml + +from arc.common import ARC_PATH +from arc.level.protocol import CompositeProtocol + + +EXAMPLES_DIR = os.path.join(ARC_PATH, "examples", "Composite") + + +class TestCompositeExamples(unittest.TestCase): + """Parse every shipped example and validate its sp_composite payload.""" + + def _example_files(self): + pattern = os.path.join(EXAMPLES_DIR, "*", "input.yml") + return sorted(glob.glob(pattern)) + + def test_examples_directory_ships_at_least_four_inputs(self): + self.assertGreaterEqual(len(self._example_files()), 4) + + def test_examples_readme_exists(self): + self.assertTrue(os.path.isfile(os.path.join(EXAMPLES_DIR, "README.md"))) + + def test_every_example_is_valid_yaml(self): + for path in self._example_files(): + with self.subTest(path=path): + with open(path, "r") as fh: + data = yaml.safe_load(fh) + self.assertIsInstance(data, dict) + self.assertIn("project", data) + self.assertIn("species", data) + + def test_every_project_level_sp_composite_builds(self): + """Project-level ``sp_composite`` (if present) is parseable.""" + for path in self._example_files(): + with open(path, "r") as fh: + data = yaml.safe_load(fh) + sp = data.get("sp_composite") + if sp is None: + continue + with self.subTest(path=path): + protocol = CompositeProtocol.from_user_input(sp) + self.assertIsInstance(protocol, CompositeProtocol) + + def test_every_species_sp_composite_builds_if_explicit(self): + """Per-species ``sp_composite`` (string/dict, not null) is parseable.""" + for path in self._example_files(): + with open(path, "r") as fh: + data = yaml.safe_load(fh) + for spc in data.get("species", []): + sp = spc.get("sp_composite", "__missing__") + if sp == "__missing__": + continue + if sp is None: + continue + with self.subTest(path=path, label=spc.get("label")): + protocol = CompositeProtocol.from_user_input(sp) + self.assertIsInstance(protocol, CompositeProtocol) + + def test_all_four_forms_covered(self): + """Each of the four documented YAML forms must appear at least once.""" + form1 = form2 = form3 = form4 = False + for path in self._example_files(): + with open(path, "r") as fh: + data = yaml.safe_load(fh) + sp = data.get("sp_composite") + if isinstance(sp, str): + form1 = True + elif isinstance(sp, dict) and "preset" in sp: + form2 = True + elif isinstance(sp, dict) and "base" in sp: + form3 = True + for spc in data.get("species", []): + if "sp_composite" in spc: + form4 = True + self.assertTrue(form1, "Form 1 (preset by name) not demonstrated.") + self.assertTrue(form2, "Form 2 (preset + override) not demonstrated.") + self.assertTrue(form3, "Form 3 (fully explicit recipe) not demonstrated.") + self.assertTrue(form4, "Form 4 (per-species override) not demonstrated.") + + def test_explicit_recipe_example_includes_cbs_extrapolation(self): + path = os.path.join(EXAMPLES_DIR, "explicit_fpa", "input.yml") + with open(path, "r") as fh: + data = yaml.safe_load(fh) + corrections = data["sp_composite"]["corrections"] + term_types = {c["type"] for c in corrections} + self.assertIn("cbs_extrapolation", term_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/arc/level/legacy_imports_test.py b/arc/level/legacy_imports_test.py new file mode 100644 index 0000000000..dd4dce5247 --- /dev/null +++ b/arc/level/legacy_imports_test.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Backward-compatibility tests for the ``arc.level`` package. 
+
+These tests assert that every public symbol that used to live in the legacy
+``arc/level.py`` module is still importable from ``arc.level`` after the package
+relocation. They guard the public surface so an accidental re-organisation of
+the new package internals cannot break the existing 50+ external call sites.
+"""
+
+import importlib
+import unittest
+
+
+class TestLegacyArcLevelImports(unittest.TestCase):
+    """Verify the public surface of ``arc.level`` is preserved."""
+
+    def test_from_arc_level_import_Level(self):
+        """``from arc.level import Level`` resolves to the legacy class."""
+        from arc.level import Level
+
+        instance = Level(method="b3lyp", basis="def2tzvp")
+        self.assertEqual(instance.method, "b3lyp")
+        self.assertEqual(instance.basis, "def2tzvp")
+
+    def test_from_arc_level_import_assign_frequency_scale_factor(self):
+        """``assign_frequency_scale_factor`` is still re-exported."""
+        from arc.level import assign_frequency_scale_factor
+
+        self.assertTrue(callable(assign_frequency_scale_factor))
+
+    def test_from_arc_level_import_module_singletons(self):
+        """``levels_ess`` and ``supported_ess`` are still accessible."""
+        from arc.level import levels_ess, supported_ess
+
+        self.assertIsNotNone(levels_ess)
+        self.assertIsNotNone(supported_ess)
+
+    def test_import_arc_level_as_module(self):
+        """``import arc.level`` succeeds (the side-effect import in arc/__init__.py).
+
+        Loaded via importlib so this test file's source contains only
+        ``from arc.level import …`` statements (CodeQL flags mixing both
+        styles in the same module).
+        """
+        module = importlib.import_module("arc.level")
+        self.assertTrue(hasattr(module, "Level"))
+        self.assertTrue(hasattr(module, "assign_frequency_scale_factor"))
+
+    def test_alias_import(self):
+        """``from arc.level import Level as Lvl`` keeps working (used in tests)."""
+        from arc.level import Level as Lvl
+
+        self.assertEqual(Lvl.__name__, "Level")
+
+    def test_level_class_is_a_real_class(self):
+        """Sanity check: re-export is the actual class, not a re-binding."""
+        from arc.level import Level
+        from arc.level.level import Level as LevelDirect
+
+        self.assertIs(Level, LevelDirect)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/arc/level.py b/arc/level/level.py
similarity index 87%
rename from arc/level.py
rename to arc/level/level.py
index 9d285e102d..7c2a3b0a67 100644
--- a/arc/level.py
+++ b/arc/level/level.py
@@ -112,13 +112,43 @@ def __init__(self,
             # it wasn't set by the user, try determining it
             self.deduce_software()
 
+    # Attributes that participate in structural equality. These are the user-
+    # provided / round-trippable fields; derived attributes (``method_type``,
+    # ``compatible_ess``) are intentionally excluded because they are computed
+    # from the others and would create false-negative equalities when only
+    # one of the operands has been resolved.
+    _EQ_ATTRS = (
+        'method', 'basis', 'auxiliary_basis', 'dispersion', 'cabs',
+        'software', 'software_version',
+        'solvation_method', 'solvent', 'solvation_scheme_level',
+        'args', 'year',
+    )
+
     def __eq__(self, other: Level) -> bool:
         """
-        Determine equality between Level object instances.
+        Determine structural equality between Level instances.
+
+        Compares every user-relevant attribute (method/basis/dispersion/cabs/
+        solvation/software/version/year/args) one-by-one rather than relying on
+        :meth:`__str__`, because ``__str__`` historically dropped ``args`` when
+        any ``args`` bucket (e.g.
an empty ``block``) was falsy — which let two + protocols whose ``args.keyword`` differed (e.g. an all-electron + ``core,...`` directive vs the molpro frozen-core default) compare equal + and silently collapse into one sub-job at composite-spawn time. """ - if isinstance(other, Level): - return str(self) == str(other) - return False + if not isinstance(other, Level): + return False + for attr in self._EQ_ATTRS: + if getattr(self, attr, None) != getattr(other, attr, None): + return False + return True + + # ``__eq__`` without ``__hash__`` makes the class unhashable in Python. + # Level was already unhashable (no ``__hash__`` was previously defined), + # and nothing in the codebase uses Level as a dict key or set element, so + # we keep that contract — explicitly setting ``__hash__ = None`` documents + # the intent. + __hash__ = None def __str__(self) -> str: """ @@ -148,12 +178,13 @@ def __str__(self) -> str: str_ += f', software: {self.software}' if self.software_version is not None: str_ += f', software_version: {self.software_version}' - if self.args is not None and self.args and all([val for val in self.args.values()]): - if any([key == 'keyword' for key in self.args.keys()]): - str_ += ', keyword args:' - for key, arg in self.args.items(): - if key == 'keyword': - str_ += f' {arg}' + # Emit ``args.keyword`` whenever it carries content, regardless of + # whether sibling buckets (e.g. ``args.block``) are empty. The previous + # ``all(values)`` guard hid keyword content (such as a frozen-core + # ``core,...`` directive) when ``block`` was an empty dict, which made + # two protocols comparing only on str() look identical. + if self.args and self.args.get('keyword'): + str_ += f", keyword args: {self.args['keyword']}" return str_ def copy(self): @@ -183,11 +214,20 @@ def as_dict(self) -> dict: """ Returns a minimal dictionary representation from which the object can be reconstructed. Useful for ARC restart files. + + ``args`` is included whenever any of its buckets carries content. + Previously a falsy sibling bucket (e.g. an empty ``block``) caused the + whole ``args`` dict to be dropped from the serialised form — which lost + meaningful settings such as ``args.keyword.core,...`` and made + round-tripped Levels compare equal to ones that never had those args. """ original_dict = self.__dict__ clean_dict = {} for key, val in original_dict.items(): - if val is not None and key != 'args' or key == 'args' and all([v for v in self.args.values()]): + if key == 'args': + if val and any(val.values()): + clean_dict[key] = val + elif val is not None: clean_dict[key] = val return clean_dict @@ -288,8 +328,11 @@ def lower(self): f'Got {arg} which is a {type(arg)} in {self.args}.') self.args = ' '.join([arg.lower() for arg in self.args]) if isinstance(self.args, str): - self.args = {'keyword': {'general': args.lower()}, 'block': dict()} - elif self.args is not None and not isinstance(args, dict): + # Phase 5.5 fix: previously ``args.lower()`` (the local *dict*), which + # raised AttributeError. The intent is to lowercase the user-supplied + # string that was just assigned to self.args. 
+ self.args = {'keyword': {'general': self.args.lower()}, 'block': dict()} + elif self.args is not None and not isinstance(self.args, dict): raise ValueError(f'The args argument must be either a string, an iterable or a dictionary.\n' f'Got {self.args} which is a {type(self.args)}.') diff --git a/arc/level_test.py b/arc/level/level_test.py similarity index 74% rename from arc/level_test.py rename to arc/level/level_test.py index b3e538a086..22b4e1e86b 100644 --- a/arc/level_test.py +++ b/arc/level/level_test.py @@ -199,6 +199,87 @@ def test_assign_frequency_scale_factor(self): self.assertEqual(assign_frequency_scale_factor(Level(method='CBS-QB3')), 1.004) self.assertEqual(assign_frequency_scale_factor(Level(method='PM6')), 1.093) + def test_level_accepts_string_args(self): + """Regression: Level.lower() used to crash on string `args` because the + code called ``.lower()`` on the local args dict instead of self.args.""" + level = Level(method='B3LYP', basis='cc-pVTZ', args='EmpiricalDispersion=GD3') + self.assertIsInstance(level.args, dict) + self.assertEqual(level.args['keyword']['general'], 'empiricaldispersion=gd3') + self.assertEqual(level.args['block'], {}) + + def test_level_accepts_iterable_args(self): + """Iterable → space-joined string → dict path should also work.""" + level = Level(method='B3LYP', basis='cc-pVTZ', + args=['EmpiricalDispersion=GD3', 'Int=UltraFine']) + self.assertEqual(level.args['keyword']['general'], + 'empiricaldispersion=gd3 int=ultrafine') + + # --- structural __eq__ + as_dict args fix (sp_composite Bug B) -------- # + + def test_eq_distinguishes_args_keyword_differences(self): + """Two Levels identical in method+basis but differing only in + ``args.keyword`` must NOT compare equal. + + Pre-fix: ``__eq__`` delegated to ``str(self)`` which dropped ``args`` + whenever any sibling bucket was empty (``block: {}``). That let HEAT + protocol's δ_CV high (all-electron ``core,...``) and low (default + frozen-core) Levels collapse into one job at composite-spawn time — + silently producing δ_CV = 0. + """ + ae_level = Level( + method='ccsd(t)', basis='cc-pCVTZ', + args={'keyword': {'core': 'core,0,0,0,0,0,0,0,0;'}, 'block': {}}, + ) + fc_level = Level(method='ccsd(t)', basis='cc-pCVTZ') + self.assertNotEqual(ae_level, fc_level) + + def test_eq_identical_levels_remain_equal(self): + """Sanity: the strict __eq__ doesn't make every Level construction unique.""" + a = Level(method='wb97xd', basis='def2-TZVP') + b = Level(method='wb97xd', basis='def2-TZVP') + self.assertEqual(a, b) + + def test_as_dict_includes_args_when_keyword_set_and_block_empty(self): + """as_dict() must serialise ``args`` whenever any bucket has content. + + Pre-fix the ``all(values)`` guard skipped ``args`` when ``block`` was + empty, dropping the keyword half on round-trip. 
+ """ + level = Level( + method='ccsd(t)', basis='cc-pCVTZ', + args={'keyword': {'core': 'core,0,0,0,0,0,0,0,0;'}, 'block': {}}, + ) + d = level.as_dict() + self.assertIn('args', d) + self.assertIn('keyword', d['args']) + self.assertEqual(d['args']['keyword']['core'], 'core,0,0,0,0,0,0,0,0;') + + def test_as_dict_omits_args_when_all_buckets_empty(self): + """No content anywhere ⇒ args is omitted from the serialised form.""" + level = Level(method='hf', basis='cc-pVTZ') + self.assertNotIn('args', level.as_dict()) + + def test_str_includes_keyword_when_block_empty(self): + """str(Level) used to drop ``keyword`` info when ``block`` was empty.""" + level = Level( + method='ccsd(t)', basis='cc-pCVTZ', + args={'keyword': {'core': 'core,0,0,0,0,0,0,0,0;'}, 'block': {}}, + ) + self.assertIn('keyword args:', str(level)) + self.assertIn('core,0,0,0,0,0,0,0,0', str(level)) + + def test_level_is_unhashable(self): + """Custom __eq__ without a matching __hash__ ⇒ unhashable. + Locks the contract; nothing in the codebase puts Level into a set/dict-key. + + We assert this via the ``__hash__`` class marker (Python's documented + mechanism for making instances unhashable) rather than by calling + ``hash()`` on an instance and expecting ``TypeError``. The behavioural + form trips CodeQL's ``py/hash-of-unhashable-value`` query — and that + query's pattern is *exactly* the contract under test, so suppressing + it via the dunder check is more direct than annotating around it.""" + self.assertIsNone(Level.__hash__) + if __name__ == '__main__': unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/level/presets.py b/arc/level/presets.py new file mode 100644 index 0000000000..c1f1ad9fd4 --- /dev/null +++ b/arc/level/presets.py @@ -0,0 +1,206 @@ +""" +``arc.level.presets`` — named composite-protocol presets shipped with ARC. + +Presets are loaded from the data file ``presets.yml`` located alongside this module. +Each entry maps a preset name (e.g. ``"HEAT-345Q"``) to a recipe dict in the same +shape that :meth:`arc.level.protocol.CompositeProtocol.from_user_input` accepts +(``base:`` + ``corrections:`` + ``reference:``). + +The :func:`expand_preset` helper resolves a preset name (with optional per-term +overrides) to a fresh, independent recipe dict suitable for handing to +``CompositeProtocol.from_user_input``. Returned dicts are deep copies so that +caller-side mutation cannot pollute the cached registry. + +References +---------- + +* Tajti, Szalay, Császár, Kállay, Gauss, Valeev, Flowers, Vázquez, Stanton, + *J. Chem. Phys.* **121**, 11599 (2004). DOI: 10.1063/1.1811608 — HEAT. +* East, Allen, *J. Chem. Phys.* **99**, 4638 (1993). DOI: 10.1063/1.466062 — focal- + point analysis methodology. +""" + +import copy +import os +from collections.abc import Mapping +from typing import Any + +import yaml + +from arc.exceptions import InputError + + +_HERE = os.path.dirname(os.path.abspath(__file__)) +_PRESETS_PATH = os.path.join(_HERE, "presets.yml") + + +def _load_presets(path: str) -> dict[str, dict[str, Any]]: + """Load ``presets.yml`` once; return the parsed mapping.""" + with open(path, "r") as fh: + data = yaml.safe_load(fh) or {} + if not isinstance(data, dict): + raise InputError(f"Preset file {path} must parse to a mapping, got {type(data).__name__}.") + return data + + +# Module-level cache. Loaded once at import time; a single source of truth. 
+PRESETS: dict[str, dict[str, Any]] = _load_presets(_PRESETS_PATH) +REGISTERED_PRESET_NAMES: list[str] = sorted(PRESETS.keys()) + + +# Fields that may appear on a preset term by its ``type`` discriminator. +# Used to reject typos in preset overrides (e.g. ``delta_T.hihg``). The key +# ``"base"`` is not a term type — it's the protocol's base level dict, for +# which we accept any Level-level keyword plus ``label``. +_ALLOWED_OVERRIDE_FIELDS_BY_TYPE: dict[str, set] = { + "single_point": {"label", "type", "level"}, + "delta": {"label", "type", "high", "low"}, + "cbs_extrapolation": {"label", "type", "formula", "components", "levels"}, +} + +# Level dict keys — accepted on the ``base`` target and on any ``high``/``low``/ +# ``level`` sub-dict the user is replacing wholesale. Kept in sync with +# ``Level.__init__`` parameters (see ``arc/level/level.py``). +_ALLOWED_LEVEL_FIELDS = { + "repr", "method", "basis", "auxiliary_basis", "dispersion", "cabs", + "method_type", "software", "software_version", "compatible_ess", + "solvation_method", "solvent", "solvation_scheme_level", "args", "year", + # Also valid in the base-of-a-preset context (YAML shorthand): + "label", "type", "level", +} + + +def _deep_merge_level_dict(target: dict[str, Any], patch: dict[str, Any]) -> None: + """Shallow-merge ``patch`` into ``target`` with one level of nesting for + ``high``/``low``/``level`` — replacing fields of the inner dict rather than + the whole dict. Mutates ``target`` in place. + + Rationale: overriding ``delta_T: {high: {basis: cc-pVTZ}}`` on a preset + where ``high`` was ``{method: ccsdt, basis: cc-pVDZ}`` should produce + ``{method: ccsdt, basis: cc-pVTZ}`` — not discard the method. Only the + nested Level dicts (high/low/level) get this treatment; scalar or + list-valued fields (formula, levels) still replace wholesale. + """ + for key, new_val in patch.items(): + existing = target.get(key) + if ( + key in {"high", "low", "level", "base"} + and isinstance(existing, dict) + and isinstance(new_val, dict) + ): + merged = dict(existing) + merged.update(new_val) + target[key] = merged + else: + target[key] = new_val + + +def _validate_override_fields(term_or_base: dict[str, Any], + patch: dict[str, Any], + target_name: str) -> None: + """Reject typos in override patch keys. + + For a correction term, patch keys must match the term's ``type``-specific + allowed fields. For ``base``, patch keys must be valid Level-dict keys + (plus the usual level-dict extensions). + """ + if target_name == "base": + allowed = _ALLOWED_LEVEL_FIELDS + else: + term_type = term_or_base.get("type") + allowed = _ALLOWED_OVERRIDE_FIELDS_BY_TYPE.get(term_type) + if allowed is None: + raise InputError( + f"Cannot validate override for term '{target_name}': its type " + f"'{term_type}' is not one of " + f"{sorted(_ALLOWED_OVERRIDE_FIELDS_BY_TYPE)}." + ) + unknown = set(patch.keys()) - allowed + if unknown: + raise InputError( + f"Override for '{target_name}' has unknown field(s) " + f"{sorted(unknown)}. Allowed for this target: {sorted(allowed)}." + ) + + +def _apply_overrides( + recipe: dict[str, Any], + overrides: Mapping[str, Any], +) -> dict[str, Any]: + """Merge per-term ``overrides`` into a recipe and return the result. + + ``overrides`` is a mapping ``{term_label: {field_name: new_value}}``. The + special key ``"base"`` targets the protocol's base level rather than a + correction. + + * **Unknown term labels** raise :class:`InputError` so a typo can't silently no-op. 
+ * **Unknown fields within a known term** also raise :class:`InputError` — + see ``_validate_override_fields``. + * Nested Level dicts (``high`` / ``low`` / ``level`` / ``base``) are + **deep-merged** when both old and new values are dicts: overriding + ``{high: {basis: cc-pVTZ}}`` preserves the existing ``method``. Other + fields (``formula``, ``levels``, scalar values) replace wholesale. + """ + if not overrides: + return recipe + + correction_labels = {c["label"] for c in recipe.get("corrections", [])} + valid_targets = correction_labels | {"base"} + + for target, patch in overrides.items(): + if target not in valid_targets: + raise InputError( + f"Override target '{target}' is not a known term in this preset. " + f"Valid targets: {sorted(valid_targets)}." + ) + if not isinstance(patch, dict): + raise InputError( + f"Override for '{target}' must be a dict; got {type(patch).__name__}." + ) + if target == "base": + _validate_override_fields(recipe.get("base") or {}, patch, target) + base = recipe["base"] + if isinstance(base, dict): + _deep_merge_level_dict(base, patch) + else: + # Base was a string shorthand; replace wholesale with the patch dict. + recipe["base"] = dict(patch) + else: + term = next(c for c in recipe["corrections"] if c["label"] == target) + _validate_override_fields(term, patch, target) + _deep_merge_level_dict(term, patch) + return recipe + + +def expand_preset( + name: str, + overrides: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """Resolve a preset name (with optional overrides) to an independent recipe dict. + + Parameters + ---------- + name : str + One of the keys in :data:`PRESETS`. Lookup is case-sensitive. + overrides : Mapping[str, Any], optional + Mapping of term label → field patch. See :func:`_apply_overrides`. + + Returns + ------- + dict + A deep-copied recipe dict in the explicit form + (``{base: ..., corrections: [...]}``) ready to be handed to + :meth:`arc.level.protocol.CompositeProtocol.from_user_input`. + + Raises + ------ + arc.exceptions.InputError + If ``name`` is unknown or the overrides target a non-existent term. + """ + if name not in PRESETS: + raise InputError( + f"Unknown sp_composite preset '{name}'. " + f"Available presets: {REGISTERED_PRESET_NAMES}." + ) + recipe = copy.deepcopy(PRESETS[name]) + return _apply_overrides(recipe, overrides or {}) diff --git a/arc/level/presets.yml b/arc/level/presets.yml new file mode 100644 index 0000000000..5a5c209bcb --- /dev/null +++ b/arc/level/presets.yml @@ -0,0 +1,377 @@ +# sp_composite presets shipped with ARC. +# +# Each entry defines a CompositeProtocol that ARC can instantiate via +# `sp_composite: ` in the project YAML. The shape of each entry +# matches the explicit form accepted by CompositeProtocol.from_user_input: +# - base: a level (string "method/basis" or dict) +# - corrections: a list of term dicts (each with type / label / level fields) +# - reference: a free-text citation including a DOI, surfaced in logs, +# notebook provenance headers, and validated by the test suite. +# +# Notes on the recipes themselves: +# +# These presets are *adapted for ARC use* — the canonical Tajti-et-al HEAT +# protocol was designed for atomization energies of small molecules, with the +# HF energy itself CBS-extrapolated to the basis-set limit. ARC's +# CompositeProtocol pins the absolute base to a single SinglePointTerm, so +# the recipes below pick a sensible high-quality "anchor" SP (CCSD(T)-F12 in +# the cc-pVTZ-F12 basis) and apply the post-(T) and other corrections on top. 
+# This matches the typical focal-point workflow when refining TS barriers +# (see e.g. Nguyen, Stanton, Barker for CHO2 PES). +# +# ESS-specific syntax used below +# ------------------------------ +# The δ_CV (core-valence) and δ_rel (scalar-relativistic, DKH2) corrections +# require ESS-specific keywords to be injected into the SP input. Native ARC +# presets target Molpro syntax via ``args.keyword``: +# +# * δ_CV "high" leg = all-electron correlation: +# ``args.keyword.core: 'core,0,0,0,0,0,0,0,0;'`` +# places ``core,0,0,0,0,0,0,0,0;`` between the basis declaration and the +# ``int;`` directive in the Molpro template, setting the global frozen-core +# specification to zero in every irreducible representation. Trailing zeros +# are harmless for lower-symmetry point groups (Molpro reads only the +# irreps that exist). +# +# * δ_CV "low" leg = Molpro's default frozen-core (no extra args). For +# first-row elements this freezes 1s; for second-row 1s2s2p. +# +# * δ_rel "high" leg = DKH2 scalar-relativistic on the cc-pVTZ-DK +# recontracted basis: +# ``args.keyword.dkho: 'SET,DKHO=2;'`` +# The Molpro manual (https://www.molpro.net/manual/doku.php?id=relativistic_corrections) +# explicitly recommends ``SET,DKHO=n`` over the legacy ``DKROLL`` form +# ("In order to avoid confusion, it is recommended only to use DKHO and +# never set DKROLL"). The directive is placed before ``int;`` so the +# integrals are evaluated with the DK-transformed Hamiltonian. The +# ``cc-pVTZ-DK`` recontracted basis is required — without ``SET,DKHO=2`` +# Molpro uses the standard non-relativistic Hamiltonian on it. +# +# * δ_rel "low" leg = vanilla CCSD(T)/cc-pVTZ. +# +# Other ESSes (CFOUR/Orca) have different syntax for these corrections; the +# presets below will write the wrong directives if pointed at a non-Molpro +# adapter for the δ_CV/δ_rel SPs. Until a per-ESS preset family lands, users +# running through CFOUR/Orca should either supply an explicit recipe or use +# the ``HEAT-345_noC`` / ``HEAT-345Q_noC`` variants below, which omit the +# δ_CV term (the most ESS-syntax-sensitive one). + +HEAT-345: + reference: "Inspired by Tajti et al., J. Chem. Phys. 121, 11599 (2004); DOI: 10.1063/1.1811608. Adapted for use as a TS barrier refinement protocol within ARC. The δ_CV and δ_rel terms below assume a Molpro adapter (see preset comment header for ESS-specific syntax)." + base: + method: ccsd(t)-f12 + basis: cc-pVTZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_CV + type: delta + # All-electron CCSD(T)/cc-pCVTZ via Molpro's ``core,0,...`` directive. + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + # Frozen-core CCSD(T)/cc-pCVTZ — Molpro's default, no extra args. + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + # DKH2 scalar-relativistic CCSD(T)/cc-pVTZ-DK via Molpro's ``dkroll=2`` directive. + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + # Non-relativistic CCSD(T)/cc-pVTZ. + low: {method: ccsd(t), basis: cc-pVTZ} + +HEAT-345Q: + reference: "Inspired by the HEAT-345(Q) protocol used by Nguyen, Stanton, Barker for the CHO2 PES (citing Tajti et al., J. Chem. Phys. 121, 11599 (2004); DOI: 10.1063/1.1811608). Adds a perturbative δ[CCSDT(Q)] term on top of HEAT-345. The δ_CV and δ_rel terms below assume a Molpro adapter." 
+ base: + method: ccsd(t)-f12 + basis: cc-pVTZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_Q + type: delta + high: {method: ccsdt(q), basis: cc-pVDZ} + low: {method: ccsdt, basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +# "_noC" variants drop the core-valence (δ_CV) correction but keep everything +# else. Use these when you cannot run an all-electron CCSD(T)/cc-pCVTZ pair +# (e.g. when targeting an ESS other than Molpro and you don't have CFOUR/Orca +# core-valence syntax wired up). These variants are NOT silently equivalent +# to HEAT-345 / HEAT-345Q — the missing δ_CV is acknowledged in the name and +# in the reference string so users can cite the protocol honestly. + +HEAT-345_noC: + reference: "Inspired by Tajti et al., J. Chem. Phys. 121, 11599 (2004); DOI: 10.1063/1.1811608. Adapted for ARC, with the δ_CV (core-valence) correction OMITTED — see preset name. Suitable when the ESS lacks a clean all-electron syntax or when the core-valence contribution is known to be negligible (e.g. first-row systems where δ_CV is typically < 0.5 kJ/mol)." + base: + method: ccsd(t)-f12 + basis: cc-pVTZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +HEAT-345Q_noC: + reference: "Inspired by HEAT-345(Q) (Nguyen/Stanton/Barker for CHO2; Tajti et al., J. Chem. Phys. 121, 11599 (2004); DOI: 10.1063/1.1811608) with the δ_CV (core-valence) correction OMITTED — see preset name. Use when ESS-specific all-electron syntax is unavailable; cite as 'HEAT-345Q_noC' to make the omission explicit rather than as 'HEAT-345Q'." + base: + method: ccsd(t)-f12 + basis: cc-pVTZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_Q + type: delta + high: {method: ccsdt(q), basis: cc-pVDZ} + low: {method: ccsdt, basis: cc-pVDZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +# HEAT-345QP extends HEAT-345Q with full quadruples (CCSDTQ) and perturbative +# pentuples (CCSDTQ(P)) corrections. The δ_QQ and δ_P legs route through the +# MRCC interface — modern Molpro builds with the MRCC interface linked in +# accept ``ccsdtq`` and ``ccsdtq(p)`` directly (the same path used today for +# ``ccsdt`` and ``ccsdt(q)`` in HEAT-345Q). CFOUR-NCC is an alternative back +# end. A plain Molpro install without MRCC will not run these sub-jobs. +HEAT-345QP: + reference: "Extension of HEAT-345Q (Bomble, Vázquez, Kállay, Michauk, Szalay, Császár, Gauss, Stanton, J. Chem. Phys. 125, 064108 (2006); DOI: 10.1063/1.2206789) with full-quadruples and perturbative-pentuples post-(T) corrections. The δ_QQ (CCSDTQ) and δ_P (CCSDTQ(P)) legs require an ESS that exposes those methods — Molpro built with the MRCC interface, or CFOUR-NCC. 
δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t)-f12 + basis: cc-pVTZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_Q + type: delta + high: {method: ccsdt(q), basis: cc-pVDZ} + low: {method: ccsdt, basis: cc-pVDZ} + - label: delta_QQ + type: delta + high: {method: ccsdtq, basis: cc-pVDZ} + low: {method: ccsdt(q), basis: cc-pVDZ} + - label: delta_P + type: delta + high: {method: ccsdtq(p), basis: cc-pVDZ} + low: {method: ccsdtq, basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +# HEAT-456Q has the same correction structure as HEAT-345Q but a tighter HF/ +# CCSD(T) CBS reference (cardinals {Q,5,6} rather than {T,Q,5}). ARC's +# CompositeProtocol pins the absolute base to a single SinglePointTerm, so +# this adaptation tightens the anchor by promoting the F12 base from +# cc-pVTZ-F12 to cc-pVQZ-F12 (effectively near-CBS quality at QZ-5Z-6Z). +HEAT-456Q: + reference: "Inspired by HEAT-456Q (Bomble, Vázquez, Kállay, Michauk, Szalay, Császár, Gauss, Stanton, J. Chem. Phys. 125, 064108 (2006); DOI: 10.1063/1.2206789). Same correction structure as HEAT-345Q with a tighter base — ARC adaptation pins the anchor to CCSD(T)-F12/cc-pVQZ-F12 to mirror the {Q,5,6}-cardinal HF/CCSD(T) extrapolation. δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t)-f12 + basis: cc-pVQZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_Q + type: delta + high: {method: ccsdt(q), basis: cc-pVDZ} + low: {method: ccsdt, basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +FPA-min: + reference: "Minimal Allen / East / Császár focal-point analysis recipe; review: East, Allen, J. Chem. Phys. 99, 4638 (1993); DOI: 10.1063/1.466062." + base: + method: ccsd(t)-f12 + basis: cc-pVTZ-f12 + corrections: + - label: cbs_corr + type: cbs_extrapolation + formula: helgaker_corr_2pt + # components currently must be "total"; adapter-level correlation-only + # parsing is a future addition. The formula name (`helgaker_corr_2pt`) + # still documents intent for the user. + components: total + levels: + - {method: ccsd(t), basis: cc-pVTZ} + - {method: ccsd(t), basis: cc-pVQZ} + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + +# ----------------------------------------------------------------------- # +# Weizmann-n (W2, W3, W4) family — Karton/Martin and predecessors. +# +# The canonical W*n* protocols build their absolute energy from a stack of +# basis-set CBS extrapolations (HF, CCSD, (T) at progressively smaller +# basis) plus δ-corrections. 
ARC's CompositeProtocol pins the absolute +# base to a single SinglePointTerm, so the recipes below pick a high- +# quality CCSD(T) or CCSD(T)-F12 anchor and apply the canonical post-(T) +# / δ_CV / δ_rel corrections on top. The original W*n* basis-cardinal +# extrapolations of the (T) component are absorbed into the anchor — +# this is faithful to the W*n* spirit (stacked corrections beyond +# CCSD(T)/CBS) but not byte-identical to the published prescription. +# Cite as 'W2 (ARC adaptation)' etc. to acknowledge the difference. +# +# As with the HEAT family, δ_CV and δ_rel use Molpro-specific +# ``args.keyword`` directives (see the file header for syntax notes). +# ----------------------------------------------------------------------- # + +W2: + reference: "Inspired by W2 (Martin, de Oliveira, J. Chem. Phys. 111, 1843 (1999); DOI: 10.1063/1.479454). ARC adaptation pins the anchor to CCSD(T)/aug-cc-pVQZ and applies the canonical δ_CV (core-valence) and δ_rel (DKH2 scalar-relativistic) corrections; the original W2 HF/CCSD/(T) basis-cardinal CBS extrapolations are absorbed into the anchor (single-anchor model). δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t) + basis: aug-cc-pVQZ + corrections: + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +W2-F12: + reference: "Inspired by W2-F12 (Karton, Martin, J. Chem. Phys. 136, 124114 (2012); DOI: 10.1063/1.3697678). F12-accelerated W2; ARC adaptation pins the anchor to CCSD(T)-F12/cc-pVQZ-F12 (near-CBS quality from a single SP) and applies the canonical δ_CV and δ_rel corrections. δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t)-f12 + basis: cc-pVQZ-f12 + corrections: + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +W3: + reference: "Inspired by W3 (Boese, Oren, Atasoylu, Martin, Kállay, Gauss, J. Chem. Phys. 120, 4129 (2004); DOI: 10.1063/1.1638736). Adds a δ[CCSDT] post-(T) correction on top of W2. ARC adaptation pins the anchor to CCSD(T)/aug-cc-pVQZ. δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t) + basis: aug-cc-pVQZ + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +W3-F12: + reference: "ARC-defined extension of the W*n*-F12 family by analogy: 'W3 = W2 + δ[CCSDT]' (Boese et al., J. Chem. Phys. 120, 4129 (2004); DOI: 10.1063/1.1638736) applied to the F12 anchor introduced in W2-F12 (Karton, Martin, J. Chem. Phys. 136, 124114 (2012); DOI: 10.1063/1.3697678). There is no canonical primary publication titled 'W3-F12'; cite as 'W3-F12 (ARC adaptation)'. 
δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t)-f12 + basis: cc-pVQZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +# W4 and W4-F12 add δ[CCSDT(Q)] and δ[CCSDTQ] (full quadruples) on top of +# the W3 stack. The δ_QQ leg routes through the MRCC interface — modern +# Molpro builds with MRCC linked in accept ``ccsdtq`` directly (same path +# already used for ``ccsdt`` / ``ccsdt(q)`` in W3 / HEAT-345Q). CFOUR-NCC +# is the alternative back end. A plain Molpro install without MRCC cannot +# run these sub-jobs. Cite as 'W4 (ARC adaptation)' etc. + +W4: + reference: "Inspired by W4 (Karton, Rabinovich, Martin, Ruscic, J. Chem. Phys. 125, 144108 (2006); DOI: 10.1063/1.2348881). Adds δ[CCSDT(Q)] and δ[CCSDTQ] on top of W3. The δ_QQ (CCSDTQ) leg requires an ESS that exposes the method — Molpro built with the MRCC interface, or CFOUR-NCC. δ_CV and δ_rel assume a Molpro adapter." + base: + method: ccsd(t) + basis: aug-cc-pVQZ + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_Q + type: delta + high: {method: ccsdt(q), basis: cc-pVDZ} + low: {method: ccsdt, basis: cc-pVDZ} + - label: delta_QQ + type: delta + high: {method: ccsdtq, basis: cc-pVDZ} + low: {method: ccsdt(q), basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} + +W4-F12: + reference: "Inspired by W4-F12 (Sylvetsky, Peterson, Karton, Martin, J. Chem. Phys. 144, 214101 (2016); DOI: 10.1063/1.4952410, 'Toward a W4-F12 approach: Can explicitly correlated and orbital-based ab initio CCSD(T) limits be reconciled?'). F12-accelerated W4. The δ_QQ (CCSDTQ) leg requires an ESS that exposes the method — Molpro built with the MRCC interface, or CFOUR-NCC. δ_CV and δ_rel assume a Molpro adapter." 
+ base: + method: ccsd(t)-f12 + basis: cc-pVQZ-f12 + corrections: + - label: delta_T + type: delta + high: {method: ccsdt, basis: cc-pVDZ} + low: {method: ccsd(t), basis: cc-pVDZ} + - label: delta_Q + type: delta + high: {method: ccsdt(q), basis: cc-pVDZ} + low: {method: ccsdt, basis: cc-pVDZ} + - label: delta_QQ + type: delta + high: {method: ccsdtq, basis: cc-pVDZ} + low: {method: ccsdt(q), basis: cc-pVDZ} + - label: delta_CV + type: delta + high: {method: ccsd(t), basis: cc-pCVTZ, args: {keyword: {core: 'core,0,0,0,0,0,0,0,0;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pCVTZ} + - label: delta_rel + type: delta + high: {method: ccsd(t), basis: cc-pVTZ-DK, args: {keyword: {dkho: 'SET,DKHO=2;'}, block: {}}} + low: {method: ccsd(t), basis: cc-pVTZ} diff --git a/arc/level/presets_test.py b/arc/level/presets_test.py new file mode 100644 index 0000000000..bcaf015993 --- /dev/null +++ b/arc/level/presets_test.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Unit tests for ``arc.level.presets`` — preset loading and override merging. + +Presets are data: every entry in ``presets.yml`` should round-trip through +:meth:`CompositeProtocol.from_user_input` and through :meth:`CompositeProtocol.from_dict` +without loss. Preset overrides may replace named keys on individual terms but may not +introduce new term labels or unknown fields. +""" + +import unittest + +from arc.exceptions import InputError +from arc.level import Level +from arc.level.presets import PRESETS, REGISTERED_PRESET_NAMES, expand_preset +from arc.level.protocol import CompositeProtocol + + +class TestPresetRegistry(unittest.TestCase): + """The ``presets.yml`` data file ships at least three named protocols.""" + + def test_registry_non_empty(self): + self.assertGreaterEqual(len(REGISTERED_PRESET_NAMES), 3) + + def test_known_presets_present(self): + for name in ( + "HEAT-345", "HEAT-345Q", "HEAT-345_noC", "HEAT-345Q_noC", + "HEAT-345QP", "HEAT-456Q", "FPA-min", + "W2", "W2-F12", "W3", "W3-F12", "W4", "W4-F12", + ): + self.assertIn(name, REGISTERED_PRESET_NAMES) + + def test_noC_variants_omit_delta_CV_term(self): + """``_noC`` variants must NOT carry a delta_CV correction; the + omission is part of the contract their name advertises.""" + for name in ("HEAT-345_noC", "HEAT-345Q_noC"): + with self.subTest(name=name): + recipe = expand_preset(name) + labels = [c["label"] for c in recipe["corrections"]] + self.assertNotIn("delta_CV", labels, + f"{name} must not include delta_CV " + f"(found: {labels})") + self.assertIn("delta_T", labels) + self.assertIn("delta_rel", labels) + + def test_noC_reference_calls_out_omission(self): + """The reference string of every ``_noC`` variant must explicitly say + the core-valence correction was omitted, so users cite honestly.""" + for name in ("HEAT-345_noC", "HEAT-345Q_noC"): + with self.subTest(name=name): + ref = PRESETS[name]["reference"] + self.assertIn("OMITTED", ref.upper()) + self.assertIn("CORE-VALENCE", ref.upper()) + + def test_HEAT_protocols_delta_CV_legs_compare_unequal(self): + """Regression for sp_composite Bug B: HEAT-345Q's δ_CV high (all-electron + ``core,...``) and low (default frozen-core) Levels must not collapse to + a single sub-job at composite-spawn time. 
Extended in Phase 5+ to cover + every shipped preset that carries a δ_CV term — the Molpro-keyword + round-trip is the load-bearing piece, and silent dedup would defeat the + whole correction regardless of which protocol introduces it.""" + for name in ( + "HEAT-345", "HEAT-345Q", "HEAT-345QP", "HEAT-456Q", + "W2", "W2-F12", "W3", "W3-F12", "W4", "W4-F12", + ): + with self.subTest(name=name): + recipe = expand_preset(name) + cv = next( + (c for c in recipe["corrections"] if c["label"] == "delta_CV"), + None, + ) + self.assertIsNotNone( + cv, f"{name} expected to ship a delta_CV term; check preset." + ) + high = Level(repr=cv["high"]) + low = Level(repr=cv["low"]) + self.assertNotEqual(high, low, + f"{name} δ_CV legs collapsed to equal Levels — " + f"composite-spawn would silently dedupe to one job.") + + def test_each_preset_carries_a_reference_field(self): + """Every preset entry must include a `reference:` string with citation + DOI.""" + for name in REGISTERED_PRESET_NAMES: + entry = PRESETS[name] + self.assertIn("reference", entry, f"Preset '{name}' missing 'reference' field.") + ref = entry["reference"] + self.assertIsInstance(ref, str) + self.assertGreater(len(ref), 20, f"Preset '{name}' reference too short.") + self.assertIn("DOI", ref.upper(), f"Preset '{name}' reference must mention a DOI.") + + def test_each_preset_round_trips_to_protocol(self): + for name in REGISTERED_PRESET_NAMES: + with self.subTest(name=name): + protocol = CompositeProtocol.from_user_input(name) + rebuilt = CompositeProtocol.from_dict(protocol.as_dict()) + self.assertEqual(rebuilt.base.label, protocol.base.label) + self.assertEqual( + [t.label for t in rebuilt.corrections], + [t.label for t in protocol.corrections], + ) + + +class TestExpandPreset(unittest.TestCase): + def test_unknown_preset_raises(self): + with self.assertRaises(InputError) as ctx: + expand_preset("not_a_real_preset") + # The error message should help the user discover the available presets. + self.assertIn("HEAT-345", str(ctx.exception)) + + def test_returns_dict_with_base_and_corrections(self): + recipe = expand_preset("HEAT-345Q") + self.assertIn("base", recipe) + self.assertIn("corrections", recipe) + self.assertIsInstance(recipe["corrections"], list) + + def test_no_overrides_returns_canonical_recipe(self): + a = expand_preset("HEAT-345Q") + b = expand_preset("HEAT-345Q") + self.assertEqual(a, b) + + def test_returns_a_deep_copy(self): + """Mutating the returned recipe must not affect later calls.""" + recipe = expand_preset("HEAT-345Q") + recipe["base"] = "tampered" + recipe["corrections"].clear() + again = expand_preset("HEAT-345Q") + self.assertNotEqual(again["base"], "tampered") + self.assertGreater(len(again["corrections"]), 0) + + +class TestExpandPresetOverrides(unittest.TestCase): + """Overrides target named term labels and replace specific fields on them.""" + + def test_override_replaces_basis_on_named_delta_term(self): + recipe = expand_preset( + "HEAT-345Q", + overrides={"delta_T": {"high": {"method": "ccsdt", "basis": "cc-pVTZ"}}}, + ) + delta_t = next(c for c in recipe["corrections"] if c["label"] == "delta_T") + self.assertEqual(delta_t["high"]["basis"], "cc-pVTZ") + + def test_override_only_touches_named_term(self): + recipe = expand_preset( + "HEAT-345Q", + overrides={"delta_T": {"high": {"method": "ccsdt", "basis": "cc-pVTZ"}}}, + ) + delta_q = next(c for c in recipe["corrections"] if c["label"] == "delta_Q") + # delta_Q should be untouched. 
+ original = expand_preset("HEAT-345Q") + original_delta_q = next(c for c in original["corrections"] if c["label"] == "delta_Q") + self.assertEqual(delta_q, original_delta_q) + + def test_override_unknown_label_raises(self): + with self.assertRaises(InputError): + expand_preset("HEAT-345Q", overrides={"not_a_term": {"high": "hf/cc-pVDZ"}}) + + def test_override_base_replaces_base_level(self): + recipe = expand_preset( + "HEAT-345Q", + overrides={"base": {"method": "ccsd(t)-f12", "basis": "cc-pVQZ-f12"}}, + ) + self.assertEqual(recipe["base"]["basis"], "cc-pVQZ-f12") + + def test_overridden_preset_still_parses_into_a_protocol(self): + recipe = expand_preset( + "HEAT-345Q", + overrides={"delta_T": {"high": {"method": "ccsdt", "basis": "cc-pVTZ"}}}, + ) + protocol = CompositeProtocol.from_user_input(recipe) + delta_t = next(c for c in protocol.corrections if c.label == "delta_T") + self.assertEqual(delta_t.high.basis, "cc-pvtz") + + # --- Phase 5.5 hardening --------------------------------------------- # + + def test_override_unknown_field_on_delta_rejected(self): + """Typo guard: ``hihg`` is not a valid field of a delta term.""" + with self.assertRaises(InputError) as ctx: + expand_preset("HEAT-345Q", overrides={ + "delta_T": {"hihg": {"method": "ccsdt", "basis": "cc-pVTZ"}}, + }) + self.assertIn("hihg", str(ctx.exception)) + + def test_override_unknown_field_on_base_rejected(self): + """``methhod`` is not a valid Level field.""" + with self.assertRaises(InputError) as ctx: + expand_preset("HEAT-345Q", overrides={ + "base": {"methhod": "hf"}, + }) + self.assertIn("methhod", str(ctx.exception)) + + def test_override_unknown_field_on_cbs_rejected(self): + """Typo on a cbs_extrapolation term is caught (FPA-min has a CBS term).""" + with self.assertRaises(InputError) as ctx: + expand_preset("FPA-min", overrides={ + "cbs_corr": {"formla": "helgaker_corr_2pt"}, + }) + self.assertIn("formla", str(ctx.exception)) + + def test_override_deep_merges_high_level_dict(self): + """Overriding ``delta_T.high.basis`` preserves the existing ``method``.""" + recipe = expand_preset( + "HEAT-345Q", + overrides={"delta_T": {"high": {"basis": "cc-pVTZ"}}}, + ) + delta_t = next(c for c in recipe["corrections"] if c["label"] == "delta_T") + self.assertEqual(delta_t["high"]["basis"], "cc-pVTZ") + # Original method ("ccsdt") is preserved by the deep-merge. + self.assertEqual(delta_t["high"]["method"], "ccsdt") + + def test_override_deep_merges_base_dict(self): + recipe = expand_preset( + "HEAT-345Q", + overrides={"base": {"basis": "cc-pVQZ-f12"}}, + ) + self.assertEqual(recipe["base"]["basis"], "cc-pVQZ-f12") + # Existing method ("ccsd(t)-f12") preserved. 
+ self.assertEqual(recipe["base"]["method"], "ccsd(t)-f12") + + +class TestPresetIntegrationWithFromUserInput(unittest.TestCase): + def test_string_form_dispatches_to_preset(self): + protocol = CompositeProtocol.from_user_input("HEAT-345Q") + self.assertIsInstance(protocol, CompositeProtocol) + + def test_preset_with_overrides_form(self): + protocol = CompositeProtocol.from_user_input({ + "preset": "HEAT-345Q", + "overrides": {"delta_T": {"high": {"method": "ccsdt", "basis": "cc-pVTZ"}}}, + }) + delta_t = next(c for c in protocol.corrections if c.label == "delta_T") + self.assertEqual(delta_t.high.basis, "cc-pvtz") + + +if __name__ == "__main__": + unittest.main() diff --git a/arc/level/protocol.py b/arc/level/protocol.py new file mode 100644 index 0000000000..480287778a --- /dev/null +++ b/arc/level/protocol.py @@ -0,0 +1,588 @@ +""" +``arc.level.protocol`` — composite-energy protocol data model. + +A ``CompositeProtocol`` describes how to compute the final electronic energy of a +stationary point as a sum of contributions, each evaluated at a different level of +theory. The motivation is HEAT-style focal-point analysis (Tajti, Szalay, Császár, +Kállay, Gauss, Valeev, Flowers, Vázquez, Stanton, *J. Chem. Phys.* **121**, 11599 +(2004); DOI: 10.1063/1.1811608) and CBS extrapolation (Helgaker et al. 1997, Halkier +et al. 1998, Martin 1996), where small post-CCSD(T) corrections accumulate to +several kJ/mol — exactly the range that affects TS barriers in kinetics. + +Data model +---------- + +A ``CompositeProtocol`` consists of: + +* ``base`` — a single :class:`SinglePointTerm` providing the absolute electronic + energy. By convention this is the "main" SP that the scheduler runs first; it is + also the level used for AEC (atom-energy-correction) lookups when the protocol + is wired into Arkane in a later phase. +* ``corrections`` — an ordered list of additional :class:`Term` objects of any + subtype: :class:`SinglePointTerm`, :class:`DeltaTerm`, or + :class:`CBSExtrapolationTerm`. + +The final energy is ``base.evaluate(...) + Σ correction.evaluate(...)``. + +Sub-job naming +-------------- + +Each ``Term`` describes the QM single-point jobs it needs via +:meth:`Term.required_levels`, returning ``[(sub_label, Level), ...]`` pairs. The +sub_labels are *globally* unique within the protocol and follow the convention: + +* ``SinglePointTerm`` → ``""`` (one sub-job). +* ``DeltaTerm`` → ``"__high"``, ``"__low"``. +* ``CBSExtrapolationTerm`` → ``"__card_"`` for each cardinal ``X``. + +The Phase 2 scheduler integration uses these sub_labels to track per-sub-job state +across restarts. +""" + +import copy +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import Any + +from arc.exceptions import InputError +from arc.level.cbs import ( + BUILTIN_FORMULAS, + cardinal_from_basis, + safe_eval_formula, + validate_formula, +) +from arc.level.level import Level +from arc.level.presets import expand_preset + + +# --------------------------------------------------------------------------- # +# Term hierarchy # +# --------------------------------------------------------------------------- # + + +class Term(ABC): + """Abstract base class for any contribution to a composite electronic energy. + + A ``Term`` knows three things: + + 1. Its ``label`` — a unique name used by the scheduler and reporter to + identify the term in logs and the provenance notebook. + 2. The QM sub-jobs it needs, via :meth:`required_levels`. + 3. 
How to combine those sub-jobs' parsed energies into a single number, via + :meth:`evaluate`. + """ + + label: str + + @abstractmethod + def required_levels(self) -> list[tuple[str, Level]]: + """Return ``[(sub_label, Level), ...]`` pairs for every SP this term needs.""" + + @abstractmethod + def evaluate(self, energies: dict[str, float]) -> float: + """Combine sub-job energies into this term's contribution. + + The keys of ``energies`` are the ``sub_label`` strings yielded by + :meth:`required_levels`. Units are passed through unchanged (kJ/mol in the + ARC scheduler, but the data model is unit-agnostic). + """ + + @abstractmethod + def as_dict(self) -> dict[str, Any]: + """Serialise to a JSON/YAML-friendly dict including a discriminator ``type``.""" + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "Term": + """Reconstruct a ``Term`` subclass from its serialised dict. + + Dispatches on the ``type`` discriminator written by :meth:`as_dict`. + """ + if not isinstance(data, dict) or "type" not in data: + raise InputError( + "Term dict must include a 'type' discriminator " + "('single_point', 'delta', or 'cbs_extrapolation')." + ) + kind = data["type"] + if kind == "single_point": + return SinglePointTerm._from_dict(data) + if kind == "delta": + return DeltaTerm._from_dict(data) + if kind == "cbs_extrapolation": + return CBSExtrapolationTerm._from_dict(data) + raise InputError( + f"Unknown term type '{kind}'. Allowed: " + "'single_point', 'delta', 'cbs_extrapolation'." + ) + + +def _coerce_level(value: str | dict[str, Any] | Level) -> Level: + """Accept either a string, dict, or Level; return a Level instance.""" + if isinstance(value, Level): + return value + if isinstance(value, (str, dict)): + return Level(repr=value) + raise InputError( + f"Cannot interpret {value!r} (type {type(value).__name__}) as a Level." + ) + + +class SinglePointTerm(Term): + """One absolute single-point energy at one level of theory.""" + + def __init__(self, label: str, level: str | dict[str, Any] | Level): + if not label: + raise InputError("SinglePointTerm requires a non-empty label.") + self.label = label + self.level = _coerce_level(level) + + def required_levels(self) -> list[tuple[str, Level]]: + return [(self.label, self.level)] + + def evaluate(self, energies: dict[str, float]) -> float: + return energies[self.label] + + def as_dict(self) -> dict[str, Any]: + return { + "type": "single_point", + "label": self.label, + "level": self.level.as_dict(), + } + + @classmethod + def _from_dict(cls, data: dict[str, Any]) -> "SinglePointTerm": + return cls(label=data["label"], level=data["level"]) + + +class DeltaTerm(Term): + """A correction ``E[high] − E[low]`` between two levels of theory. + + Used to capture, e.g., the post-(T) correction + ``δ[CCSDT] = E[CCSDT/cc-pVDZ] − E[CCSD(T)/cc-pVDZ]``. + """ + + def __init__( + self, + label: str, + high: str | dict[str, Any] | Level | None, + low: str | dict[str, Any] | Level | None, + ): + if not label: + raise InputError("DeltaTerm requires a non-empty label.") + if high is None or low is None: + raise InputError( + f"DeltaTerm '{label}' requires both 'high' and 'low' levels; " + f"got high={high!r}, low={low!r}." 
+ ) + self.label = label + self.high = _coerce_level(high) + self.low = _coerce_level(low) + + def _sub(self, suffix: str) -> str: + return f"{self.label}__{suffix}" + + def required_levels(self) -> list[tuple[str, Level]]: + return [(self._sub("high"), self.high), (self._sub("low"), self.low)] + + def evaluate(self, energies: dict[str, float]) -> float: + return energies[self._sub("high")] - energies[self._sub("low")] + + def as_dict(self) -> dict[str, Any]: + return { + "type": "delta", + "label": self.label, + "high": self.high.as_dict(), + "low": self.low.as_dict(), + } + + @classmethod + def _from_dict(cls, data: dict[str, Any]) -> "DeltaTerm": + return cls(label=data["label"], high=data["high"], low=data["low"]) + + +# Currently only "total" is supported: the energies fed to CBS formulas come +# from ``arc.parser.parse_e_elect``, which returns the total electronic energy +# of a single-point job. There is no parser pathway that surfaces correlation- +# only or HF-only components yet, so accepting ``components='corr'`` or +# ``'hf'`` would silently extrapolate *total* energies while pretending to be +# component-specific — a correctness hazard. When adapter-level component +# parsing is added, widen this tuple and add tests that the right component is +# actually routed per sub-job. +_ALLOWED_COMPONENTS = ("total",) + + +class CBSExtrapolationTerm(Term): + """Complete-Basis-Set extrapolated contribution. + + Computes one term in the composite from ≥2 single-point energies at the same + method but different basis-set cardinalities, combined via a closed-form + formula. ``formula`` may be the name of a built-in + (:data:`arc.level.cbs.BUILTIN_FORMULAS`) or a user-supplied arithmetic + expression evaluated by :func:`arc.level.cbs.safe_eval_formula`. + + Parameters + ---------- + label : str + Term identifier. + formula : str + Built-in name or arithmetic expression. User expressions may reference + ``X``, ``Y``, ``Z`` (cardinal numbers) and ``E_X``, ``E_Y``, ``E_Z`` + (corresponding energies), bound by ascending cardinal order. + User formulas with more than 3 levels are rejected: expose only the + first three cardinal variables we bind. + levels : list of Level + ≥2 levels, all with the same method, all with deducible distinct cardinals. + components : {'total'} + Which energy component the extrapolation applies to. **Only ``'total'`` + is currently accepted.** Other values are rejected at construction time + until component-specific parsing exists — see ``_ALLOWED_COMPONENTS`` + above for rationale. + """ + + def __init__( + self, + label: str, + formula: str, + levels: list[str | dict[str, Any] | Level], + components: str = "total", + ): + if not label: + raise InputError("CBSExtrapolationTerm requires a non-empty label.") + if components not in _ALLOWED_COMPONENTS: + raise InputError( + f"CBSExtrapolationTerm '{label}': components={components!r} not in " + f"{_ALLOWED_COMPONENTS}." + ) + coerced = [_coerce_level(lvl) for lvl in levels] + if len(coerced) < 2: + raise InputError( + f"CBSExtrapolationTerm '{label}' needs at least 2 levels, got {len(coerced)}." + ) + methods = {lvl.method for lvl in coerced} + if len(methods) > 1: + raise InputError( + f"CBSExtrapolationTerm '{label}': all levels must share one method, " + f"got {sorted(methods)}." + ) + cardinals = [cardinal_from_basis(lvl.basis) for lvl in coerced] + if len(set(cardinals)) != len(cardinals): + raise InputError( + f"CBSExtrapolationTerm '{label}': cardinals must be distinct, got " + f"{cardinals}." 
+ ) + # Sort levels and cardinals together by ascending cardinal so callers can rely + # on a canonical ordering downstream. + ordered = sorted(zip(cardinals, coerced)) + self._cardinals = [c for c, _ in ordered] + self.levels = [lvl for _, lvl in ordered] + self.label = label + self.components = components + self.formula = formula + self._formula_callable = self._resolve_formula(formula, len(self.levels)) + + # Arity required by each shipped built-in formula. Surfacing this at + # construction time catches "martin_3pt with 2 levels" before a sub-job + # ever runs. When new built-ins are added, update this table alongside + # the entry in arc.level.cbs.BUILTIN_FORMULAS. + _BUILTIN_FORMULA_ARITY: dict[str, int] = { + "helgaker_corr_2pt": 2, + "helgaker_hf_2pt": 2, + "martin_3pt": 3, + } + + # Upper bound for user-supplied formula arity: the safe-eval variable + # binder exposes only X/Y/Z (and E_X/E_Y/E_Z). Supporting more would + # require extending both the binder and the safe-eval allow-list tests. + _USER_FORMULA_MAX_LEVELS = 3 + + @staticmethod + def _resolve_formula(formula: str, n_levels: int): + """Validate ``formula`` against the built-in registry and (if user-supplied) + the safe-eval whitelist; return a callable taking ``{cardinal: energy}``. + + Built-in formulas additionally have their required arity enforced here + (Phase 5.5) so a recipe with the wrong number of levels fails at + construction, not at sub-job-completion time. + """ + if formula in BUILTIN_FORMULAS: + required = CBSExtrapolationTerm._BUILTIN_FORMULA_ARITY.get(formula) + if required is not None and n_levels != required: + raise InputError( + f"Built-in CBS formula '{formula}' requires exactly " + f"{required} levels; got {n_levels}." + ) + return BUILTIN_FORMULAS[formula] + # User expression: validate the AST eagerly so malformed formulas raise + # at construction, not when sub-job energies are first plugged in. We + # advertise X/Y/Z and E_X/E_Y/E_Z up to the number of levels. + if n_levels > CBSExtrapolationTerm._USER_FORMULA_MAX_LEVELS: + raise InputError( + f"User CBS formulas currently support at most " + f"{CBSExtrapolationTerm._USER_FORMULA_MAX_LEVELS} levels " + f"(X/Y/Z and E_X/E_Y/E_Z variables); got {n_levels}." 
+ ) + allowed = {f"E_{var}" for var in ("X", "Y", "Z")[:n_levels]} + allowed.update({var for var in ("X", "Y", "Z")[:n_levels]}) + validate_formula(formula, allowed) + + def _user_fn(energies): + env = {} + for idx, (X, E) in enumerate(sorted(energies.items())): + var = ("X", "Y", "Z")[idx] + env[var] = X + env[f"E_{var}"] = E + return safe_eval_formula(formula, env) + + return _user_fn + + def _sub(self, cardinal: int) -> str: + return f"{self.label}__card_{cardinal}" + + def required_levels(self) -> list[tuple[str, Level]]: + return [(self._sub(c), lvl) for c, lvl in zip(self._cardinals, self.levels)] + + def evaluate(self, energies: dict[str, float]) -> float: + cardinal_to_energy = {c: energies[self._sub(c)] for c in self._cardinals} + return self._formula_callable(cardinal_to_energy) + + def as_dict(self) -> dict[str, Any]: + return { + "type": "cbs_extrapolation", + "label": self.label, + "formula": self.formula, + "components": self.components, + "levels": [lvl.as_dict() for lvl in self.levels], + } + + @classmethod + def _from_dict(cls, data: dict[str, Any]) -> "CBSExtrapolationTerm": + return cls( + label=data["label"], + formula=data["formula"], + levels=data["levels"], + components=data.get("components", "total"), + ) + + +# --------------------------------------------------------------------------- # +# CompositeProtocol # +# --------------------------------------------------------------------------- # + + +class CompositeProtocol: + """An ordered sum of :class:`Term` objects defining the final electronic energy. + + The protocol's electronic energy is ``base.evaluate(...) + Σ correction.evaluate(...)``. + + Optional metadata: + + * ``preset_name`` — the name of the preset this protocol was expanded from + (``"HEAT-345Q"`` etc.), or ``None`` for explicit recipes. Populated + automatically by :meth:`from_user_input` when the input is a preset name + or a ``{preset: ..., overrides: ...}`` dict; carried through ``as_dict`` + and restored by ``from_dict``. + * ``reference`` — a citation string (typically a DOI) describing the source + of the protocol. For presets, this comes from ``presets.yml``'s + ``reference:`` field; for explicit recipes, users may supply a + ``reference:`` key at the top level of their recipe dict. + """ + + def __init__( + self, + base: SinglePointTerm, + corrections: list[Term] | None = None, + preset_name: str | None = None, + reference: str | None = None, + ): + if not isinstance(base, SinglePointTerm): + raise InputError( + "CompositeProtocol.base must be a SinglePointTerm; " + f"got {type(base).__name__}." + ) + corrections = list(corrections) if corrections else [] + labels = [base.label] + [t.label for t in corrections] + if len(set(labels)) != len(labels): + raise InputError( + f"All term labels must be unique within a CompositeProtocol; " + f"got duplicates in {labels}." + ) + # sub_labels are a *global* namespace within a protocol — they key the + # scheduler's pending dict and the output-dict's 'paths/sp_composite'. + # A collision (e.g. SinglePointTerm(label='delta_T__high') plus a + # DeltaTerm(label='delta_T', ...) whose 'high' sub-leg also ends up as + # 'delta_T__high') would overwrite state silently. Reject at construction. 
+ sub_labels: list[str] = [] + for term in [base, *corrections]: + for sub_label, _level in term.required_levels(): + sub_labels.append(sub_label) + if len(set(sub_labels)) != len(sub_labels): + duplicates = sorted({s for s in sub_labels if sub_labels.count(s) > 1}) + raise InputError( + f"CompositeProtocol has colliding sub_labels across terms: " + f"{duplicates}. Rename the offending term(s) so their " + f"sub_labels ('