diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d0a5005a21..27c5bfbdba 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -79,6 +79,44 @@ jobs: CONDA_ALWAYS_YES: "true" run: make install-ci + - name: Restore GoFlow cache (clone + goflow_env + checkpoint) + id: goflow-cache + uses: actions/cache@v4 + with: + path: | + ${{ github.workspace }}/goflow_lean + ~/micromamba/envs/goflow_env + key: goflow-cpu-${{ runner.os }}-${{ hashFiles('ARC/devtools/install_goflow.sh') }} + + - name: Install GoFlow (CPU) + shell: micromamba-shell {0} + working-directory: ${{ github.workspace }} + run: | + if [[ -d goflow_lean/.git ]]; then + bash ARC/devtools/install_goflow.sh --cpu --path "$PWD/goflow_lean" + else + bash ARC/devtools/install_goflow.sh --cpu + fi + + - name: Restore RitS cache (clone + rits_env + checkpoint) + id: rits-cache + uses: actions/cache@v4 + with: + path: | + ${{ github.workspace }}/RitS + ~/micromamba/envs/rits_env + key: rits-cpu-${{ runner.os }}-${{ hashFiles('ARC/devtools/install_rits.sh') }} + + - name: Install RitS (CPU) + shell: micromamba-shell {0} + working-directory: ${{ github.workspace }} + run: | + if [[ -d RitS/.git ]]; then + bash ARC/devtools/install_rits.sh --cpu --path "$PWD/RitS" + else + bash ARC/devtools/install_rits.sh --cpu + fi + - name: Set TS-GCN and AutoTST in PYTHONPATH shell: micromamba-shell {0} working-directory: ARC @@ -100,6 +138,8 @@ jobs: ARC_COVERAGE: 1 CYTHON_TRACE: 1 COVERAGE_CORE: ctrace + ARC_GOFLOW_CKPT: ${{ github.workspace }}/goflow_lean/data/RDB7/epoch_316.ckpt + ARC_RITS_CKPT: ${{ github.workspace }}/RitS/data/rits.ckpt run: | echo "Running Unit Tests..." 
export PYTHONPATH="${{ github.workspace }}/AutoTST:${{ github.workspace }}/KinBot:$PYTHONPATH" diff --git a/Makefile b/Makefile index eb3f8d00bd..27c94d10b2 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ DEVTOOLS_DIR := devtools .PHONY: all help clean test test-unittests test-functional test-all \ install-all install-ci install-rmg install-rmgdb install-autotst install-gcn \ install-gcn-cpu install-kinbot install-sella install-xtb install-torchani install-ob \ - lite check-env compile + install-goflow install-rits lite check-env compile # Default target @@ -37,6 +37,8 @@ help: @echo " install-xtb Install xTB" @echo " install-torchani Install TorchANI" @echo " install-ob Install OpenBabel" + @echo " install-goflow Install GoFlow (TS guesser, ~2-3 GB env; downloads pretrained ckpt from Zenodo, SHA-256-verified)" + @echo " install-rits Install RitS (TS guesser, ~3 GB env; downloads pretrained ckpt from Zenodo, SHA-256-verified)" @echo "" @echo "Maintenance:" @echo " lite Run lite installation (no tests)" @@ -65,8 +67,8 @@ install: bash $(DEVTOOLS_DIR)/install_all.sh --rmg-rms install-ci: - @echo "Installing all external ARC dependencies for CI (no clean)..." - bash $(DEVTOOLS_DIR)/install_all.sh --no-clean + @echo "Installing all external ARC dependencies for CI (no clean, no GoFlow, no RitS — each runs in its own CI lane)..." + bash $(DEVTOOLS_DIR)/install_all.sh --no-clean --no-goflow --no-rits install-lite: @echo "Installing ARC's lite version (no external dependencies)..." @@ -102,6 +104,12 @@ install-torchani: install-ob: bash $(DEVTOOLS_DIR)/install_ob.sh +install-goflow: + bash $(DEVTOOLS_DIR)/install_goflow.sh + +install-rits: + bash $(DEVTOOLS_DIR)/install_rits.sh + lite: bash $(DEVTOOLS_DIR)/lite.sh diff --git a/arc/common.py b/arc/common.py index 98f4916749..359a5a8d9b 100644 --- a/arc/common.py +++ b/arc/common.py @@ -140,7 +140,7 @@ def check_ess_settings(ess_settings: dict | None = None) -> dict: f'strings. 
Got: {server_list} which is a {type(server_list)}') # run checks: for ess, server_list in settings_dict.items(): - if ess.lower() not in supported_ess + ['gcn', 'heuristics', 'autotst', 'kinbot', 'xtb_gsm', 'orca_neb']: + if ess.lower() not in supported_ess + ['gcn', 'goflow', 'heuristics', 'autotst', 'kinbot', 'rits', 'xtb_gsm', 'orca_neb']: raise SettingsError(f'Recognized ESS software are {supported_ess}. Got: {ess}') for server in server_list: if not isinstance(server, bool) and server.lower() not in [s.lower() for s in servers.keys()]: diff --git a/arc/job/adapter.py b/arc/job/adapter.py index c9e00a010b..8a4382ed92 100644 --- a/arc/job/adapter.py +++ b/arc/job/adapter.py @@ -97,6 +97,8 @@ class JobEnum(str, Enum): heuristics = 'heuristics' # ARC's heuristics kinbot = 'kinbot' # KinBot, 10.1016/j.cpc.2019.106947 gcn = 'gcn' # Graph neural network for isomerization, https://doi.org/10.1021/acs.jpclett.0c00500 + goflow = 'goflow' # GoFlow, flow-matching E(3)-equivariant TS generator (Galustian et al., Digital Discovery 2025, 10.1039/D5DD00283D); https://github.com/heid-lab/goflow_lean + rits = 'rits' # Right into the Saddle, flow-matching TS generator, https://github.com/isayevlab/RitS, 10.26434/chemrxiv.15001681/v1 user = 'user' # user guesses xtb_gsm = 'xtb_gsm' # Double ended growing string method (DE-GSM), [10.1021/ct400319w, 10.1063/1.4804162] via xTB orca_neb = 'orca_neb' diff --git a/arc/job/adapters/common.py b/arc/job/adapters/common.py index e689ba97b7..e348e26f75 100644 --- a/arc/job/adapters/common.py +++ b/arc/job/adapters/common.py @@ -74,11 +74,12 @@ 'Singlet_Carbene_Intra_Disproportionation': ['gcn', 'xtb_gsm', 'orca_neb'], } -all_families_ts_adapters = [] +all_families_ts_adapters = ['goflow', 'rits'] adapters_that_do_not_require_a_level_arg = ['xtb', 'torchani'] # Default is "queue", "pipe" will be called whenever needed. So just list 'incore'. 
-default_incore_adapters = ['autotst', 'crest', 'gcn', 'heuristics', 'kinbot', 'openbabel', 'torchani', 'psi4', 'xtb', 'xtb_gsm'] +default_incore_adapters = ['autotst', 'crest', 'gcn', 'goflow', 'heuristics', 'kinbot', 'openbabel', 'psi4', + 'rits', 'torchani', 'xtb', 'xtb_gsm'] def _initialize_adapter(obj: JobAdapter, diff --git a/arc/job/adapters/scripts/goflow_script.py b/arc/job/adapters/scripts/goflow_script.py new file mode 100644 index 0000000000..29bd3c87d1 --- /dev/null +++ b/arc/job/adapters/scripts/goflow_script.py @@ -0,0 +1,453 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +A standalone script to run GoFlow inference on a single reaction and emit +TS guesses as a YAML file consumable by ARC's GoFlowAdapter. + +This script must be invoked from inside the ``goflow_env`` conda environment +(it imports ``goflow``, ``hydra``, ``torch``, ``torch_geometric`` and +``torchdiffeq``). The parent ARC process shells out to it via +``subprocess.run`` so that ARC's main env does not have to carry the heavy +ML dependency stack. + +Architecture note +----------------- +GoFlow Lean ships no single-reaction inference CLI (the closest is ``test_save_all_samples_rdb7.sh``, +which runs the entire RDB7 test split via Hydra). This script: + + 1. Loads ``feat_dict_organic.pkl`` and derives the model's atom-feature + dimension (``n_atom_rdkit_feats = sum(len(v) for v in feat_dict.values())``). + This is necessary because ``configs/model/flow.yaml`` defaults to 27 but + the lean repo's own training script overrides to 36 — using the wrong + value causes a silent ``state_dict`` shape mismatch on load. + 2. Composes Hydra config ``train.yaml`` programmatically with overrides + ``model=flow``, ``data=rdb7``, ``model.representation.n_atom_rdkit_feats=``, + ``model.num_samples=``, ``model.num_steps=``, ``model.sample_method=gaussian``. + 3. Instantiates the FlowModule via Hydra and loads the checkpoint's raw + ``state_dict`` with ``strict=True``. 
Validates the checkpoint is a real + Lightning ckpt (not the 45-byte LFS-pointer placeholder). + 4. Builds a single-reaction PyG ``Data`` via ``goflow.preprocessing.generate_graph_data``, + using the atom-mapped reactant + product SMILES that ARC produced. + Sets ``pos_gt`` to the reactant geometry only as a length-N placeholder + (goflow's ``CountNodesPerGraph`` reads ``len(data.pos)``); since we + never call ``test_step``, the GT-alignment branch is never triggered. + 5. Runs a custom in-script ODE sampling loop (mirroring only the sampling + part of ``FlowModule.test_step``, NOT the substruct-match/Kabsch align + of samples to GT). Yields one geometry per sample. + 6. Writes a multi-frame XYZ + a list-of-TSGuess-dicts YAML. + +If GoFlow fails to produce any usable output, the script writes a list with +a single failed-guess entry instead of raising — the parent adapter then +logs the failure but continues running other TS methods. + +Input file (``input.yml``) — required keys +------------------------------------------- + reactant_xyz_path : str absolute path to a plain XYZ file + product_xyz_path : str absolute path to the matching product XYZ + reactant_smiles : str atom-mapped SMILES (every H explicit) + product_smiles : str ditto, map numbers consistent with reactant + goflow_repo_path : str absolute path to the goflow_lean source checkout + ckpt_path : str absolute path to the pretrained ckpt + feat_dict_path : str absolute path to feat_dict_organic.pkl + output_xyz_path : str absolute path for the multi-frame XYZ output + yml_out_path : str absolute path for the parsed TSGuess list + +Optional keys (with defaults): + n_samples: int default 10 + num_steps: int default 25 + device : str default 'auto' + +Output (``yml_out_path``) +------------------------- +A YAML *list* of TSGuess dictionaries. 
Each entry has: + method : 'GoFlow' + method_direction : 'F' + method_index : int (0-based sample index) + initial_xyz : str (XYZ-format coordinate block, no header lines) + success : bool + execution_time : str (str(datetime.timedelta)) +""" + +import argparse +import datetime +import os +import pickle +import sys +import traceback +from typing import List, Optional + +import yaml + + +def read_xyz_positions(xyz_path: str): + """ + Parse a single-frame plain XYZ file into an (N, 3) coordinate array. + + The file is expected to start with an atom-count line, then a comment + line, then N coordinate lines of the form `` ``. + Leading blank lines are tolerated; trailing rows are not required to + match the count exactly (extra rows are ignored). + + Args: + xyz_path (str): Path to a plain XYZ file. + + Returns: + numpy.ndarray: An ``(N, 3)`` float32 array of Cartesian coordinates (atomic symbols are dropped). + + Raises: + ValueError: If the file is empty, the header declares more atoms than are present, or any row is malformed. + """ + import numpy as np + with open(xyz_path, 'r') as f: + lines = [ln.rstrip('\n') for ln in f] + # Skip leading blank lines. 
+ i = 0 + while i < len(lines) and not lines[i].strip(): + i += 1 + if i >= len(lines): + raise ValueError(f'XYZ file is empty: {xyz_path}') + n_atoms = int(lines[i].strip()) + i += 2 # skip count + comment line + coords: List[List[float]] = [] + for _ in range(n_atoms): + if i >= len(lines): + raise ValueError(f'XYZ file {xyz_path} is truncated: header declares {n_atoms} ' + f'atoms but only {len(coords)} coordinate rows are present.') + parts = lines[i].split() + if len(parts) < 4: + raise ValueError(f'Malformed XYZ row in {xyz_path} at line {i + 1}: ' + f'expected ` `, got {lines[i]!r}') + coords.append([float(parts[1]), float(parts[2]), float(parts[3])]) + i += 1 + return np.asarray(coords, dtype=np.float32) + + +def format_xyz_block(symbols, pos) -> str: + """ + Return a body-only XYZ coordinate block, one atom per line. + + No leading count or comment header is emitted — just N rows of `` ``. + Coordinates are formatted to 6 decimal places. + + Args: + symbols (Iterable[str]): Atomic-symbol strings, length N. + pos (Iterable[Sequence[float]]): N triples of Cartesian coordinates (each a 3-element ``(x, y, z)`` sequence). + + Returns: + str: Multi-line XYZ block, no trailing newline. + """ + rows = [] + for sym, (x, y, z) in zip(symbols, pos): + rows.append(f'{sym} {float(x):.6f} {float(y):.6f} {float(z):.6f}') + return '\n'.join(rows) + + +def write_multi_frame_xyz(path: str, symbols, pos_S_N_3) -> None: + """ + Write a multi-frame XYZ file (one frame per GoFlow sample). + + Each frame is laid out as:: + + + GoFlow sample + + ... + + Args: + path (str): Output file path. Overwritten if it exists. + symbols (Sequence[str]): Atomic-symbol strings, length N. + pos_S_N_3 (Iterable[Sequence[Sequence[float]]]): S frames, each an N×3 coordinate iterable. 
+ """ + n_atoms = len(symbols) + with open(path, 'w') as f: + for i, frame in enumerate(pos_S_N_3): + f.write(f'{n_atoms}\n') + f.write(f'GoFlow sample {i}\n') + f.write(format_xyz_block(symbols, frame)) + f.write('\n') + + +def _failed_guess(elapsed: datetime.timedelta, index: int = 0) -> dict: + """ + Build the standard failed-TSGuess sentinel dict. + + Returned to the parent ARC adapter when GoFlow inference raises so the + adapter can mark the attempt as unsuccessful without losing track of + the elapsed time. + + Args: + elapsed (datetime.timedelta): Wall-clock time spent before failure. + index (int): Sample index to record. Defaults to 0. + + Returns: + dict: A TSGuess-shaped dict with ``success=False`` and ``initial_xyz=None``. + """ + return {'method': 'GoFlow', + 'method_direction': 'F', + 'method_index': index, + 'initial_xyz': None, + 'success': False, + 'execution_time': str(elapsed)} + + +def _string_representer(dumper, data): + """ + Represent a Python ``str`` as a YAML scalar. + + Multi-line strings (e.g. an XYZ coordinate block stored under ``initial_xyz``) + get the literal-block ``|`` style so they round-trip cleanly; single-line strings get the default style. + + Args: + dumper (yaml.Dumper): The YAML dumper invoking the representer. + data (str): The string being serialized. + + Returns: + yaml.ScalarNode: The representer node. + """ + if len(data.splitlines()) > 1: + return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|') + return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data) + + +def save_yaml_file_local(path: str, content) -> None: + """ + Serialize ``content`` to ``path`` as YAML. + + Multi-line strings are written using the literal-block ``|`` style (see :func:`_string_representer`). + + Args: + path (str): Output file path. Overwritten if it exists. + content: Any YAML-serializable Python object (typically a list/dict). 
+ + Returns: + None + """ + yaml.add_representer(str, _string_representer) + with open(path, 'w') as f: + f.write(yaml.dump(data=content)) + + +def read_yaml_file_local(path: str) -> dict: + """ + Read a YAML file using the safe loader. + + Args: + path (str): Path to a YAML file. + + Returns: + dict: The loaded mapping (or whatever top-level type ``yaml.safe_load`` returned, + typically a ``dict`` for our input.yml schema). + """ + with open(path, 'r') as f: + return yaml.safe_load(stream=f) + + +def _resolve_device(requested: str) -> str: + """ + Pick a concrete torch device string. + + ``'auto'`` defers to ``torch.cuda.is_available()`` (returns ``'cuda'`` if available, ``'cpu'`` otherwise). + Any explicit value (``'cpu'``, ``'cuda'``, ``'cuda:1'``, …) is honored as-is. + + Args: + requested (str): Either ``'auto'`` or a literal torch device string. + + Returns: + str: The resolved torch device string. + """ + if requested != 'auto': + return requested + try: + import torch + except ImportError: + return 'cpu' + return 'cuda' if torch.cuda.is_available() else 'cpu' + + +def _validate_ckpt(ckpt_path: str) -> None: + """ + Verify the checkpoint at ``ckpt_path`` is plausibly real. + + Three guards: file exists, size ≥ 1 MB (rejects the 45-byte LFS-pointer placeholder shipped in goflow_lean@main), + and ``torch.load`` returns a dict containing a ``'state_dict'`` key (rejects malformed pickles or + non-Lightning checkpoints). + + Args: + ckpt_path (str): Path to the checkpoint file. + + Raises: + FileNotFoundError: If ``ckpt_path`` does not exist. + ValueError: If the file is too small or not a Lightning-style ckpt. + """ + if not os.path.isfile(ckpt_path): + raise FileNotFoundError(f'GoFlow checkpoint not found: {ckpt_path}') + if os.path.getsize(ckpt_path) < 1_000_000: + raise ValueError( + f'GoFlow checkpoint is suspiciously small ({os.path.getsize(ckpt_path)} bytes) ' + f'at {ckpt_path}. 
The 45-byte file shipped in goflow_lean@main is an LFS ' + f'pointer; set ARC_GOFLOW_CKPT to a real Lightning ckpt.' + ) + import torch + # weights_only=False: Lightning ckpts embed an omegaconf.DictConfig in + # 'hyper_parameters' which PyTorch 2.6+'s safe-by-default unpickler refuses. + # We trust the source (a user-supplied or self-trained ckpt that already + # passed the size check above). + obj = torch.load(ckpt_path, map_location='cpu', weights_only=False) + if not isinstance(obj, dict) or 'state_dict' not in obj: + raise ValueError(f'GoFlow checkpoint at {ckpt_path} is not a Lightning ckpt ' + f'(missing "state_dict" key). Got type={type(obj).__name__}.') + + +def run_goflow_inference(input_dict: dict) -> List[dict]: + """ + Run flow-matching ODE sampling on a single reaction. + + Loads the pretrained GoFlow model (Hydra-composed FlowModule + ckpt + state_dict), builds a single-reaction PyG ``Data`` from the atom-mapped + SMILES, and runs ``num_steps`` Euler ODE steps to draw ``n_samples`` + TS-geometry samples. Never propagates exceptions to the caller — any + failure produces a single sentinel entry with ``success=False`` so the + parent adapter can log and continue. + + Args: + input_dict (dict): Parsed ``input.yml`` payload. + Required keys: + ``reactant_xyz_path``, ``product_xyz_path``, ``reactant_smiles``, ``product_smiles``, + ``goflow_repo_path``, ``ckpt_path``, ``feat_dict_path``, ``output_xyz_path``, ``yml_out_path``. + Optional keys: ``n_samples`` (default 10), ``num_steps`` (default 25), ``device`` (default ``'auto'``). + + Returns: + list[dict]: One TSGuess-shaped dict per sample (or a single failure sentinel if the pipeline raised). + Each entry has keys ``method``, ``method_direction``, ``method_index``, ``initial_xyz``, + ``success``, ``execution_time``. + """ + t0 = datetime.datetime.now() + try: + # Late imports: this function only runs inside goflow_env. 
+ import torch + import numpy as np + from torch_geometric.data import Batch + from torchdiffeq import odeint + from hydra import initialize_config_dir, compose + from hydra.utils import instantiate + from ase.data import chemical_symbols + + from goflow.preprocessing import generate_graph_data + from goflow.gotennet.data.components.utils import CountNodesPerGraph + + _validate_ckpt(input_dict['ckpt_path']) + with open(input_dict['feat_dict_path'], 'rb') as f: + feat_dict = pickle.load(f) + feat_dim = sum(len(v) for v in feat_dict.values()) + + n_samples = int(input_dict.get('n_samples', 10)) + num_steps = int(input_dict.get('num_steps', 25)) + device = _resolve_device(input_dict.get('device', 'auto')) + + cfg_dir = os.path.join(input_dict['goflow_repo_path'], 'src', 'goflow', 'configs') + with initialize_config_dir(config_dir=cfg_dir, version_base='1.3'): + cfg = compose(config_name='train', + overrides=['model=flow', + 'data=rdb7', + f'model.representation.n_atom_rdkit_feats={feat_dim}', + f'model.num_samples={n_samples}', + f'model.num_steps={num_steps}', + 'model.sample_method=gaussian'], + ) + flow_module = instantiate(cfg.model) + + ckpt = torch.load(input_dict['ckpt_path'], map_location='cpu', weights_only=False) + flow_module.load_state_dict(ckpt['state_dict'], strict=True) + flow_module = flow_module.to(device).eval() + + pos_r = read_xyz_positions(input_dict['reactant_xyz_path']) + # goflow's CountNodesPerGraph transform reads len(data.pos) to set + # num_nodes, so pos_gt must be a tensor of shape (N, 3) — None breaks it. + # We pass pos_r purely as a placeholder of the right shape; FlowModule's + # test_step (which would Kabsch-align samples to data.pos) is never + # invoked because our _ode_sample helper drives the ODE directly. 
+ data = generate_graph_data(r_smiles=input_dict['reactant_smiles'], + p_smiles=input_dict['product_smiles'], + pos_guess=pos_r, + pos_gt=pos_r, + feat_dict=feat_dict, + ) + data = CountNodesPerGraph()(data) + batch = Batch.from_data_list([data]).to(device) + + n_nodes = batch.num_nodes + seed_base = getattr(flow_module, 'seed', 1) or 1 + + t_T = torch.linspace(0, 1, steps=num_steps, device=device) + + def ode_func(t, x_t_N_3): + t_G = torch.tensor([t] * batch.num_graphs, device=device) + return flow_module.model_output(x_t_N_3, batch, t_G) + + out_S_N_3 = torch.zeros((n_samples, n_nodes, 3), device=device) + with torch.no_grad(): + for i in range(n_samples): + torch.manual_seed(seed_base + i) + x0 = torch.randn(n_nodes, 3, device=device) + out_S_N_3[i] = odeint(ode_func, x0, t_T, method='euler')[-1] + + pos_S_N_3 = out_S_N_3.cpu().numpy() + symbols = [chemical_symbols[int(z)] for z in batch.atom_type.cpu().numpy()] + write_multi_frame_xyz(input_dict['output_xyz_path'], symbols, pos_S_N_3) + + elapsed = datetime.datetime.now() - t0 + return [{'method': 'GoFlow', + 'method_direction': 'F', + 'method_index': i, + 'initial_xyz': format_xyz_block(symbols, pos_S_N_3[i]), + 'success': True, + 'execution_time': str(elapsed)} + for i in range(n_samples)] + except Exception: + traceback.print_exc() + elapsed = datetime.datetime.now() - t0 + return [_failed_guess(elapsed, index=0)] + + +def parse_command_line_arguments(command_line_args: Optional[list] = None): + """ + Parse the script's command-line arguments. + + Args: + command_line_args (list, optional): Override sys.argv (used by tests). + Defaults to ``None`` which reads from ``sys.argv``. + + Returns: + argparse.Namespace: Parsed flags. Currently exposes only ``yml_in_path`` (path to ``input.yml``). 
+ """ + parser = argparse.ArgumentParser(description='Run GoFlow to generate TS guesses for an ARC reaction.') + parser.add_argument('--yml_in_path', metavar='input', type=str, default='input.yml', + help='Path to the input YAML file (default: ./input.yml).') + return parser.parse_args(command_line_args) + + +def main(): + """ + Script entry point. + + Reads ``input.yml`` (path from ``--yml_in_path``), runs :func:`run_goflow_inference`, and writes the TSGuess list + to the ``yml_out_path`` declared in the input. Prints a one-line summary to stdout. + Exits with code 1 if the input file is missing. + """ + args = parse_command_line_arguments() + yml_in_path = str(args.yml_in_path) + if not os.path.isfile(yml_in_path): + print(f'[goflow_script] input file not found: {yml_in_path}', file=sys.stderr) + sys.exit(1) + input_dict = read_yaml_file_local(yml_in_path) + + tsgs = run_goflow_inference(input_dict) + save_yaml_file_local(path=input_dict['yml_out_path'], content=tsgs) + n_ok = sum(1 for tsg in tsgs if tsg.get('success')) + print(f'[goflow_script] wrote {len(tsgs)} TSGuess entries ({n_ok} successful) to {input_dict["yml_out_path"]}', + flush=True) + + +if __name__ == '__main__': + main() diff --git a/arc/job/adapters/scripts/goflow_script_test.py b/arc/job/adapters/scripts/goflow_script_test.py new file mode 100644 index 0000000000..a730d8d54e --- /dev/null +++ b/arc/job/adapters/scripts/goflow_script_test.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Unit tests for the pure-Python helpers in +``arc.job.adapters.scripts.goflow_script``. + +The script as a whole is intended to run inside ``goflow_env`` (where torch ++ goflow are importable). The helpers tested here are stdlib-only and run in +ARC's main env — that's by design: I/O parsing, YAML serialization, and +sentinel construction must work even before the heavy ML stack is installed. 
+ +Heavy paths (Hydra config compose, ckpt load, ODE sampling) are exercised by +the env-gated Tier-2 tests in ``arc/job/adapters/ts/goflow_test.py``. +""" + +import datetime +import os +import tempfile +import unittest + +from arc.job.adapters.scripts.goflow_script import ( + _failed_guess, + format_xyz_block, + read_xyz_positions, + save_yaml_file_local, + write_multi_frame_xyz, +) + + +class TestReadXyzPositions(unittest.TestCase): + """`read_xyz_positions(path)` parses a single-frame plain XYZ → (N, 3) array.""" + + def setUp(self): + self.tmp = tempfile.NamedTemporaryFile('w', suffix='.xyz', delete=False) + + def tearDown(self): + try: + os.unlink(self.tmp.name) + except OSError: + pass + + def test_parses_three_atom_xyz(self): + self.tmp.write('3\n# comment\nC 0.0 0.0 0.0\nH 1.0 0.0 0.0\nH 0.0 1.0 0.0\n') + self.tmp.close() + pos = read_xyz_positions(self.tmp.name) + self.assertEqual(pos.shape, (3, 3)) + self.assertAlmostEqual(pos[1, 0], 1.0) + self.assertAlmostEqual(pos[2, 1], 1.0) + + def test_handles_trailing_blank_lines(self): + self.tmp.write('2\n\nN 0.5 0.5 0.5\nO 1.5 0.5 0.5\n\n\n') + self.tmp.close() + pos = read_xyz_positions(self.tmp.name) + self.assertEqual(pos.shape, (2, 3)) + + +class TestFormatXyzBlock(unittest.TestCase): + """`format_xyz_block(symbols, pos)` returns body-only XYZ (no header).""" + + def test_emits_no_header_no_comment(self): + block = format_xyz_block(['H', 'O', 'H'], + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]]) + lines = block.strip().split('\n') + self.assertEqual(len(lines), 3) + # First token of every line is the symbol. + self.assertEqual(lines[0].split()[0], 'H') + self.assertEqual(lines[1].split()[0], 'O') + # Should NOT start with a numeric atom-count header. 
+ self.assertFalse(lines[0].strip().isdigit()) + + +class TestWriteMultiFrameXyz(unittest.TestCase): + """`write_multi_frame_xyz(path, symbols, pos_S_N_3)` round-trips.""" + + def test_round_trips_two_frames(self): + symbols = ['C', 'H'] + pos_S_N_3 = [[[0.0, 0.0, 0.0], [1.1, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]]] + with tempfile.NamedTemporaryFile('w', suffix='.xyz', delete=False) as f: + path = f.name + try: + write_multi_frame_xyz(path, symbols, pos_S_N_3) + with open(path) as fh: + lines = [ln.rstrip('\n') for ln in fh] + # Each frame: 1 atom-count line + 1 comment line + 2 atom lines = 4 lines. + # 2 frames = 8 lines (allowing trailing newlines). + self.assertGreaterEqual(len(lines), 8) + atom_count_lines = [ln for ln in lines if ln.strip().isdigit()] + self.assertEqual(len(atom_count_lines), 2) + self.assertEqual(int(atom_count_lines[0]), 2) + finally: + os.unlink(path) + + +class TestFailedGuessSentinel(unittest.TestCase): + """`_failed_guess(elapsed)` returns the standard failure dict shape.""" + + def test_sentinel_has_required_keys_and_method_name(self): + sentinel = _failed_guess(datetime.timedelta(seconds=1), index=0) + for key in ('method', 'method_direction', 'method_index', + 'initial_xyz', 'success', 'execution_time'): + self.assertIn(key, sentinel) + self.assertEqual(sentinel['method'], 'GoFlow') + self.assertFalse(sentinel['success']) + self.assertIsNone(sentinel['initial_xyz']) + + +class TestSaveYamlFileBlockLiteral(unittest.TestCase): + """`save_yaml_file_local(path, content)` writes multi-line strings as block literals.""" + + def test_multi_line_xyz_string_uses_block_literal_style(self): + content = [{'method': 'GoFlow', 'method_index': 0, + 'initial_xyz': 'C 0 0 0\nH 1 0 0\nH 0 1 0', + 'success': True}] + with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f: + path = f.name + try: + save_yaml_file_local(path, content) + with open(path) as fh: + text = fh.read() + # Block literal style uses a `|` indicator on 
the value line. + self.assertIn('initial_xyz: |', text) + finally: + os.unlink(path) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/job/adapters/scripts/rits_script.py b/arc/job/adapters/scripts/rits_script.py new file mode 100644 index 0000000000..9c0eae1611 --- /dev/null +++ b/arc/job/adapters/scripts/rits_script.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +A standalone script to run RitS (Right into the Saddle) and emit TS guesses +as a YAML file consumable by ARC's RitSAdapter. + +This script must be invoked from inside the ``rits_env`` conda environment +(it does NOT import ``megalodon`` directly — RitS's own +``scripts/sample_transition_state.py`` does that). The parent ARC process +shells out to this script via ``subprocess.run`` so that ARC's main env +does not have to carry the heavy ML dependency stack. + +Input file (``input.yml``) +-------------------------- +Required keys: + reactant_xyz_path : str absolute path to a plain XYZ file (atom-mapped) + product_xyz_path : str absolute path to the matching product XYZ + rits_repo_path : str absolute path to the RitS source checkout + ckpt_path : str absolute path to the pretrained ``rits.ckpt`` + output_xyz_path : str absolute path RitS should write its raw output to + yml_out_path : str absolute path this script writes the parsed TSGuess list to + +Optional keys (with defaults): + config_path : str defaults to ``/scripts/conf/rits.yaml`` + n_samples : int default 10 + batch_size : int default 32 + charge : int default 0 + device : str default 'auto' (RitS picks GPU if visible, else CPU) + add_stereo : bool default False + num_steps : int default None (use config value) + +Output (``yml_out_path``) +------------------------- +A YAML *list* of TSGuess dictionaries. 
+Each entry has:
+    method : 'RitS'
+    method_direction : 'F'
+    method_index : int (0-based sample index)
+    initial_xyz : str (XYZ-format coordinate block, no header lines)
+    success : bool
+    execution_time : str (str(datetime.timedelta))
+
+If RitS fails to produce any usable output, the script writes a list with a
+single failed-guess entry instead of raising — the parent adapter then logs
+the failure but continues running other TS methods.
+"""
+
+import argparse
+import datetime
+import os
+import subprocess
+import sys
+import traceback
+from typing import List, Optional
+
+import yaml
+
+
+def read_yaml_file(path: str) -> dict:
+    """Read a YAML file and return its contents as a dict."""
+    with open(path, 'r') as f:
+        return yaml.safe_load(stream=f)
+
+
+def string_representer(dumper, data):
+    """YAML representer that uses block literals for multi-line strings."""
+    if len(data.splitlines()) > 1:
+        return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|')
+    return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data)
+
+
+def save_yaml_file(path: str, content) -> None:
+    """Save ``content`` to a YAML file at ``path``."""
+    yaml.add_representer(str, string_representer)
+    with open(path, 'w') as f:
+        f.write(yaml.dump(data=content))
+
+
+def parse_multi_frame_xyz(xyz_path: str) -> List[str]:
+    """
+    Parse a (possibly multi-frame) XYZ file into a list of coordinate-block strings.
+
+    RitS writes a single XYZ file when ``--n_samples == 1`` and a multi-frame
+    XYZ when ``n_samples > 1`` (frames concatenated, each prefixed by an atom
+    count line and a blank/comment line). This parser handles both.
+
+    Args:
+        xyz_path (str): Path to the XYZ file emitted by RitS.
+
+    Returns:
+        List[str]: One coordinate block per frame, suitable for passing to
+            ``arc.species.converter.str_to_xyz`` (atom symbols + xyz only — no
+            header / comment lines).
+    """
+    if not os.path.isfile(xyz_path):
+        return list()
+    with open(xyz_path, 'r') as f:
+        raw_lines = [line.rstrip('\n') for line in f]
+    frames = list()
+    i, n = 0, len(raw_lines)
+    while i < n:
+        # Skip blank lines between frames
+        while i < n and not raw_lines[i].strip():
+            i += 1
+        if i >= n:
+            break
+        # First non-blank line of a frame should be the atom count
+        try:
+            n_atoms = int(raw_lines[i].strip())
+        except ValueError:
+            # Not a frame header — bail on this row to avoid an infinite loop
+            i += 1
+            continue
+        i += 1
+        # Comment / energy line (may be blank)
+        if i < n:
+            i += 1
+        # The next n_atoms lines are coordinates
+        coord_lines = list()
+        for _ in range(n_atoms):
+            if i >= n:
+                break
+            coord_lines.append(raw_lines[i])
+            i += 1
+        if len(coord_lines) == n_atoms:
+            frames.append('\n'.join(coord_lines))
+    return frames
+
+
+def run_rits(input_dict: dict) -> List[dict]:
+    """
+    Invoke ``scripts/sample_transition_state.py`` from the RitS source tree
+    and parse the resulting XYZ frames into a list of TSGuess dictionaries.
+
+    Args:
+        input_dict (dict): The parsed contents of ``input.yml``.
+
+    Returns:
+        List[dict]: One TSGuess-shaped dict per generated sample. Always at
+            least one entry — a failed sentinel if RitS produced nothing.
+    """
+    repo = input_dict['rits_repo_path']
+    sample_script = os.path.join(repo, 'scripts', 'sample_transition_state.py')
+    config_path = input_dict.get('config_path') or os.path.join(repo, 'scripts', 'conf', 'rits.yaml')
+    output_xyz = input_dict['output_xyz_path']
+    n_samples = int(input_dict.get('n_samples', 10))
+    batch_size = int(input_dict.get('batch_size', 32))
+    charge = int(input_dict.get('charge', 0))
+    device = str(input_dict.get('device', 'auto'))
+    add_stereo = bool(input_dict.get('add_stereo', False))
+    num_steps = input_dict.get('num_steps')
+
+    cmd = [
+        sys.executable, sample_script,
+        '--reactant_xyz', input_dict['reactant_xyz_path'],
+        '--product_xyz', input_dict['product_xyz_path'],
+        '--config', config_path,
+        '--ckpt', input_dict['ckpt_path'],
+        '--output', output_xyz,
+        '--n_samples', str(n_samples),
+        '--batch_size', str(batch_size),
+        '--charge', str(charge),
+        '--device', device,
+    ]
+    if add_stereo:
+        cmd.append('--add_stereo')
+    if num_steps is not None:
+        cmd.extend(['--num_steps', str(num_steps)])
+
+    t0 = datetime.datetime.now()
+    print(f'[rits_script] running: {" ".join(cmd)}', flush=True)
+    completed = subprocess.run(cmd, cwd=repo)
+    elapsed = datetime.datetime.now() - t0
+
+    if completed.returncode != 0:
+        print(f'[rits_script] sample_transition_state.py exited with code {completed.returncode}', flush=True)
+        return [_failed_guess(elapsed, index=0)]
+
+    frames = parse_multi_frame_xyz(output_xyz)
+    if not frames:
+        print(f'[rits_script] no frames parsed from {output_xyz}', flush=True)
+        return [_failed_guess(elapsed, index=0)]
+
+    tsgs = list()
+    for i, coord_block in enumerate(frames):
+        tsgs.append({
+            'method': 'RitS',
+            'method_direction': 'F',
+            'method_index': i,
+            'initial_xyz': coord_block,
+            'success': True,
+            'execution_time': str(elapsed),
+        })
+    return tsgs
+
+
+def _failed_guess(elapsed: datetime.timedelta, index: int = 0) -> dict:
+    """Return a failed-TSGuess sentinel dict."""
+    return {
+        'method': 'RitS',
+        'method_direction': 'F',
+        'method_index': index,
+        'initial_xyz': None,
+        'success': False,
+        'execution_time': str(elapsed),
+    }
+
+
+def parse_command_line_arguments(command_line_args: Optional[list] = None) -> argparse.Namespace:
+    """Parse the script's command-line arguments."""
+    parser = argparse.ArgumentParser(description='Run RitS to generate TS guesses for an ARC reaction.')
+    parser.add_argument('--yml_in_path', metavar='input', type=str, default='input.yml',
+                        help='Path to the input YAML file (default: ./input.yml).')
+    return parser.parse_args(command_line_args)
+
+
+def main():
+    """Entry point: read input.yml, run RitS, write output YAML."""
+    args = parse_command_line_arguments()
+    yml_in_path = str(args.yml_in_path)
+    if not os.path.isfile(yml_in_path):
+        print(f'[rits_script] input file not found: {yml_in_path}', file=sys.stderr)
+        sys.exit(1)
+    input_dict = read_yaml_file(yml_in_path)
+
+    try:
+        tsgs = run_rits(input_dict)
+    except Exception:
+        traceback.print_exc()
+        tsgs = [_failed_guess(datetime.timedelta(0), index=0)]
+
+    save_yaml_file(path=input_dict['yml_out_path'], content=tsgs)
+    n_ok = sum(1 for tsg in tsgs if tsg.get('success'))
+    print(f'[rits_script] wrote {len(tsgs)} TSGuess entries ({n_ok} successful) to {input_dict["yml_out_path"]}',
+          flush=True)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/arc/job/adapters/scripts/rits_script_test.py b/arc/job/adapters/scripts/rits_script_test.py
new file mode 100644
index 0000000000..524d575366
--- /dev/null
+++ b/arc/job/adapters/scripts/rits_script_test.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+
+"""
+Unit tests for the pure-Python helpers in
+``arc.job.adapters.scripts.rits_script``.
+
+The script as a whole is intended to run inside ``rits_env`` (where torch +
+megalodon are importable).
+The helpers tested here are stdlib-only and
+run in ARC's main env — that's by design: I/O parsing, multi-frame XYZ
+splitting, and YAML serialization must work even before the heavy ML
+stack is installed.
+
+Heavy paths (``run_rits`` orchestration with the real subprocess) are
+exercised by the env-gated Tier-2 tests in
+``arc/job/adapters/ts/rits_test.py``.
+"""
+
+import os
+import shutil
+import unittest
+
+from arc.common import ARC_TESTING_PATH
+from arc.job.adapters.scripts.rits_script import parse_multi_frame_xyz
+
+
+class TestRitSScriptParser(unittest.TestCase):
+    """Direct unit tests for arc/job/adapters/scripts/rits_script.py:parse_multi_frame_xyz."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.tmp_dir = os.path.join(ARC_TESTING_PATH, 'rits_script_parser')
+        os.makedirs(cls.tmp_dir, exist_ok=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.tmp_dir, ignore_errors=True)
+
+    def _write(self, name: str, body: str) -> str:
+        path = os.path.join(self.tmp_dir, name)
+        with open(path, 'w') as f:
+            f.write(body)
+        return path
+
+    def test_single_frame_xyz(self):
+        body = "3\n\nC 0.0 0.0 0.0\nH 1.0 0.0 0.0\nH -1.0 0.0 0.0\n"
+        frames = parse_multi_frame_xyz(self._write('one.xyz', body))
+        self.assertEqual(len(frames), 1)
+        self.assertEqual(frames[0].splitlines()[0].split()[0], 'C')
+
+    def test_multi_frame_xyz(self):
+        body = ("3\n\nC 0.0 0.0 0.0\nH 1.0 0.0 0.0\nH -1.0 0.0 0.0\n"
+                "3\n\nC 0.1 0.0 0.0\nH 1.1 0.0 0.0\nH -0.9 0.0 0.0\n")
+        frames = parse_multi_frame_xyz(self._write('two.xyz', body))
+        self.assertEqual(len(frames), 2)
+        # Frame 0 starts at the origin; frame 1 is shifted by +0.1 in x
+        self.assertAlmostEqual(float(frames[0].splitlines()[0].split()[1]), 0.0)
+        self.assertAlmostEqual(float(frames[1].splitlines()[0].split()[1]), 0.1)
+
+    def test_missing_file_returns_empty_list(self):
+        frames = parse_multi_frame_xyz(os.path.join(self.tmp_dir, 'nope.xyz'))
+        self.assertEqual(frames, list())
+
+    def test_garbage_does_not_loop_forever(self):
+        body = "this is not an xyz\nat all\n"
+        frames = parse_multi_frame_xyz(self._write('garbage.xyz', body))
+        self.assertEqual(frames, list())
+
+
+if __name__ == '__main__':
+    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
diff --git a/arc/job/adapters/ts/__init__.py b/arc/job/adapters/ts/__init__.py
index 5d571e8e80..89f5de1d3d 100644
--- a/arc/job/adapters/ts/__init__.py
+++ b/arc/job/adapters/ts/__init__.py
@@ -1,6 +1,8 @@
 import arc.job.adapters.ts.autotst_ts
 import arc.job.adapters.ts.gcn_ts
+import arc.job.adapters.ts.goflow_ts
 import arc.job.adapters.ts.heuristics
 import arc.job.adapters.ts.kinbot_ts
+import arc.job.adapters.ts.rits_ts
 import arc.job.adapters.ts.xtb_gsm
 import arc.job.adapters.ts.orca_neb
diff --git a/arc/job/adapters/ts/goflow_test.py b/arc/job/adapters/ts/goflow_test.py
new file mode 100644
index 0000000000..85eaeef054
--- /dev/null
+++ b/arc/job/adapters/ts/goflow_test.py
@@ -0,0 +1,1286 @@
+#!/usr/bin/env python3
+# encoding: utf-8
+
+"""
+Unit tests for the GoFlow TS-guess adapter (``arc.job.adapters.ts.goflow_ts``).
+
+Test tiers
+----------
+**Tier-1** (always runs): wiring, settings resolution, helper functions, adapter
+instantiation, graceful skip when ``goflow_env`` / ckpt are missing, mocked
+subprocess.
+
+**Tier-2** (gated on ``_goflow_environment_ready()``): end-to-end
+``execute_incore`` against the real ``goflow_env`` for a handful of family-
+diverse reactions. Includes an explicit `strict=True` `load_state_dict` test
+so a placeholder checkpoint file (e.g. the 45-byte LFS pointer shipped in
+goflow_lean@main) is rejected immediately.
+
+The Tier-2 tests are skipped automatically on CI runners that did not run
+``install_goflow.sh`` AND do not have ``ARC_GOFLOW_CKPT`` pointing at a real
+checkpoint.
+""" + +import json +import math +import os +import shutil +import subprocess +import sys +import tempfile +import unittest +from types import SimpleNamespace +from unittest import mock + +from rdkit import Chem + +from arc.common import read_yaml_file, save_yaml_file +from arc.job.adapter import JobEnum +from arc.job.adapters.common import all_families_ts_adapters, default_incore_adapters +from arc.settings import external_paths as goflow_paths +from arc.job.adapters.ts.goflow_ts import ( + GOFLOW_DEDUP_DMAT_RMSD, + GoFlowAdapter, + MAX_GOFLOW_ATOMS, + _goflow_environment_ready, + _within_goflow_supported_domain, + build_atom_mapped_smiles, + process_goflow_tsg, +) +from arc.reaction import ARCReaction +from arc.settings import settings as settings_mod +from arc.species.converter import str_to_xyz, xyz_to_str +from arc.species.species import ARCSpecies, TSGuess + + +class TestJobEnumIncludesGoFlow(unittest.TestCase): + """JobEnum must expose `goflow` for adapter selection.""" + + def test_goflow_is_a_member_of_job_enum(self): + self.assertTrue(hasattr(JobEnum, 'goflow'), 'JobEnum is missing the `goflow` member') + self.assertEqual(JobEnum.goflow.value, 'goflow') + + +class TestDefaultIncoreAdaptersIncludesGoFlow(unittest.TestCase): + """GoFlow runs incore (no queue submission needed).""" + + def test_goflow_in_default_incore_adapters(self): + self.assertIn('goflow', default_incore_adapters) + + +class TestAllFamiliesTSAdaptersIncludesGoFlow(unittest.TestCase): + """ + GoFlow ships in the default ``ts_adapters`` list, so it must also be in + ``all_families_ts_adapters`` — otherwise the scheduler's gating in + ``spawn_ts_jobs`` would silently never spawn it. Out-of-domain reactions + (non-H/C/N/O/F elements, >100 atoms) are filtered at runtime by the + adapter's own ``_within_goflow_supported_domain`` guard; hosts without + ``goflow_env`` skip cleanly via ``_goflow_environment_ready``. 
+ """ + + def test_goflow_in_all_families_ts_adapters(self): + self.assertIn('goflow', all_families_ts_adapters) + + +class TestWithinGoFlowSupportedDomain(unittest.TestCase): + """ + GoFlow was trained on RDB7 (small organic, H/C/N/O/F). Reactions outside + the validated domain must be skipped cleanly with a clear warning, not + sent to the model where they would either crash or produce silently bad + geometries. + """ + + def test_accepts_h_abstraction_ch4_oh(self): + rxn = ARCReaction(r_species=[ARCSpecies(label='CH4', smiles='C'), + ARCSpecies(label='OH', smiles='[OH]')], + p_species=[ARCSpecies(label='CH3', smiles='[CH3]'), + ARCSpecies(label='H2O', smiles='O')]) + ok, reason = _within_goflow_supported_domain(rxn) + self.assertTrue(ok, msg=f'Should accept H-abstraction; got reason={reason!r}') + self.assertEqual(reason, '') + + def test_rejects_unsupported_element(self): + """Sulfur (or any element not in H/C/N/O/F) → reject.""" + rxn = ARCReaction(r_species=[ARCSpecies(label='H2S', smiles='S')], + p_species=[ARCSpecies(label='HS', smiles='[SH]'), + ARCSpecies(label='H', smiles='[H]')]) + ok, reason = _within_goflow_supported_domain(rxn) + self.assertFalse(ok) + self.assertIn('S', reason) + + def test_rejects_reaction_above_max_atom_threshold(self): + """ + Build a lightweight rxn-shaped stand-in. The function only iterates + rxn.r_species and reads either get_xyz() or mol.atoms. A SimpleNamespace + with the minimum surface area avoids slow RDKit perception over a + 100-carbon polyene SMILES that the production code never sees. 
+ """ + n_extra = 5 + n_atoms = MAX_GOFLOW_ATOMS + n_extra + big_xyz = {'symbols': ('C',) * n_atoms, + 'isotopes': (12,) * n_atoms, + 'coords': tuple((float(i), 0.0, 0.0) for i in range(n_atoms))} + fake_spc = SimpleNamespace(mol=None, get_xyz=lambda: big_xyz) + rxn = SimpleNamespace(r_species=[fake_spc], p_species=[fake_spc]) + ok, reason = _within_goflow_supported_domain(rxn) + self.assertFalse(ok) + self.assertIn('atom', reason.lower()) + + +class TestGoFlowEnvironmentReady(unittest.TestCase): + """ + `_goflow_environment_ready()` reads four module-level globals + (GOFLOW_PYTHON, GOFLOW_REPO_PATH, GOFLOW_CKPT_PATH, GOFLOW_FEAT_DICT_PATH) + and returns True iff all four point at real, plausibly-valid files/dirs. + """ + + def test_returns_false_when_python_missing(self): + with unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_PYTHON', None), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_REPO_PATH', '/tmp/repo'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_CKPT_PATH', '/tmp/ckpt'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_FEAT_DICT_PATH', '/tmp/fd'): + self.assertFalse(_goflow_environment_ready()) + + def test_returns_false_when_ckpt_missing(self): + with unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_PYTHON', '/usr/bin/python'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_REPO_PATH', '/tmp/repo'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_CKPT_PATH', None), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_FEAT_DICT_PATH', '/tmp/fd'): + self.assertFalse(_goflow_environment_ready()) + + +class TestProcessGoFlowTSG(unittest.TestCase): + """ + `process_goflow_tsg(tsg_dict, local_path, ts_species)` converts a script- + output TSGuess dict into an ARC TSGuess, checks for collisions, dedups + against existing guesses, and saves the geometry. 
+ """ + + def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix='goflow_test_') + self.ts_species = ARCSpecies(label='ts', is_ts=True, smiles='[CH2]C') + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_returns_false_for_failed_guess(self): + bad = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': False, 'initial_xyz': None} + self.assertFalse(process_goflow_tsg(bad, self.tmpdir, self.ts_species)) + self.assertEqual(len(self.ts_species.ts_guesses), 0) + + def test_returns_false_when_atoms_collide(self): + # Two atoms at exactly the same position. + collide_xyz = "C 0.0 0.0 0.0\nH 0.0 0.0 0.0" + tsg = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, 'initial_xyz': collide_xyz} + self.assertFalse(process_goflow_tsg(tsg, self.tmpdir, self.ts_species)) + + def test_appends_new_unique_guess_to_species(self): + tsg = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, + 'initial_xyz': 'C 0.0 0.0 0.0\nH 0.0 0.0 1.1\nH 0.0 1.0 -0.4\n' + 'H 0.9 -0.5 -0.4\nH -0.9 -0.5 -0.4'} + ok = process_goflow_tsg(tsg, self.tmpdir, self.ts_species) + self.assertTrue(ok) + self.assertEqual(len(self.ts_species.ts_guesses), 1) + self.assertIn('goflow', self.ts_species.ts_guesses[0].method.lower()) + + def test_consolidates_near_duplicate_against_existing_guess(self): + """A second guess that's only a tiny perturbation (well under the + dmat-RMSD threshold) of an existing one should NOT be appended; + the existing guess's `method` should be annotated to credit GoFlow.""" + first = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, + 'initial_xyz': 'C 0.000 0.000 0.000\nH 0.000 0.000 1.100\n' + 'H 0.000 1.000 -0.400\nH 0.900 -0.500 -0.400\n' + 'H -0.900 -0.500 -0.400'} + # Same skeleton, every atom shifted by ~0.01 Å — well below the + # 0.15 Å aggregate dmat-RMSD threshold. 
+ second = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 1, + 'success': True, + 'initial_xyz': 'C 0.010 0.010 0.000\nH 0.010 0.010 1.100\n' + 'H 0.010 1.010 -0.400\nH 0.910 -0.490 -0.400\n' + 'H -0.890 -0.490 -0.400'} + self.assertTrue(process_goflow_tsg(first, self.tmpdir, self.ts_species)) + self.assertFalse(process_goflow_tsg(second, self.tmpdir, self.ts_species)) + self.assertEqual(len(self.ts_species.ts_guesses), 1, + 'second near-duplicate should have been consolidated, not appended') + + def test_appends_distinct_guess_with_dmat_rmsd_above_threshold(self): + """A geometrically distinct second guess (dmat-RMSD > threshold) + must be appended as a new unique TSGuess.""" + first = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, + 'initial_xyz': 'C 0.000 0.000 0.000\nH 0.000 0.000 1.100\n' + 'H 0.000 1.000 -0.400\nH 0.900 -0.500 -0.400\n' + 'H -0.900 -0.500 -0.400'} + # Same connectivity but the migrating-H is in a clearly different + # position; aggregate dmat-RMSD will be well above 0.15 Å. + second = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 1, + 'success': True, + 'initial_xyz': 'C 0.000 0.000 0.000\nH 0.000 0.000 1.500\n' + 'H 0.500 1.300 -0.400\nH 1.200 -0.500 -0.400\n' + 'H -1.200 -0.500 -0.400'} + self.assertTrue(process_goflow_tsg(first, self.tmpdir, self.ts_species)) + self.assertTrue(process_goflow_tsg(second, self.tmpdir, self.ts_species)) + self.assertEqual(len(self.ts_species.ts_guesses), 2) + + def test_consolidates_rotor_twins_with_heavy_atoms_unchanged(self): + """Two TS guesses that agree on the heavy-atom skeleton but differ + only in the torsion of a terminal H pair must be consolidated — + they're the same TS, sampled at different rotor wells.""" + # CH3 with one bond stretched — heavy atom (C) at origin in both. 
+ first = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, + 'initial_xyz': 'C 0.000 0.000 0.000\nH 0.000 0.000 1.500\n' + 'H 0.000 1.000 -0.400\nH 0.900 -0.500 -0.400\n' + 'H -0.900 -0.500 -0.400'} + # Same heavy atom, but the three "spectator" H's are rotated ~60° + # around the C-H(1) axis. The non-reactive H positions move ~0.5 Å + # each — pushing the all-atom dmat-RMSD well above 0.15 Å, but the + # heavy-atom dmat-RMSD stays at zero. + second = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 1, + 'success': True, + 'initial_xyz': 'C 0.000 0.000 0.000\nH 0.000 0.000 1.500\n' + 'H 0.866 0.500 -0.400\nH -0.866 0.500 -0.400\n' + 'H 0.000 -1.000 -0.400'} + self.assertTrue(process_goflow_tsg(first, self.tmpdir, self.ts_species)) + self.assertFalse(process_goflow_tsg(second, self.tmpdir, self.ts_species)) + self.assertEqual(len(self.ts_species.ts_guesses), 1, + 'rotor twin with same heavy-atom skeleton must consolidate') + + def test_consolidation_annotates_existing_method_string(self): + """When a near-duplicate hits an existing guess from a *different* + adapter (e.g. heuristics), the existing guess's `method` should be + appended with ' and GoFlow' so downstream consumers see both.""" + # Pre-seed ts_species with a heuristics-style guess. + ts_xyz_str = ('C 0.000 0.000 0.000\nH 0.000 0.000 1.100\n' + 'H 0.000 1.000 -0.400\nH 0.900 -0.500 -0.400\n' + 'H -0.900 -0.500 -0.400') + seed = TSGuess(method='Heuristics', method_direction='F', + method_index=0, success=True, + xyz=str_to_xyz(ts_xyz_str)) + self.ts_species.ts_guesses.append(seed) + + # GoFlow produces a near-duplicate of the heuristics guess. 
+ twin = {'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, + 'initial_xyz': 'C 0.005 0.005 0.000\nH 0.005 0.005 1.100\n' + 'H 0.005 1.005 -0.400\nH 0.905 -0.495 -0.400\n' + 'H -0.895 -0.495 -0.400'} + self.assertFalse(process_goflow_tsg(twin, self.tmpdir, self.ts_species)) + self.assertEqual(len(self.ts_species.ts_guesses), 1) + # Method string should now mention BOTH adapters. + merged_method = self.ts_species.ts_guesses[0].method.lower() + self.assertIn('heuristics', merged_method) + self.assertIn('goflow', merged_method) + + +class TestBuildAtomMappedSmiles(unittest.TestCase): + """ + `build_atom_mapped_smiles(rxn, side)` produces SMILES with every atom + (including every H) carrying an atom-map number 1..N. This is the + highest-risk helper in the adapter: GoFlow's preprocessor parses the + SMILES and reorders atoms by map number, so any silent loss of + hydrogens or duplicate map numbers will silently corrupt inference. + """ + + @classmethod + def setUpClass(cls): + # nC3H7 (10 atoms: 3C + 7H) → iC3H7. Same atom count, easy to verify. + cls.rxn = ARCReaction(r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')]) + # ARC will compute the atom_map lazily; force it once and cache. + cls.atom_map = cls.rxn.atom_map # may be None if mapping fails + + def test_returns_none_for_unknown_side(self): + self.assertIsNone(build_atom_mapped_smiles(self.rxn, side='other')) + + def test_returns_smiles_with_every_h_explicit(self): + smi = build_atom_mapped_smiles(self.rxn, side='reactants') + self.assertIsNotNone(smi, 'reactant SMILES build returned None') + # Round-trip the SMILES with H preservation; every H must be a real atom. 
+ params = Chem.SmilesParserParams() + params.removeHs = False + mol = Chem.MolFromSmiles(smi, params) + self.assertIsNotNone(mol) + n_h_atoms = sum(1 for a in mol.GetAtoms() if a.GetSymbol() == 'H') + n_heavy_atoms = sum(1 for a in mol.GetAtoms() if a.GetSymbol() != 'H') + self.assertEqual(n_h_atoms, 7) + self.assertEqual(n_heavy_atoms, 3) + self.assertEqual(mol.GetNumAtoms(), 10) + + def test_map_numbers_are_one_through_n_with_no_gaps(self): + smi = build_atom_mapped_smiles(self.rxn, side='reactants') + params = Chem.SmilesParserParams() + params.removeHs = False + mol = Chem.MolFromSmiles(smi, params) + maps = sorted(a.GetAtomMapNum() for a in mol.GetAtoms()) + self.assertEqual(maps, list(range(1, 11))) + + def test_reactant_and_product_smiles_share_the_same_map_set(self): + if self.atom_map is None: + self.skipTest('rxn.atom_map could not be computed in this env') + r_smi = build_atom_mapped_smiles(self.rxn, side='reactants') + p_smi = build_atom_mapped_smiles(self.rxn, side='products') + self.assertIsNotNone(r_smi) + self.assertIsNotNone(p_smi) + params = Chem.SmilesParserParams() + params.removeHs = False + r_maps = sorted(a.GetAtomMapNum() for a in Chem.MolFromSmiles(r_smi, params).GetAtoms()) + p_maps = sorted(a.GetAtomMapNum() for a in Chem.MolFromSmiles(p_smi, params).GetAtoms()) + self.assertEqual(r_maps, p_maps) + + def test_returns_none_when_atom_map_unavailable_for_products(self): + """If rxn has no mapping, we cannot build product-side mapped SMILES.""" + rxn = ARCReaction( + r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')], + ) + rxn._atom_map = None + # Force the lazy property to return our None override. 
+ with unittest.mock.patch.object(type(rxn), 'atom_map', + new_callable=unittest.mock.PropertyMock, + return_value=None): + self.assertIsNone(build_atom_mapped_smiles(rxn, side='products')) + + +class TestBuildAtomMappedSmilesStress(unittest.TestCase): + """ + Stress tests for ``build_atom_mapped_smiles`` across a family-diverse + set of reactions. Each fixture exercises a different code path: + - cross-fragment H migration (H + CH4 → CH3 + H2) + - heavy-atom permutation across fragments (2 CH3 → C2H6) + - heteroatom on reactant side (H + NH3, H + HF) + - O-H abstraction (H + CH3OH → H2 + CH3O) + - large complex permutation (CH3 + C2H6 → CH4 + C2H5) + - addition (H + propene → nC3H7) — multi-fragment reactant → single product + """ + + @staticmethod + def _build(name, r_species, p_species): + rxn = ARCReaction(r_species=r_species, p_species=p_species) + # Force atom_map computation; some fixtures fail at this stage + # (e.g. when ARC's mapping heuristics choke). We test only those + # that produce a real atom_map — others would also be skipped at + # adapter runtime. 
+ try: + am = rxn.atom_map + except Exception: + am = None + return name, rxn, am + + @classmethod + def setUpClass(cls): + cls.fixtures = [] + cls.fixtures.append(cls._build( + 'h_abstraction_ch4', + [ARCSpecies(label='CH4', smiles='C'), ARCSpecies(label='H', smiles='[H]')], + [ARCSpecies(label='CH3', smiles='[CH3]'), ARCSpecies(label='H2', smiles='[H][H]')], + )) + cls.fixtures.append(cls._build( + 'oh_abstraction', + [ARCSpecies(label='CH3OH', smiles='CO'), ARCSpecies(label='H', smiles='[H]')], + [ARCSpecies(label='CH3O', smiles='[CH2]O'), ARCSpecies(label='H2', smiles='[H][H]')], + )) + cls.fixtures.append(cls._build( + 'nh3_h_abstraction', + [ARCSpecies(label='NH3', smiles='N'), ARCSpecies(label='H', smiles='[H]')], + [ARCSpecies(label='NH2', smiles='[NH2]'), ARCSpecies(label='H2', smiles='[H][H]')], + )) + cls.fixtures.append(cls._build( + 'methyl_recombination', + [ARCSpecies(label='CH3', smiles='[CH3]'), ARCSpecies(label='CH3', smiles='[CH3]')], + [ARCSpecies(label='C2H6', smiles='CC')], + )) + cls.fixtures.append(cls._build( + 'cross_h_abstraction', + [ARCSpecies(label='CH3', smiles='[CH3]'), ARCSpecies(label='C2H6', smiles='CC')], + [ARCSpecies(label='CH4', smiles='C'), ARCSpecies(label='C2H5', smiles='C[CH2]')], + )) + cls.fixtures.append(cls._build( + 'h_plus_propene_addition', + [ARCSpecies(label='C3H6', smiles='C=CC'), ARCSpecies(label='H', smiles='[H]')], + [ARCSpecies(label='C3H7', smiles='[CH2]CC')], + )) + cls.fixtures.append(cls._build( + 'hf_h_abstraction', + [ARCSpecies(label='HF', smiles='F'), ARCSpecies(label='H', smiles='[H]')], + [ARCSpecies(label='F', smiles='[F]'), ARCSpecies(label='H2', smiles='[H][H]')], + )) + + def _smiles_atoms(self, smi): + """Parse SMILES with explicit Hs preserved; return list of (map_num, element).""" + params = Chem.SmilesParserParams() + params.removeHs = False + mol = Chem.MolFromSmiles(smi, params) + self.assertIsNotNone(mol, f'RDKit could not re-parse: {smi!r}') + return [(a.GetAtomMapNum(), a.GetSymbol()) 
for a in mol.GetAtoms()] + + def test_every_fixture_produces_n_atom_smiles_with_complete_map_set(self): + """Reactant SMILES has N atoms, every atom carries a unique map num in 1..N. + N is taken from len(atom_map), which is ARC's stoichiometry-aware count + (counts each occurrence of identical species, e.g. 2 CH3 = 8 atoms even + though r_species is deduped to a single CH3 entry).""" + for name, rxn, am in self.fixtures: + with self.subTest(fixture=name): + if am is None: + self.skipTest(f'{name}: ARC atom_map computation failed') + n_atoms = len(am) + smi = build_atom_mapped_smiles(rxn, side='reactants') + self.assertIsNotNone(smi, f'{name}: build returned None') + atoms = self._smiles_atoms(smi) + self.assertEqual(len(atoms), n_atoms, + f'{name}: SMILES has {len(atoms)} atoms, expected {n_atoms}') + self.assertEqual(sorted(m for m, _ in atoms), list(range(1, n_atoms + 1)), + f'{name}: map numbers are not exactly 1..{n_atoms}') + + def test_element_at_each_map_number_is_consistent_across_sides(self): + """For every map number i, the atomic symbol on reactant and product + side must be identical — otherwise GoFlow's preprocessor would assert + on `atomic_numbers_match` mid-inference.""" + for name, rxn, am in self.fixtures: + with self.subTest(fixture=name): + if am is None: + self.skipTest(f'{name}: ARC atom_map computation failed') + r_smi = build_atom_mapped_smiles(rxn, side='reactants') + p_smi = build_atom_mapped_smiles(rxn, side='products') + self.assertIsNotNone(r_smi, f'{name}: reactant build returned None') + self.assertIsNotNone(p_smi, f'{name}: product build returned None') + r_by_map = dict(self._smiles_atoms(r_smi)) + p_by_map = dict(self._smiles_atoms(p_smi)) + self.assertEqual(set(r_by_map), set(p_by_map), f'{name}: map number sets differ across sides') + for mn in r_by_map: + self.assertEqual(r_by_map[mn], p_by_map[mn], + f'{name}: map={mn} is {r_by_map[mn]} on reactant ' + f'but {p_by_map[mn]} on product — atom identity not preserved') + + def 
test_h_migration_in_ch4_plus_h_actually_swaps_h_position(self): + """Sanity-check the cross-fragment H-migration semantics. For + H + CH4 → CH3 + H2 with atom_map=[0, 5, 1, 2, 3, 4]: + - reactant atom 0 (C) → product atom 0 (C) + - reactant atom 1 (lone H) → product atom 5 (one H of H2) + - reactant atoms 2..5 (CH-H) → product atoms 1..4 (CH3 Hs + the other H of H2) + Map number 2 should therefore label an H on both sides, but in + DIFFERENT bond environments — bonded to C on reactant, bonded to H + on product (or vice versa for the swapped index). + """ + name, rxn, am = self.fixtures[0] # h_abstraction_ch4 + if am is None: + self.skipTest('atom_map unavailable') + r_smi = build_atom_mapped_smiles(rxn, side='reactants') + p_smi = build_atom_mapped_smiles(rxn, side='products') + params = Chem.SmilesParserParams(); params.removeHs = False + r_mol = Chem.MolFromSmiles(r_smi, params) + p_mol = Chem.MolFromSmiles(p_smi, params) + + def neighbors_by_map(mol, map_num): + for a in mol.GetAtoms(): + if a.GetAtomMapNum() == map_num: + return sorted(n.GetSymbol() for n in a.GetNeighbors()) + self.fail(f'no atom with map={map_num}') + + # Map 1 is the carbon: bonded to 4 H on reactant (CH4), 3 H on product (CH3). + self.assertEqual(neighbors_by_map(r_mol, 1), ['H', 'H', 'H', 'H']) + self.assertEqual(neighbors_by_map(p_mol, 1), ['H', 'H', 'H']) + + def test_radical_electrons_total_is_preserved_through_round_trip(self): + """`SetNumRadicalElectrons` is in the build path; this test trips + if RDKit silently re-perceives radicals during sanitization.""" + for name, rxn, am in self.fixtures: + with self.subTest(fixture=name): + if am is None: + self.skipTest(f'{name}: atom_map unavailable') + # Sum of radical electrons on the reactant side from ARC's + # ARCSpecies.mol — that's what the build sees. 
+ r_smi = build_atom_mapped_smiles(rxn, side='reactants') + params = Chem.SmilesParserParams(); params.removeHs = False + r_mol = Chem.MolFromSmiles(r_smi, params) + # Total spin = sum(num_radical_electrons) should be at least 1 + # for every fixture above (every one has a radical somewhere). + total_rad = sum(a.GetNumRadicalElectrons() for a in r_mol.GetAtoms()) + self.assertGreater(total_rad, 0, + f'{name}: round-tripped reactant SMILES has zero radical ' + f'electrons; SetNumRadicalElectrons was lost') + + def test_smiles_round_trips_with_strict_sanitization(self): + """The SMILES we hand to GoFlow must parse cleanly with default + sanitization (which is what goflow.preprocessing uses internally). + A SMILES that only parses with sanitize=False would silently + generate a bad PyG graph downstream.""" + for name, rxn, am in self.fixtures: + with self.subTest(fixture=name): + if am is None: + self.skipTest(f'{name}: atom_map unavailable') + for side in ('reactants', 'products'): + smi = build_atom_mapped_smiles(rxn, side=side) + self.assertIsNotNone(smi, f'{name}/{side}: build returned None') + # Default sanitization, removeHs=False (matches goflow's parser). 
+ params = Chem.SmilesParserParams(); params.removeHs = False + mol = Chem.MolFromSmiles(smi, params) + self.assertIsNotNone(mol, f'{name}/{side}: SMILES failed strict parse: {smi!r}') + + def test_returns_none_when_atom_map_is_incomplete(self): + """If atom_map has a hole (missing index), the inversion loop on the + product side raises ValueError → function returns None gracefully.""" + rxn = ARCReaction(r_species=[ARCSpecies(label='CH4', smiles='C'), + ARCSpecies(label='H', smiles='[H]')], + p_species=[ARCSpecies(label='CH3', smiles='[CH3]'), + ARCSpecies(label='H2', smiles='[H][H]')]) + with unittest.mock.patch.object(type(rxn), 'atom_map', + new_callable=unittest.mock.PropertyMock, + return_value=[0, 5, 1, 2, 3, 99]): # 4 missing, 99 stray + self.assertIsNone(build_atom_mapped_smiles(rxn, side='products')) + + +class TestGoFlowAdapterInstantiation(unittest.TestCase): + """A bare adapter instance with `testing=True` does no I/O and is happy.""" + + def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix='goflow_proj_') + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_adapter_constructs_in_testing_mode(self): + rxn = ARCReaction(r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')]) + adapter = GoFlowAdapter(project='goflow_test', + project_directory=self.tmpdir, + job_type='tsg', + reactions=[rxn], + testing=True) + self.assertEqual(adapter.job_adapter, 'goflow') + self.assertEqual(adapter.command, 'goflow_script.py') + + +class TestExecuteIncoreWithMockedSubprocess(unittest.TestCase): + """ + Verifies the adapter's execute_incore lifecycle without touching the real + goflow_env: monkeypatches `_goflow_environment_ready → True`, mocks + `subprocess.run` to write a stub `output.yml`, and asserts that: + - `input.yml` was written with all required keys + - the stub TSGuess made it into `rxn.ts_species.ts_guesses` + - reactant.xyz and product.xyz exist on disk + """ + 
+ def setUp(self): + self.tmpdir = tempfile.mkdtemp(prefix='goflow_e2e_') + self.rxn = ARCReaction(r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')]) + + def tearDown(self): + shutil.rmtree(self.tmpdir, ignore_errors=True) + + def test_skips_cleanly_when_environment_not_ready(self): + adapter = GoFlowAdapter(project='goflow_test', + project_directory=self.tmpdir, + job_type='tsg', + reactions=[self.rxn], + testing=False) + with unittest.mock.patch('arc.job.adapters.ts.goflow_ts._goflow_environment_ready', return_value=False): + # Should not raise. + adapter.execute_incore() + # No TSGuesses should have been added. + if self.rxn.ts_species is not None: + self.assertEqual(len(self.rxn.ts_species.ts_guesses), 0) + + def test_writes_input_yml_and_ingests_tsg_when_subprocess_mocked(self): + + adapter = GoFlowAdapter(project='goflow_test', + project_directory=self.tmpdir, + job_type='tsg', + reactions=[self.rxn], + testing=False) + + # 5-atom toy TS guess (CH3 with one bond stretched). 
+ stub_xyz = ('C 0.0 0.0 0.0\nH 0.0 0.0 1.6\nH 0.0 1.0 -0.4\n' + 'H 0.9 -0.5 -0.4\nH -0.9 -0.5 -0.4') + stub_tsgs = [{'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, 'initial_xyz': stub_xyz, 'execution_time': '0:00:00.001'}] + + def fake_subprocess_run(cmd, **kwargs): + save_yaml_file(path=adapter.yml_out_path, content=stub_tsgs) + return unittest.mock.Mock(returncode=0) + + with unittest.mock.patch('arc.job.adapters.ts.goflow_ts._goflow_environment_ready', return_value=True), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts._within_goflow_supported_domain', return_value=(True, '')), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_PYTHON', '/usr/bin/python'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_REPO_PATH', self.tmpdir), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_CKPT_PATH', '/dev/null'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_FEAT_DICT_PATH', '/dev/null'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.subprocess.run', side_effect=fake_subprocess_run): + adapter.execute_incore() + + # input.yml exists and has the right keys. + self.assertTrue(os.path.isfile(adapter.yml_in_path)) + input_dict = read_yaml_file(adapter.yml_in_path) + for key in ('reactant_xyz_path', 'product_xyz_path', + 'reactant_smiles', 'product_smiles', + 'goflow_repo_path', 'ckpt_path', 'feat_dict_path', + 'output_xyz_path', 'yml_out_path', + 'n_samples', 'num_steps', 'device'): + self.assertIn(key, input_dict) + + # reactant.xyz and product.xyz exist. + self.assertTrue(os.path.isfile(adapter.reactant_xyz_path)) + self.assertTrue(os.path.isfile(adapter.product_xyz_path)) + + # The stub TSGuess made it into ts_species. 
+ self.assertEqual(len(self.rxn.ts_species.ts_guesses), 1) + self.assertIn('goflow', self.rxn.ts_species.ts_guesses[0].method.lower()) + + def test_iterates_over_every_reaction_in_self_reactions(self): + """When the adapter is constructed with multiple reactions, each one + must produce a TSGuess — not just self.reactions[0]. This guards + against a regression where execute_goflow forgets to loop.""" + rxn1 = ARCReaction(r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')]) + rxn2 = ARCReaction(r_species=[ARCSpecies(label='nC4H9', smiles='[CH2]CCC')], + p_species=[ARCSpecies(label='sC4H9', smiles='C[CH]CC')]) + # The mocked subprocess writes a 5-atom stub xyz; the adapter + # creates rxn.ts_species lazily and appends it. Don't pre-seed + # ts_species — that would trigger ARC's atom-balance check against + # the placeholder smiles. + ts_xyz = ('C 0.0 0.0 0.0\nH 0.0 0.0 1.6\nH 0.0 1.0 -0.4\n' + 'H 0.9 -0.5 -0.4\nH -0.9 -0.5 -0.4') + + adapter = GoFlowAdapter(project='goflow_multirxn', + project_directory=self.tmpdir, + job_type='tsg', + reactions=[rxn1, rxn2], + testing=False) + + def fake_subprocess_run(cmd, **kwargs): + save_yaml_file(path=adapter.yml_out_path, content=[{ + 'method': 'GoFlow', 'method_direction': 'F', 'method_index': 0, + 'success': True, 'initial_xyz': ts_xyz, + 'execution_time': '0:00:00.001'}]) + return unittest.mock.Mock(returncode=0) + + with unittest.mock.patch('arc.job.adapters.ts.goflow_ts._goflow_environment_ready', return_value=True), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts._within_goflow_supported_domain', return_value=(True, '')), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_PYTHON', '/usr/bin/python'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_REPO_PATH', self.tmpdir), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_CKPT_PATH', '/dev/null'), \ + 
unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_FEAT_DICT_PATH', '/dev/null'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.subprocess.run', side_effect=fake_subprocess_run): + adapter.execute_incore() + + self.assertIsNotNone(rxn1.ts_species, 'first rxn ts_species not created') + self.assertIsNotNone(rxn2.ts_species, 'second rxn ts_species not created') + self.assertEqual(len(rxn1.ts_species.ts_guesses), 1, 'first rxn missing TSGuess') + self.assertEqual(len(rxn2.ts_species.ts_guesses), 1, + 'second rxn missing TSGuess — execute_goflow likely only processes self.reactions[0]') + + def test_subprocess_timeout_logged_and_skipped_gracefully(self): + """When the subprocess hangs past the timeout, the adapter must log + a warning and continue, not propagate the TimeoutExpired.""" + adapter = GoFlowAdapter(project='goflow_timeout', + project_directory=self.tmpdir, + job_type='tsg', + reactions=[self.rxn], + testing=False) + adapter.goflow_subprocess_timeout = 1 # second + + def hanging_subprocess_run(cmd, **kwargs): + raise subprocess.TimeoutExpired(cmd=cmd, timeout=kwargs.get('timeout', 1)) + + with unittest.mock.patch('arc.job.adapters.ts.goflow_ts._goflow_environment_ready', return_value=True), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts._within_goflow_supported_domain', return_value=(True, '')), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_PYTHON', '/usr/bin/python'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_REPO_PATH', self.tmpdir), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_CKPT_PATH', '/dev/null'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.GOFLOW_FEAT_DICT_PATH', '/dev/null'), \ + unittest.mock.patch('arc.job.adapters.ts.goflow_ts.subprocess.run', side_effect=hanging_subprocess_run): + # Must not raise — adapter swallows TimeoutExpired and continues. + adapter.execute_incore() + # No TSGuess appended (subprocess never wrote output.yml). 
+ self.assertEqual(len(self.rxn.ts_species.ts_guesses), 0)
+
+
+###############################################################################
+# Tier-2 — env-gated end-to-end tests
+###############################################################################
+#
+# These tests exercise the full GoFlow inference pipeline by spawning the
+# real `goflow_env` subprocess against a real reaction. They self-skip on
+# any host where `_goflow_environment_ready()` returns False — i.e. when
+# python/repo/ckpt/feat_dict aren't all present and valid (the 45-byte LFS
+# placeholder is rejected by the size guard in the settings layer).
+#
+# To run locally:
+# bash devtools/install_goflow.sh --no-ckpt-check
+# export ARC_GOFLOW_CKPT=/path/to/your/epoch_316.ckpt
+# pytest arc/job/adapters/ts/goflow_test.py::TestGoFlowEndToEnd -v
+
+
+def _refresh_goflow_paths_and_check_ready() -> bool:
+ """
+ Tier-2 gating predicate. Called from each Tier-2 class's ``setUpClass``
+ (NOT at module-import time) so that user-set ``ARC_GOFLOW_*`` env vars
+ take effect even when set after pytest collection.
+
+ ``goflow_ts`` caches its module-level ``GOFLOW_*`` globals at import time
+ via ``settings.get(...)``. ``_goflow_environment_ready()`` reads those
+ cached globals — refreshing only ``settings_mod.GOFLOW_*`` is not enough.
+ Both the settings module and ``goflow_ts``'s own globals must be + re-bound from a fresh discovery for the readiness check to see new + values.""" + repo = goflow_paths.find_goflow_repo() + ckpt = goflow_paths.find_goflow_ckpt(repo) + feat = goflow_paths.find_goflow_feat_dict(repo) + settings_mod.GOFLOW_REPO_PATH = repo + settings_mod.GOFLOW_CKPT_PATH = ckpt + settings_mod.GOFLOW_FEAT_DICT_PATH = feat + goflow_ts_mod = sys.modules['arc.job.adapters.ts.goflow_ts'] + goflow_ts_mod.GOFLOW_REPO_PATH = repo + goflow_ts_mod.GOFLOW_CKPT_PATH = ckpt + goflow_ts_mod.GOFLOW_FEAT_DICT_PATH = feat + return _goflow_environment_ready() + + +_TIER2_SKIP_MSG = ('goflow_env or real ckpt not available — ' + 'set ARC_GOFLOW_CKPT or run devtools/install_goflow.sh') + + +class TestGoFlowRealCheckpointStrictLoad(unittest.TestCase): + """ + Acid test for "is this a real checkpoint, not a 45-byte placeholder?" + Instantiate FlowModule via the same Hydra recipe the adapter uses, then + require `load_state_dict(strict=True)` succeeds with zero missing/unexpected + keys. A placeholder ckpt fails this immediately. + """ + + @classmethod + def setUpClass(cls): + if not _refresh_goflow_paths_and_check_ready(): + raise unittest.SkipTest(_TIER2_SKIP_MSG) + + def test_strict_load_succeeds_against_paper_equivalent_ckpt(self): + + # We must run the strict-load probe inside goflow_env (where torch + + # goflow are importable), not in arc_env. Spawn a tiny subprocess. 
+ probe = ('import sys, pickle, torch\n' + 'from hydra import initialize_config_dir, compose\n' + 'from hydra.utils import instantiate\n' + 'ckpt = sys.argv[1]; feat = sys.argv[2]; cfg_dir = sys.argv[3]\n' + 'with open(feat, "rb") as f: fd = pickle.load(f)\n' + 'feat_dim = sum(len(v) for v in fd.values())\n' + 'with initialize_config_dir(config_dir=cfg_dir, version_base="1.3"):\n' + ' cfg = compose(config_name="train", overrides=[' + '"model=flow", "data=rdb7", ' + 'f"model.representation.n_atom_rdkit_feats={feat_dim}", ' + '"model.num_samples=1", "model.num_steps=5", ' + '"model.sample_method=gaussian"])\n' + 'fm = instantiate(cfg.model)\n' + 'obj = torch.load(ckpt, map_location="cpu", weights_only=False)\n' + 'res = fm.load_state_dict(obj["state_dict"], strict=True)\n' + 'import json; print(json.dumps({"missing": list(res.missing_keys), ' + '"unexpected": list(res.unexpected_keys), "feat_dim": feat_dim}))\n') + cfg_dir = os.path.join(settings_mod.GOFLOW_REPO_PATH, 'src', 'goflow', 'configs') + result = subprocess.run([settings_mod.GOFLOW_PYTHON, '-c', probe, + settings_mod.GOFLOW_CKPT_PATH, settings_mod.GOFLOW_FEAT_DICT_PATH, cfg_dir], + capture_output=True, text=True, timeout=120) + if result.returncode != 0: + self.fail(f'strict-load probe failed:\nstdout={result.stdout}\nstderr={result.stderr}') + report = json.loads(result.stdout.strip().splitlines()[-1]) + self.assertEqual(report['missing'], [], f'state_dict has missing keys: {report["missing"]}') + self.assertEqual(report['unexpected'], [], f'state_dict has unexpected keys: {report["unexpected"]}') + self.assertEqual(report['feat_dim'], 36, f'feat_dim derived from feat_dict_organic.pkl should be 36 ' + f'for the published RDB7 dictionary, got {report["feat_dim"]}') + + +class TestGoFlowEndToEnd(unittest.TestCase): + """End-to-end: real subprocess, real ckpt, real ARCReaction → real TSGuesses.""" + + @classmethod + def setUpClass(cls): + if not _refresh_goflow_paths_and_check_ready(): + raise 
unittest.SkipTest(_TIER2_SKIP_MSG) + cls.tmpdir = tempfile.mkdtemp(prefix='goflow_tier2_') + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir, ignore_errors=True) + + def _build_h_abstraction_rxn(self): + """H + CH4 ↔ CH3 + H2 — the canonical small H-abstraction in RDB7's + chemistry domain. Six atoms, all H/C; in-domain for GoFlow.""" + ch4 = ARCSpecies(label='CH4', smiles='C', xyz=( + 'C 0.00000000 0.00000000 0.00000000\n' + 'H 0.62911800 0.62911800 0.62911800\n' + 'H -0.62911800 -0.62911800 0.62911800\n' + 'H -0.62911800 0.62911800 -0.62911800\n' + 'H 0.62911800 -0.62911800 -0.62911800')) + h_atom = ARCSpecies(label='H', smiles='[H]', xyz='H 0.0 0.0 0.0') + ch3 = ARCSpecies(label='CH3', smiles='[CH3]', xyz=( + 'C 0.00000 0.00000 0.00000\n' + 'H 1.07770 0.00000 0.00000\n' + 'H -0.53885 0.93333 0.00000\n' + 'H -0.53885 -0.93333 0.00000')) + h2 = ARCSpecies(label='H2', smiles='[H][H]', + xyz='H 0.0 0.0 0.0\nH 0.0 0.0 0.74') + rxn = ARCReaction(r_species=[ch4, h_atom], p_species=[ch3, h2]) + # TS species needs a 6-atom placeholder so ARC's atom-balance check + # accepts the reaction. The actual coordinates are irrelevant — the + # adapter overwrites ts_guesses entirely with its own samples. 
+ ts_placeholder_xyz = ('C 0.0 0.0 0.0\n' + 'H 0.6 0.6 0.6\n' + 'H -0.6 -0.6 0.6\n' + 'H -0.6 0.6 -0.6\n' + 'H 0.6 -0.6 -0.6\n' + 'H 1.5 1.5 1.5') + ts = ARCSpecies(label='TS_h_abstr', is_ts=True, charge=0, multiplicity=2, xyz=ts_placeholder_xyz) + ts.ts_guesses = [] + rxn.ts_species = ts + return rxn + + def test_h_abstraction_produces_geometrically_valid_tsguesses(self): + + rxn = self._build_h_abstraction_rxn() + + adapter = GoFlowAdapter(project='goflow_tier2_e2e', + project_directory=self.tmpdir, + job_type='tsg', + reactions=[rxn], + testing=False) + adapter.execute_incore() + + guesses = rxn.ts_species.ts_guesses + self.assertGreater(len(guesses), 0, 'GoFlow produced no TSGuesses for an in-domain reaction') + + n_atoms_expected = sum(spc.number_of_atoms for spc in rxn.r_species) + + for i, g in enumerate(guesses): + with self.subTest(guess_idx=i): + self.assertTrue(g.success, f'guess {i} marked unsuccessful') + self.assertIsNotNone(g.initial_xyz) + xyz = (str_to_xyz(g.initial_xyz) + if isinstance(g.initial_xyz, str) else g.initial_xyz) + + self.assertEqual(len(xyz['symbols']), n_atoms_expected, + f'guess {i}: expected {n_atoms_expected} atoms, ' + f'got {len(xyz["symbols"])}') + + # Element ordering matches the mapped reactant ordering. For + # CH4 + H, ARC's canonicalization puts the carbon first. + self.assertEqual(xyz['symbols'][0], 'C', f'guess {i}: first atom should be C') + self.assertEqual(sorted(xyz['symbols']), sorted(('C', 'H', 'H', 'H', 'H', 'H'))) + + flat = [c for tup in xyz['coords'] for c in tup] + self.assertFalse(any(math.isnan(c) or math.isinf(c) for c in flat), + f'guess {i}: NaN/inf in coordinates') + + # Not collapsed to all-zero or near-zero (would indicate + # the model output a degenerate geometry). + norm_sq = sum(c * c for c in flat) + self.assertGreater(norm_sq, 1.0, f'guess {i}: geometry collapsed (norm² = {norm_sq})') + + # No two atoms occupy the same position (collision). 
+ for a in range(len(xyz['symbols'])): + for b in range(a + 1, len(xyz['symbols'])): + d = math.sqrt(sum( + (xyz['coords'][a][k] - xyz['coords'][b][k]) ** 2 + for k in range(3))) + self.assertGreater(d, 0.3, f'guess {i}: atoms {a} and {b} colliding (d={d:.3f} Å)') + + +############################################################################### +# Tier-2 — isomerization fixtures (manual-inspection driver) +############################################################################### +# +# 10 isomerization reactions adapted from the linear-adapter test suite +# (arc/job/adapters/ts/linear_test.py — see ARC main). Each fixture stays +# within GoFlow's HCNOF training domain. The tests here drive +# GoFlowAdapter.execute_incore() against a real ckpt and PRINT every +# surviving TSGuess as a plain XYZ string to stdout — for the user to +# eyeball geometric sanity. Light structural assertions only (atom count, +# no NaN/inf, no atom collisions). +# +# Run with `pytest -s` to see the printed XYZs. 
+ + +class TestGoFlowIsomerizationFixtures(unittest.TestCase): + """One method per isomerization reaction; prints TS XYZs for manual review.""" + + @classmethod + def setUpClass(cls): + if not _refresh_goflow_paths_and_check_ready(): + raise unittest.SkipTest(_TIER2_SKIP_MSG) + cls.tmpdir_root = tempfile.mkdtemp(prefix='goflow_isom_') + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmpdir_root, ignore_errors=True) + + @staticmethod + def _heavy_dmat(coords, symbols): + """Pairwise distance matrix over heavy atoms only (rotor-invariant).""" + heavy = [i for i, s in enumerate(symbols) if s != 'H'] + if len(heavy) < 2: + heavy = list(range(len(symbols))) + return [math.sqrt(sum((coords[a][k] - coords[b][k]) ** 2 for k in range(3))) + for ai, a in enumerate(heavy) for b in heavy[ai + 1:]] + + @classmethod + def _dmat_rmsd(cls, c1, c2, symbols): + """Heavy-atom distance-matrix RMSD between two geometries (Å).""" + d1 = cls._heavy_dmat(c1, symbols) + d2 = cls._heavy_dmat(c2, symbols) + return math.sqrt(sum((a - b) ** 2 for a, b in zip(d1, d2)) / len(d1)) + + def _run_isomerization(self, name, r_smiles, r_xyz, p_smiles, p_xyz): + """Build the reaction, run GoFlow, print all surviving TSGuesses.""" + + r = ARCSpecies(label='R', smiles=r_smiles, xyz=r_xyz) + p = ARCSpecies(label='P', smiles=p_smiles, xyz=p_xyz) + rxn = ARCReaction(r_species=[r], p_species=[p]) + + # Skip cleanly if ARC's mapping engine can't handle this reaction — + # adapter would then return zero guesses (correct behavior) but that + # isn't a model-quality failure, so don't fail the test. + try: + am = rxn.atom_map + except Exception: + am = None + if am is None: + self.skipTest(f'{name}: ARC atom-map computation failed; adapter would skip') + + # TS placeholder must have the right atom count for atom-balance. + # Reactant XYZ has the right atomic composition; reuse it. 
+ ts = ARCSpecies(label=f'TS_{name}', is_ts=True, charge=r.charge, + multiplicity=r.multiplicity, xyz=r_xyz) + ts.ts_guesses = [] + rxn.ts_species = ts + + proj_dir = tempfile.mkdtemp(prefix=f'{name}_', dir=self.tmpdir_root) + adapter = GoFlowAdapter(project=f'goflow_isom_{name}', + project_directory=proj_dir, + job_type='tsg', + reactions=[rxn], + testing=False) + adapter.execute_incore() + + guesses = rxn.ts_species.ts_guesses + n_atoms_expected = sum(spc.number_of_atoms for spc in rxn.r_species) + + print() + print(f'==== {name} ====') + print(f' reaction : {r_smiles} -> {p_smiles}') + print(f' expected atoms : {n_atoms_expected}') + print(f' guesses (after dedup): {len(guesses)}') + + self.assertGreater(len(guesses), 0, f'{name}: GoFlow produced no surviving TSGuess') + + for i, g in enumerate(guesses): + self.assertTrue(g.success, f'{name} guess {i}: marked unsuccessful') + xyz = (str_to_xyz(g.initial_xyz) if isinstance(g.initial_xyz, str) else g.initial_xyz) + + self.assertEqual(len(xyz['symbols']), n_atoms_expected, + f'{name} guess {i}: expected {n_atoms_expected} atoms, ' + f'got {len(xyz["symbols"])}') + + flat = [c for tup in xyz['coords'] for c in tup] + self.assertFalse(any(math.isnan(c) or math.isinf(c) for c in flat), + f'{name} guess {i}: NaN/inf coordinates') + for a in range(len(xyz['symbols'])): + for b in range(a + 1, len(xyz['symbols'])): + d = math.sqrt(sum((xyz['coords'][a][k] - xyz['coords'][b][k]) ** 2 for k in range(3))) + self.assertGreater(d, 0.3, f'{name} guess {i}: atoms {a} and {b} colliding (d={d:.3f} Å)') + + print(f' --- TS guess {i} ---') + print(xyz_to_str(xyz)) + + # ----- uniqueness diagnostics + assertion ----- + # Parse all guesses once. + parsed = [(str_to_xyz(g.initial_xyz) if isinstance(g.initial_xyz, str) else g.initial_xyz) for g in guesses] + if len(parsed) >= 2: + # Pairwise HEAVY-atom dmat-RMSD (matches the metric the adapter + # uses for consolidation: rotation + torsion invariant). 
+ print(f' --- pairwise heavy-atom dmat-RMSD (Å), upper triangle ---') + header = ' ' + ' '.join(f' g{j}' for j in range(len(parsed))) + print(header) + min_rmsd = float('inf') + for i in range(len(parsed)): + row = [f' g{i}: '] + for j in range(len(parsed)): + if j <= i: + row.append(' ') + else: + rmsd = self._dmat_rmsd(parsed[i]['coords'], + parsed[j]['coords'], + parsed[i]['symbols']) + row.append(f'{rmsd:5.3f}') + min_rmsd = min(min_rmsd, rmsd) + print(' '.join(row)) + print(f' --- min pairwise heavy-atom dmat-RMSD: {min_rmsd:.3f} Å ---') + + # Hard uniqueness assertion: every surviving pair must have + # heavy-atom dmat-RMSD >= GOFLOW_DEDUP_DMAT_RMSD; else dedup leaked. + for i in range(len(parsed)): + for j in range(i + 1, len(parsed)): + rmsd = self._dmat_rmsd(parsed[i]['coords'], + parsed[j]['coords'], + parsed[i]['symbols']) + self.assertGreaterEqual( + rmsd, GOFLOW_DEDUP_DMAT_RMSD, + f'{name} guesses {i} and {j} are too similar: dmat-RMSD ' + f'= {rmsd:.3f} Å < {GOFLOW_DEDUP_DMAT_RMSD} Å threshold ' + f'(dedup pass missed them)') + + # -------------------------- the 10 fixtures -------------------------- + + def test_intra_h_migration_cco(self): # V + """4-membered ring H migration: [CH2]CO <=> CC[O]""" + r_xyz = """C -3.35807020 0.39772754 -0.02139706 +H -2.80953191 0.44242278 -0.93900704 +H -4.34767471 -0.00900040 -0.00893508 +C -2.72326461 0.91878394 1.28133933 +H -1.66157493 0.79378755 1.23561273 +H -2.95540282 1.95641525 1.40106030 +O -3.24245519 0.18293235 2.39213346 +H -2.84673223 0.50774673 3.20422887""" + p_xyz = """C -0.34334771 -0.13590857 0.00000002 +H 0.01333400 0.36848377 -0.87365124 +H -1.41334771 -0.13588640 -0.00000560 +C 0.16999407 0.59004821 1.25740487 +H 1.23999407 0.59002603 1.25741049 +H -0.18665169 1.59886128 1.25739942 +O -0.30669270 -0.08404623 2.42499487 +H 0.01329805 -1.14472164 0.00000547""" + self._run_isomerization('intra_h_migration_cco', '[CH2]CO', r_xyz, 'CC[O]', p_xyz) + + def test_intra_h_migration_ccoo(self): + 
"""5-membered ring H migration: CCO[O] <=> [CH2]COO""" + r_xyz = """C -1.05582103 -0.03329574 -0.10080257 +C 0.41792695 0.17831205 0.21035514 +O 1.19234020 -0.65389683 -0.61111443 +O 2.44749684 -0.41401220 -0.28381363 +H -1.33614002 -1.09151783 0.08714882 +H -1.25953618 0.21489046 -1.16411897 +H -1.67410396 0.62341419 0.54699514 +H 0.59566350 -0.06437686 1.28256640 +H 0.67254676 1.24676329 0.02676370""" + p_xyz = """C -1.40886397 0.22567351 -0.37379668 +C 0.06280787 0.04097694 -0.38515682 +O 0.44130326 -0.57668419 0.84260864 +O 1.89519755 -0.66754203 0.80966180 +H -1.87218376 0.90693511 -1.07582340 +H -2.03646287 -0.44342165 0.20255768 +H 0.35571681 -0.60165457 -1.22096147 +H 0.56095122 1.01161503 -0.47393734 +H 2.05354047 -0.10415729 1.58865243""" + self._run_isomerization('intra_h_migration_ccoo', 'CCO[O]', r_xyz, '[CH2]COO', p_xyz) + + def test_intra_h_migration_cccoo(self): + """6-membered ring H migration: CCCO[O] <=> [CH2]CCOO""" + r_xyz = """C -1.31455963 0.65305704 0.00229593 +C 0.17407454 0.87684185 0.32708610 +O 0.97540012 0.03343074 -0.50443961 +O 2.25137227 0.22524629 -0.22604804 +H -1.56888362 -0.37060266 0.18212958 +H -1.49495314 0.89014604 -1.02539419 +H 0.35446804 0.63975284 1.35477623 +H 0.42839853 1.90050154 0.14725245 +C -2.17752564 1.56134592 0.89778516 +H -3.21183640 1.40585907 0.67211926 +H -1.99713214 1.32425692 1.92547529 +H -1.92320166 2.58500562 0.71795151""" + p_xyz = """C 0.10191448 0.80917231 0.12324900 +C 1.63680299 0.68488584 0.13968460 +O 2.03194937 -0.20270773 1.18894005 +O 3.34756810 -0.30923899 1.20302771 +H -0.33221037 -0.15465524 -0.04249800 +H 1.97345768 0.29884684 -0.79975007 +H 2.07092784 1.64871339 0.30543160 +H 3.73706329 0.55550348 1.35173530 +H -0.23474021 1.19521131 1.06268367 +C -0.32362778 1.76504231 -1.00671841 +H -1.26726146 1.63176387 -1.49322877 +H 0.32433693 2.56246418 -1.30531527""" + self._run_isomerization('intra_h_migration_cccoo', 'CCCO[O]', r_xyz, '[CH2]CCOO', p_xyz) + + def test_intra_oh_migration(self): + 
"""OH migration: [CH2]COO <=> [O]CCO""" + r_xyz = """C -1.40886397 0.22567351 -0.37379668 +C 0.06280787 0.04097694 -0.38515682 +O 0.44130326 -0.57668419 0.84260864 +O 1.89519755 -0.66754203 0.80966180 +H -1.87218376 0.90693511 -1.07582340 +H -2.03646287 -0.44342165 0.20255768 +H 0.35571681 -0.60165457 -1.22096147 +H 0.56095122 1.01161503 -0.47393734 +H 2.05354047 -0.10415729 1.58865243""" + p_xyz = """O 0.97298522 1.16961708 0.68631092 +C 0.83017736 0.23002128 -0.24518707 +C -0.46505265 -0.55857538 0.09146589 +O -1.54540067 0.36524471 0.24441655 +H 1.61381747 -0.53531530 -0.35348282 +H 0.69744639 0.56361493 -1.28695526 +H -0.71560487 -1.25802813 -0.71249310 +H -0.36288272 -1.12613201 1.02419042 +H -1.03086141 1.13813060 0.58426610""" + self._run_isomerization('intra_oh_migration', '[CH2]COO', r_xyz, '[O]CCO', p_xyz) + + def test_intra_halogen_migration(self): + """Fluorine 1,4-shift: FCCC[C](F)F <=> [CH2]CCC(F)(F)F""" + r_xyz = """F 1.93592759 -1.04813200 0.17239309 +C 1.41395997 -0.06443750 -0.60748935 +C 0.46854139 0.77821484 0.23269059 +C 1.16469946 1.45000317 1.41577429 +C 2.13600384 2.49526387 0.98914077 +F 1.69221606 3.70990602 0.60332208 +F 3.45162393 2.20655153 0.91224277 +H 2.23977740 0.51935595 -1.02311040 +H 0.87599990 -0.54232434 -1.43132912 +H -0.01588539 1.53022886 -0.40118094 +H -0.31963114 0.12637206 0.62794629 +H 0.40903520 1.92224463 2.05360591 +H 1.67965177 0.70327850 2.03007255""" + p_xyz = """C -2.10258623 0.28609914 -0.11161659 +C -0.80850454 -0.44729615 0.01949484 +C -0.27209648 -0.40163127 1.44584029 +C 1.03111915 -1.15786446 1.56235292 +F 1.97934384 -0.63177629 0.75822896 +F 0.87880578 -2.45776869 1.23195390 +F 1.49664262 -1.10927826 2.83007421 +H -2.25664107 1.23858311 0.38402441 +H -2.81716662 -0.01824459 -0.86926845 +H -0.96292814 -1.48784377 -0.28803477 +H -0.08395313 -0.00553132 -0.67357116 +H -1.00377942 -0.83558539 2.13782580 +H -0.11333646 0.63795904 1.75659256""" + self._run_isomerization('intra_halogen_migration', 'FCCC[C](F)F', 
r_xyz, '[CH2]CCC(F)(F)F', p_xyz) + + def test_intra_no2_ono_conversion(self): + """NO2 ↔ ONO rearrangement: [O-][N+](=O)CC <=> CCON=O""" + r_xyz = """O 1.77136558 -0.91790626 0.88650594 +N 1.34754589 -0.18857388 -0.01862669 +O 1.86645005 -0.03906737 -1.13182045 +C 0.08946605 0.57559465 0.25484606 +C 0.46072863 1.91146690 0.86342166 +H -0.52075344 -0.02737899 0.93392769 +H -0.43797095 0.69242674 -0.69660400 +H 1.09014915 2.48001164 0.17179384 +H -0.42932512 2.51112436 1.08295532 +H 1.01533324 1.78326517 1.79934783""" + p_xyz = """C -1.36894499 0.07118059 -0.24801399 +C -0.01369535 0.17184136 0.42591278 +O -0.03967083 -0.62462610 1.60609048 +N 1.23538512 -0.53558048 2.24863846 +O 1.25629155 -1.21389295 3.27993827 +H -2.16063255 0.41812452 0.42429392 +H -1.39509985 0.66980796 -1.16284741 +H -1.59800183 -0.96960842 -0.49986392 +H 0.19191326 1.21800574 0.68271847 +H 0.76371340 -0.19234475 -0.25650067""" + self._run_isomerization('intra_no2_ono_conversion', '[O-][N+](=O)CC', r_xyz, 'CCON=O', p_xyz) + + def test_1_5_h_shift_pentadiene(self): + """Degenerate sigmatropic 1,5-H shift in penta-1,3-diene: CC=CC=C <=> CC=CC=C""" + xyz = """C 2.6362 0.0000 0.0000 +C 1.3442 0.6930 0.0000 +C 0.0000 0.0000 0.0000 +C -1.3442 0.6930 0.0000 +C -2.6362 0.0000 0.0000 +H 2.5820 -0.6289 0.8928 +H 2.5820 -0.6289 -0.8928 +H 3.6014 0.5018 0.0000 +H 1.3970 1.7729 0.0000 +H 0.0000 -1.0847 0.0000 +H -1.3970 1.7729 0.0000 +H -3.6014 0.5018 0.0000 +H -2.6362 -1.0847 0.0000""" + self._run_isomerization('1_5_h_shift_pentadiene', 'CC=CC=C', xyz, 'CC=CC=C', xyz) + + def test_6_mem_central_cc_shift_alkyne_to_allene(self): + """6-membered central C-C shift: C#CCCC#C <=> C=C=CC=C=C""" + r_xyz = """C 3.03272979 -0.11060195 -0.24229461 +C 1.85599055 -0.34675713 -0.20247149 +C 0.41485966 -0.64142590 -0.15352412 +C -0.41485965 0.64142578 -0.17240633 +C -1.85599061 0.34675702 -0.12346178 +C -3.03272995 0.11060190 -0.08364096 +H 4.07762286 0.09693448 -0.27758589 +H 0.19106566 -1.21954180 0.75163518 +H 
0.14301783 -1.27648597 -1.00582442 +H -0.19106412 1.21954271 -1.07756459 +H -0.14301928 1.27648492 0.67989514 +H -4.07762310 -0.09693448 -0.04835177""" + p_xyz = """C -3.03124363 0.21595810 -0.01068883 +C -1.77136356 -0.00875193 -0.22839960 +C -0.51035344 -0.23538255 -0.44913569 +C 0.51035356 0.23538291 0.44913621 +C 1.77136365 0.00875234 0.22839985 +C 3.03124358 -0.21595777 0.01068824 +H -3.50880107 1.10742857 -0.40051872 +H -3.62554573 -0.48341738 0.56587595 +H -0.21235801 -0.79338469 -1.33170668 +H 0.21235823 0.79338484 1.33170737 +H 3.50880076 -1.10742925 0.40051615 +H 3.62554580 0.48341866 -0.56587535""" + self._run_isomerization('6_mem_central_cc_shift', 'C#CCCC#C', r_xyz, 'C=C=CC=C=C', p_xyz) + + def test_1_3_sigmatropic_rearrangement_imidazole(self): + """1,3-sigmatropic rearrangement on imidazole: c1ncc[nH]1 <=> N=CN1C=C1""" + r_xyz = """C -0.96405208 -0.58870010 -0.35675666 +N 0.09948347 -1.35699528 -0.30406608 +C 1.08781769 -0.57088551 0.22943180 +C 0.61245126 0.68985747 0.50218591 +N -0.70083129 0.66320502 0.12207481 +H -1.93870511 -0.87854432 -0.72608823 +H 2.08729155 -0.95482079 0.38815067 +H 1.07812779 1.57128662 0.91862266 +H -1.36158329 1.42559689 0.18141711""" + p_xyz = """N 0.76582385 -0.14849540 -1.32485588 +C 0.78208226 0.49284271 -0.20399502 +N -0.04861443 0.34490826 0.88039960 +C -0.56227958 -0.84609375 1.31645778 +C -1.38522743 0.06039446 0.80970400 +H 1.52092135 0.20130809 -1.92536405 +H 1.53681129 1.27833147 -0.02452505 +H -0.33519514 -1.78256934 0.82247210 +H -1.89445111 -0.06503499 -0.13767862""" + self._run_isomerization('1_3_sigmatropic_imidazole', 'c1ncc[nH]1', r_xyz, 'N=CN1C=C1', p_xyz) + + def test_1_2_methyl_shift_on_cyclopentadienyl(self): + """1,2-methyl shift on cyclopentadienyl carbene: CC[C]1C=CC=C1 <=> [CH2]C1(C)C=CC=C1""" + r_xyz = """C -2.08011725 -0.87098529 -0.24102896 +C -1.38616808 0.31243567 0.41701874 +C 0.09289885 0.19281646 0.35695343 +C 0.92864438 0.68782411 -0.70819340 +C 2.18636908 0.35957487 -0.37255721 +C 
2.18107732 -0.34427638 0.89885251 +C 0.92008966 -0.45002583 1.34718072 +H -1.81540032 -1.81207279 0.25290661 +H -1.80896484 -0.95285941 -1.29907735 +H -3.16674193 -0.75248342 -0.17993048 +H -1.70601347 1.23815887 -0.07595913 +H -1.71241303 0.38717620 1.46117876 +H 0.59841756 1.21177196 -1.58944757 +H 3.07727387 0.57336525 -0.94230522 +H 3.06757417 -0.71677188 1.38813917 +H 0.58240767 -0.91755259 2.25688920""" + p_xyz = """C -0.91419261 -0.92211886 1.28775915 +C -0.38593444 -0.06230282 0.18302891 +C 0.67135826 0.91653743 0.70043366 +C -1.50477869 0.67123329 -0.52120219 +C -1.56600546 0.29578439 -1.80779070 +C -0.54194075 -0.68042899 -2.06986329 +C 0.15393453 -0.90959193 -0.94583904 +H -1.87479029 -1.41555103 1.18169570 +H -0.24773685 -1.28139411 2.06376191 +H 1.52757979 0.38651768 1.13544157 +H 1.05855879 1.55868444 -0.10049234 +H 0.25983545 1.57373963 1.47630068 +H -2.15472843 1.39365373 -0.04968993 +H -2.26617185 0.65598910 -2.54596855 +H -0.37671673 -1.14563680 -3.02965012 +H 0.98207416 -1.59686236 -0.85449771""" + self._run_isomerization('1_2_methyl_shift_cpd', 'CC[C]1C=CC=C1', r_xyz, '[CH2]C1(C)C=CC=C1', p_xyz) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/job/adapters/ts/goflow_ts.py b/arc/job/adapters/ts/goflow_ts.py new file mode 100644 index 0000000000..30a48b243a --- /dev/null +++ b/arc/job/adapters/ts/goflow_ts.py @@ -0,0 +1,691 @@ +""" +An adapter for executing GoFlow TS-guess jobs. + +GoFlow is a flow-matching, E(3)-equivariant ML model that predicts 3D +transition-state geometries directly from atom-mapped reactant + product +SMILES (with RDKit-derived atom features), without requiring an initial +guess. Like RitS it complements ARC's existing TS-search stack with a +fast ML method; unlike GCN it is not restricted to isomerizations and +unlike RitS it is conditioned on 2D reaction graphs rather than mapped +3D structures. 
+ +References +---------- +- Galustian et al., *Digital Discovery* 2025: 10.1039/D5DD00283D +- Preprint: doi.org/10.26434/chemrxiv-2025-bk2rh +- Code (practical fork) : https://github.com/heid-lab/goflow_lean +- Code (research fork) : https://github.com/heid-lab/goflow + +Implementation notes +-------------------- +* The heavy ML stack (torch + torch-geometric + lightning + torchdiffeq) + lives in its own conda env (``goflow_env``). This adapter never imports + it directly — it shells out to ``arc/job/adapters/scripts/goflow_script.py`` + via ``subprocess.run``, which loads the pretrained checkpoint and runs + the flow-matching ODE sampler. +* GoFlow's input is **atom-mapped reactant + product SMILES** (every H must + be an explicit, mapped atom). ARC builds them in :func:`build_atom_mapped_smiles` + from each ARCSpecies' RDKit Mol, using ``rxn.atom_map`` to keep + reactant→product correspondences consistent. +* GoFlow was trained on RDB7 (small organic; H/C/N/O/F). Reactions outside + this domain are skipped cleanly with a one-line warning by + :func:`_within_goflow_supported_domain`. +* The shipped checkpoint in goflow_lean@main is a 45-byte LFS pointer + rather than a real Lightning ckpt. We validate by file size at adapter + init time, defer ``torch.load``-level validation to ``goflow_script.py``, + and skip cleanly if no real checkpoint is available — the rest of ARC's + TS-search pipeline (heuristics, GCN, AutoTST, …) keeps running. +* ``incore_capacity = 1`` so the scheduler serializes GoFlow jobs and a + single GPU is not asked to load multiple checkpoints in parallel. 
+""" + +import datetime +import os +import subprocess +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import numpy as np +from rdkit import Chem + +from arc.common import ARC_PATH, get_logger, save_yaml_file, read_yaml_file +from arc.imports import settings +from arc.job.adapter import JobAdapter +from arc.job.adapters.common import _initialize_adapter +from arc.job.factory import register_job_adapter +from arc.plotter import save_geo +from arc.species.converter import str_to_xyz, to_rdkit_mol, xyz_to_dmat, xyz_to_str +from arc.species.species import ARCSpecies, TSGuess, colliding_atoms + +if TYPE_CHECKING: + from arc.level import Level + from arc.reaction import ARCReaction + + +GOFLOW_PYTHON = settings.get('GOFLOW_PYTHON') +GOFLOW_REPO_PATH = settings.get('GOFLOW_REPO_PATH') +GOFLOW_CKPT_PATH = settings.get('GOFLOW_CKPT_PATH') +GOFLOW_FEAT_DICT_PATH = settings.get('GOFLOW_FEAT_DICT_PATH') + +GOFLOW_SCRIPT_PATH = os.path.join(ARC_PATH, 'arc', 'job', 'adapters', 'scripts', 'goflow_script.py') +DEFAULT_N_SAMPLES = 10 +DEFAULT_NUM_STEPS = 25 + +# Domain guard — GoFlow was trained on RDB7 (small organic). +SUPPORTED_GOFLOW_ELEMENTS = frozenset({'H', 'C', 'N', 'O', 'F'}) +MAX_GOFLOW_ATOMS = 100 +GOFLOW_DEDUP_DMAT_RMSD = 0.15 + +# File-size thresholds (mirror settings.py, used by _goflow_environment_ready). +_GOFLOW_CKPT_MIN_SIZE = 1_000_000 +_GOFLOW_FEAT_DICT_MIN_SIZE = 100 + +logger = get_logger() + + +class GoFlowAdapter(JobAdapter): + """ + A class for executing GoFlow TS-guess jobs. + + Args: + project (str): The project's name. + project_directory (str): The path to the local project directory. + job_type (list, str): The job's type, validated against ``JobTypeEnum``. + args (dict, optional): Methods/troubleshooting; honored keys are + ``args['keyword']['n_samples']`` and ``args['keyword']['num_steps']``. + bath_gas (str, optional): A bath gas. Currently only used in OneDMin. 
+        checkfile (str, optional): The path to a previous Gaussian checkfile.
+        conformer (int, optional): Conformer number if optimizing conformers.
+        constraints (list, optional): A list of constraints.
+        cpu_cores (int, optional): The total number of cpu cores requested for a job.
+        dihedral_increment (float, optional): Unused for GoFlow.
+        dihedrals (List[float], optional): The dihedral angles corresponding to self.torsions.
+        directed_scan_type (str, optional): The type of the directed scan.
+        ess_settings (dict, optional): A dictionary of available ESS.
+        ess_trsh_methods (List[str], optional): A list of troubleshooting methods.
+        execution_type (str, optional): The execution type, 'incore', 'queue', or 'pipe'.
+        fine (bool, optional): Whether to use fine geometry optimization parameters.
+        initial_time (datetime.datetime or str, optional): The time at which this job was initiated.
+        irc_direction (str, optional): The direction of the IRC job.
+        job_id (int, optional): The job's ID determined by the server.
+        job_memory_gb (int, optional): The total job allocated memory in GB.
+        job_name (str, optional): The job's name.
+        job_num (int, optional): Used as the entry number in the database.
+        job_server_name (str, optional): Job's name on the server.
+        job_status (list, optional): The job's server and ESS statuses.
+        level (Level, optional): The level of theory to use.
+        max_job_time (float, optional): The maximal allowed job time on the server in hours.
+        run_multi_species (bool, optional): Whether to run a job for multiple species in the same input file.
+        reactions (List[ARCReaction], optional): Entries are ARCReaction instances.
+        rotor_index (int, optional): The 0-indexed rotor number.
+        server (str, optional): The server to run on.
+        server_nodes (list, optional): The nodes this job was previously submitted to.
+        species (List[ARCSpecies], optional): Entries are ARCSpecies instances.
+        testing (bool, optional): Whether the object is generated for testing purposes.
+ times_rerun (int, optional): Number of times this job was re-run. + torsions (List[List[int]], optional): The 0-indexed atom indices of the torsion(s). + tsg (int, optional): TSGuess number if optimizing TS guesses. + xyz (dict, optional): The 3D coordinates to use. + """ + + def __init__(self, + project: str, + project_directory: str, + job_type: Union[List[str], str], + args: Optional[dict] = None, + bath_gas: Optional[str] = None, + checkfile: Optional[str] = None, + conformer: Optional[int] = None, + constraints: Optional[List[Tuple[List[int], float]]] = None, + cpu_cores: Optional[str] = None, + dihedral_increment: Optional[float] = None, + dihedrals: Optional[List[float]] = None, + directed_scan_type: Optional[str] = None, + ess_settings: Optional[dict] = None, + ess_trsh_methods: Optional[List[str]] = None, + execution_type: Optional[str] = None, + fine: bool = False, + initial_time: Optional[Union['datetime.datetime', str]] = None, + irc_direction: Optional[str] = None, + job_id: Optional[int] = None, + job_memory_gb: float = 14.0, + job_name: Optional[str] = None, + job_num: Optional[int] = None, + job_server_name: Optional[str] = None, + job_status: Optional[List[Union[dict, str]]] = None, + level: Optional['Level'] = None, + max_job_time: Optional[float] = None, + run_multi_species: bool = False, + reactions: Optional[List['ARCReaction']] = None, + rotor_index: Optional[int] = None, + server: Optional[str] = None, + server_nodes: Optional[list] = None, + queue: Optional[str] = None, + attempted_queues: Optional[List[str]] = None, + species: Optional[List['ARCSpecies']] = None, + testing: bool = False, + times_rerun: int = 0, + torsions: Optional[List[List[int]]] = None, + tsg: Optional[int] = None, + xyz: Optional[dict] = None, + ): + + self.incore_capacity = 1 + self.job_adapter = 'goflow' + self.execution_type = execution_type or 'incore' + self.command = 'goflow_script.py' + self.url = 'https://github.com/heid-lab/goflow_lean' + + if reactions is 
None: + raise ValueError('Cannot execute GoFlow without ARCReaction object(s).') + + self.n_samples = DEFAULT_N_SAMPLES + self.num_steps = DEFAULT_NUM_STEPS + if args and isinstance(args, dict): + kw = args.get('keyword') or dict() + if 'n_samples' in kw: + try: + self.n_samples = int(kw['n_samples']) + except (TypeError, ValueError): + logger.warning(f"GoFlow adapter: could not parse args['keyword']['n_samples']=" + f"{kw['n_samples']!r} as an int; falling back to " + f"DEFAULT_N_SAMPLES={DEFAULT_N_SAMPLES}.") + if 'num_steps' in kw: + try: + self.num_steps = int(kw['num_steps']) + except (TypeError, ValueError): + logger.warning(f"GoFlow adapter: could not parse args['keyword']['num_steps']=" + f"{kw['num_steps']!r} as an int; falling back to " + f"DEFAULT_NUM_STEPS={DEFAULT_NUM_STEPS}.") + + _initialize_adapter(obj=self, + is_ts=True, + project=project, + project_directory=project_directory, + job_type=job_type, + args=args, + bath_gas=bath_gas, + checkfile=checkfile, + conformer=conformer, + constraints=constraints, + cpu_cores=cpu_cores, + dihedral_increment=dihedral_increment, + dihedrals=dihedrals, + directed_scan_type=directed_scan_type, + ess_settings=ess_settings, + ess_trsh_methods=ess_trsh_methods, + fine=fine, + initial_time=initial_time, + irc_direction=irc_direction, + job_id=job_id, + job_memory_gb=job_memory_gb, + job_name=job_name, + job_num=job_num, + job_server_name=job_server_name, + job_status=job_status, + level=level, + max_job_time=max_job_time, + run_multi_species=run_multi_species, + reactions=reactions, + rotor_index=rotor_index, + server=server, + server_nodes=server_nodes, + queue=queue, + attempted_queues=attempted_queues, + species=species, + testing=testing, + times_rerun=times_rerun, + torsions=torsions, + tsg=tsg, + xyz=xyz, + ) + + def write_input_file(self) -> None: + """No standalone input file — see set_files() (writes input.yml).""" + pass + + def set_files(self) -> None: + """ + Set files to be uploaded and downloaded for 
queue execution. + + ``self.files_to_upload`` is a list of dictionaries, each with the keys + ``'name'``, ``'source'``, ``'make_x'``, ``'local'``, and ``'remote'``. + """ + # 1. Upload + if self.execution_type != 'incore': + self.write_submit_script() + from arc.imports import settings as _s + self.files_to_upload.append(self.get_file_property_dictionary( + file_name=_s['submit_filenames'][_s['servers'][self.server]['cluster_soft']])) + if os.path.isfile(self.yml_in_path): + self.files_to_upload.append(self.get_file_property_dictionary(file_name='input.yml')) + if os.path.isfile(self.reactant_xyz_path): + self.files_to_upload.append(self.get_file_property_dictionary(file_name='reactant.xyz')) + if os.path.isfile(self.product_xyz_path): + self.files_to_upload.append(self.get_file_property_dictionary(file_name='product.xyz')) + # 2. Download + self.files_to_download.append(self.get_file_property_dictionary(file_name='output.yml')) + self.files_to_download.append(self.get_file_property_dictionary(file_name='goflow_ts.xyz')) + + def set_additional_file_paths(self) -> None: + """Set the local file paths used by GoFlow at job time.""" + self.reactant_xyz_path = os.path.join(self.local_path, 'reactant.xyz') + self.product_xyz_path = os.path.join(self.local_path, 'product.xyz') + self.ts_out_xyz_path = os.path.join(self.local_path, 'goflow_ts.xyz') + self.yml_in_path = os.path.join(self.local_path, 'input.yml') + self.yml_out_path = os.path.join(self.local_path, 'output.yml') + + def set_input_file_memory(self) -> None: + """Set the input file memory attribute.""" + self.cpu_cores, self.job_memory_gb = 1, 1 + + def execute_incore(self): + """Execute the GoFlow job locally (in-process subprocess).""" + self._log_job_execution() + self.initial_time = self.initial_time if self.initial_time else datetime.datetime.now() + self.execute_goflow() + self.final_time = datetime.datetime.now() + + def execute_queue(self): + """Execute the GoFlow job to the server's queue.""" + 
self.execute_goflow(exe_type='queue') + + def execute_goflow(self, exe_type: str = 'incore'): + """ + Drive the GoFlow subprocess and stitch its output back into ARC. + + Iterates over every reaction in ``self.reactions`` (per the JobAdapter + contract; multiple reactions can share one adapter instance). The + per-rxn input.yml / reactant.xyz / product.xyz / output.yml files + share the adapter's ``self.local_path`` and are overwritten between + reactions; results are consumed into ``rxn.ts_species.ts_guesses`` + before the next reaction runs (same idiom as AutoTST). + + Args: + exe_type (str, optional): Either ``'incore'`` (run locally now) or + ``'queue'`` (just stage the input.yml + submit script). + """ + if not _goflow_environment_ready(): + return + self.reactions = [self.reactions] if not isinstance(self.reactions, list) else self.reactions + timeout_s = getattr(self, 'goflow_subprocess_timeout', 600) + + for rxn in self.reactions: + ok, reason = _within_goflow_supported_domain(rxn) + if not ok: + logger.warning(f'GoFlow: skipping {rxn.label} — outside validated domain ({reason}).') + continue + + if rxn.ts_species is None: + rxn.ts_species = ARCSpecies(label=self.species_label, + is_ts=True, + charge=rxn.charge, + multiplicity=rxn.multiplicity, + ) + + # Build atom-aligned reactant + product XYZ files. ARC's get_reactants_xyz / + # get_products_xyz already use rxn.atom_map to align orderings. + try: + r_xyz_dict = rxn.get_reactants_xyz(return_format='dict') + p_xyz_dict = rxn.get_products_xyz(return_format='dict') + except Exception as e: + logger.warning(f'GoFlow: could not build mapped XYZs for {rxn.label}: {e}') + continue + if r_xyz_dict is None or p_xyz_dict is None: + logger.warning(f'GoFlow: empty mapped XYZs for {rxn.label}') + continue + if len(r_xyz_dict['symbols']) != len(p_xyz_dict['symbols']): + logger.warning(f'GoFlow: atom count mismatch for {rxn.label} ' + f'(R has {len(r_xyz_dict["symbols"])}, P has {len(p_xyz_dict["symbols"])}). 
Skipping.') + continue + + r_smiles = build_atom_mapped_smiles(rxn, side='reactants') + p_smiles = build_atom_mapped_smiles(rxn, side='products') + if r_smiles is None or p_smiles is None: + logger.warning(f'GoFlow: could not build atom-mapped SMILES for {rxn.label} — skipping.') + continue + + write_xyz_file(r_xyz_dict, self.reactant_xyz_path, comment=f'{rxn.label} reactant') + write_xyz_file(p_xyz_dict, self.product_xyz_path, comment=f'{rxn.label} product') + + input_dict = {'reactant_xyz_path': self.reactant_xyz_path, + 'product_xyz_path': self.product_xyz_path, + 'reactant_smiles': r_smiles, + 'product_smiles': p_smiles, + 'goflow_repo_path': GOFLOW_REPO_PATH, + 'ckpt_path': GOFLOW_CKPT_PATH, + 'feat_dict_path': GOFLOW_FEAT_DICT_PATH, + 'output_xyz_path': self.ts_out_xyz_path, + 'yml_out_path': self.yml_out_path, + 'n_samples': self.n_samples, + 'num_steps': self.num_steps, + 'device': 'auto'} + save_yaml_file(path=self.yml_in_path, content=input_dict) + + if exe_type == 'queue': + self.legacy_queue_execution() + continue + + cmd = [GOFLOW_PYTHON, GOFLOW_SCRIPT_PATH, '--yml_in_path', self.yml_in_path] + try: + result = subprocess.run(cmd, check=False, timeout=timeout_s) + except subprocess.TimeoutExpired: + logger.warning(f'GoFlow subprocess timed out after {timeout_s}s for {rxn.label}; ' + f'skipping. 
Increase adapter.goflow_subprocess_timeout to extend.') + continue + if result.returncode != 0: + logger.warning(f'GoFlow subprocess returned non-zero exit code {result.returncode} ' + f'for {rxn.label}.') + continue + + if not os.path.isfile(self.yml_out_path): + logger.warning(f'GoFlow produced no output YAML at {self.yml_out_path} for {rxn.label}.') + continue + + tsg_dicts = read_yaml_file(self.yml_out_path) or list() + n_added = 0 + for tsg_dict in tsg_dicts: + if process_goflow_tsg(tsg_dict=tsg_dict, + local_path=self.local_path, + ts_species=rxn.ts_species): + n_added += 1 + + if len(self.reactions) < 5: + if n_added: + logger.info(f'GoFlow successfully found {n_added} TS guesses for {rxn.label}.') + else: + logger.info(f'GoFlow did not find any successful TS guesses for {rxn.label}.') + + +def write_xyz_file(xyz_dict: dict, path: str, comment: str = '') -> None: + """Write an ARC xyz dict to a plain XYZ file (``\\n\\n``).""" + body = xyz_to_str(xyz_dict) + n_atoms = len(xyz_dict['symbols']) + safe_comment = comment.replace('\n', ' ').strip() + with open(path, 'w') as f: + f.write(f'{n_atoms}\n{safe_comment}\n{body}\n') + + +def build_atom_mapped_smiles(rxn: 'ARCReaction', side: str) -> Optional[str]: + """ + Build an atom-mapped SMILES (every H an explicit, mapped atom) for the + reactant or product side of ``rxn``, with map numbers consistent across + the two sides via ``rxn.atom_map``. + + Args: + rxn: The ARCReaction. + side: ``'reactants'`` or ``'products'``. + + Returns: + A SMILES string with every atom carrying an atom-map number 1..N, + or ``None`` if any precondition fails (mol unavailable, atom_map + missing, post-roundtrip validation fails). Returning ``None`` lets + the caller skip cleanly without raising. 
+ """ + if side not in ('reactants', 'products'): + return None + try: + expanded_r, expanded_p = rxn.get_reactants_and_products(return_copies=False) + except Exception: + return None + species = expanded_r if side == 'reactants' else expanded_p + if not species: + return None + try: + atom_map = rxn.atom_map + except Exception: + return None + if atom_map is None and side == 'products': + return None + + # Combine each species's RDKit mol (with explicit Hs) into one editable mol. + combined = Chem.RWMol() + running = 0 + for spc in species: + if spc.mol is None: + return None + try: + rd_mol = to_rdkit_mol(spc.mol, remove_h=False, sanitize=True) + except Exception: + return None + # Defensive: ensure every H is a real atom (not a bracket H count). + rd_mol = Chem.AddHs(rd_mol) + for atom in rd_mol.GetAtoms(): + new_atom = Chem.Atom(atom.GetAtomicNum()) + new_atom.SetFormalCharge(atom.GetFormalCharge()) + new_atom.SetNumRadicalElectrons(atom.GetNumRadicalElectrons()) + combined.AddAtom(new_atom) + for bond in rd_mol.GetBonds(): + combined.AddBond(running + bond.GetBeginAtomIdx(), + running + bond.GetEndAtomIdx(), + bond.GetBondType()) + running += rd_mol.GetNumAtoms() + + n_atoms_combined = combined.GetNumAtoms() + + # Apply atom-map numbers. + if side == 'reactants': + for i in range(n_atoms_combined): + combined.GetAtomWithIdx(i).SetAtomMapNum(i + 1) + else: + # Product atom at combined-position p was reactant atom r where + # atom_map[r] == p. Build the inverse map once (O(N)) instead of + # calling list.index per atom (O(N²)). 
+ product_to_reactant = [-1] * n_atoms_combined + for r_idx, p_idx in enumerate(atom_map): + if 0 <= p_idx < n_atoms_combined: + product_to_reactant[p_idx] = r_idx + if any(r == -1 for r in product_to_reactant): + return None # mapping incomplete or out of range + for p_idx, r_idx in enumerate(product_to_reactant): + combined.GetAtomWithIdx(p_idx).SetAtomMapNum(r_idx + 1) + + try: + smiles = Chem.MolToSmiles(combined.GetMol(), canonical=False) + except Exception: + return None + + # Hard validation: round-trip the SMILES and confirm atom count + map set. + params = Chem.SmilesParserParams() + params.removeHs = False + check = Chem.MolFromSmiles(smiles, params) + if check is None: + return None + if check.GetNumAtoms() != n_atoms_combined: + return None + maps = sorted(a.GetAtomMapNum() for a in check.GetAtoms()) + if maps != list(range(1, n_atoms_combined + 1)): + return None + return smiles + + +def _heavy_atom_pair_distances(xyz: dict) -> Optional[np.ndarray]: + """ + Flat array of pairwise atom-atom distances over the HEAVY-atom subset + of ``xyz`` (skipping every H). Returns ``None`` if there are <2 heavy + atoms — caller should fall back to all-atom dmat. + """ + symbols = xyz['symbols'] + heavy = [i for i, s in enumerate(symbols) if s != 'H'] + if len(heavy) < 2: + return None + coords = np.asarray(xyz['coords'], dtype=float)[heavy] + iu = np.triu_indices(len(heavy), k=1) + return np.linalg.norm(coords[iu[0]] - coords[iu[1]], axis=1) + + +def _heavy_atom_dmat_rmsd(xyz1: dict, xyz2: dict, + dmat1: Optional[np.ndarray] = None, + dmat2: Optional[np.ndarray] = None) -> float: + """ + Aggregate root-mean-square deviation of pairwise distance matrices + computed over the HEAVY-atom subset only. Translation- and rotation- + invariant (uses internal distances) AND torsion-invariant (rotor H + motions don't enter the comparison). + + Why heavy-atom only: torsional sampling of non-reactive rotors (e.g. 
+ a terminal CH2 swinging around its single bond) moves only the H atoms; + the heavy-atom skeleton stays put. ML samples drawn from a flat + torsional PES therefore look "different" by all-atom dmat (rotor H's + at different angles) but are chemically the same TS — they collapse + to the same optimized geometry under any QM saddle-point search. + + ``dmat1`` / ``dmat2`` accept pre-computed heavy-atom distance vectors + (from :func:`_heavy_atom_pair_distances`); pass them when comparing + one xyz against many to avoid recomputation in the dedup loop. + + Falls back to all-atom dmat for purely-hydrogen systems (defensive; + real TS guesses always have ≥2 heavy atoms). Returns ``inf`` if the + two xyzs have differing symbol sequences. + """ + if xyz1.get('symbols') != xyz2.get('symbols'): + return float('inf') + if dmat1 is None: + dmat1 = _heavy_atom_pair_distances(xyz1) + if dmat2 is None: + dmat2 = _heavy_atom_pair_distances(xyz2) + if dmat1 is None or dmat2 is None: + d1 = xyz_to_dmat(xyz1) + d2 = xyz_to_dmat(xyz2) + if d1 is None or d2 is None or d1.shape != d2.shape: + return float('inf') + diff = (d1 - d2).flatten() + return float(np.sqrt(np.mean(diff * diff))) + diff = dmat1 - dmat2 + return float(np.sqrt(np.mean(diff * diff))) + + +def process_goflow_tsg(tsg_dict: dict, + local_path: str, + ts_species: ARCSpecies, + dmat_rmsd_tol: float = GOFLOW_DEDUP_DMAT_RMSD, + ) -> bool: + """ + Convert a single TSGuess-shaped dict from ``goflow_script.py`` into an + ARC ``TSGuess`` object, consolidate against existing similar guesses, + and append if unique. + + Dedup uses an *aggregate* distance-matrix RMSD threshold + (``dmat_rmsd_tol``, default ``GOFLOW_DEDUP_DMAT_RMSD`` Å) computed over + HEAVY atoms only. This is translation/rotation-invariant (important for + ML samples in random orientations) AND torsion-invariant (so terminal + rotor wells of the same TS collapse to a single guess instead of + leaking 5+ near-identical structures into the downstream queue). 
+ + If a near-duplicate is found, the existing guess's ``method`` is + annotated with ``" and GoFlow"`` (mirroring the heuristics adapter's + consolidation pattern) so the consumer knows GoFlow also produced it. + + Returns: + ``True`` if a new (unique, non-colliding) TS guess was appended, + ``False`` otherwise. + """ + if not tsg_dict.get('success') or not tsg_dict.get('initial_xyz'): + return False + try: + ts_xyz = str_to_xyz(tsg_dict['initial_xyz']) + except Exception as e: + logger.warning(f'GoFlow: could not parse TS xyz: {e}') + return False + if colliding_atoms(ts_xyz): + return False + + # Pre-compute the new candidate's heavy-atom dmat once; cache existing + # guesses' dmats on the TSGuess object on first comparison so successive + # calls amortize to O(1) per pair (the dedup loop itself is O(N) per + # call → O(N²) over a batch of N samples; without caching it would be O(N^3)). + ts_dmat = _heavy_atom_pair_distances(ts_xyz) + for other_tsg in ts_species.ts_guesses: + if not (other_tsg.success and other_tsg.initial_xyz is not None): + continue + other_dmat = getattr(other_tsg, '_goflow_heavy_dmat', None) + if other_dmat is None: + other_dmat = _heavy_atom_pair_distances(other_tsg.initial_xyz) + try: + other_tsg._goflow_heavy_dmat = other_dmat + except (AttributeError, TypeError): + pass # accept the recompute cost + if _heavy_atom_dmat_rmsd(ts_xyz, other_tsg.initial_xyz, + dmat1=ts_dmat, dmat2=other_dmat) < dmat_rmsd_tol: + if 'goflow' not in other_tsg.method.lower(): + other_tsg.method += ' and GoFlow' + return False + + method_index = int(tsg_dict.get('method_index', 0)) + tsg = TSGuess(method='GoFlow', + method_direction=tsg_dict.get('method_direction', 'F'), + method_index=method_index, + index=len(ts_species.ts_guesses), + success=True, + ) + tsg.process_xyz(ts_xyz) + ts_species.ts_guesses.append(tsg) + save_geo(xyz=ts_xyz, + path=local_path, + filename=f'GoFlow {method_index}', + format_='xyz', + comment=f'GoFlow sample {method_index}', + ) + 
return True + + +def _within_goflow_supported_domain(rxn: 'ARCReaction') -> Tuple[bool, str]: + """ + Check that ``rxn`` falls within GoFlow's validated training domain + (RDB7-like: H/C/N/O/F, modest atom count). + + Returns: + (True, '') if the reaction is supported; (False, reason) otherwise. + """ + elements_seen = set() + n_atoms = 0 + for spc in (rxn.r_species or []): + try: + symbols = spc.get_xyz()['symbols'] if spc.get_xyz() else \ + tuple(a.element.symbol for a in spc.mol.atoms) if spc.mol is not None else tuple() + except Exception: + symbols = tuple(a.element.symbol for a in spc.mol.atoms) if spc.mol is not None else tuple() + elements_seen.update(symbols) + n_atoms += len(symbols) + unsupported = elements_seen - SUPPORTED_GOFLOW_ELEMENTS + if unsupported: + return False, f'unsupported element(s): {sorted(unsupported)}' + if n_atoms > MAX_GOFLOW_ATOMS: + return False, f'reaction has {n_atoms} atoms, above MAX_GOFLOW_ATOMS={MAX_GOFLOW_ATOMS}' + if n_atoms == 0: + return False, 'no atoms found in reactants' + return True, '' + + +def _goflow_environment_ready() -> bool: + """ + Verify that everything GoFlow needs at runtime is in place. Logs a clear + one-line warning per missing piece and returns ``False`` so the adapter + can skip cleanly without raising. + + Note: ``torch.load``-level checkpoint validation is deferred to + ``goflow_script.py`` (which runs in goflow_env where torch is available). + """ + ok = True + if not GOFLOW_PYTHON or not os.path.isfile(GOFLOW_PYTHON): + logger.warning('GoFlow adapter: goflow_env python not found ' + '(set GOFLOW_PYTHON or run `make install-goflow`). Skipping GoFlow TS guesses.') + ok = False + if not GOFLOW_REPO_PATH or not os.path.isdir(GOFLOW_REPO_PATH): + logger.warning('GoFlow adapter: goflow_lean source checkout not found ' + '(set ARC_GOFLOW_REPO or run `make install-goflow`). 
Skipping GoFlow TS guesses.') + ok = False + if not GOFLOW_CKPT_PATH or not os.path.isfile(GOFLOW_CKPT_PATH) \ + or os.path.getsize(GOFLOW_CKPT_PATH) < _GOFLOW_CKPT_MIN_SIZE: + logger.warning('GoFlow adapter: pretrained checkpoint not found or too small ' + '(the in-repo file may be a 45-byte LFS pointer; set ARC_GOFLOW_CKPT to a real ckpt). ' + 'Skipping GoFlow TS guesses.') + ok = False + if not GOFLOW_FEAT_DICT_PATH or not os.path.isfile(GOFLOW_FEAT_DICT_PATH) \ + or os.path.getsize(GOFLOW_FEAT_DICT_PATH) < _GOFLOW_FEAT_DICT_MIN_SIZE: + logger.warning('GoFlow adapter: atom-feature dictionary not found or too small ' + '(set ARC_GOFLOW_FEAT_DICT to a real pickle). Skipping GoFlow TS guesses.') + ok = False + return ok + + +register_job_adapter('goflow', GoFlowAdapter) diff --git a/arc/job/adapters/ts/rits_test.py b/arc/job/adapters/ts/rits_test.py new file mode 100644 index 0000000000..60dcc40cdc --- /dev/null +++ b/arc/job/adapters/ts/rits_test.py @@ -0,0 +1,1024 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Unit tests for the RitS TS-guess adapter (``arc.job.adapters.ts.rits_ts``). + +Tier-1 (always runs): + * settings resolution and finder helpers + * pure-Python helpers: ``write_xyz_file``, ``parse_multi_frame_xyz``, + ``process_rits_tsg`` dedup + * adapter instantiation with ``testing=True``, file-path layout + * graceful skip when ``rits_env`` / checkpoint are missing + * input.yml writer (mocked subprocess) + +Tier-2 (gated on ``_rits_environment_ready()``): + * end-to-end ``execute_incore`` against the real ``rits_env`` for a + handful of family-diverse reactions sourced from + ``arc/job/adapters/ts/linear_test.py``. + +The Tier-2 tests are skipped automatically on CI runners that did not run +``install_rits.sh`` — the matching CI lane (``rits-install`` in +``.github/workflows/ci.yml``) installs the env and exercises them. 
+""" + +import inspect +import math +import os +import shutil +import unittest +from collections import Counter +from unittest import mock + +import arc.job.adapters.ts.rits_ts as rits_mod +from arc.common import ARC_TESTING_PATH, read_yaml_file +from arc.job.adapters.ts.rits_ts import (RitSAdapter, + _rits_environment_ready, + process_rits_tsg, + write_xyz_file, + ) +from arc.plotter import save_geo +from arc.reaction import ARCReaction +from arc.species.converter import str_to_xyz, compare_confs +from arc.species.species import ARCSpecies, TSGuess + + +def _save_debug_geometries(ts_xyzs, rxn) -> None: + """Hard-coded debug helper for visualizing RitS-adapter TS guesses. + + Disabled unless ``ARC_DEBUG_RITS=1`` is set in the environment, so CI + and other developers do not get surprise writes/deletes under + ``~/Desktop/xyz/rits/`` (which can also be slow on networked home dirs). + + When enabled, clears ``~/Desktop/xyz/rits/`` and dumps the current + reaction's reactant(s), product(s), and TS guesses there as Gaussian + ``.gjf`` files. An empty marker file named after the calling test + function (``.txt``) is also written so it is obvious + which test produced the contents. + + This is *for debugging only* — it is not part of any assertion or + correctness check. When the directory does not exist, it is + created. All errors are swallowed defensively so that calling + this helper from a test never makes the test crash. + """ + if os.environ.get('ARC_DEBUG_RITS') != '1': + return + out_dir = os.path.expanduser('~/Desktop/xyz/rits') + try: + if os.path.isdir(out_dir): + for entry in os.listdir(out_dir): + full = os.path.join(out_dir, entry) + if os.path.isfile(full): + try: + os.remove(full) + except OSError: + # Best-effort debug cleanup; ignore deletion failures + # so a stale file does not break a test run. + pass + else: + os.makedirs(out_dir, exist_ok=True) + # Marker file with the calling test function's name. 
The + # ``x_`` prefix makes it sort last in the directory listing, + # *after* ``R_*``, ``P_*``, and ``TS_*`` files. Walk up the call + # stack until we hit a frame whose function starts with ``test_``, + # so the marker reflects the test method even when the helper is + # invoked from an intermediate ``_run_e2e``-style helper. + try: + for frame in inspect.stack()[1:]: + if frame.function.startswith('test_'): + caller = frame.function + break + else: + # No test_* in the stack — fall back to the immediate caller. + caller = inspect.stack()[1].function + except Exception: + caller = 'rits_debug' + try: + open(os.path.join(out_dir, f'x_{caller}.txt'), 'w').close() + except OSError: + # Best-effort debug marker only; never fail a test if writing + # the marker file fails (read-only filesystem, full disk, etc.). + pass + # Reactants. + for i, sp in enumerate(getattr(rxn, 'r_species', []) or []): + try: + save_geo(xyz=sp.get_xyz(), path=out_dir, + filename=f'R_{i}', format_='gjf') + except Exception: + # Debug helper must remain best-effort and never fail tests. + pass + # Products. + for i, sp in enumerate(getattr(rxn, 'p_species', []) or []): + try: + save_geo(xyz=sp.get_xyz(), path=out_dir, + filename=f'P_{i}', format_='gjf') + except Exception: + # Debug helper must remain best-effort and never fail tests. + pass + # TS guesses. + for i, ts in enumerate(ts_xyzs or []): + try: + save_geo(xyz=ts, path=out_dir, + filename=f'TS_{i}', format_='gjf') + except Exception: + # Debug helper must remain best-effort and never fail tests. + pass + except Exception: + # The helper must NEVER make a test fail; swallow everything that + # bubbles up from the per-step blocks above (filesystem errors, + # missing attributes on partially-built reactions, …). + pass + + +HAS_RITS = _rits_environment_ready(log=False) + + +def _build_rxn_isomerization_propyl(): + """nC3H7 → iC3H7. 
The simplest isomerization in ARC's test suite.""" + return ARCReaction(r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')]) + + +def _build_rxn_diels_alder(): + """C=CC(=C)C + C=CC=O → CC1=CCC(C=O)CC1 — bimolecular Diels-Alder.""" + r1_xyz = """C 1.97753426 -0.34691463 -0.12195850 +C 0.96032171 0.45485914 -0.46215363 +C -0.43629664 0.27157147 -0.09968556 +C -1.35584640 1.15966116 -0.51269091 +C -0.83651671 -0.91436221 0.73635894 +H 2.98719352 -0.11575642 -0.44772907 +H 1.84910220 -1.24076974 0.47792776 +H 1.19368072 1.33006788 -1.06832846 +H -2.40510842 1.04750710 -0.25687679 +H -1.09525737 2.02366247 -1.11636739 +H -0.32888591 -0.89422114 1.70676182 +H -1.91408642 -0.93005704 0.93479551 +H -0.58767904 -1.85093188 0.22577726""" + r2_xyz = """C -1.22034116 -0.10890246 0.02353603 +C -0.04004107 0.51094374 -0.08149118 +C 1.22322531 -0.24393463 0.03286276 +O 2.30875132 0.31445302 -0.06186255 +H -1.30612429 -1.17741471 0.19480533 +H -2.14393224 0.45618508 -0.06217786 +H 0.04657041 1.57753840 -0.25245803 +H 1.13189173 -1.32886845 0.20678550""" + p_xyz = """C 2.60098776 -0.04177774 0.73723478 +C 1.20465630 0.10105432 0.20245819 +C 0.16278370 -0.55312927 0.74494799 +C -1.24024239 -0.46705077 0.21761600 +C -1.33954822 0.16452081 -1.17701034 +C -1.06935354 -0.87644399 -2.25040126 +O -0.50075393 -0.64415323 -3.31363975 +C -0.41124651 1.37364733 -1.29938488 +C 1.04721460 1.02438027 -0.98148987 +H 3.26841747 -0.42094194 -0.04336972 +H 2.64920967 -0.73532885 1.58328037 +H 2.97843218 0.92762356 1.07822844 +H 0.31418172 -1.19138627 1.61332708 +H -1.82762138 0.12846013 0.92764672 +H -1.67646259 -1.47309646 0.21384290 +H -2.37737283 0.48324136 -1.33650826 +H -1.50255476 -1.87505625 -2.06737417 +H -0.75069363 2.15000964 -0.60076538 +H -0.46865428 1.81280411 -2.30253884 +H 1.51473571 0.55339822 -1.85465668 +H 1.59082870 1.95894204 -0.79688170""" + r1 = ARCSpecies(label='R1', smiles='C=CC(=C)C', xyz=r1_xyz) + r2 = 
ARCSpecies(label='R2', smiles='C=CC=O', xyz=r2_xyz) + p = ARCSpecies(label='P', smiles='CC1=CCC(C=O)CC1', xyz=p_xyz) + return ARCReaction(r_species=[r1, r2], p_species=[p]) + + +def _build_rxn_one_plus_two_cycloaddition(): + """Singlet CH2 + C=C=C → C=C1CC1 — bimolecular addition with carbene.""" + ch2_xyz = """C 0.00000000 0.00000000 0.10513200 +H 0.00000000 0.98826300 -0.31539600 +H 0.00000000 -0.98826300 -0.31539600""" + c3h4_xyz = """C 1.29697653 0.02233190 0.00658756 +C 0.00000000 -0.00000034 0.00000210 +C -1.29697654 -0.02233198 -0.00658580 +H 1.86532844 -0.70256077 -0.56460908 +H 1.83420869 0.76626329 0.58339481 +H -1.85591941 0.54211003 -0.74397783 +H -1.84361771 -0.60581213 0.72518823""" + c4h6_xyz = """C 1.59999925 -0.11618654 -0.14166302 +C 0.29517860 -0.02143486 -0.02613492 +C -0.92013120 -0.71833111 0.10894610 +C -0.81238032 0.84414025 0.04444949 +H 2.21797993 0.77036923 -0.22897655 +H 2.09015362 -1.08321135 -0.15246324 +H -1.12327237 -1.17593811 1.06705013 +H -1.28992770 -1.23997489 -0.76270297 +H -0.94547237 1.40230195 0.96062403 +H -1.11212744 1.33826544 -0.86912905""" + r1 = ARCSpecies(label='CH2_singlet', adjlist="""multiplicity 1 +1 C u0 p1 c0 {2,S} {3,S} +2 H u0 p0 c0 {1,S} +3 H u0 p0 c0 {1,S} +""", xyz=ch2_xyz) + r2 = ARCSpecies(label='allene', smiles='C=C=C', xyz=c3h4_xyz) + p = ARCSpecies(label='methylene_cyclopropane', smiles='C=C1CC1', xyz=c4h6_xyz) + return ARCReaction(r_species=[r1, r2], p_species=[p]) + + +def _build_rxn_nh3_elimination(): + """NNN → H2NN(s) + NH3 — 1 reactant → 2 products elimination.""" + n3_xyz = """N -1.26709244 -0.00392551 -0.17821516 +N -0.00831159 0.62912211 -0.22607923 +N -0.03650217 1.66537185 0.72488290 +H -1.36396603 -0.52480010 0.69598616 +H -1.33497366 -0.72150540 -0.90528855 +H 0.20276134 1.00409437 -1.16407646 +H 0.01517757 1.28943240 1.67165685 +H -0.93213409 2.15501337 0.67312449""" + h2nn_xyz = """N 1.24087876 0.00949543 0.60790318 +N -0.09033762 -0.00069128 0.02459641 +H -0.47927195 -0.84665038 
-0.39226764 +H -0.67126919 0.83784623 0.01648883""" + nh3_xyz = """N 0.00064924 -0.00099698 0.29559292 +H -0.41786606 0.84210396 -0.09477452 +H -0.52039228 -0.78225292 -0.10002797 +H 0.93760911 -0.05885406 -0.10079043""" + r = ARCSpecies(label='triazene', smiles='NNN', xyz=n3_xyz) + p1 = ARCSpecies(label='H2NNs', adjlist="""multiplicity 1 +1 N u0 p0 c+1 {2,S} {3,S} {4,D} +2 H u0 p0 c0 {1,S} +3 H u0 p0 c0 {1,S} +4 N u0 p2 c-1 {1,D} +""", xyz=h2nn_xyz) + p2 = ARCSpecies(label='NH3', smiles='N', xyz=nh3_xyz) + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +# === Group A: 1<->1 isomerizations ========================================= + +def _build_rxn_vinyl_alcohol_to_acetaldehyde(): + """Keto-enol tautomerization C2H4O: C=CO -> CC=O (6 atoms, 1,3-H shift).""" + r = ARCSpecies(label='vinyl_alcohol', smiles='C=CO') + p = ARCSpecies(label='acetaldehyde', smiles='CC=O') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_propenol_to_acetone(): + """Keto-enol tautomerization C3H6O: OC(=C)C -> CC(=O)C (10 atoms).""" + r = ARCSpecies(label='propen_2_ol', smiles='OC(=C)C') + p = ARCSpecies(label='acetone', smiles='CC(=O)C') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_cyclobutene_to_butadiene(): + """Electrocyclic ring opening C4H6: C1=CCC1 -> C=CC=C (10 atoms).""" + r = ARCSpecies(label='cyclobutene', smiles='C1=CCC1') + p = ARCSpecies(label='1_3_butadiene', smiles='C=CC=C') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_methoxy_to_hydroxymethyl(): + """1,2-H migration in CH3O radical: [O]C -> O[CH2] (5 atoms).""" + r = ARCSpecies(label='methoxy', smiles='[O]C') + p = ARCSpecies(label='hydroxymethyl', smiles='O[CH2]') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_ethoxy_to_alpha_hydroxyethyl(): + """1,2-H migration in CH3CH2O radical: CC[O] -> [CH2]CO (8 atoms).""" + r = ARCSpecies(label='ethoxy', smiles='CC[O]') + p = ARCSpecies(label='alpha_hydroxyethyl', 
smiles='[CH2]CO') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_cyclopropane_to_propene(): + """Ring opening C3H6: C1CC1 -> C=CC (9 atoms).""" + r = ARCSpecies(label='cyclopropane', smiles='C1CC1') + p = ARCSpecies(label='propene', smiles='C=CC') + return ARCReaction(r_species=[r], p_species=[p]) + + +# === Group B: 1<->2 / 2<->1 (eliminations / cycloadditions) ================ + +def _build_rxn_cyclobutane_retro_22(): + """Retro [2+2] C4H8 -> 2 C2H4 (cyclobutane -> 2 ethene), 12 atoms.""" + r = ARCSpecies(label='cyclobutane', smiles='C1CCC1') + p1 = ARCSpecies(label='ethene_a', smiles='C=C') + p2 = ARCSpecies(label='ethene_b', smiles='C=C') + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +def _build_rxn_da_butadiene_ethene(): + """Small Diels-Alder C4H6 + C2H4 -> cyclohexene C6H10 (16 atoms).""" + r1 = ARCSpecies(label='1_3_butadiene', smiles='C=CC=C') + r2 = ARCSpecies(label='ethene', smiles='C=C') + p = ARCSpecies(label='cyclohexene', smiles='C1=CCCCC1') + return ARCReaction(r_species=[r1, r2], p_species=[p]) + + +def _build_rxn_ethanol_dehydration(): + """β-elimination CCO -> C=C + H2O (9 atoms).""" + r = ARCSpecies(label='ethanol', smiles='CCO') + p1 = ARCSpecies(label='ethene', smiles='C=C') + p2 = ARCSpecies(label='water', smiles='O') + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +def _build_rxn_methylamine_dehydrogenation(): + """1,2-dehydrogenation CN -> C=N + H2 (7 atoms total).""" + r = ARCSpecies(label='methylamine', smiles='CN') + p1 = ARCSpecies(label='methyleneamine', smiles='C=N') + p2 = ARCSpecies(label='dihydrogen', smiles='[H][H]') + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +def _build_rxn_ethyl_peroxy_ho2_elimination(): + """β-scission CCO[O] -> C=C + O[O] (9 atoms).""" + r = ARCSpecies(label='ethyl_peroxy', smiles='CCO[O]') + p1 = ARCSpecies(label='ethene', smiles='C=C') + p2 = ARCSpecies(label='hydroperoxyl', smiles='O[O]') + return ARCReaction(r_species=[r], p_species=[p1, 
p2]) + + +# === Group C: 2<->2 H-abstractions ========================================= + +def _build_rxn_hab_ch4_oh(): + """H-abstraction CH4 + OH -> CH3 + H2O (6 atoms each side).""" + r1 = ARCSpecies(label='methane', smiles='C') + r2 = ARCSpecies(label='hydroxyl', smiles='[OH]') + p1 = ARCSpecies(label='methyl', smiles='[CH3]') + p2 = ARCSpecies(label='water', smiles='O') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +def _build_rxn_hab_c2h6_h(): + """H-abstraction C2H6 + H -> C2H5 + H2 (9 atoms).""" + r1 = ARCSpecies(label='ethane', smiles='CC') + r2 = ARCSpecies(label='H_atom', smiles='[H]') + p1 = ARCSpecies(label='ethyl', smiles='C[CH2]') + p2 = ARCSpecies(label='dihydrogen', smiles='[H][H]') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +def _build_rxn_hab_nh3_oh(): + """H-abstraction NH3 + OH -> NH2 + H2O (6 atoms).""" + r1 = ARCSpecies(label='ammonia', smiles='N') + r2 = ARCSpecies(label='hydroxyl', smiles='[OH]') + p1 = ARCSpecies(label='amidogen', smiles='[NH2]') + p2 = ARCSpecies(label='water', smiles='O') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +def _build_rxn_hab_ch3oh_h(): + """H-abstraction CH3OH + H -> CH2OH + H2 (7 atoms; abstracts α-CH).""" + r1 = ARCSpecies(label='methanol', smiles='CO') + r2 = ARCSpecies(label='H_atom', smiles='[H]') + p1 = ARCSpecies(label='hydroxymethyl', smiles='[CH2]O') + p2 = ARCSpecies(label='dihydrogen', smiles='[H][H]') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +# --------------------------------------------------------------------------- +# Pure-python helpers + plumbing +# --------------------------------------------------------------------------- + +class TestRitSHelpers(unittest.TestCase): + """Helper-function unit tests that don't need rits_env.""" + + @classmethod + def setUpClass(cls): + cls.tmp_dir = os.path.join(ARC_TESTING_PATH, 'rits_helpers') + os.makedirs(cls.tmp_dir, exist_ok=True) + + @classmethod + def 
tearDownClass(cls): + shutil.rmtree(cls.tmp_dir, ignore_errors=True) + + def test_write_xyz_file_round_trip(self): + """write_xyz_file should produce a parseable XYZ file with correct atom count.""" + xyz_dict = { + 'symbols': ('C', 'H', 'H', 'H', 'H'), + 'isotopes': (12, 1, 1, 1, 1), + 'coords': ( + (0.0, 0.0, 0.0), + (1.0, 0.0, 0.0), + (-1.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + (0.0, -1.0, 0.0), + ), + } + path = os.path.join(self.tmp_dir, 'methane.xyz') + write_xyz_file(xyz_dict, path, comment='methane test') + self.assertTrue(os.path.isfile(path)) + with open(path) as f: + lines = f.read().splitlines() + # Header + self.assertEqual(int(lines[0]), 5) + self.assertEqual(lines[1], 'methane test') + # Body — 5 coordinate lines starting with the right symbols + body_symbols = [ln.split()[0] for ln in lines[2:7]] + self.assertEqual(body_symbols, ['C', 'H', 'H', 'H', 'H']) + # Round-trip via str_to_xyz + rt = str_to_xyz(path) + self.assertEqual(rt['symbols'], xyz_dict['symbols']) + + def test_write_xyz_file_strips_newlines_in_comment(self): + """A multi-line comment must not corrupt the XYZ format.""" + xyz_dict = { + 'symbols': ('H', 'H'), + 'isotopes': (1, 1), + 'coords': ((0.0, 0.0, 0.0), (0.74, 0.0, 0.0)), + } + path = os.path.join(self.tmp_dir, 'h2.xyz') + write_xyz_file(xyz_dict, path, comment='line1\nline2\nline3') + with open(path) as f: + lines = f.read().splitlines() + # Header is exactly 2 lines + 2 atoms = 4 lines minimum + self.assertEqual(int(lines[0]), 2) + self.assertNotIn('\n', lines[1]) + self.assertEqual(len(lines), 4) + + def test_process_rits_tsg_failed_entry(self): + """A failed-sentinel dict should not produce a TSGuess.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': None, 'success': False, 'execution_time': '0:00:00.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added) + 
self.assertEqual(len(ts_species.ts_guesses), 0) + + def test_process_rits_tsg_dedup_against_existing(self): + """A RitS guess that matches an existing GCN guess should not be appended; + the existing guess should be re-labeled to credit RitS as well.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + # Plant a GCN guess first. + existing_xyz_str = """C 0.0 0.0 0.0 +H 1.0 0.0 0.0 +H -1.0 0.0 0.0 +H 0.0 1.0 0.0 +H 0.0 -1.0 0.0""" + existing = TSGuess(method='GCN', method_direction='F', method_index=0, + index=0, success=True) + existing.process_xyz(str_to_xyz(existing_xyz_str)) + ts_species.ts_guesses.append(existing) + # Submit a RitS guess with identical coordinates. + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': existing_xyz_str, 'success': True, + 'execution_time': '0:00:01.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added) # not appended + self.assertEqual(len(ts_species.ts_guesses), 1) + # The existing guess should now credit both methods. Note: TSGuess + # lowercases the method string on construction. + merged = ts_species.ts_guesses[0].method.lower() + self.assertIn('rits', merged) + self.assertIn('gcn', merged) + + def test_process_rits_tsg_unique_guess_appended(self): + """A unique non-colliding guess should be appended.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + unique_xyz = """C 0.0 0.0 0.0 +H 1.5 0.0 0.0 +H -1.5 0.0 0.0 +H 0.0 1.5 0.0 +H 0.0 -1.5 0.0""" + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 2, + 'initial_xyz': unique_xyz, 'success': True, + 'execution_time': '0:00:02.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertTrue(added) + self.assertEqual(len(ts_species.ts_guesses), 1) + # TSGuess lowercases method on construction. 
+ self.assertEqual(ts_species.ts_guesses[0].method.lower(), 'rits') + self.assertEqual(ts_species.ts_guesses[0].method_index, 2) + self.assertTrue(ts_species.ts_guesses[0].success) + + def test_process_rits_tsg_collision_rejected(self): + """A guess where two atoms overlap must be rejected by colliding_atoms.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + bad_xyz = """C 0.0 0.0 0.0 +H 0.0 0.0 0.0 +H -1.5 0.0 0.0 +H 0.0 1.5 0.0 +H 0.0 -1.5 0.0""" + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': bad_xyz, 'success': True, + 'execution_time': '0:00:00.5'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added) + self.assertEqual(len(ts_species.ts_guesses), 0) + + def test_process_rits_tsg_dedup_catches_rigid_rotation(self): + """A rigidly rotated + translated copy of an existing TSGuess must be + deduped. This is the whole point of switching from byte-level + almost_equal_coords to distance-matrix compare_confs — RitS samples + each TS in its own random orientation, so rotated copies are common. + """ + ts_species = ARCSpecies(label='TS', is_ts=True) + # Plant the original (use atypical CH bond lengths so we can be sure + # the assertion isn't accidentally matching some default geometry). + original_xyz = """C 0.000 0.000 0.000 +H 0.700 0.700 0.700 +H -0.700 -0.700 0.700 +H -0.700 0.700 -0.700 +H 0.700 -0.700 -0.700""" + first = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': original_xyz, 'success': True, + 'execution_time': '0:00:00.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertTrue(first) + self.assertEqual(len(ts_species.ts_guesses), 1) + + # Build a 37° z-axis rotation + translation of the same molecule. 
+ theta = math.radians(37.0) + cos_t, sin_t = math.cos(theta), math.sin(theta) + original_coords = [ + (0.000, 0.000, 0.000), + (0.700, 0.700, 0.700), + (-0.700, -0.700, 0.700), + (-0.700, 0.700, -0.700), + (0.700, -0.700, -0.700), + ] + rotated = [] + for x, y, z in original_coords: + rx = cos_t * x - sin_t * y + 10.0 # also translate by (+10, +5, -3) + ry = sin_t * x + cos_t * y + 5.0 + rz = z - 3.0 + rotated.append((rx, ry, rz)) + symbols = ('C', 'H', 'H', 'H', 'H') + rotated_xyz_str = '\n'.join( + f'{s} {x:.6f} {y:.6f} {z:.6f}' for s, (x, y, z) in zip(symbols, rotated) + ) + + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 1, + 'initial_xyz': rotated_xyz_str, 'success': True, + 'execution_time': '0:00:00.5'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added, + 'rotated+translated duplicate of an existing RitS guess ' + 'must be deduped via compare_confs (distance-matrix RMSD)') + self.assertEqual(len(ts_species.ts_guesses), 1, + 'no new TSGuess should be appended for a rotated duplicate') + + +class TestRitSAdapterInstantiation(unittest.TestCase): + """Verify the adapter constructs and lays out files even without rits_env.""" + + @classmethod + def setUpClass(cls): + cls.maxDiff = None + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'instantiation') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def _build_adapter(self, project_dir: str, n_samples: int = 5): + rxn = _build_rxn_isomerization_propyl() + return RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=project_dir, + args={'keyword': {'n_samples': n_samples}}, + ) + + def test_instantiation_sets_paths_and_metadata(self): + proj = os.path.join(self.output_dir, 'paths') + adapter = self._build_adapter(proj, n_samples=7) + 
self.assertEqual(adapter.job_adapter, 'rits') + self.assertEqual(adapter.execution_type, 'incore') + self.assertEqual(adapter.url, 'https://github.com/isayevlab/RitS') + self.assertEqual(adapter.incore_capacity, 1) + self.assertEqual(adapter.n_samples, 7) + # File paths should all live under the local_path the adapter set up + self.assertTrue(adapter.reactant_xyz_path.endswith('reactant.xyz')) + self.assertTrue(adapter.product_xyz_path.endswith('product.xyz')) + self.assertTrue(adapter.ts_out_xyz_path.endswith('rits_ts.xyz')) + self.assertTrue(adapter.yml_in_path.endswith('input.yml')) + self.assertTrue(adapter.yml_out_path.endswith('output.yml')) + # All five paths should share a parent directory + parents = {os.path.dirname(p) for p in (adapter.reactant_xyz_path, + adapter.product_xyz_path, + adapter.ts_out_xyz_path, + adapter.yml_in_path, + adapter.yml_out_path)} + self.assertEqual(len(parents), 1) + + def test_default_n_samples(self): + proj = os.path.join(self.output_dir, 'default_samples') + adapter = RitSAdapter( + job_type='tsg', + reactions=[_build_rxn_isomerization_propyl()], + testing=True, + project='test_rits', + project_directory=proj, + ) + self.assertEqual(adapter.n_samples, rits_mod.DEFAULT_N_SAMPLES) + + def test_n_samples_invalid_args_falls_back_to_default(self): + proj = os.path.join(self.output_dir, 'bad_samples') + adapter = RitSAdapter( + job_type='tsg', + reactions=[_build_rxn_isomerization_propyl()], + testing=True, + project='test_rits', + project_directory=proj, + args={'keyword': {'n_samples': 'not-a-number'}}, + ) + self.assertEqual(adapter.n_samples, rits_mod.DEFAULT_N_SAMPLES) + + def test_missing_reactions_raises(self): + proj = os.path.join(self.output_dir, 'no_reactions') + with self.assertRaises(ValueError): + RitSAdapter(job_type='tsg', reactions=None, testing=True, + project='test_rits', project_directory=proj) + + +class TestRitSGracefulSkip(unittest.TestCase): + """When rits_env / checkpoint are missing, execute_incore must 
NOT raise.""" + + @classmethod + def setUpClass(cls): + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'graceful_skip') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def test_missing_python_logs_and_returns(self): + rxn = _build_rxn_isomerization_propyl() + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=os.path.join(self.output_dir, 'no_python'), + ) + # Patch the module-level constants to simulate a host without rits_env. + with mock.patch.object(rits_mod, 'RITS_PYTHON', None), \ + mock.patch.object(rits_mod, 'RITS_REPO_PATH', '/nonexistent/RitS'), \ + mock.patch.object(rits_mod, 'RITS_CKPT_PATH', '/nonexistent/rits.ckpt'): + # Should not raise + adapter.execute_incore() + # No TS guesses should have been created + if rxn.ts_species is not None: + self.assertEqual(len(rxn.ts_species.ts_guesses), 0) + + def test_missing_checkpoint_logs_and_returns(self): + rxn = _build_rxn_isomerization_propyl() + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=os.path.join(self.output_dir, 'no_ckpt'), + ) + with mock.patch.object(rits_mod, 'RITS_CKPT_PATH', '/nonexistent/ckpt'): + adapter.execute_incore() + if rxn.ts_species is not None: + self.assertEqual(len(rxn.ts_species.ts_guesses), 0) + + +class TestRitSInputYamlWritten(unittest.TestCase): + """Verify input.yml is written correctly without invoking the real subprocess.""" + + @classmethod + def setUpClass(cls): + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'input_yml') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def test_input_yml_contents(self): + """A successful execute_incore should write input.yml with all required 
keys. + + We mock subprocess.run so the test does not depend on rits_env actually + being installed.""" + rxn = _build_rxn_diels_alder() + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=os.path.join(self.output_dir, 'da'), + args={'keyword': {'n_samples': 4}}, + ) + + # Pretend the env is fully ready, but make subprocess.run a no-op so we + # never actually invoke RitS — we only care about input.yml + the + # mapped reactant.xyz / product.xyz files we wrote. + fake_completed = mock.Mock(returncode=0) + with mock.patch.object(rits_mod, '_rits_environment_ready', return_value=True), \ + mock.patch.object(rits_mod, 'RITS_PYTHON', '/fake/python'), \ + mock.patch.object(rits_mod, 'RITS_REPO_PATH', '/fake/RitS'), \ + mock.patch.object(rits_mod, 'RITS_CKPT_PATH', '/fake/rits.ckpt'), \ + mock.patch('arc.job.adapters.ts.rits_ts.subprocess.run', + return_value=fake_completed) as run_mock: + adapter.execute_incore() + + self.assertTrue(run_mock.called) + # input.yml should exist with the keys our standalone script expects + self.assertTrue(os.path.isfile(adapter.yml_in_path)) + in_dict = read_yaml_file(adapter.yml_in_path) + for key in ('reactant_xyz_path', 'product_xyz_path', 'rits_repo_path', + 'ckpt_path', 'output_xyz_path', 'yml_out_path', + 'config_path', 'n_samples', 'batch_size', 'charge', 'device'): + self.assertIn(key, in_dict, f'missing key {key} in input.yml') + self.assertEqual(in_dict['n_samples'], 4) + self.assertEqual(in_dict['device'], 'auto') + self.assertEqual(in_dict['rits_repo_path'], '/fake/RitS') + self.assertEqual(in_dict['ckpt_path'], '/fake/rits.ckpt') + self.assertTrue(in_dict['config_path'].endswith('rits.yaml')) + # The reactant + product XYZ files should be on disk and have matching atom counts + self.assertTrue(os.path.isfile(adapter.reactant_xyz_path)) + self.assertTrue(os.path.isfile(adapter.product_xyz_path)) + with open(adapter.reactant_xyz_path) as f: + r_n = 
int(f.readline()) + with open(adapter.product_xyz_path) as f: + p_n = int(f.readline()) + self.assertEqual(r_n, p_n) + # Diels-Alder C=CC(=C)C + C=CC=O → CC1=CCC(C=O)CC1 has 21 atoms + self.assertEqual(r_n, 21) + + +# --------------------------------------------------------------------------- +# End-to-end runs against the real rits_env (skipped without it) +# --------------------------------------------------------------------------- + +@unittest.skipUnless(HAS_RITS, 'rits_env / checkpoint not installed; run `make install-rits` to enable.') +class TestRitSEndToEnd(unittest.TestCase): + """End-to-end runs through subprocess into the real rits_env. + + These tests are gated on `_rits_environment_ready()` so a CI runner that + skipped install_rits.sh still gets a green run. The matching CI lane + `rits-install` in .github/workflows/ci.yml installs the env and exercises + them on every PR. + + Each test asks for a small number of samples (n_samples=2) so the runtime + stays reasonable: even on CPU, two samples per reaction completes in well + under a minute on the model RitS ships. + """ + + @classmethod + def setUpClass(cls): + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'e2e') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def _run_e2e(self, rxn, label: str, expected_n_atoms: int, n_samples: int = 2, + expect_success: bool = True): + """Helper: build adapter, execute, return the (rxn, adapter) pair after assertions. + + Args: + rxn: The ARCReaction to feed to RitS. + label: Subdirectory name under the test output dir. + expected_n_atoms: Atom count both reactant and product XYZs should match. + n_samples: Number of TS samples to ask RitS for. + expect_success: When True, assert at least one usable TSGuess was produced. 
+ When False, assert only that the adapter handled RitS's failure + gracefully (output.yml exists, failed-sentinel entry inside, no + crash). Used for reactions RitS cannot handle by design — e.g. + charged/zwitterionic species, where its OpenBabel bond inference + trips RDKit sanitization. + """ + proj = os.path.join(self.output_dir, label) + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=proj, + args={'keyword': {'n_samples': n_samples}}, + ) + adapter.execute_incore() + + # Dump R/P/TS geometries to ~/Desktop/xyz/rits/ for visual debugging. + # This must run BEFORE any assertions so the files are still produced + # if a test later fails. The helper swallows all exceptions internally. + ts_xyzs_for_debug = list() + if rxn.ts_species is not None: + ts_xyzs_for_debug = [tsg.initial_xyz for tsg in rxn.ts_species.ts_guesses + if tsg.success and tsg.initial_xyz is not None] + _save_debug_geometries(ts_xyzs_for_debug, rxn) + + # The reactant + product XYZ that ARC fed to RitS must have matching atom counts + with open(adapter.reactant_xyz_path) as f: + r_n = int(f.readline()) + with open(adapter.product_xyz_path) as f: + p_n = int(f.readline()) + self.assertEqual(r_n, expected_n_atoms) + self.assertEqual(p_n, expected_n_atoms) + # The reactant and product elements must match as multisets — atoms are + # neither created nor destroyed across an elementary reaction. 
+ r_xyz_dict = str_to_xyz(adapter.reactant_xyz_path) + p_xyz_dict = str_to_xyz(adapter.product_xyz_path) + expected_formula = Counter(r_xyz_dict['symbols']) + self.assertEqual(expected_formula, Counter(p_xyz_dict['symbols']), + f'reactant and product element multisets disagree for {label}') + + # The output YAML should exist and be readable in either case + self.assertTrue(os.path.isfile(adapter.yml_out_path), + f'rits_script.py did not write {adapter.yml_out_path}') + out = read_yaml_file(adapter.yml_out_path) or list() + self.assertGreater(len(out), 0, f'rits_script.py produced 0 entries for {label}') + successes = [tsg for tsg in out if tsg.get('success') and tsg.get('initial_xyz')] + + if expect_success: + self.assertGreater(len(successes), 0, + f'RitS produced 0 successful TSGuess entries for {label}') + # Strict check: EVERY successful TS must have the same atom count + # AND the same element multiset as the reactants. Catches both + # atom-count mismatches and element-shuffling bugs. + for i, tsg_dict in enumerate(successes): + ts_xyz = str_to_xyz(tsg_dict['initial_xyz']) + self.assertEqual( + len(ts_xyz['symbols']), expected_n_atoms, + f'{label} TS sample {i}: atom count {len(ts_xyz["symbols"])} ' + f'!= expected {expected_n_atoms}', + ) + actual_formula = Counter(ts_xyz['symbols']) + self.assertEqual( + actual_formula, expected_formula, + f'{label} TS sample {i}: molecular formula ' + f'{dict(actual_formula)} does not match reactant ' + f'{dict(expected_formula)}', + ) + else: + # Failure path: there should be exactly one failed sentinel entry, + # and the adapter must not have created any successful TS guesses + # on the reaction object. + self.assertEqual(len(successes), 0, + f'Expected RitS to fail on {label}, but got ' + f'{len(successes)} successful guess(es)') + self.assertTrue(all(not tsg.get('success') for tsg in out)) + return adapter + + def test_e2e_isomerization_propyl(self): + """nC3H7 → iC3H7 (10 atoms, isomerization). 
+ + With ``n_samples=2`` RitS produces two TS guesses. We assert that + BOTH survive the distance-matrix dedup — verified empirically: + the two samples differ along the reaction coordinate (C-C of the + donor side: 1.52 Å vs 1.76 Å; migrating-H acceptor distance: + 1.28 Å vs 1.16 Å), with a distance-matrix RMSD of ~0.56 Å — well + above the 0.1 Å dedup threshold. They represent two diverse + starting points for downstream Gaussian/ORCA TS optimization, + which is exactly the value of asking for ``n_samples > 1``. + Rotated/translated *exact* copies would be merged — see + TestRitSHelpers.test_process_rits_tsg_dedup_catches_rigid_rotation. + """ + adapter = self._run_e2e(_build_rxn_isomerization_propyl(), + label='isom_propyl', expected_n_atoms=10) + rxn = adapter.reactions[0] + successful = [tsg for tsg in rxn.ts_species.ts_guesses if tsg.success] + # Both samples should survive — they are structurally distinct. + self.assertEqual( + len(successful), 2, + f'Expected 2 unique TS guesses for nC3H7→iC3H7 (each from a ' + f'separate point on the reaction coordinate), got {len(successful)}.', + ) + # Sanity-check they ARE distinct under compare_confs (else dedup is broken). + self.assertFalse( + compare_confs(successful[0].initial_xyz, successful[1].initial_xyz), + 'The two propyl TS guesses unexpectedly compare equal — RitS may have ' + 'collapsed onto a single saddle, or the dedup is mis-tuned.', + ) + + def test_e2e_diels_alder(self): # fails!! 
+ """Diels-Alder bimolecular addition (21 atoms).""" + self._run_e2e(_build_rxn_diels_alder(), + label='diels_alder', expected_n_atoms=21) + + def test_e2e_one_plus_two_cycloaddition(self): # fails + """1+2 cycloaddition with singlet carbene (10 atoms, bimolecular).""" + self._run_e2e(_build_rxn_one_plus_two_cycloaddition(), + label='one_plus_two', expected_n_atoms=10) + + def test_e2e_nh3_elimination_graceful_failure(self): # fails (as planned) + """1,2-NH3 elimination NNN → H2NN(s) + NH3 — RitS cannot handle this + because its OpenBabel bond inference rejects the zwitterionic + aminonitrene product (4-valent N+). The adapter must: + + * still write input.yml + reactant.xyz + product.xyz + * still get a non-empty output.yml back + * write a failed-sentinel TSGuess entry + * NOT raise + + This test pins the graceful-failure code path so it doesn't regress. + """ + adapter = self._run_e2e(_build_rxn_nh3_elimination(), + label='nh3_elim_graceful', expected_n_atoms=8, + expect_success=False) + # The reaction's ts_species should still exist but have no successful TSGuesses. 
+ rxn = adapter.reactions[0] + self.assertIsNotNone(rxn.ts_species) + successful = [tsg for tsg in rxn.ts_species.ts_guesses if tsg.success] + self.assertEqual(len(successful), 0) + + # ----- Group A: 1<->1 isomerizations ------------------------------------- + + def test_e2e_vinyl_alcohol_to_acetaldehyde(self): # not amazing, can contrast in the paper + """Keto-enol tautomerization C2H4O (7 atoms: 2C + 4H + 1O).""" + self._run_e2e(_build_rxn_vinyl_alcohol_to_acetaldehyde(), + label='vinyl_alcohol_to_acetaldehyde', expected_n_atoms=7) + + def test_e2e_propenol_to_acetone(self): + """Keto-enol tautomerization C3H6O (10 atoms).""" + self._run_e2e(_build_rxn_propenol_to_acetone(), + label='propenol_to_acetone', expected_n_atoms=10) + + def test_e2e_cyclobutene_to_butadiene(self): # not amazing, but will probably converge + """Electrocyclic ring opening C4H6 (10 atoms).""" + self._run_e2e(_build_rxn_cyclobutene_to_butadiene(), + label='cyclobutene_to_butadiene', expected_n_atoms=10) + + def test_e2e_methoxy_to_hydroxymethyl(self): # not good + """1,2-H migration in CH3O radical (5 atoms).""" + self._run_e2e(_build_rxn_methoxy_to_hydroxymethyl(), + label='methoxy_to_hydroxymethyl', expected_n_atoms=5) + + def test_e2e_ethoxy_to_alpha_hydroxyethyl(self): # good + """1,2-H migration in CH3CH2O radical (8 atoms).""" + self._run_e2e(_build_rxn_ethoxy_to_alpha_hydroxyethyl(), + label='ethoxy_to_alpha_hydroxyethyl', expected_n_atoms=8) + + def test_e2e_cyclopropane_to_propene(self): # not good + """Cyclopropane ring opening C3H6 (9 atoms).""" + self._run_e2e(_build_rxn_cyclopropane_to_propene(), + label='cyclopropane_to_propene', expected_n_atoms=9) + + # ----- Group B: 1<->2 / 2<->1 (eliminations / cycloadditions) ----------- + + def test_e2e_cyclobutane_retro_22(self): + """Retro [2+2] cyclobutane -> 2 ethene (12 atoms).""" + self._run_e2e(_build_rxn_cyclobutane_retro_22(), + label='cyclobutane_retro_22', expected_n_atoms=12) + + def test_e2e_da_butadiene_ethene(self): + 
"""Small Diels-Alder butadiene + ethene -> cyclohexene (16 atoms).""" + self._run_e2e(_build_rxn_da_butadiene_ethene(), + label='da_butadiene_ethene', expected_n_atoms=16) + + def test_e2e_ethanol_dehydration(self): + """β-elimination ethanol -> ethene + water (9 atoms).""" + self._run_e2e(_build_rxn_ethanol_dehydration(), + label='ethanol_dehydration', expected_n_atoms=9) + + def test_e2e_methylamine_dehydrogenation(self): + """1,2-dehydrogenation methylamine -> methyleneamine + H2 (7 atoms).""" + self._run_e2e(_build_rxn_methylamine_dehydrogenation(), + label='methylamine_dehydrogenation', expected_n_atoms=7) + + def test_e2e_ethyl_peroxy_ho2_elimination(self): + """β-scission ethyl peroxy -> ethene + HO2 (9 atoms).""" + self._run_e2e(_build_rxn_ethyl_peroxy_ho2_elimination(), + label='ethyl_peroxy_ho2_elimination', expected_n_atoms=9) + + # ----- Group C: 2<->2 H-abstractions ------- ----------------------------- + + def test_e2e_hab_ch4_oh(self): + """H-abstraction CH4 + OH -> CH3 + H2O (7 atoms total: 1C + 5H + 1O).""" + self._run_e2e(_build_rxn_hab_ch4_oh(), + label='hab_ch4_oh', expected_n_atoms=7) + + def test_e2e_hab_c2h6_h(self): + """H-abstraction C2H6 + H -> C2H5 + H2 (9 atoms).""" + self._run_e2e(_build_rxn_hab_c2h6_h(), + label='hab_c2h6_h', expected_n_atoms=9) + + def test_e2e_hab_nh3_oh(self): + """H-abstraction NH3 + OH -> NH2 + H2O (6 atoms).""" + self._run_e2e(_build_rxn_hab_nh3_oh(), + label='hab_nh3_oh', expected_n_atoms=6) + + def test_e2e_hab_ch3oh_h(self): + """H-abstraction CH3OH + H -> CH2OH + H2 (7 atoms; abstracts α-CH).""" + self._run_e2e(_build_rxn_hab_ch3oh_h(), + label='hab_ch3oh_h', expected_n_atoms=7) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/job/adapters/ts/rits_ts.py b/arc/job/adapters/ts/rits_ts.py new file mode 100644 index 0000000000..82332c1436 --- /dev/null +++ b/arc/job/adapters/ts/rits_ts.py @@ -0,0 +1,475 @@ +""" +An adapter for executing RitS (Right 
into the Saddle) TS-guess jobs. + +RitS is a flow-matching ML model that generates 3D transition-state geometries +directly from atom-mapped reactant + product structures, without requiring an +initial guess. Unlike GCN (which is restricted to isomerizations), RitS can +handle bimolecular reactions and supports charged species, so it covers a +strictly larger reaction space. + +Code source : https://github.com/isayevlab/RitS +Paper : 10.26434/chemrxiv.15001681/v1 +Pretrained ckpt : https://doi.org/10.5281/zenodo.19474153 + +Implementation notes +-------------------- +* The heavy ML stack (torch + torch-geometric + megalodon) lives in its own + conda env (``rits_env``), so this adapter never imports it directly. It + shells out to ``arc/job/adapters/scripts/rits_script.py`` via subprocess, + which in turn invokes RitS's own ``scripts/sample_transition_state.py``. +* RitS requires the reactant and product XYZ files to have the *same atom + count and the same atom ordering* (it aligns them by index). ARC's + ``rxn.get_reactants_xyz`` / ``get_products_xyz`` already produce mapped + outputs via ``rxn.atom_map``, so we can use them as-is. +* Multiple samples per reaction are produced in a single subprocess call + (RitS's ``--n_samples`` flag), avoiding the per-sample model-load overhead + that GCN incurs. +* If ``rits_env`` or the pretrained checkpoint is missing on the host, the + adapter logs a warning and exits cleanly without raising — the rest of + ARC's TS-search pipeline (heuristics, GCN, AutoTST, …) keeps running. +* ``incore_capacity = 1`` so the scheduler serializes RitS jobs and a single + GPU is not asked to load multiple checkpoints in parallel. 
import datetime
import os
import subprocess
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from arc.common import ARC_PATH, get_logger, save_yaml_file, read_yaml_file
from arc.imports import settings
from arc.job.adapter import JobAdapter
from arc.job.adapters.common import _initialize_adapter
from arc.job.factory import register_job_adapter
from arc.plotter import save_geo
from arc.species.converter import compare_confs, str_to_xyz, xyz_to_str
from arc.species.species import ARCSpecies, TSGuess, colliding_atoms

if TYPE_CHECKING:
    from arc.level import Level
    from arc.reaction import ARCReaction


# Locations of the rits_env interpreter, the RitS checkout, and the pretrained
# checkpoint. These come from ARC's settings (discovered at import time by
# arc/settings/external_paths.py helpers); any of them may be None on hosts
# without the RitS stack installed — _rits_environment_ready() guards for that.
RITS_PYTHON = settings.get('RITS_PYTHON')
RITS_REPO_PATH = settings.get('RITS_REPO_PATH')
RITS_CKPT_PATH = settings.get('RITS_CKPT_PATH')

RITS_SCRIPT_PATH = os.path.join(ARC_PATH, 'arc', 'job', 'adapters', 'scripts', 'rits_script.py')
DEFAULT_N_SAMPLES = 10
DEFAULT_BATCH_SIZE = 32

logger = get_logger()


class RitSAdapter(JobAdapter):
    """
    A class for executing RitS (Right into the Saddle) TS-guess jobs.

    Args:
        project (str): The project's name. Used for setting the remote path.
        project_directory (str): The path to the local project directory.
        job_type (list, str): The job's type, validated against ``JobTypeEnum``.
        args (dict, optional): Methods (including troubleshooting) to be used in
                               input files. For RitS the only currently-honored entry is
                               ``args['keyword']['n_samples']`` (int, default 10).
        bath_gas (str, optional): A bath gas. Currently only used in OneDMin.
        checkfile (str, optional): The path to a previous Gaussian checkfile.
        conformer (int, optional): Conformer number if optimizing conformers.
        constraints (list, optional): A list of constraints.
        cpu_cores (int, optional): The total number of cpu cores requested for a job.
        dihedral_increment (float, optional): Unused for RitS.
        dihedrals (List[float], optional): The dihedral angles corresponding to self.torsions.
        directed_scan_type (str, optional): The type of the directed scan.
        ess_settings (dict, optional): A dictionary of available ESS.
        ess_trsh_methods (List[str], optional): A list of troubleshooting methods.
        execution_type (str, optional): The execution type, 'incore', 'queue', or 'pipe'.
        fine (bool, optional): Whether to use fine geometry optimization parameters.
        initial_time (datetime.datetime or str, optional): The time at which this job was initiated.
        irc_direction (str, optional): The direction of the IRC job.
        job_id (int, optional): The job's ID determined by the server.
        job_memory_gb (int, optional): The total job allocated memory in GB.
        job_name (str, optional): The job's name.
        job_num (int, optional): Used as the entry number in the database.
        job_server_name (str, optional): Job's name on the server.
        job_status (list, optional): The job's server and ESS statuses.
        level (Level, optional): The level of theory to use.
        max_job_time (float, optional): The maximal allowed job time on the server in hours.
        run_multi_species (bool, optional): Whether to run a job for multiple species in the same input file.
        reactions (List[ARCReaction], optional): Entries are ARCReaction instances.
        rotor_index (int, optional): The 0-indexed rotor number.
        server (str): The server to run on.
        server_nodes (list, optional): The nodes this job was previously submitted to.
        queue (str, optional): The name of the queue to submit to.
        attempted_queues (List[str], optional): Queues this job was previously submitted to.
        species (List[ARCSpecies], optional): Entries are ARCSpecies instances.
        testing (bool, optional): Whether the object is generated for testing purposes.
        times_rerun (int, optional): Number of times this job was re-run.
        torsions (List[List[int]], optional): The 0-indexed atom indices of the torsion(s).
        tsg (int, optional): TSGuess number if optimizing TS guesses.
        xyz (dict, optional): The 3D coordinates to use.
    """

    def __init__(self,
                 project: str,
                 project_directory: str,
                 job_type: Union[List[str], str],
                 args: Optional[dict] = None,
                 bath_gas: Optional[str] = None,
                 checkfile: Optional[str] = None,
                 conformer: Optional[int] = None,
                 constraints: Optional[List[Tuple[List[int], float]]] = None,
                 cpu_cores: Optional[str] = None,
                 dihedral_increment: Optional[float] = None,
                 dihedrals: Optional[List[float]] = None,
                 directed_scan_type: Optional[str] = None,
                 ess_settings: Optional[dict] = None,
                 ess_trsh_methods: Optional[List[str]] = None,
                 execution_type: Optional[str] = None,
                 fine: bool = False,
                 initial_time: Optional[Union['datetime.datetime', str]] = None,
                 irc_direction: Optional[str] = None,
                 job_id: Optional[int] = None,
                 job_memory_gb: float = 14.0,
                 job_name: Optional[str] = None,
                 job_num: Optional[int] = None,
                 job_server_name: Optional[str] = None,
                 job_status: Optional[List[Union[dict, str]]] = None,
                 level: Optional['Level'] = None,
                 max_job_time: Optional[float] = None,
                 run_multi_species: bool = False,
                 reactions: Optional[List['ARCReaction']] = None,
                 rotor_index: Optional[int] = None,
                 server: Optional[str] = None,
                 server_nodes: Optional[list] = None,
                 queue: Optional[str] = None,
                 attempted_queues: Optional[List[str]] = None,
                 species: Optional[List['ARCSpecies']] = None,
                 testing: bool = False,
                 times_rerun: int = 0,
                 torsions: Optional[List[List[int]]] = None,
                 tsg: Optional[int] = None,
                 xyz: Optional[dict] = None,
                 ):

        # Single in-flight job per scheduler tick — RitS holds an ML model in
        # GPU memory, parallelizing it across reactions would risk OOM.
        self.incore_capacity = 1
        self.job_adapter = 'rits'
        self.execution_type = execution_type or 'incore'
        self.command = 'sample_transition_state.py'
        self.url = 'https://github.com/isayevlab/RitS'

        if reactions is None:
            raise ValueError('Cannot execute RitS without ARCReaction object(s).')

        # Number of TS samples to draw per reaction. Honored from
        # args['keyword']['n_samples'] so users can bump it via the standard
        # ARC adapter-args path; non-int values fall back to the default.
        self.n_samples = DEFAULT_N_SAMPLES
        if args and isinstance(args, dict):
            kw = args.get('keyword') or dict()
            if 'n_samples' in kw:
                try:
                    self.n_samples = int(kw['n_samples'])
                except (TypeError, ValueError):
                    logger.warning(
                        f"RitS adapter: could not parse args['keyword']['n_samples']="
                        f"{kw['n_samples']!r} as an int; falling back to "
                        f"DEFAULT_N_SAMPLES={DEFAULT_N_SAMPLES}."
                    )

        _initialize_adapter(obj=self,
                            is_ts=True,
                            project=project,
                            project_directory=project_directory,
                            job_type=job_type,
                            args=args,
                            bath_gas=bath_gas,
                            checkfile=checkfile,
                            conformer=conformer,
                            constraints=constraints,
                            cpu_cores=cpu_cores,
                            dihedral_increment=dihedral_increment,
                            dihedrals=dihedrals,
                            directed_scan_type=directed_scan_type,
                            ess_settings=ess_settings,
                            ess_trsh_methods=ess_trsh_methods,
                            fine=fine,
                            initial_time=initial_time,
                            irc_direction=irc_direction,
                            job_id=job_id,
                            job_memory_gb=job_memory_gb,
                            job_name=job_name,
                            job_num=job_num,
                            job_server_name=job_server_name,
                            job_status=job_status,
                            level=level,
                            max_job_time=max_job_time,
                            run_multi_species=run_multi_species,
                            reactions=reactions,
                            rotor_index=rotor_index,
                            server=server,
                            server_nodes=server_nodes,
                            queue=queue,
                            attempted_queues=attempted_queues,
                            species=species,
                            testing=testing,
                            times_rerun=times_rerun,
                            torsions=torsions,
                            tsg=tsg,
                            xyz=xyz,
                            )

    def write_input_file(self) -> None:
        """No standalone input file — see set_files() (writes input.yml)."""
        pass

    def set_files(self) -> None:
        """
        Set files to be uploaded and downloaded for queue execution.

        ``self.files_to_upload`` is a list of dictionaries, each with the keys
        ``'name'``, ``'source'``, ``'make_x'``, ``'local'``, and ``'remote'``.
        """
        # 1. Upload
        if self.execution_type != 'incore':
            self.write_submit_script()
            # The module-level `settings` import already provides this mapping;
            # no local re-import needed.
            self.files_to_upload.append(self.get_file_property_dictionary(
                file_name=settings['submit_filenames'][settings['servers'][self.server]['cluster_soft']]))
        if os.path.isfile(self.yml_in_path):
            self.files_to_upload.append(self.get_file_property_dictionary(file_name='input.yml'))
        if os.path.isfile(self.reactant_xyz_path):
            self.files_to_upload.append(self.get_file_property_dictionary(file_name='reactant.xyz'))
        if os.path.isfile(self.product_xyz_path):
            self.files_to_upload.append(self.get_file_property_dictionary(file_name='product.xyz'))
        # 2. Download
        self.files_to_download.append(self.get_file_property_dictionary(file_name='output.yml'))
        self.files_to_download.append(self.get_file_property_dictionary(file_name='rits_ts.xyz'))

    def set_additional_file_paths(self) -> None:
        """Set the local file paths used by RitS at job time."""
        self.reactant_xyz_path = os.path.join(self.local_path, 'reactant.xyz')
        self.product_xyz_path = os.path.join(self.local_path, 'product.xyz')
        self.ts_out_xyz_path = os.path.join(self.local_path, 'rits_ts.xyz')
        self.yml_in_path = os.path.join(self.local_path, 'input.yml')
        self.yml_out_path = os.path.join(self.local_path, 'output.yml')

    def set_input_file_memory(self) -> None:
        """Set the input file memory attribute (minimal: the heavy lifting runs in rits_env)."""
        self.cpu_cores, self.job_memory_gb = 1, 1

    def execute_incore(self):
        """Execute the RitS job locally (in-process subprocess)."""
        self._log_job_execution()
        self.initial_time = self.initial_time if self.initial_time else datetime.datetime.now()
        self.execute_rits()
        self.final_time = datetime.datetime.now()

    def execute_queue(self):
        """Execute the RitS job to the server's queue."""
        self.execute_rits(exe_type='queue')

    def execute_rits(self, exe_type: str = 'incore'):
        """
        Drive the RitS subprocess and stitch its output back into ARC.

        Args:
            exe_type (str, optional): Either ``'incore'`` (run locally now) or
                ``'queue'`` (just stage the input.yml + submit script).
        """
        if not _rits_environment_ready():
            return
        rxn = self.reactions[0]
        if rxn.ts_species is None:
            rxn.ts_species = ARCSpecies(label=self.species_label,
                                        is_ts=True,
                                        charge=rxn.charge,
                                        multiplicity=rxn.multiplicity,
                                        )

        # Build atom-aligned reactant + product XYZ files. ARC's get_reactants_xyz /
        # get_products_xyz already use rxn.atom_map to align orderings.
        try:
            r_xyz_dict = rxn.get_reactants_xyz(return_format='dict')
            p_xyz_dict = rxn.get_products_xyz(return_format='dict')
        except Exception as e:
            logger.warning(f'RitS: could not build mapped XYZs for {rxn.label}: {e}')
            return
        if r_xyz_dict is None or p_xyz_dict is None:
            logger.warning(f'RitS: empty mapped XYZs for {rxn.label}')
            return
        # RitS aligns R and P by atom index, so the counts must agree.
        if len(r_xyz_dict['symbols']) != len(p_xyz_dict['symbols']):
            logger.warning(f'RitS: atom count mismatch for {rxn.label} '
                           f'(R has {len(r_xyz_dict["symbols"])}, P has {len(p_xyz_dict["symbols"])}). '
                           f'Skipping.')
            return

        write_xyz_file(r_xyz_dict, self.reactant_xyz_path, comment=f'{rxn.label} reactant')
        write_xyz_file(p_xyz_dict, self.product_xyz_path, comment=f'{rxn.label} product')

        input_dict = {
            'reactant_xyz_path': self.reactant_xyz_path,
            'product_xyz_path': self.product_xyz_path,
            'rits_repo_path': RITS_REPO_PATH,
            'ckpt_path': RITS_CKPT_PATH,
            'output_xyz_path': self.ts_out_xyz_path,
            'yml_out_path': self.yml_out_path,
            'config_path': os.path.join(RITS_REPO_PATH, 'scripts', 'conf', 'rits.yaml'),
            'n_samples': self.n_samples,
            'batch_size': DEFAULT_BATCH_SIZE,
            'charge': int(rxn.charge or 0),
            'device': 'auto',
        }
        save_yaml_file(path=self.yml_in_path, content=input_dict)

        if exe_type == 'queue':
            self.legacy_queue_execution()
            return

        # Incore: subprocess into rits_script.py inside rits_env.
        # Pass argv as a list (not shell=True) so paths containing spaces or
        # shell-special characters are handled safely without quoting.
        cmd = [RITS_PYTHON, RITS_SCRIPT_PATH, '--yml_in_path', self.yml_in_path]
        timeout_s = getattr(self, 'rits_subprocess_timeout', 600)
        try:
            result = subprocess.run(cmd, check=False, timeout=timeout_s)
        except subprocess.TimeoutExpired:
            logger.warning(f'RitS subprocess timed out after {timeout_s}s for {rxn.label}; '
                           f'skipping. Increase adapter.rits_subprocess_timeout to extend.')
            return
        if result.returncode != 0:
            logger.warning(f'RitS subprocess returned non-zero exit code {result.returncode} for {rxn.label}.')
            return

        if not os.path.isfile(self.yml_out_path):
            logger.warning(f'RitS produced no output YAML at {self.yml_out_path} for {rxn.label}.')
            return

        tsg_dicts = read_yaml_file(self.yml_out_path) or list()
        n_added = 0
        for tsg_dict in tsg_dicts:
            if process_rits_tsg(tsg_dict=tsg_dict,
                                local_path=self.local_path,
                                ts_species=rxn.ts_species):
                n_added += 1

        if len(self.reactions) < 5:
            if n_added:
                logger.info(f'RitS successfully found {n_added} TS guesses for {rxn.label}.')
            else:
                logger.info(f'RitS did not find any successful TS guesses for {rxn.label}.')


def write_xyz_file(xyz_dict: dict, path: str, comment: str = '') -> None:
    """
    Write an ARC xyz dict to a plain XYZ file with the standard
    ``<n_atoms>\\n<comment>\\n<coordinates>`` header layout.

    Args:
        xyz_dict (dict): An ARC xyz dictionary.
        path (str): Output file path.
        comment (str): Optional comment line (newlines are flattened so the
                       comment stays on a single line, as the XYZ format requires).
    """
    body = xyz_to_str(xyz_dict)
    n_atoms = len(xyz_dict['symbols'])
    safe_comment = comment.replace('\n', ' ').strip()
    with open(path, 'w') as f:
        f.write(f'{n_atoms}\n{safe_comment}\n{body}\n')


def process_rits_tsg(tsg_dict: dict,
                     local_path: str,
                     ts_species: ARCSpecies) -> bool:
    """
    Convert a single TSGuess-shaped dict from ``rits_script.py`` into an ARC
    ``TSGuess`` object, dedup against existing guesses, and append it.

    Dedup uses :func:`arc.species.converter.compare_confs`, which compares
    *internal distance matrices* — so it correctly merges two RitS samples
    that are the same TS structure in different rigid orientations. Every
    flow-matching sample lands the molecule in its own random orientation,
    so rotated duplicates are the common case for RitS.

    Args:
        tsg_dict (dict): One entry from the YAML written by rits_script.py.
        local_path (str): The job's local working directory (used by save_geo).
        ts_species (ARCSpecies): The reaction's TS species accumulator.

    Returns:
        bool: ``True`` if a new (unique, non-colliding) TS guess was appended,
              ``False`` otherwise.
    """
    if not tsg_dict.get('success') or not tsg_dict.get('initial_xyz'):
        return False
    try:
        ts_xyz = str_to_xyz(tsg_dict['initial_xyz'])
    except Exception as e:
        logger.warning(f'RitS: could not parse TS xyz: {e}')
        return False
    if colliding_atoms(ts_xyz):
        return False

    # Dedup against every existing TSGuess (regardless of method) using a
    # rotation/translation-invariant distance-matrix comparator. If a match
    # is found, augment the existing guess's method label instead of
    # appending a duplicate.
    for other_tsg in ts_species.ts_guesses:
        if other_tsg.success and other_tsg.initial_xyz is not None \
                and other_tsg.initial_xyz.get('symbols') == ts_xyz['symbols'] \
                and compare_confs(ts_xyz, other_tsg.initial_xyz):
            if 'rits' not in other_tsg.method.lower():
                other_tsg.method += ' and RitS'
            return False

    method_index = int(tsg_dict.get('method_index', 0))
    tsg = TSGuess(method='RitS',
                  method_direction=tsg_dict.get('method_direction', 'F'),
                  method_index=method_index,
                  index=len(ts_species.ts_guesses),
                  success=True,
                  )
    tsg.process_xyz(ts_xyz)
    ts_species.ts_guesses.append(tsg)
    save_geo(xyz=ts_xyz,
             path=local_path,
             filename=f'RitS {method_index}',
             format_='xyz',
             comment=f'RitS sample {method_index}',
             )
    return True


def _rits_environment_ready(log: bool = True) -> bool:
    """
    Check that everything RitS needs at runtime is in place. Returns ``False``
    if anything is missing so the adapter (or a test guard) can skip cleanly
    without raising.

    Args:
        log (bool): When ``True`` (default), emit one warning per missing
            piece. Pass ``False`` to silence — useful at pytest collection
            time, where the same check is invoked just to set a skip flag
            and warnings would otherwise spam every run on hosts without
            the RitS stack installed.
    """
    ok = True
    if not RITS_PYTHON or not os.path.isfile(RITS_PYTHON):
        if log:
            logger.warning('RitS adapter: rits_env python not found '
                           '(set RITS_PYTHON or run `make install-rits`). Skipping RitS TS guesses.')
        ok = False
    if not RITS_REPO_PATH or not os.path.isdir(RITS_REPO_PATH):
        if log:
            logger.warning('RitS adapter: RitS source checkout not found '
                           '(set ARC_RITS_REPO or run `make install-rits`). Skipping RitS TS guesses.')
        ok = False
    if not RITS_CKPT_PATH or not os.path.isfile(RITS_CKPT_PATH):
        if log:
            logger.warning('RitS adapter: pretrained checkpoint not found '
                           '(set ARC_RITS_CKPT or run `make install-rits`). Skipping RitS TS guesses.')
        ok = False
    return ok


register_job_adapter('rits', RitSAdapter)
The "lean" practical fork lives at +https://github.com/heid-lab/goflow_lean . The shipped epoch_337.ckpt is a +45-byte placeholder (not a real Lightning checkpoint), so the ckpt finder +rejects anything below 1 MB. The shipped feat_dict_organic.pkl is a real +(small) pickle of about 387 bytes — feat dicts are inherently tiny — so +the feat-dict finder only guards against trivially-small (<100 B) files +and otherwise accepts whatever is in place. Real artifacts can be supplied +via env-var overrides (``ARC_GOFLOW_REPO``, ``ARC_GOFLOW_CKPT``, +``ARC_GOFLOW_FEAT_DICT``). + +RitS (Right into the Saddle, Isayev lab, 10.26434/chemrxiv.15001681/v1) +is a flow-matching TS generator that handles bimolecular reactions and +charged species. The upstream repository lives at +https://github.com/isayevlab/RitS . The shipped pretrained checkpoint is +~364 MB (downloaded from Zenodo by ``devtools/install_rits.sh``); it is +located by env-var override (``ARC_RITS_REPO``, ``ARC_RITS_CKPT``) or by +filesystem convention. +""" + +import os + + +_GOFLOW_CKPT_MIN_SIZE = 1_000_000 # any real Lightning ckpt is >> 1 MB +_GOFLOW_FEAT_DICT_MIN_SIZE = 100 # rejects only trivially-empty stubs + + +def _arc_root() -> str: + """ + Return the absolute path of the ARC repo root. + + Returns: + str: Absolute path of the ARC repo (this file lives at + ``/arc/settings/external_paths.py``). + """ + # this file: arc/settings/external_paths.py — three dirnames → repo root. + return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def _goflow_sibling_of_arc() -> str: + """ + Return the conventional sibling-of-ARC location ``/goflow_lean``. + + Returns: + str: Absolute path candidate (no existence check). + """ + return os.path.join(os.path.dirname(_arc_root()), 'goflow_lean') + + +def _rits_sibling_of_arc() -> str: + """ + Return the conventional sibling-of-ARC location ``/RitS``. + + Returns: + str: Absolute path candidate (no existence check). 
+ """ + return os.path.join(os.path.dirname(_arc_root()), 'RitS') + + +def find_goflow_repo() -> str | None: + """ + Locate a goflow_lean source checkout. + + Used by the GoFlow TS adapter to find ``src/goflow/configs/`` (Hydra + configs) and the (validated) shipped ``data/RDB7/`` artifacts. + + Search order: + 1. ``ARC_GOFLOW_REPO`` environment variable (explicit override). + 2. ``~/Code/goflow_lean`` (default for ARC dev machines). + 3. Sibling-of-ARC location ``/goflow_lean`` — + matches what ``devtools/install_goflow.sh`` produces. + + A directory is considered "found" only if it contains + ``src/goflow/__init__.py`` (the package entry point). + + Returns: + str | None: Absolute path to the checkout, or ``None`` if no + candidate was located. + """ + home = os.getenv('HOME') or os.path.expanduser('~') + candidates = [] + env_override = os.getenv('ARC_GOFLOW_REPO') + if env_override: + candidates.append(env_override) + candidates.append(os.path.join(home, 'Code', 'goflow_lean')) + candidates.append(_goflow_sibling_of_arc()) + for path in candidates: + if path and os.path.isfile(os.path.join(path, 'src', 'goflow', '__init__.py')): + return os.path.abspath(path) + return None + + +def find_goflow_ckpt(repo_path: str | None) -> str | None: + """ + Locate a pretrained GoFlow Lightning checkpoint. + + Validates by file size to reject the LFS-pointer placeholder (45 bytes) + shipped in goflow_lean@main. + + Search order: + 1. ``ARC_GOFLOW_CKPT`` env-var override. + 2. ``/data/RDB7/epoch_*.ckpt`` — any ``epoch_NNN.ckpt`` + past the size guard. Matches both the Zenodo-installed + ``epoch_316.ckpt`` and the upstream-canonical ``epoch_337.ckpt``; + multiple matches are sorted by epoch number descending so the + newest wins. + + A file is considered valid iff size >= 1 MB. ``torch.load``-level + validation is deferred to the adapter's runtime (this module stays + torch-free). 
+ + Args: + repo_path (str | None): The goflow_lean checkout to search inside + (``find_goflow_repo()`` output is the typical input). + + Returns: + str | None: Absolute path to a real ckpt, or ``None``. + """ + env_override = os.getenv('ARC_GOFLOW_CKPT') + if env_override and os.path.isfile(env_override) \ + and os.path.getsize(env_override) >= _GOFLOW_CKPT_MIN_SIZE: + return os.path.abspath(env_override) + if repo_path: + ckpt_dir = os.path.join(repo_path, 'data', 'RDB7') + if os.path.isdir(ckpt_dir): + candidates = [] + for name in os.listdir(ckpt_dir): + if not (name.startswith('epoch_') and name.endswith('.ckpt')): + continue + full = os.path.join(ckpt_dir, name) + if os.path.isfile(full) and os.path.getsize(full) >= _GOFLOW_CKPT_MIN_SIZE: + try: + epoch_num = int(name[len('epoch_'):-len('.ckpt')]) + except ValueError: + epoch_num = -1 + candidates.append((epoch_num, full)) + if candidates: + candidates.sort(reverse=True) + return os.path.abspath(candidates[0][1]) + return None + + +def find_goflow_feat_dict(repo_path: str | None) -> str | None: + """ + Locate the GoFlow atom-feature codebook pickle. + + Validates by file size to reject only trivially-empty (<100 B) files; + the in-repo 387-byte file is a real (tiny) pickle and is accepted + as-is. + + Search order: + 1. ``ARC_GOFLOW_FEAT_DICT`` env-var override. + 2. ``/data/RDB7/feat_dict_organic.pkl``. + + Pickle-level validation is deferred to the adapter's runtime. + + Args: + repo_path (str | None): The goflow_lean checkout to search inside. + + Returns: + str | None: Absolute path to a feat-dict pickle, or ``None``. 
+ """ + candidates = [] + env_override = os.getenv('ARC_GOFLOW_FEAT_DICT') + if env_override: + candidates.append(env_override) + if repo_path: + candidates.append(os.path.join(repo_path, 'data', 'RDB7', 'feat_dict_organic.pkl')) + for path in candidates: + if path and os.path.isfile(path) and os.path.getsize(path) >= _GOFLOW_FEAT_DICT_MIN_SIZE: + return os.path.abspath(path) + return None + + +def find_rits_repo() -> str | None: + """ + Locate a RitS source checkout. + + Used by the RitS TS adapter to find ``scripts/sample_transition_state.py`` + and ``scripts/conf/rits.yaml``, which are not part of the importable + ``megalodon`` package. + + Search order: + 1. ``ARC_RITS_REPO`` environment variable (explicit override). + 2. ``~/Code/RitS`` (default for ARC dev machines). + 3. Sibling-of-ARC location ``/RitS`` — + matches what ``devtools/install_rits.sh`` produces. + + A directory is considered "found" only if it contains + ``scripts/sample_transition_state.py`` (the inference entry point). + + Returns: + str | None: Absolute path to the checkout, or ``None`` if no + candidate was located. + """ + home = os.getenv('HOME') or os.path.expanduser('~') + candidates = [] + env_override = os.getenv('ARC_RITS_REPO') + if env_override: + candidates.append(env_override) + candidates.append(os.path.join(home, 'Code', 'RitS')) + candidates.append(_rits_sibling_of_arc()) + for path in candidates: + if path and os.path.isfile(os.path.join(path, 'scripts', 'sample_transition_state.py')): + return os.path.abspath(path) + return None + + +def find_rits_ckpt(repo_path: str | None) -> str | None: + """ + Locate the pretrained RitS checkpoint file (``rits.ckpt``). + + Search order: + 1. ``ARC_RITS_CKPT`` environment variable (explicit override). + 2. ``/data/rits.ckpt`` — what ``install_rits.sh`` writes. + + No size guard is applied: the upstream Zenodo-distributed checkpoint is + a single canonical ~364 MB file and the install script verifies it via + SHA-256 at install time. 
Lightning-level validation is deferred to the + adapter's runtime (this module stays torch-free). + + Args: + repo_path (str | None): The RitS repo path returned by + ``find_rits_repo()``. If ``None``, only the env-var override + is consulted. + + Returns: + str | None: Absolute path to the checkpoint, or ``None``. + """ + env_override = os.getenv('ARC_RITS_CKPT') + if env_override and os.path.isfile(env_override): + return os.path.abspath(env_override) + if repo_path: + candidate = os.path.join(repo_path, 'data', 'rits.ckpt') + if os.path.isfile(candidate): + return os.path.abspath(candidate) + return None diff --git a/arc/settings/external_paths_test.py b/arc/settings/external_paths_test.py new file mode 100644 index 0000000000..c9980a6eb2 --- /dev/null +++ b/arc/settings/external_paths_test.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Unit tests for the filesystem-discovery helpers in +``arc/settings/external_paths.py``. + +Each test fully isolates filesystem + env-var state so it doesn't accidentally +match the developer's real ~/Code/goflow_lean or ~/Code/RitS checkout if one +exists. 
+""" + +import os +import pickle +import tempfile +import unittest +from unittest import mock + +from arc.settings import external_paths + + +class TestFindGoFlowRepo(unittest.TestCase): + """find_goflow_repo() — locates a goflow_lean source checkout.""" + + def test_returns_none_when_no_candidates_exist(self): + """No env var, no shipped path on disk → None.""" + with tempfile.TemporaryDirectory() as tmp: + with mock.patch.dict(os.environ, {'HOME': tmp}, clear=False): + os.environ.pop('ARC_GOFLOW_REPO', None) + with mock.patch.object(external_paths, '_goflow_sibling_of_arc', + return_value=os.path.join(tmp, 'definitely_no_goflow_here')): + self.assertIsNone(external_paths.find_goflow_repo()) + + def test_uses_env_var_override_when_repo_is_real(self): + """ARC_GOFLOW_REPO points at a dir with src/goflow/__init__.py → returns it.""" + with tempfile.TemporaryDirectory() as tmp: + init_dir = os.path.join(tmp, 'src', 'goflow') + os.makedirs(init_dir) + with open(os.path.join(init_dir, '__init__.py'), 'w') as f: + f.write('') + with mock.patch.dict(os.environ, {'ARC_GOFLOW_REPO': tmp}): + self.assertEqual(os.path.abspath(tmp), external_paths.find_goflow_repo()) + + def test_env_var_pointing_at_dir_without_src_goflow_returns_none(self): + """ARC_GOFLOW_REPO points at the wrong directory → not "found" → None.""" + with tempfile.TemporaryDirectory() as tmp: + with mock.patch.dict(os.environ, {'ARC_GOFLOW_REPO': tmp, 'HOME': tmp}): + with mock.patch.object(external_paths, '_goflow_sibling_of_arc', + return_value=os.path.join(tmp, 'no_goflow')): + self.assertIsNone(external_paths.find_goflow_repo()) + + +class TestFindGoFlowCkpt(unittest.TestCase): + """find_goflow_ckpt() — locates the pretrained checkpoint file.""" + + def test_returns_none_when_no_repo_and_no_env_var(self): + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_GOFLOW_CKPT', None) + self.assertIsNone(external_paths.find_goflow_ckpt(repo_path=None)) + + def 
test_uses_env_var_when_set_and_file_is_large_enough(self): + with tempfile.NamedTemporaryFile(suffix='.ckpt', delete=False) as f: + f.write(b'\0' * (1_000_001)) # >= 1 MB + ckpt_path = f.name + try: + with mock.patch.dict(os.environ, {'ARC_GOFLOW_CKPT': ckpt_path}): + self.assertEqual(os.path.abspath(ckpt_path), + external_paths.find_goflow_ckpt(repo_path=None)) + finally: + os.unlink(ckpt_path) + + def test_rejects_undersized_ckpt_file_45_bytes_placeholder(self): + """The 45-byte LFS-pointer file shipped in goflow_lean must be rejected.""" + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = os.path.join(tmp, 'data', 'RDB7', 'epoch_337.ckpt') + os.makedirs(os.path.dirname(ckpt_path)) + with open(ckpt_path, 'wb') as f: + f.write(b'\0' * 45) + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_GOFLOW_CKPT', None) + self.assertIsNone(external_paths.find_goflow_ckpt(repo_path=tmp)) + + def test_accepts_ckpt_in_repo_when_size_is_realistic(self): + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = os.path.join(tmp, 'data', 'RDB7', 'epoch_337.ckpt') + os.makedirs(os.path.dirname(ckpt_path)) + with open(ckpt_path, 'wb') as f: + f.write(b'\0' * (1_000_001)) + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_GOFLOW_CKPT', None) + self.assertEqual(os.path.abspath(ckpt_path), + external_paths.find_goflow_ckpt(repo_path=tmp)) + + +class TestFindGoFlowFeatDict(unittest.TestCase): + """find_goflow_feat_dict() — locates the atom-feature codebook pickle.""" + + def test_returns_none_when_no_repo_and_no_env_var(self): + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_GOFLOW_FEAT_DICT', None) + self.assertIsNone(external_paths.find_goflow_feat_dict(repo_path=None)) + + def test_rejects_feat_dict_file_below_size_threshold(self): + """Trivially-small feat_dict files (<100 B) must be rejected by the size guard. 
+ + Note: the 387-byte ``feat_dict_organic.pkl`` shipped in goflow_lean@main + is a real (small) pickle and is *accepted*; the size guard only catches + clearly-empty stubs.""" + with tempfile.TemporaryDirectory() as tmp: + fd_path = os.path.join(tmp, 'data', 'RDB7', 'feat_dict_organic.pkl') + os.makedirs(os.path.dirname(fd_path)) + with open(fd_path, 'wb') as f: + f.write(b'\0' * 50) + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_GOFLOW_FEAT_DICT', None) + self.assertIsNone(external_paths.find_goflow_feat_dict(repo_path=tmp)) + + def test_accepts_real_pickle_when_above_size_threshold(self): + with tempfile.TemporaryDirectory() as tmp: + fd_path = os.path.join(tmp, 'data', 'RDB7', 'feat_dict_organic.pkl') + os.makedirs(os.path.dirname(fd_path)) + real_dict = {f'feat_{i}': {j: j for j in range(20)} for i in range(20)} + with open(fd_path, 'wb') as f: + pickle.dump(real_dict, f) + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_GOFLOW_FEAT_DICT', None) + self.assertEqual(os.path.abspath(fd_path), + external_paths.find_goflow_feat_dict(repo_path=tmp)) + + +class TestFindRitsRepo(unittest.TestCase): + """find_rits_repo() — locates a RitS source checkout.""" + + def test_returns_none_when_no_candidates_exist(self): + """No env var, no ~/Code/RitS, no sibling-of-ARC → None.""" + with tempfile.TemporaryDirectory() as tmp: + with mock.patch.dict(os.environ, {'HOME': tmp}, clear=False): + os.environ.pop('ARC_RITS_REPO', None) + with mock.patch.object(external_paths, '_rits_sibling_of_arc', + return_value=os.path.join(tmp, 'definitely_no_rits_here')): + self.assertIsNone(external_paths.find_rits_repo()) + + def test_uses_env_var_override_when_repo_is_real(self): + """ARC_RITS_REPO points at a dir with scripts/sample_transition_state.py → returns it.""" + with tempfile.TemporaryDirectory() as tmp: + scripts_dir = os.path.join(tmp, 'scripts') + os.makedirs(scripts_dir) + with open(os.path.join(scripts_dir, 
'sample_transition_state.py'), 'w') as f: + f.write('') + with mock.patch.dict(os.environ, {'ARC_RITS_REPO': tmp}): + self.assertEqual(os.path.abspath(tmp), external_paths.find_rits_repo()) + + def test_env_var_pointing_at_dir_without_sampler_returns_none(self): + """ARC_RITS_REPO points at the wrong directory → not "found" → None.""" + with tempfile.TemporaryDirectory() as tmp: + with mock.patch.dict(os.environ, {'ARC_RITS_REPO': tmp, 'HOME': tmp}): + with mock.patch.object(external_paths, '_rits_sibling_of_arc', + return_value=os.path.join(tmp, 'no_rits')): + self.assertIsNone(external_paths.find_rits_repo()) + + def test_finds_repo_via_sibling_of_arc_fallback(self): + """No env var; sibling-of-ARC contains a valid checkout → returns it.""" + with tempfile.TemporaryDirectory() as tmp: + home = os.path.join(tmp, 'home') + os.makedirs(home) + sibling = os.path.join(tmp, 'RitS') + scripts_dir = os.path.join(sibling, 'scripts') + os.makedirs(scripts_dir) + with open(os.path.join(scripts_dir, 'sample_transition_state.py'), 'w') as f: + f.write('') + with mock.patch.dict(os.environ, {'HOME': home}, clear=False): + os.environ.pop('ARC_RITS_REPO', None) + with mock.patch.object(external_paths, '_rits_sibling_of_arc', + return_value=sibling): + self.assertEqual(os.path.abspath(sibling), external_paths.find_rits_repo()) + + +class TestFindRitsCkpt(unittest.TestCase): + """find_rits_ckpt() — locates the pretrained RitS checkpoint.""" + + def test_returns_none_when_no_repo_and_no_env_var(self): + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_RITS_CKPT', None) + self.assertIsNone(external_paths.find_rits_ckpt(repo_path=None)) + + def test_uses_env_var_override_when_file_exists(self): + with tempfile.NamedTemporaryFile(suffix='.ckpt', delete=False) as f: + f.write(b'\0' * 1024) + ckpt_path = f.name + try: + with mock.patch.dict(os.environ, {'ARC_RITS_CKPT': ckpt_path}): + self.assertEqual(os.path.abspath(ckpt_path), + 
external_paths.find_rits_ckpt(repo_path=None)) + finally: + os.unlink(ckpt_path) + + def test_env_var_pointing_at_missing_file_returns_none(self): + with mock.patch.dict(os.environ, {'ARC_RITS_CKPT': '/nonexistent/path/to/rits.ckpt'}): + self.assertIsNone(external_paths.find_rits_ckpt(repo_path=None)) + + def test_finds_ckpt_at_repo_data_rits_ckpt(self): + with tempfile.TemporaryDirectory() as tmp: + ckpt_path = os.path.join(tmp, 'data', 'rits.ckpt') + os.makedirs(os.path.dirname(ckpt_path)) + with open(ckpt_path, 'wb') as f: + f.write(b'\0' * 1024) + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_RITS_CKPT', None) + self.assertEqual(os.path.abspath(ckpt_path), + external_paths.find_rits_ckpt(repo_path=tmp)) + + def test_returns_none_when_repo_lacks_data_rits_ckpt(self): + with tempfile.TemporaryDirectory() as tmp: + with mock.patch.dict(os.environ, {}, clear=False): + os.environ.pop('ARC_RITS_CKPT', None) + self.assertIsNone(external_paths.find_rits_ckpt(repo_path=tmp)) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/settings/settings.py b/arc/settings/settings.py index 9e17b62d91..48be911f27 100644 --- a/arc/settings/settings.py +++ b/arc/settings/settings.py @@ -10,6 +10,14 @@ import string import sys +from arc.settings.external_paths import ( + find_goflow_ckpt, + find_goflow_feat_dict, + find_goflow_repo, + find_rits_ckpt, + find_rits_repo, +) + # Users should update the following server dictionary. 
# Instructions for RSA key generation can be found here: # https://www.digitalocean.com/community/tutorials/how-to-set-up-ssh-keys--2 @@ -77,6 +85,7 @@ 'onedmin': 'server1', 'orca': 'local', 'qchem': 'server1', + 'rits': 'local', 'terachem': 'server1', 'xtb': 'local', 'xtb_gsm': 'local', @@ -89,7 +98,10 @@ supported_ess = ['cfour', 'gaussian', 'mockter', 'molpro', 'orca', 'qchem', 'terachem', 'onedmin', 'xtb', 'torchani', 'openbabel'] # TS methods to try when appropriate for a reaction (other than user guesses which are always allowed): -ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'xtb_gsm', 'orca_neb'] +# Note: 'RitS' is intentionally NOT in the default — its env (rits_env + +# pretrained ckpt) is heavyweight, so users opt in explicitly via +# ``ts_adapters: ['rits', ...]`` in their input.yml. +ts_adapters = ['heuristics', 'AutoTST', 'goflow', 'orca_neb'] # List here job types to execute by default default_job_types = {'conf_opt': True, # defaults to True if not specified @@ -172,12 +184,14 @@ output_filenames = {'cfour': 'output.out', 'gaussian': 'input.log', 'gcn': 'output.yml', + 'goflow': 'output.yml', 'mockter': 'output.yml', 'molpro': 'input.out', 'onedmin': 'output.out', 'orca': 'input.log', 'orca_neb': 'input.log', 'qchem': 'output.out', + 'rits': 'output.yml', 'terachem': 'output.out', 'torchani': 'output.yml', 'xtb': 'output.out', @@ -325,8 +339,10 @@ ARC_FAMILIES_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'data', 'families') # Default environment names for sister repos -TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \ - None, None, None, None, None, None, None, None, None +TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, GOFLOW_PYTHON, GOFLOW_REPO_PATH, GOFLOW_CKPT_PATH, \ + GOFLOW_FEAT_DICT_PATH, RITS_PYTHON, RITS_REPO_PATH, RITS_CKPT_PATH, \ + ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \ + None, None, None, None, None, 
None, None, None, None, None, None, None, None, None, None, None home = os.getenv("HOME") or os.path.expanduser("~") @@ -366,11 +382,25 @@ def find_executable(env_name, executable_name='python'): OB_PYTHON = find_executable('ob_env') TS_GCN_PYTHON = find_executable('ts_gcn') AUTOTST_PYTHON = find_executable('tst_env') +GOFLOW_PYTHON = find_executable('goflow_env') +RITS_PYTHON = find_executable('rits_env') ARC_PYTHON = find_executable('arc_env') RMG_ENV_NAME = 'rmg_env' RMG_PYTHON = find_executable('rmg_env') XTB = find_executable('xtb_env', 'xtb') + +# Filesystem-discovery helpers for ML-based TS adapters (find_goflow_*, +# find_rits_*) are defined in arc/settings/external_paths.py — kept out of +# this module so settings.py stays a pure data/config layer. We invoke +# them here so the resulting paths sit alongside the other adapter paths +# (RMG_PATH, etc.) for downstream ``settings.get(...)`` consumers. +GOFLOW_REPO_PATH = find_goflow_repo() +GOFLOW_CKPT_PATH = find_goflow_ckpt(GOFLOW_REPO_PATH) +GOFLOW_FEAT_DICT_PATH = find_goflow_feat_dict(GOFLOW_REPO_PATH) +RITS_REPO_PATH = find_rits_repo() +RITS_CKPT_PATH = find_rits_ckpt(RITS_REPO_PATH) + # Ensure BABEL_LIBDIR and BABEL_DATADIR are set before any openbabel import. # The danagroup conda build doesn't ship activate scripts that configure these. # Remove once the danagroup package is fixed upstream. diff --git a/arc/settings/settings_test.py b/arc/settings/settings_test.py new file mode 100644 index 0000000000..067dd19de2 --- /dev/null +++ b/arc/settings/settings_test.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +Unit tests for arc/settings/settings.py. + +Path finder-helper tests live in ``arc/settings/external_paths_test.py`` +(GoFlow + RitS sister-repo discovery) — they cover the discovery logic +itself; this module only checks that the resulting paths are exposed as +settings globals and that the default ``ts_adapters`` list is correct. 
+""" + +import unittest + +from arc.settings import settings as settings_mod + + +class TestExternalAdapterSettingsExposed(unittest.TestCase): + """All sister-repo discovery globals must exist on the settings module after import.""" + + def test_goflow_globals_are_defined(self): + for name in ('GOFLOW_PYTHON', 'GOFLOW_REPO_PATH', 'GOFLOW_CKPT_PATH', 'GOFLOW_FEAT_DICT_PATH'): + self.assertTrue(hasattr(settings_mod, name), f"settings module missing attribute {name}") + + def test_rits_globals_are_defined(self): + for name in ('RITS_PYTHON', 'RITS_REPO_PATH', 'RITS_CKPT_PATH'): + self.assertTrue(hasattr(settings_mod, name), f"settings module missing attribute {name}") + + def test_ts_adapters_includes_goflow_by_default(self): + """GoFlow is enabled by default — case-insensitive match (the default + list uses 'goflow' lowercase).""" + self.assertIn('goflow', [a.lower() for a in settings_mod.ts_adapters]) + + def test_ts_adapters_does_not_include_rits_by_default(self): + """RitS's env (rits_env + pretrained ckpt) is heavyweight, so it must + stay opt-in. 
Users enable it via ``ts_adapters: ['rits', ...]`` in + their input.yml.""" + self.assertNotIn('RitS', settings_mod.ts_adapters) + self.assertNotIn('rits', settings_mod.ts_adapters) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/devtools/install_all.sh b/devtools/install_all.sh index 10c7696223..b3797b6dbb 100644 --- a/devtools/install_all.sh +++ b/devtools/install_all.sh @@ -26,6 +26,8 @@ run_devtool () { bash "$DEVTOOLS_DIR/$1" "${@:2}"; } SKIP_CLEAN=false SKIP_EXT=false SKIP_ARC=false +SKIP_GOFLOW=false +SKIP_RITS=false RMG_ARGS=() ARC_ARGS=() EXT_ARGS=() @@ -33,9 +35,11 @@ GENERIC_ARGS=() while [[ $# -gt 0 ]]; do case "$1" in - --no-clean) SKIP_CLEAN=true ;; - --no-ext) SKIP_EXT=true ;; - --no-arc) SKIP_ARC=true ;; + --no-clean) SKIP_CLEAN=true ;; + --no-ext) SKIP_EXT=true ;; + --no-arc) SKIP_ARC=true ;; + --no-goflow) SKIP_GOFLOW=true ;; + --no-rits) SKIP_RITS=true ;; --rmg-*) RMG_ARGS+=("--${1#--rmg-}") ;; --arc-*) ARC_ARGS+=("--${1#--arc-}") ;; --ext-*) EXT_ARGS+=("--${1#--ext-}") ;; @@ -44,6 +48,8 @@ while [[ $# -gt 0 ]]; do Usage: $0 [global-flags] [--rmg-xxx] [--arc-yyy] [--ext-zzz] --no-clean Skip micromamba/conda cache cleanup --no-ext Skip external tools (AutoTST, KinBot, …) + --no-goflow Skip the GoFlow installer (heavy ML stack — usually run in its own CI lane) + --no-rits Skip the RitS installer (heavy ML stack — usually run in its own CI lane) --rmg-path Forward '--path' to RMG installer --rmg-pip Forward '--pip' to RMG installer ... 
@@ -97,12 +103,28 @@ if [[ $SKIP_EXT == false ]]; then [xtb]=install_xtb.sh [Sella]=install_sella.sh [TorchANI]=install_torchani.sh + [GoFlow]=install_goflow.sh + [RitS]=install_rits.sh ) + # Optionally drop GoFlow — used by `make install-ci` since CI runs GoFlow in its own lane + if [[ $SKIP_GOFLOW == true ]]; then + unset 'EXT_INSTALLERS[GoFlow]' + echo "ℹ️ --no-goflow: skipping GoFlow installer (run \`make install-goflow\` or the goflow CI lane separately)" + fi + + # Optionally drop RitS — used by `make install-ci` since CI runs RitS in its own lane + if [[ $SKIP_RITS == true ]]; then + unset 'EXT_INSTALLERS[RitS]' + echo "ℹ️ --no-rits: skipping RitS installer (run \`make install-rits\` or the rits CI lane separately)" + fi + # installer-specific flag whitelists declare -A EXT_FLAG_WHITELIST=( [install_gcn.sh]="--conda" [install_autotst.sh]="--conda" + [install_goflow.sh]="--cpu --no-ckpt-check" + [install_rits.sh]="--cpu --no-ckpt-check" # add more later, e.g. [install_xtb.sh]="--cuda --prefix" ) diff --git a/devtools/install_goflow.sh b/devtools/install_goflow.sh new file mode 100755 index 0000000000..e5b33d0401 --- /dev/null +++ b/devtools/install_goflow.sh @@ -0,0 +1,478 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ── defaults ─────────────────────────────────────────────────────────────── +GOFLOW_REPO_URL="https://github.com/heid-lab/goflow_lean.git" +GOFLOW_ENV_NAME="goflow_env" +FORCE_CPU=false +GOFLOW_PATH="" +SKIP_CKPT_CHECK=false +USER_CKPT_PATH="" +USER_FEAT_DICT_PATH="" +CUDA_VARIANT="" # one of: cpu, cu118, cu121, cu124, cu126 (empty → autodetect) +TORCH_VERSION="2.6.0" # must match GoFlow's pinned torch version (see goflow_lean README) +TORCHVISION_VERSION="0.21.0" +SKIP_CKPT=false + +# Pretrained checkpoint mirror (Dana Research Group, Zenodo) +# Self-trained reproduction of the paper's epoch_337.ckpt with seed=1, same +# hyperparameters; see goflow/README.md in DanaResearchGroup/training_checkpoints +# for full provenance. 
+# Source paper : Galustian et al., Digital Discovery 2025, 10.1039/D5DD00283D
+# Source repo : https://github.com/heid-lab/goflow_lean
+GOFLOW_CKPT_URL="https://zenodo.org/records/20073635/files/epoch_316.ckpt?download=1"
+GOFLOW_CKPT_SHA256="f0db9762687e4c9e5ce8af54c62e77b087bbacdb48374fa5fb6c6ecda16f13b8"
+
+# Pretrained checkpoint policy:
+#
+# GoFlow's published `goflow_lean` repo stores `data/RDB7/epoch_337.ckpt` and
+# `data/RDB7/feat_dict_organic.pkl` as Git LFS pointers / placeholders (45 B
+# and 387 B respectively at the time of writing). The 387-byte
+# feat_dict_organic.pkl IS a real (small) pickle — feat_dicts are inherently
+# tiny — but the 45-byte epoch_337.ckpt is a placeholder, NOT usable.
+#
+# This installer therefore:
+#   1. accepts a user-supplied ckpt via --ckpt or ARC_GOFLOW_CKPT
+#   2. otherwise downloads from Zenodo + verifies SHA-256
+#   3. validates the in-repo feat_dict_organic.pkl by size + pickle.load
+#   4. accepts --no-ckpt-check to skip download/validation entirely (CI
+#      smoke installs only — adapter will skip cleanly until real ckpt
+#      is available)
+
+# ── parse flags ────────────────────────────────────────────────────────────
+TEMP=$(getopt -o h --long cpu,cuda:,path:,no-ckpt-check,no-ckpt,ckpt:,feat-dict:,help -- "$@")
+eval set -- "$TEMP"
+while true; do
+  case "$1" in
+    --cpu)
+      FORCE_CPU=true
+      shift
+      ;;
+    --cuda)
+      CUDA_VARIANT="$2"
+      shift 2
+      ;;
+    --path)
+      GOFLOW_PATH="$2"
+      shift 2
+      ;;
+    --no-ckpt-check)
+      SKIP_CKPT_CHECK=true
+      shift
+      ;;
+    --no-ckpt)
+      SKIP_CKPT=true
+      shift
+      ;;
+    --ckpt)
+      USER_CKPT_PATH="$2"
+      shift 2
+      ;;
+    --feat-dict)
+      USER_FEAT_DICT_PATH="$2"
+      shift 2
+      ;;
+    -h|--help)
+      cat <<EOF
+Usage: $0 [--cpu] [--cuda <variant>] [--path <dir>]
+          [--ckpt <file>] [--feat-dict <file>]
+          [--no-ckpt] [--no-ckpt-check] [--help]
+
+  --cpu               force a CPU-only PyTorch install (shortcut for --cuda cpu)
+  --cuda <variant>    pick a specific PyG wheel variant: cpu, cu118, cu121, cu124, cu126
+                      (default: autodetect via nvcc / nvidia-smi)
+  --path <dir>        use an existing goflow_lean checkout instead of cloning
+  --ckpt <file>       copy this file into <repo>/data/RDB7/epoch_316.ckpt
+                      (overrides ARC_GOFLOW_CKPT and the Zenodo download)
+  --feat-dict <file>  copy this file into <repo>/data/RDB7/feat_dict_organic.pkl
+                      (overrides ARC_GOFLOW_FEAT_DICT)
+  --no-ckpt           skip the Zenodo checkpoint download (offline installs)
+  --no-ckpt-check     build the env without validating ckpt+feat_dict — useful for
+                      CI smoke installs. The ARC adapter will skip GoFlow at runtime
+                      until real artifacts are placed at the expected paths.
+  -h                  this help
+
+By default the script clones (or updates) goflow_lean as a sibling of the ARC
+repo, creates the '${GOFLOW_ENV_NAME}' conda env with python=3.11, autodetects
+the host CUDA version, installs torch=${TORCH_VERSION} + matching PyTorch
+Geometric companion wheels (torch-scatter / torch-sparse / torch-cluster /
+torch-spline-conv / pyg-lib / torch-geometric) from PyG's wheel index, runs
+'pip install -e .' so that 'import goflow' works inside the env, and validates
+the checkpoint + feature-dictionary artifacts in place.
+
+By default the GoFlow checkpoint is downloaded from Zenodo
+(${GOFLOW_CKPT_URL%%\?*}) and SHA-256-verified. Override with --ckpt or
+ARC_GOFLOW_CKPT to use a local file, --no-ckpt to skip the download
+(offline installs), or --no-ckpt-check to skip download AND validation
+entirely (CI smoke).
+
+Citation:
+  Galustian, L. et al. GoFlow: efficient transition state geometry prediction
+  with flow matching and E(3)-equivariant neural networks.
+  Digital Discovery 2025.
https://doi.org/10.1039/D5DD00283D +EOF + exit 0 + ;; + --) shift; break ;; + *) echo "Invalid flag: $1" >&2; exit 1 ;; + esac +done + +# ── pick a CUDA variant for the PyG wheels ─────────────────────────────── +# PyG publishes wheels for torch ${TORCH_VERSION} against these variants only: +SUPPORTED_VARIANTS=(cpu cu118 cu121 cu124 cu126) + +map_cuda_to_variant() { # X.Y → cu118|cu121|cu124|cu126|cpu + local ver="$1" + local major minor + major=${ver%%.*} + minor=${ver#*.} + minor=${minor%%.*} + if [[ -z "$major" || -z "$minor" ]]; then echo cpu; return; fi + if (( major > 12 )) || { (( major == 12 )) && (( minor >= 6 )); }; then echo cu126 + elif (( major == 12 )) && (( minor >= 4 )); then echo cu124 + elif (( major == 12 )) && (( minor >= 1 )); then echo cu121 + elif { (( major == 12 )) && (( minor == 0 )); } || \ + { (( major == 11 )) && (( minor >= 8 )); }; then echo cu118 + else echo cpu + fi +} + +if [[ -n "$CUDA_VARIANT" ]]; then + if $FORCE_CPU && [[ "$CUDA_VARIANT" != cpu ]]; then + echo "❌ --cpu and --cuda $CUDA_VARIANT are contradictory" >&2 + exit 1 + fi + if ! printf '%s\n' "${SUPPORTED_VARIANTS[@]}" | grep -qx "$CUDA_VARIANT"; then + echo "❌ Unsupported --cuda variant: $CUDA_VARIANT" >&2 + echo " Supported: ${SUPPORTED_VARIANTS[*]}" >&2 + exit 1 + fi +elif $FORCE_CPU; then + CUDA_VARIANT="cpu" +elif command -v nvcc &>/dev/null; then + # POSIX-friendly: works with both GNU and BSD grep (grep -oP / \K is GNU-only). 
+ VER=$(nvcc --version | sed -nE 's/.*release ([0-9]+\.[0-9]+).*/\1/p' | head -n1) + CUDA_VARIANT=$(map_cuda_to_variant "$VER") + echo "🔍 nvcc reports CUDA $VER → using PyG variant '$CUDA_VARIANT'" +elif command -v nvidia-smi &>/dev/null; then + VER=$(nvidia-smi 2>/dev/null | sed -nE 's/.*CUDA Version: ([0-9]+\.[0-9]+).*/\1/p' | head -n1 || true) + if [[ -n "$VER" ]]; then + CUDA_VARIANT=$(map_cuda_to_variant "$VER") + echo "🔍 nvidia-smi reports max CUDA $VER → using PyG variant '$CUDA_VARIANT'" + else + CUDA_VARIANT="cpu" + echo "🔍 Could not parse CUDA version from nvidia-smi → falling back to CPU" + fi +else + CUDA_VARIANT="cpu" + echo "🔍 No nvcc / nvidia-smi found → falling back to CPU" +fi +echo "→ PyG wheel variant: $CUDA_VARIANT" + +# ── locate ARC repo and the sibling clone root ──────────────────────────── +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +if ARC_ROOT=$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel 2>/dev/null); then + : +else + ARC_ROOT=$(cd "$SCRIPT_DIR/.." && pwd) +fi +CLONE_ROOT="$(dirname "$ARC_ROOT")" +echo "📂 ARC root : $ARC_ROOT" +echo "📂 Clone root: $CLONE_ROOT" + +# ── pick a conda frontend ───────────────────────────────────────────────── +if command -v micromamba &>/dev/null; then + COMMAND_PKG=micromamba +elif command -v mamba &>/dev/null; then + COMMAND_PKG=mamba +elif command -v conda &>/dev/null; then + COMMAND_PKG=conda +else + echo "❌ No micromamba/mamba/conda found in PATH" >&2 + exit 1 +fi +echo "✔️ Using $COMMAND_PKG" + +if [[ $COMMAND_PKG == micromamba ]]; then + eval "$(micromamba shell hook --shell=bash)" +else + BASE=$(conda info --base) + source "$BASE/etc/profile.d/conda.sh" +fi + +# ── clone or update goflow_lean ─────────────────────────────────────────── +if [[ -n "$GOFLOW_PATH" ]]; then + if [[ ! 
-d "$GOFLOW_PATH" ]]; then + echo "❌ --path was given but directory does not exist: $GOFLOW_PATH" >&2 + exit 1 + fi + GOFLOW_DIR="$(cd "$GOFLOW_PATH" && pwd)" + echo "📂 Using existing goflow_lean checkout at: $GOFLOW_DIR" +else + GOFLOW_DIR="$CLONE_ROOT/goflow_lean" + if [[ -d "$GOFLOW_DIR/.git" ]]; then + echo "🔄 Updating existing goflow_lean clone at $GOFLOW_DIR" + git -C "$GOFLOW_DIR" fetch origin + git -C "$GOFLOW_DIR" pull --ff-only || echo "⚠️ Could not fast-forward; leaving working tree as-is." + else + echo "⬇️ Cloning goflow_lean into $GOFLOW_DIR" + git clone "$GOFLOW_REPO_URL" "$GOFLOW_DIR" + fi +fi + +GOFLOW_COMMIT="$(git -C "$GOFLOW_DIR" rev-parse --short HEAD 2>/dev/null || echo unknown)" +echo "🔖 goflow_lean commit: $GOFLOW_COMMIT" + +# ── create / update the goflow_env conda environment ───────────────────── +# We deliberately do NOT use 'conda env create -f environment.yml' because the +# upstream env name is 'goflow' and we want our ARC-managed env at 'goflow_env'. +if $COMMAND_PKG env list | awk '{print $1}' | grep -qx "$GOFLOW_ENV_NAME"; then + echo "♻️ '$GOFLOW_ENV_NAME' already exists — updating in place." +else + echo "🆕 Creating '$GOFLOW_ENV_NAME' (python=3.11)" + $COMMAND_PKG create -n "$GOFLOW_ENV_NAME" -c conda-forge python=3.11 -y +fi + +set +u; $COMMAND_PKG activate "$GOFLOW_ENV_NAME"; set -u + +# RDKit / OpenBabel / ASE are far smoother via conda-forge than pip. 
+echo "📦 Installing rdkit + ase + numpy/pandas/tqdm/ipykernel from conda-forge" +$COMMAND_PKG install -n "$GOFLOW_ENV_NAME" -c conda-forge -y \ + rdkit ase numpy pandas tqdm ipykernel + +python -m pip install --upgrade pip + +if [[ "$CUDA_VARIANT" == "cpu" ]]; then + TORCH_INDEX="https://download.pytorch.org/whl/cpu" +else + TORCH_INDEX="https://download.pytorch.org/whl/${CUDA_VARIANT}" +fi +PYG_WHEELS="https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VARIANT}.html" + +echo "🚀 Installing torch==${TORCH_VERSION} (${CUDA_VARIANT}) from $TORCH_INDEX" +python -m pip install \ + "torch==${TORCH_VERSION}" "torchvision==${TORCHVISION_VERSION}" \ + --index-url "$TORCH_INDEX" + +echo "🧮 Installing PyG companion wheels from $PYG_WHEELS" +# --only-binary :all: forces wheels, never source builds (those would need a CUDA toolkit) +python -m pip install \ + pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \ + --only-binary :all: -f "$PYG_WHEELS" + +echo "📦 Installing pure-Python goflow dependencies from PyPI" +python -m pip install \ + hydra-core lightning torchdiffeq wandb pymatgen einops rich + +# Editable install of the goflow package (this is what puts 'import goflow' on path) +echo "🧷 pip install -e . (goflow, src layout)" +python -m pip install -e "$GOFLOW_DIR" + +# Sanity check — import goflow AND the PyG companions, since a successful pip +# install does not guarantee the .so files actually load against the host's CUDA. 
+echo "🔍 Verifying inference stack inside $GOFLOW_ENV_NAME" +python - <<'PYEOF' +import importlib, sys +mods = ["torch", "torch_geometric", "torch_scatter", "torch_sparse", + "torch_cluster", "torch_spline_conv", "lightning", "torchdiffeq", + "ase", "goflow"] +for m in mods: + try: + mod = importlib.import_module(m) + ver = getattr(mod, "__version__", "?") + print(f" ✔️ {m:<22} {ver}") + except Exception as e: + print(f" ❌ {m:<22} FAILED: {e}", file=sys.stderr) + sys.exit(1) +import torch +print(f" ℹ️ torch.cuda.is_available() = {torch.cuda.is_available()}") +PYEOF + +# Capture absolute path to the env's python so the artifact-validation +# block (after deactivation) can still reach `import torch`. +ENV_PY="$(which python)" + +set +u; $COMMAND_PKG deactivate; set -u + +# ── checkpoint + feat_dict acquisition + validation ────────────────────── +CKPT_DIR="$GOFLOW_DIR/data/RDB7" +CKPT_PATH="$CKPT_DIR/epoch_316.ckpt" +FEAT_DICT_PATH="$CKPT_DIR/feat_dict_organic.pkl" +mkdir -p "$CKPT_DIR" + +verify_sha256() { # path expected_sha256 + local path="$1" expected="$2" + local actual + if command -v sha256sum &>/dev/null; then + actual=$(sha256sum "$path" | awk '{print $1}') + elif command -v shasum &>/dev/null; then + actual=$(shasum -a 256 "$path" | awk '{print $1}') + else + echo "❌ Neither sha256sum nor shasum found in PATH; cannot verify checkpoint." >&2 + return 2 + fi + if [[ "$actual" != "$expected" ]]; then + echo "❌ Checksum mismatch for $path" >&2 + echo " expected: $expected" >&2 + echo " actual : $actual" >&2 + return 1 + fi + return 0 +} + +# Highest-priority overrides: CLI flag, then env var, then Zenodo download. +# If a path is supplied, copy it into the canonical in-repo location so +# settings discovery picks it up. +if [[ -n "$USER_CKPT_PATH" ]]; then + if [[ ! 
-f "$USER_CKPT_PATH" ]]; then + echo "❌ --ckpt: file not found: $USER_CKPT_PATH" >&2 + exit 1 + fi + echo "📥 Copying user-supplied checkpoint into $CKPT_PATH" + cp "$USER_CKPT_PATH" "$CKPT_PATH" +elif [[ -n "${ARC_GOFLOW_CKPT:-}" ]]; then + if [[ ! -f "$ARC_GOFLOW_CKPT" ]]; then + echo "❌ ARC_GOFLOW_CKPT: file not found: $ARC_GOFLOW_CKPT" >&2 + exit 1 + fi + echo "📥 Copying \$ARC_GOFLOW_CKPT into $CKPT_PATH" + cp "$ARC_GOFLOW_CKPT" "$CKPT_PATH" +elif $SKIP_CKPT || $SKIP_CKPT_CHECK; then + echo "ℹ️ --no-ckpt / --no-ckpt-check: skipping Zenodo checkpoint download." +elif [[ -f "$CKPT_PATH" ]] && \ + [[ $(stat -c%s "$CKPT_PATH" 2>/dev/null || stat -f%z "$CKPT_PATH" 2>/dev/null) -ge 1000000 ]] && \ + verify_sha256 "$CKPT_PATH" "$GOFLOW_CKPT_SHA256" 2>/dev/null; then + echo "✔️ Existing checkpoint at $CKPT_PATH matches expected SHA-256; skipping download." +else + if ! command -v curl &>/dev/null; then + echo "❌ curl is required to download the GoFlow checkpoint." >&2 + exit 1 + fi + echo "⬇️ Downloading epoch_316.ckpt (~57 MB) from Zenodo:" + echo " $GOFLOW_CKPT_URL" + TMP_CKPT="$(mktemp "${CKPT_DIR}/epoch_316.ckpt.XXXXXX")" + if ! curl -fL --retry 3 --retry-delay 5 -o "$TMP_CKPT" "$GOFLOW_CKPT_URL"; then + rm -f "$TMP_CKPT" + echo "❌ Download failed. Re-run the install, or pass --no-ckpt to skip." >&2 + exit 1 + fi + if verify_sha256 "$TMP_CKPT" "$GOFLOW_CKPT_SHA256"; then + mv "$TMP_CKPT" "$CKPT_PATH" + echo "✔️ Checkpoint verified and saved to $CKPT_PATH" + else + rm -f "$TMP_CKPT" + echo "❌ Downloaded checkpoint failed SHA-256 verification — aborting." >&2 + exit 1 + fi +fi + +# feat_dict_organic.pkl: unlike the ckpt, the in-repo file in goflow_lean +# IS a real (small) pickle — no Zenodo download needed by default. Only +# acquire from --feat-dict / env var if the user wants to override. +if [[ -n "$USER_FEAT_DICT_PATH" ]]; then + if [[ ! 
-f "$USER_FEAT_DICT_PATH" ]]; then + echo "❌ --feat-dict: file not found: $USER_FEAT_DICT_PATH" >&2 + exit 1 + fi + echo "📥 Copying user-supplied feat_dict into $FEAT_DICT_PATH" + cp "$USER_FEAT_DICT_PATH" "$FEAT_DICT_PATH" +elif [[ -n "${ARC_GOFLOW_FEAT_DICT:-}" ]]; then + if [[ ! -f "$ARC_GOFLOW_FEAT_DICT" ]]; then + echo "❌ ARC_GOFLOW_FEAT_DICT: file not found: $ARC_GOFLOW_FEAT_DICT" >&2 + exit 1 + fi + echo "📥 Copying \$ARC_GOFLOW_FEAT_DICT into $FEAT_DICT_PATH" + cp "$ARC_GOFLOW_FEAT_DICT" "$FEAT_DICT_PATH" +fi + +if $SKIP_CKPT_CHECK; then + echo "ℹ️ --no-ckpt-check set; skipping artifact validation." +else + echo "🔬 Validating $CKPT_PATH and $FEAT_DICT_PATH" + if ! "$ENV_PY" - "$CKPT_PATH" "$FEAT_DICT_PATH" <<'PYEOF' +import os, pickle, sys + +ckpt_path, feat_path = sys.argv[1], sys.argv[2] +errors = [] + +# Checkpoint +if not os.path.isfile(ckpt_path): + errors.append(f"missing: {ckpt_path}") +else: + sz = os.path.getsize(ckpt_path) + if sz < 1_000_000: + errors.append( + f"too small ({sz} B; expected >=1 MB — likely an LFS placeholder): {ckpt_path}" + ) + else: + try: + import torch + # weights_only=False because Lightning ckpts embed an + # omegaconf.DictConfig in 'hyper_parameters' that PyTorch 2.6+'s + # safe-by-default unpickler refuses. 
+ obj = torch.load(ckpt_path, map_location="cpu", weights_only=False) + except Exception as e: + errors.append(f"torch.load failed for {ckpt_path}: {e}") + else: + if not isinstance(obj, dict) or "state_dict" not in obj: + errors.append( + f"not a Lightning-style checkpoint (no 'state_dict' key): {ckpt_path}" + ) + +# Feat dict +if not os.path.isfile(feat_path): + errors.append(f"missing: {feat_path}") +else: + sz = os.path.getsize(feat_path) + if sz < 100: + errors.append( + f"too small ({sz} B; expected >=100 B — likely an LFS placeholder): {feat_path}" + ) + else: + try: + with open(feat_path, "rb") as f: + fd = pickle.load(f) + except Exception as e: + errors.append(f"pickle.load failed for {feat_path}: {e}") + else: + if not isinstance(fd, dict): + errors.append(f"feat_dict is not a dict: {feat_path}") + +if errors: + print("\n".join(" ❌ " + e for e in errors), file=sys.stderr) + sys.exit(1) +print(" ✔️ GoFlow artifacts validated.") +PYEOF + then + echo "" >&2 + echo "❌ GoFlow artifact validation failed." >&2 + echo "" >&2 + echo " Provide real files via one of:" >&2 + echo " --ckpt and --feat-dict " >&2 + echo " ARC_GOFLOW_CKPT=... and ARC_GOFLOW_FEAT_DICT=..." >&2 + echo " OR re-run with --no-ckpt-check to install the env without" >&2 + echo " validation (the adapter will then skip GoFlow at runtime)." >&2 + exit 1 + fi +fi + +# ── final notes ─────────────────────────────────────────────────────────── +echo "" +echo "✅ GoFlow installation complete." 
+echo " Repo : $GOFLOW_DIR (commit $GOFLOW_COMMIT)" +echo " Env : $GOFLOW_ENV_NAME" +if [[ -f "$CKPT_PATH" && $(stat -c%s "$CKPT_PATH" 2>/dev/null || stat -f%z "$CKPT_PATH" 2>/dev/null) -ge 1000000 ]]; then + echo " Ckpt : $CKPT_PATH" +else + echo " Ckpt : (not yet installed — set ARC_GOFLOW_CKPT or use --ckpt)" +fi +if [[ -f "$FEAT_DICT_PATH" && $(stat -c%s "$FEAT_DICT_PATH" 2>/dev/null || stat -f%z "$FEAT_DICT_PATH" 2>/dev/null) -ge 100 ]]; then + echo " Feat dict : $FEAT_DICT_PATH" +else + echo " Feat dict : (not yet installed — set ARC_GOFLOW_FEAT_DICT or use --feat-dict)" +fi +echo "" +echo " Source : https://github.com/heid-lab/goflow_lean" +echo " Paper DOI : https://doi.org/10.1039/D5DD00283D" +echo " Ckpt mirror: https://zenodo.org/records/20073635" diff --git a/devtools/install_rits.sh b/devtools/install_rits.sh new file mode 100755 index 0000000000..2c488ec367 --- /dev/null +++ b/devtools/install_rits.sh @@ -0,0 +1,328 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ── defaults ─────────────────────────────────────────────────────────────── +RITS_REPO_URL="https://github.com/isayevlab/RitS.git" +RITS_ENV_NAME="rits_env" +FORCE_CPU=false +RITS_PATH="" +SKIP_CKPT=false +SKIP_CKPT_CHECK=false +CUDA_VARIANT="" # one of: cpu, cu118, cu121, cu124, cu126 (empty → autodetect) +TORCH_VERSION="2.7.0" # must match RitS's pinned torch version + +# Pretrained checkpoint mirror (Dana Research Group, Zenodo) +# Google Drive checkpoint file source: https://drive.google.com/drive/folders/1DD2hmWx3E1klM3Ljon5r4gdquGoN_4v6 +# Source paper: https://github.com/isayevlab/RitS, 10.26434/chemrxiv.15001681/v1 +# Mirror DOI : https://doi.org/10.5281/zenodo.19474153 +RITS_CKPT_URL="https://zenodo.org/records/19474153/files/rits.ckpt?download=1" +RITS_CKPT_SHA256="2c6fdacc13a304cb8bb4030881ff4198de647d8f8f7fdaa878414e2514be0668" + +# ── parse flags ──────────────────────────────────────────────────────────── +TEMP=$(getopt -o h --long cpu,cuda:,path:,no-ckpt,no-ckpt-check,help -- 
"$@") +eval set -- "$TEMP" +while true; do + case "$1" in + --cpu) + FORCE_CPU=true + shift + ;; + --cuda) + CUDA_VARIANT="$2" + shift 2 + ;; + --path) + RITS_PATH="$2" + shift 2 + ;; + --no-ckpt) + SKIP_CKPT=true + shift + ;; + --no-ckpt-check) + SKIP_CKPT_CHECK=true + SKIP_CKPT=true + shift + ;; + -h|--help) + cat <] [--path ] + [--no-ckpt] [--no-ckpt-check] [--help] + + --cpu force a CPU-only PyTorch install (shortcut for --cuda cpu) + --cuda pick a specific PyG wheel variant: cpu, cu118, cu121, cu124, cu126 + (default: autodetect via nvcc / nvidia-smi) + --path use an existing RitS checkout instead of cloning + --no-ckpt skip the pretrained checkpoint download (offline installs) + --no-ckpt-check build the env without validating an existing checkpoint — + useful for CI smoke installs. The ARC adapter will skip RitS + at runtime until a real ckpt is placed at the expected path. + -h this help + +By default the script clones (or updates) RitS as a sibling of the ARC repo, +creates the '${RITS_ENV_NAME}' conda env with python=3.10, autodetects the +host CUDA version, installs torch=${TORCH_VERSION} + matching PyTorch Geometric +companion wheels (torch-scatter / torch-sparse / torch-cluster / +torch-spline-conv / pyg-lib) from PyG's wheel index, runs 'pip install -e .' +so that 'import megalodon' works inside the env, and downloads + verifies the +pretrained 'rits.ckpt' from Zenodo +(${RITS_CKPT_URL%%\?*}). + +No training is required — RitS ships pretrained weights. 
+EOF + exit 0 + ;; + --) shift; break ;; + *) echo "Invalid flag: $1" >&2; exit 1 ;; + esac +done + +# ── pick a CUDA variant for the PyG wheels ─────────────────────────────── +# PyG publishes wheels for torch ${TORCH_VERSION} against these variants only: +SUPPORTED_VARIANTS=(cpu cu118 cu121 cu124 cu126) + +map_cuda_to_variant() { # X.Y → cu118|cu121|cu124|cu126|cpu + local ver="$1" + local major minor + major=${ver%%.*} + minor=${ver#*.} + minor=${minor%%.*} + if [[ -z "$major" || -z "$minor" ]]; then echo cpu; return; fi + if (( major > 12 )) || { (( major == 12 )) && (( minor >= 6 )); }; then echo cu126 + elif (( major == 12 )) && (( minor >= 4 )); then echo cu124 + elif (( major == 12 )) && (( minor >= 1 )); then echo cu121 + elif { (( major == 12 )) && (( minor == 0 )); } || \ + { (( major == 11 )) && (( minor >= 8 )); }; then echo cu118 + else echo cpu + fi +} + +if [[ -n "$CUDA_VARIANT" ]]; then + if $FORCE_CPU && [[ "$CUDA_VARIANT" != cpu ]]; then + echo "❌ --cpu and --cuda $CUDA_VARIANT are contradictory" >&2 + exit 1 + fi + # validate against the supported set + if ! printf '%s\n' "${SUPPORTED_VARIANTS[@]}" | grep -qx "$CUDA_VARIANT"; then + echo "❌ Unsupported --cuda variant: $CUDA_VARIANT" >&2 + echo " Supported: ${SUPPORTED_VARIANTS[*]}" >&2 + exit 1 + fi +elif $FORCE_CPU; then + CUDA_VARIANT="cpu" +elif command -v nvcc &>/dev/null; then + # POSIX-friendly: works with both GNU and BSD grep (grep -oP / \K is GNU-only). + VER=$(nvcc --version | sed -nE 's/.*release ([0-9]+\.[0-9]+).*/\1/p' | head -n1) + CUDA_VARIANT=$(map_cuda_to_variant "$VER") + echo "🔍 nvcc reports CUDA $VER → using PyG variant '$CUDA_VARIANT'" +elif command -v nvidia-smi &>/dev/null; then + # The 'CUDA Version' field in nvidia-smi is the *driver's max supported* CUDA, which is the + # right ceiling for binary wheel compatibility (not driver_version, which is a different number). 
+ VER=$(nvidia-smi 2>/dev/null | sed -nE 's/.*CUDA Version: ([0-9]+\.[0-9]+).*/\1/p' | head -n1 || true) + if [[ -n "$VER" ]]; then + CUDA_VARIANT=$(map_cuda_to_variant "$VER") + echo "🔍 nvidia-smi reports max CUDA $VER → using PyG variant '$CUDA_VARIANT'" + else + CUDA_VARIANT="cpu" + echo "🔍 Could not parse CUDA version from nvidia-smi → falling back to CPU" + fi +else + CUDA_VARIANT="cpu" + echo "🔍 No nvcc / nvidia-smi found → falling back to CPU" +fi +echo "→ PyG wheel variant: $CUDA_VARIANT" + +# ── locate ARC repo and the sibling clone root ──────────────────────────── +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +if ARC_ROOT=$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel 2>/dev/null); then + : +else + ARC_ROOT=$(cd "$SCRIPT_DIR/.." && pwd) +fi +CLONE_ROOT="$(dirname "$ARC_ROOT")" +echo "📂 ARC root : $ARC_ROOT" +echo "📂 Clone root: $CLONE_ROOT" + +# ── pick a conda frontend ───────────────────────────────────────────────── +if command -v micromamba &>/dev/null; then + COMMAND_PKG=micromamba +elif command -v mamba &>/dev/null; then + COMMAND_PKG=mamba +elif command -v conda &>/dev/null; then + COMMAND_PKG=conda +else + echo "❌ No micromamba/mamba/conda found in PATH" >&2 + exit 1 +fi +echo "✔️ Using $COMMAND_PKG" + +# Initialize shell integration so 'activate' works in this script +if [[ $COMMAND_PKG == micromamba ]]; then + eval "$(micromamba shell hook --shell=bash)" +else + BASE=$(conda info --base) + source "$BASE/etc/profile.d/conda.sh" +fi + +# ── clone or update RitS ────────────────────────────────────────────────── +if [[ -n "$RITS_PATH" ]]; then + if [[ ! 
-d "$RITS_PATH" ]]; then + echo "❌ --path was given but directory does not exist: $RITS_PATH" >&2 + exit 1 + fi + RITS_DIR="$(cd "$RITS_PATH" && pwd)" + echo "📂 Using existing RitS checkout at: $RITS_DIR" +else + RITS_DIR="$CLONE_ROOT/RitS" + if [[ -d "$RITS_DIR/.git" ]]; then + echo "🔄 Updating existing RitS clone at $RITS_DIR" + git -C "$RITS_DIR" fetch origin + git -C "$RITS_DIR" pull --ff-only || echo "⚠️ Could not fast-forward; leaving working tree as-is." + else + echo "⬇️ Cloning RitS into $RITS_DIR" + git clone "$RITS_REPO_URL" "$RITS_DIR" + fi +fi + +# ── create / update the rits_env conda environment ─────────────────────── +if $COMMAND_PKG env list | awk '{print $1}' | grep -qx "$RITS_ENV_NAME"; then + echo "♻️ '$RITS_ENV_NAME' already exists — updating in place." +else + echo "🆕 Creating '$RITS_ENV_NAME' (python=3.10)" + $COMMAND_PKG create -n "$RITS_ENV_NAME" -c conda-forge python=3.10 -y +fi + +set +u; $COMMAND_PKG activate "$RITS_ENV_NAME"; set -u + +# RDKit & OpenBabel are far smoother via conda-forge than pip +echo "📦 Installing rdkit + openbabel from conda-forge" +$COMMAND_PKG install -n "$RITS_ENV_NAME" -c conda-forge -y \ + "rdkit=2025.3.2" openbabel + +# Install PyTorch + PyTorch Geometric companion wheels for the chosen variant. +# We deliberately do NOT use RitS's requirements.txt because it pins +pt27cu126 +# specifically — we install the variant-matched companion wheels instead so the +# install works on CPU runners and on GPUs with CUDA != 12.6. 
+python -m pip install --upgrade pip + +if [[ "$CUDA_VARIANT" == "cpu" ]]; then + TORCH_INDEX="https://download.pytorch.org/whl/cpu" +else + TORCH_INDEX="https://download.pytorch.org/whl/${CUDA_VARIANT}" +fi +PYG_WHEELS="https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VARIANT}.html" + +echo "🚀 Installing torch==${TORCH_VERSION} (${CUDA_VARIANT}) from $TORCH_INDEX" +python -m pip install "torch==${TORCH_VERSION}" --index-url "$TORCH_INDEX" + +echo "🧮 Installing PyG companion wheels from $PYG_WHEELS" +# --only-binary :all: forces wheels, never source builds (those would need a CUDA toolkit) +python -m pip install \ + pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv \ + --only-binary :all: -f "$PYG_WHEELS" + +echo "📦 Installing pure-Python megalodon dependencies from PyPI" +python -m pip install \ + "torch_geometric==2.6.1" \ + "hydra-core==1.3.2" \ + "lightning==2.5.1.post0" \ + "einops==0.8.1" \ + "wandb==0.19.11" \ + "pandas==2.2.3" \ + "tqdm==4.67.1" + +# Editable install of the megalodon package (this is what puts 'import megalodon' on path) +echo "🧷 pip install -e . (megalodon, src layout)" +python -m pip install -e "$RITS_DIR" + +# Sanity check — import megalodon AND the PyG companions, since a successful pip +# install does not guarantee the .so files actually load against the host's CUDA. 
+echo "🔍 Verifying inference stack inside $RITS_ENV_NAME" +python - <<'PYEOF' +import importlib, sys +mods = ["torch", "torch_geometric", "torch_scatter", "torch_sparse", + "torch_cluster", "torch_spline_conv", "megalodon"] +for m in mods: + try: + mod = importlib.import_module(m) + ver = getattr(mod, "__version__", "?") + print(f" ✔️ {m:<22} {ver}") + except Exception as e: + print(f" ❌ {m:<22} FAILED: {e}", file=sys.stderr) + sys.exit(1) +import torch +print(f" ℹ️ torch.cuda.is_available() = {torch.cuda.is_available()}") +PYEOF + +set +u; $COMMAND_PKG deactivate; set -u + +# ── download + verify pretrained checkpoint ────────────────────────────── +CKPT_DIR="$RITS_DIR/data" +CKPT_PATH="$CKPT_DIR/rits.ckpt" + +verify_sha256() { # path expected_sha256 + local path="$1" expected="$2" + local actual + if command -v sha256sum &>/dev/null; then + actual=$(sha256sum "$path" | awk '{print $1}') + elif command -v shasum &>/dev/null; then + actual=$(shasum -a 256 "$path" | awk '{print $1}') + else + echo "❌ Neither sha256sum nor shasum found in PATH; cannot verify checkpoint." >&2 + return 2 + fi + if [[ "$actual" != "$expected" ]]; then + echo "❌ Checksum mismatch for $path" >&2 + echo " expected: $expected" >&2 + echo " actual : $actual" >&2 + return 1 + fi + return 0 +} + +if $SKIP_CKPT_CHECK; then + echo "ℹ️ --no-ckpt-check: skipping Zenodo checkpoint download AND validation." +elif $SKIP_CKPT; then + echo "ℹ️ --no-ckpt set, skipping checkpoint download." +elif [[ -f "$CKPT_PATH" ]]; then + echo "📦 Existing checkpoint found at $CKPT_PATH — verifying SHA-256..." + if verify_sha256 "$CKPT_PATH" "$RITS_CKPT_SHA256"; then + echo "✔️ Checkpoint SHA-256 OK ($RITS_CKPT_SHA256)" + else + echo "❌ Existing checkpoint does not match the expected SHA-256." >&2 + echo " Refusing to overwrite — move it aside or delete it and re-run." >&2 + exit 1 + fi +else + mkdir -p "$CKPT_DIR" + if ! command -v curl &>/dev/null; then + echo "❌ curl is required to download the RitS checkpoint." 
>&2 + exit 1 + fi + echo "⬇️ Downloading rits.ckpt (~364 MB) from Zenodo:" + echo " $RITS_CKPT_URL" + TMP_CKPT="$(mktemp "${CKPT_DIR}/rits.ckpt.XXXXXX")" + if ! curl -fL --retry 3 --retry-delay 5 -o "$TMP_CKPT" "$RITS_CKPT_URL"; then + rm -f "$TMP_CKPT" + echo "❌ Download failed. Re-run the install, or pass --no-ckpt to skip." >&2 + exit 1 + fi + if verify_sha256 "$TMP_CKPT" "$RITS_CKPT_SHA256"; then + mv "$TMP_CKPT" "$CKPT_PATH" + echo "✔️ Checkpoint verified and saved to $CKPT_PATH" + else + rm -f "$TMP_CKPT" + echo "❌ Downloaded checkpoint failed SHA-256 verification — aborting." >&2 + exit 1 + fi +fi + +# ── final notes ─────────────────────────────────────────────────────────── +echo "" +echo "✅ RitS installation complete." +echo " Repo : $RITS_DIR" +echo " Env : $RITS_ENV_NAME" +echo " Ckpt : $([[ -f $CKPT_PATH ]] && echo $CKPT_PATH || echo '(not installed)')" +echo "" +echo " Mirror DOI : https://doi.org/10.5281/zenodo.19474153" +echo " Source : https://github.com/isayevlab/RitS" diff --git a/docs/source/TS_search.rst b/docs/source/TS_search.rst index 87199e4b8c..cc149a2334 100644 --- a/docs/source/TS_search.rst +++ b/docs/source/TS_search.rst @@ -57,4 +57,169 @@ A detailed description of the methodology, design choices, and validation benchm L. Fahoum, A. Grinberg Dana, *“Automated Reaction Transition State Search for Neutral Hydrolysis Reactions”*, Digital Discovery, 2026. +GoFlow (flow-matching ML TS generator) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +ARC supports automated TS generation via **GoFlow**, a flow-matching, E(3)-equivariant +neural network that predicts transition-state Cartesian geometries from atom-mapped +reactant + product 2D graphs (SMILES + RDKit features). The model was trained on the +`RDB7 `_ database of single-step organic reactions. + +Supported domain +"""""""""""""""" +GoFlow is **enabled by default** in ARC's ``ts_adapters`` list. 
The adapter is safe to
+ship as a default because it skips cleanly at runtime if ``goflow_env`` or the
+checkpoint is not installed, and because it enforces a runtime domain guard:
+
+- Elements: H, C, N, O, F
+- Reaction size: up to 100 atoms
+
+Reactions outside this domain (or hosts without the GoFlow stack installed) are
+skipped with a warning instead of being attempted with out-of-distribution inputs.
+
+How it is used
+""""""""""""""
+GoFlow runs out of the box once its environment + checkpoint are in place:
+
+1. Install the dedicated conda env and download the pretrained checkpoint:
+
+   .. code-block:: bash
+
+      make install-goflow
+
+   This creates ``goflow_env`` (PyTorch 2.6 + PyTorch Geometric + GoFlow), clones
+   ``goflow_lean``, downloads the published checkpoint from Zenodo
+   (`10.5281/zenodo.20073635 <https://doi.org/10.5281/zenodo.20073635>`_), and verifies
+   its SHA-256.
+
+2. Once installed, GoFlow is invoked automatically for every reaction within its
+   supported domain. To disable it for a given run, override ``ts_adapters`` in
+   the input file without ``goflow``:
+
+   .. code-block:: yaml
+
+      ts_adapters:
+        - heuristics
+        - AutoTST
+
+3. (Optional) Override the default checkpoint or feature-dictionary location via env vars:
+
+   .. code-block:: bash
+
+      export ARC_GOFLOW_CKPT=/path/to/your/epoch_<N>.ckpt
+      export ARC_GOFLOW_FEAT_DICT=/path/to/your/feat_dict_organic.pkl
+
+   These take precedence over both the in-repo paths and the Zenodo download.
+
+What ARC does
+"""""""""""""
+For each reaction with GoFlow selected, ARC:
+
+1. Validates the reaction is within GoFlow's supported domain (elements + atom count); skips with a warning otherwise.
+2. Builds atom-mapped reactant and product SMILES (every hydrogen explicit; map numbers consistent across sides via ``rxn.atom_map``).
+3. Spawns the GoFlow inference subprocess (in ``goflow_env``), which performs flow-matching ODE sampling and returns multiple candidate TS geometries.
+4.
Filters out colliding-atom geometries and consolidates near-duplicate samples that share a heavy-atom skeleton (torsion-invariant deduplication; controlled by ``GOFLOW_DEDUP_DMAT_RMSD = 0.15`` Å).
+5. Appends the surviving guesses to the reaction's TS species for downstream optimization, frequency, and IRC validation by ARC's standard pipeline.
+
+GoFlow is best used **alongside** other adapters (e.g. ``heuristics``) — its samples
+provide additional starting points but do not replace the optimization/validation steps.
+
+Outputs and validation
+""""""""""""""""""""""
+Each surviving TS guess is written as a numbered ``GoFlow N.xyz`` file under the
+TS-guess directory of the reaction, alongside the staged ``input.yml`` and the raw
+multi-frame ``output.yml`` returned by the subprocess. Optimized + validated TSs follow
+the same reporting flow as any other ARC TS guess.
+
+Reference
+"""""""""
+The GoFlow model is described in:
+L. Galustian, K. Mark, J. Karwounopoulos, M. P.-P. Kovar, E. Heid,
+*"GoFlow: efficient transition state geometry prediction with flow matching and
+E(3)-equivariant neural networks"*, Digital Discovery 2025, DOI
+`10.1039/D5DD00283D <https://doi.org/10.1039/D5DD00283D>`_.
+
+The upstream implementation lives at `heid-lab/goflow_lean
+<https://github.com/heid-lab/goflow_lean>`_.
+
+RitS (flow-matching ML TS generator)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ARC supports automated TS generation via **RitS** (*Right into the Saddle*),
+a flow-matching neural network from the Isayev lab that predicts transition-
+state Cartesian geometries directly from atom-mapped reactant + product 3D
+structures. Unlike GCN — which is restricted to single-bond isomerizations —
+RitS handles bimolecular reactions and charged species and is therefore
+applied to **all** reaction families (it is the only entry currently in
+``all_families_ts_adapters``).
+
+How it is used
+""""""""""""""
+RitS is **opt-in only** — its inference stack is heavyweight, so it is
+intentionally absent from the default ``ts_adapters`` list.
To enable it,
+install its environment once and request it in the ARC input file:
+
+1. Install the dedicated conda env and download the pretrained checkpoint:
+
+   .. code-block:: bash
+
+      make install-rits
+
+   This creates ``rits_env`` (PyTorch 2.7 + PyTorch Geometric + RitS / megalodon),
+   clones ``RitS``, downloads the published checkpoint from Zenodo
+   (`10.5281/zenodo.19474153 <https://doi.org/10.5281/zenodo.19474153>`_), and
+   verifies its SHA-256 (~364 MB).
+
+2. Opt in to the adapter for a given run by adding it to ``ts_adapters``:
+
+   .. code-block:: yaml
+
+      ts_adapters:
+        - heuristics
+        - rits
+
+3. (Optional) Override the default repository or checkpoint location via env vars:
+
+   .. code-block:: bash
+
+      export ARC_RITS_REPO=/path/to/your/RitS
+      export ARC_RITS_CKPT=/path/to/your/rits.ckpt
+
+   These take precedence over both the default ``~/Code/RitS`` /
+   sibling-of-ARC discovery and the install-time Zenodo download.
+
+What ARC does
+"""""""""""""
+For each reaction with RitS selected, ARC:
+
+1. Builds atom-mapped reactant and product XYZ files using the reaction's
+   ``rxn.atom_map``-aligned coordinates.
+2. Spawns the RitS inference subprocess (in ``rits_env``), which performs
+   flow-matching ODE sampling and returns ``n_samples`` candidate TS
+   geometries in a single multi-frame XYZ.
+3. Deduplicates near-duplicate samples that share a heavy-atom skeleton via
+   ``compare_confs`` (translation- and rotation-invariant distance-matrix
+   comparison).
+4. Appends the surviving guesses to the reaction's TS species for downstream
+   optimization, frequency, and IRC validation by ARC's standard pipeline.
+
+If ``rits_env`` or the checkpoint is missing, the adapter logs a warning
+and skips cleanly — the rest of the TS pipeline continues unaffected.
+
+Outputs and validation
+""""""""""""""""""""""
+Each surviving TS guess is written as a numbered ``RitS N.xyz`` file under
+the TS-guess directory of the reaction, alongside the staged ``input.yml``
+and the raw multi-frame ``output.yml`` returned by the subprocess. Optimized
++ validated TSs follow the same reporting flow as any other ARC TS guess.
+
+Reference
+"""""""""
+The RitS model is described in:
+*"Right into the Saddle"* (Isayev lab),
+DOI `10.26434/chemrxiv.15001681/v1 <https://doi.org/10.26434/chemrxiv.15001681/v1>`_.
+
+The upstream implementation lives at `isayevlab/RitS
+<https://github.com/isayevlab/RitS>`_.
+
 .. include:: links.txt