Merged
28 changes: 28 additions & 0 deletions README.md
@@ -146,6 +146,34 @@ Inference note:
- Some environments crash during `torch.compile`.
- Disable compile: `TORCH_COMPILE=0 bash code/scripts/local_run.sh`.
- Or try a safer mode: `TORCH_COMPILE=1 TORCH_COMPILE_MODE=reduce-overhead bash code/scripts/local_run.sh`.
- **Blackwell GPUs (RTX 5080/5090, GB200/GB300)**:
- Stable PyTorch wheels built for `cu124` do not ship Blackwell kernels (for example SM 12.0 on RTX 5080/5090).
Install the nightly build with the `cu128` index:
```bash
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
```
- **Windows (Git Bash / WSL)**:
- Triton is not supported on native Windows (for example under Git Bash), which causes `torch.compile` to
fail there. Disable it before running:
```bash
export TORCH_COMPILE_DISABLE=1 # PyTorch-level flag
# or, equivalently for the repo scripts:
export PREDECODER_TORCH_COMPILE=0
```
- When running scripts directly (outside the notebook or `local_run.sh`),
set the Python path so that repo modules are importable:
```bash
export PYTHONPATH="code"
```
- **Pre-trained model not found during inference**:
- `find_best_model` searches inside `{output}/models/best_model/` first,
then falls back to `{output}/models/`. If you placed the downloaded
`.pt` file elsewhere, either move it into one of those directories or
point to it directly:
```bash
PREDECODER_MODEL_CHECKPOINT_FILE=path/to/Ising-Decoder-SurfaceCode-1-Accurate.pt \
WORKFLOW=inference bash code/scripts/local_run.sh
```
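
The lookup order described in the last item can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the repository's `find_best_model`: the helper name `resolve_checkpoint` and the "first file by name" rule are hypothetical, and only the directory order and the `PREDECODER_MODEL_CHECKPOINT_FILE` variable come from the text above.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_checkpoint(output_dir: str, env: Optional[Mapping[str, str]] = None) -> Path:
    """Hypothetical helper mirroring the lookup order described above."""
    env = os.environ if env is None else env

    # 1. An explicit path wins (variable name taken from the README text).
    explicit = env.get("PREDECODER_MODEL_CHECKPOINT_FILE")
    if explicit:
        return Path(explicit)

    # 2. Otherwise look in models/best_model/ first, then in models/ itself.
    models_dir = Path(output_dir) / "models"
    for directory in (models_dir / "best_model", models_dir):
        if not directory.is_dir():
            continue
        candidates = sorted(directory.glob("*.pt"))
        if candidates:
            # Assumption: take the first file by name. The repository's own
            # selection rule (for example by epoch) is not reproduced here.
            return candidates[0]

    raise FileNotFoundError(f"No .pt checkpoint found under {models_dir}")
```

Calling `resolve_checkpoint("outputs")` would return a path inside `outputs/models/best_model/` when that directory holds a `.pt` file, and fall back to `outputs/models/` otherwise; `outputs` is a placeholder directory name.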

## Inference (pre-trained models)

8 changes: 7 additions & 1 deletion code/workflows/run.py
@@ -140,7 +140,13 @@ def find_best_model(path, *, rank: int = 0):
print(f" [{marker}] {filename} (epoch {epoch_str})")

     if best_file is None:
-        raise FileNotFoundError(f"No valid model checkpoint files found in {path}")
+        raise FileNotFoundError(
+            f"No valid model checkpoint files found in {path}\n"
+            f"Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
+            f"PreDecoderModelMemory_*.pt).\n"
+            f"Hint: download the pretrained weights and place them in this directory, "
+            f"or set model_checkpoint_file in your config to an explicit path."
+        )

     best_model_path = os.path.join(path, best_file)
     if rank == 0:
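
To see how the new message renders, the raise can be reproduced in isolation. The sketch below is a hypothetical stand-in, not the code in `run.py`: the function name `pick_checkpoint` and the "any `.pt` file is valid" rule are assumptions, and only the message text is copied from the hunk above.

```python
import os


def pick_checkpoint(path: str) -> str:
    """Hypothetical stand-in for the last step of find_best_model."""
    # Assumption: any .pt file counts as valid. The real validity and
    # ranking checks are outside this hunk and are not reproduced.
    pt_files = sorted(f for f in os.listdir(path) if f.endswith(".pt"))
    if not pt_files:
        # Adjacent f-strings are concatenated; the explicit "\n" markers
        # are what split the message into separate lines.
        raise FileNotFoundError(
            f"No valid model checkpoint files found in {path}\n"
            f"Expected .pt files (e.g. Ising-Decoder-SurfaceCode-1-Fast.pt or "
            f"PreDecoderModelMemory_*.pt).\n"
            f"Hint: download the pretrained weights and place them in this directory, "
            f"or set model_checkpoint_file in your config to an explicit path."
        )
    return os.path.join(path, pt_files[0])
```

Running `pick_checkpoint` on an empty directory raises a `FileNotFoundError` whose text spans three lines: the original summary, the expected file names, and the hint.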