Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
df440c5
fix: ty now doesn't complain but a bunch of tests fail.
CompRhys Mar 25, 2026
5e7af8e
remove privileged role of spin and charge
CompRhys Mar 25, 2026
22476c5
fix: down to 28 test failures
CompRhys Mar 26, 2026
465c74f
Fixes to Extensible Extras PR (#526)
falletta Mar 26, 2026
ce8e104
Merge remote-tracking branch 'origin/main' into extensible-extras
CompRhys Apr 2, 2026
1445ed0
lint: avoid SLF001 for extras in elastic code.
CompRhys Apr 2, 2026
2e94247
fea: add state modifier hook to interface test
CompRhys Apr 4, 2026
006a7ce
fix: units should be relative
CompRhys Apr 4, 2026
ebd4d22
fea: add some more blessed extras keys to enums
CompRhys Apr 4, 2026
e86ca52
fea: configure ase atoms to ts io better.
CompRhys Apr 4, 2026
0567eae
Merge remote-tracking branch 'origin/main' into extensible-extras
CompRhys Apr 5, 2026
db6f0f4
fix: all rather than any on retain graph
CompRhys Apr 5, 2026
1c251df
Merge remote-tracking branch 'origin/main' into extensible-extras
CompRhys Apr 5, 2026
1ba3ea4
fix units issue
CompRhys Apr 5, 2026
c2892cf
patch orb forward to handle extras. Bump version to 0.6.0 for a new r…
CompRhys Apr 6, 2026
174195e
remove metatomic from docs CI
CompRhys Apr 6, 2026
d6d0b4d
Merge branch 'main' into extensible-extras
CompRhys Apr 8, 2026
560b27f
fix
CompRhys Apr 8, 2026
88079a2
fix: address extrasmap hint comment
CompRhys Apr 9, 2026
26c49fd
Merge remote-tracking branch 'origin/main' into extensible-extras
CompRhys Apr 9, 2026
cbdb583
more tests
CompRhys Apr 9, 2026
b7694dd
fix fairchem test
CompRhys Apr 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ jobs:
uv pip install ".[test,docs]" --system

- name: Install extras for tutorial generation
run: uv pip install ".[mace,metatomic]" --system
run: uv pip install ".[mace]" --system

- name: Copy tutorials
run: |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ name = "torch-sim-atomistic"
version = "0.5.2"
description = "A pytorch toolkit for calculating material properties using MLIPs"
authors = [
{ name = "TorchSim Maintainers", email = "torchsimatomistic@gmail.com" },
{ name = "Abhijeet Gangan", email = "abhijeetgangan@g.ucla.edu" },
{ name = "Janosh Riebesell", email = "janosh.riebesell@gmail.com" },
{ name = "Orion Cohen", email = "orioncohen@berkeley.edu" },
Expand Down
24 changes: 23 additions & 1 deletion tests/models/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Pytest fixtures and test factories for model testing."""

from __future__ import annotations

import typing

import pytest
Expand All @@ -10,7 +12,10 @@


if typing.TYPE_CHECKING:
from collections.abc import Callable, Sequence

from torch_sim.models.interface import ModelInterface
from torch_sim.state import SimState


def make_model_calculator_consistency_test(
Expand Down Expand Up @@ -81,22 +86,39 @@ def make_validate_model_outputs_test(
dtype: torch.dtype = DTYPE,
*,
check_detached: bool = True,
state_modifiers: Sequence[Callable[[SimState], SimState]] = (),
):
"""Factory function to create model output validation tests.

Runs ``validate_model_outputs`` once with no modifier (baseline), then
once more for each entry in *state_modifiers* so that every modifier
gets a full, independent validation pass.

Args:
model_fixture_name: Name of the model fixture to validate
device: Device to run validation on
dtype: Data type to use for validation
check_detached: Whether to assert output tensors are detached from the
autograd graph (skipped for models with ``retain_graph=True``).
state_modifiers: Each callable receives a ``SimState`` and returns a
(possibly new) ``SimState``. The full validation suite is run
once per modifier so that different input edge-cases are
exercised independently.
"""
from torch_sim.models.interface import validate_model_outputs

def test_model_output_validation(request: pytest.FixtureRequest) -> None:
"""Test that a model implementation follows the ModelInterface contract."""
model: ModelInterface = request.getfixturevalue(model_fixture_name)
validate_model_outputs(model, device, dtype, check_detached=check_detached)
modifiers = state_modifiers or [None]
for modifier in modifiers:
validate_model_outputs(
model,
device,
dtype,
check_detached=check_detached,
state_modifier=modifier,
)

test_model_output_validation.__name__ = f"test_{model_fixture_name}_output_validation"
return test_model_output_validation
16 changes: 8 additions & 8 deletions tests/models/test_fairchem.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,21 @@
import traceback
from collections.abc import Callable

import pytest
import torch
from ase.build import bulk, fcc100, molecule

import torch_sim as ts
from tests.conftest import DEVICE, DTYPE
from tests.models.conftest import make_validate_model_outputs_test


try:
from collections.abc import Callable

from ase.build import bulk, fcc100, molecule
from fairchem.core.calculate.pretrained_mlip import (
pretrained_checkpoint_path_from_name,
)
from huggingface_hub.utils._auth import get_token

import torch_sim as ts
from torch_sim.models.fairchem import FairChemModel

except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
Expand Down Expand Up @@ -263,10 +261,12 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None:
mol.info["charge"] = charge
mol.info["spin"] = spin

# Convert to SimState (should extract charge/spin)
state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE)

# Verify charge/spin were extracted correctly
state = ts.io.atoms_to_state(
[mol],
device=DEVICE,
dtype=DTYPE,
system_extras_map={"charge": "charge", "spin": "spin"},
)
assert state.charge is not None
assert state.spin is not None
assert state.charge[0].item() == charge
Expand Down
16 changes: 8 additions & 8 deletions tests/test_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_get_elementary_deformations_strain_consistency(
n_deform=n_deform,
max_strain_normal=max_strain_normal,
max_strain_shear=max_strain_shear,
bravais_type=BravaisType.triclinic, # Test all axes
bravais_type=BravaisType.TRICLINIC, # Test all axes
)

# Should generate deformations for all 6 axes (triclinic)
Expand Down Expand Up @@ -271,12 +271,12 @@ def mace_model() -> MaceModel:
@pytest.mark.parametrize(
("sim_state_name", "expected_bravais_type", "atol"),
[
("cu_sim_state", BravaisType.cubic, 2e-1),
("mg_sim_state", BravaisType.hexagonal, 5e-1),
("sb_sim_state", BravaisType.trigonal, 5e-1),
("tio2_sim_state", BravaisType.tetragonal, 5e-1),
("ga_sim_state", BravaisType.orthorhombic, 5e-1),
("niti_sim_state", BravaisType.monoclinic, 5e-1),
("cu_sim_state", BravaisType.CUBIC, 2e-1),
("mg_sim_state", BravaisType.HEXAGONAL, 5e-1),
("sb_sim_state", BravaisType.TRIGONAL, 5e-1),
("tio2_sim_state", BravaisType.TETRAGONAL, 5e-1),
("ga_sim_state", BravaisType.ORTHORHOMBIC, 5e-1),
("niti_sim_state", BravaisType.MONOCLINIC, 5e-1),
],
)
def test_elastic_tensor_symmetries(
Expand Down Expand Up @@ -340,7 +340,7 @@ def test_elastic_tensor_symmetries(
)
C_triclinic = (
calculate_elastic_tensor(
state=state, model=model, bravais_type=BravaisType.triclinic
state=state, model=model, bravais_type=BravaisType.TRICLINIC
)
* UnitConversion.eV_per_Ang3_to_GPa
)
Expand Down
Loading
Loading