Geometry-R1 is a two-stage training pipeline for teaching vision-language models to solve geometry problems. Using Qwen2.5-VL-3B-Instruct as the base model and the Hugging Face TRL framework, it combines Supervised Fine-Tuning (SFT) with Group Relative Policy Optimization (GRPO) for mathematical reasoning.
The repository now includes:
- End-to-end SFT training entrypoint: `python -m src.train_sft`
- End-to-end GRPO training entrypoint: `python -m src.train_grpo`
- Offline evaluation entrypoint: `python -m src.eval`
- Unit tests for reward logic: `python -m unittest tests/test_reward.py`
The first phase uses distillation from DashScope teacher models to create training data:
- Dataset: First 500 samples from `hiyouga/geometry3k`
- Base recommendation: `qwen3-vl-plus`
- Stronger but more expensive option: `qwen3-vl-235b-a22b-instruct`
- Method: Call the DashScope multimodal API to generate pseudo chain-of-thought
- Format: `<think>...</think>` → `<answer>...</answer>` structured reasoning
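The split between the two phases can be sketched as follows. This is a minimal illustration rather than the logic in `scripts/prepare_geometry3k.py`: the function name `split_samples` is hypothetical, and it assumes the dataset has already been loaded as an ordered sequence.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_samples(samples: Sequence[T], sft_samples: int = 500) -> Tuple[List[T], List[T]]:
    """Return (sft_split, rl_split): the first `sft_samples` items, then the rest.

    Hypothetical helper: it mirrors the "first N for SFT" rule described above
    and leaves every remaining sample for the later RL phase.
    """
    if sft_samples < 0:
        raise ValueError("sft_samples must be non-negative")
    sft_split = list(samples[:sft_samples])
    rl_split = list(samples[sft_samples:])
    return sft_split, rl_split
```

Because the split is positional, the two sets never overlap as long as the dataset order is stable between runs.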
The second phase uses rule-based rewards for self-improvement:
- Dataset: Remaining samples from `hiyouga/geometry3k`
- Reward: SymPy-based mathematical equivalence verification
- Method: Compare model predictions with ground truth symbolically
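GRPO scores each sampled completion relative to the other completions drawn for the same prompt, which is why a simple rule-based reward is enough to drive learning. The sketch below shows that group-relative normalization in isolation. The name `group_advantages` and the `eps` value are illustrative assumptions; TRL performs this step inside its trainer, and its exact details may differ.

```python
from statistics import fmean, pstdev
from typing import List


def group_advantages(rewards: List[float], eps: float = 1e-4) -> List[float]:
    """Normalize rewards within one prompt's group: (r - mean) / (std + eps).

    Illustrative only: `eps` guards against division by zero when every
    completion in the group receives the same reward.
    """
    if not rewards:
        return []
    mean = fmean(rewards)
    std = pstdev(rewards)  # population standard deviation over the group
    return [(r - mean) / (std + eps) for r in rewards]
```

When all completions in a group are correct (or all are wrong), every advantage is zero, so that prompt contributes no gradient signal for the step.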
Geometry-R1/
├── configs/
│ ├── sft_config.yaml # Phase 1 SFT hyperparameters
│ └── grpo_config.yaml # Phase 2 GRPO hyperparameters
├── data/
│ ├── images/
│ │ ├── sft/ # SFT images
│ │ └── rl/ # RL images
│ ├── sft_geometry3k.jsonl # SFT training data
│ └── rl_geometry3k.jsonl # RL training data
├── output/ # Model checkpoints (git-ignored)
├── scripts/
│ ├── prepare_geometry3k.py # Data preparation script
│ ├── setup_autodl.sh # AutoDL environment setup
│ ├── run_sft.sh # SFT training launcher
│ └── run_grpo.sh # GRPO training launcher
├── src/
│ ├── __init__.py
│ ├── config.py # Path configuration module
│ ├── eval.py # Offline evaluation script
│ ├── train_sft.py # SFT training entrypoint
│ ├── train_grpo.py # GRPO training entrypoint
│ ├── train_utils.py # Shared training/evaluation utilities
│ └── reward/
│   ├── __init__.py
│   └── reward.py # SymPy-based reward functions
├── tests/
│ └── test_reward.py # Unit tests for reward parsing/equivalence
├── pyproject.toml # Project configuration (uv)
└── README.md
# Clone the repository
git clone <your-repo-url>
cd Geometry-R1
# Install dependencies with uv
uv sync

# Clone the repository
git clone <your-repo-url>
cd Geometry-R1
# Run setup script
chmod +x scripts/setup_autodl.sh
./scripts/setup_autodl.sh

All paths are managed through `src/config.py` and support environment variable overrides:
from src.config import get_data_dir, get_output_dir, setup_directories
# Get configured paths
print(get_data_dir()) # -> data/ or $DATA_DIR
print(get_output_dir()) # -> output/ or $OUTPUT_DIR

| Variable | Default | Description |
|---|---|---|
| `PROJECT_ROOT` | (auto-detect) | Project root directory |
| `DATA_DIR` | `data/` | Data directory for datasets and images |
| `OUTPUT_DIR` | `output/` | Output directory for checkpoints |
| `HF_ENDPOINT` | - | HuggingFace mirror URL (for China) |
| `HF_HOME` | `.cache/` | HuggingFace cache directory |
| `DASHSCOPE_API_KEY` | - | DashScope API key for distillation |
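The override behavior in the table can be illustrated with a short sketch. It is not the code in `src/config.py`: `resolve_dir` is a hypothetical helper, and it assumes an empty variable should be treated the same as an unset one.

```python
import os
from pathlib import Path
from typing import Optional


def resolve_dir(env_var: str, default: str, project_root: Optional[Path] = None) -> Path:
    """Return the directory named by `env_var` if set, else `project_root / default`.

    Hypothetical helper: falls back to the current working directory when no
    project root is supplied.
    """
    override = os.environ.get(env_var)
    if override:  # unset or empty both fall through to the default
        return Path(override).expanduser()
    root = project_root if project_root is not None else Path.cwd()
    return root / default
```

With `DATA_DIR=/root/autodl-tmp/data` exported, `resolve_dir("DATA_DIR", "data")` would return that path; without it, the result is the `data` directory under the project root.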
For AutoDL, data and models should be stored on the data disk:
# Set environment variables
export HF_ENDPOINT=https://hf-mirror.com
export DATA_DIR=/root/autodl-tmp/data
export OUTPUT_DIR=/root/autodl-tmp/output
export HF_HOME=/root/autodl-tmp/hf_cache

Or use the setup script:
./scripts/setup_autodl.sh # Automatically sets these variables

# Option 1: Environment variable
export DASHSCOPE_API_KEY="your_api_key_here"
# Option 2: Command line argument
# --api-key "your_api_key_here"

# Local (mock mode for testing)
uv run python scripts/prepare_geometry3k.py --sft-samples 500
# With real API distillation
uv run python scripts/prepare_geometry3k.py --sft-samples 500 --use-api
# Explicitly choose the teacher model
uv run python scripts/prepare_geometry3k.py \
--sft-samples 500 \
--use-api \
--teacher-model qwen3-vl-plus
# AutoDL (custom data directory)
export DATA_DIR=/root/autodl-tmp/data
python scripts/prepare_geometry3k.py --sft-samples 500 --use-api

| Option | Default | Description |
|---|---|---|
| `--sft-samples` | 500 | Number of samples for SFT split |
| `--data-dir` | `data/` | Data directory (overrides the `DATA_DIR` env var) |
| `--use-api` | False | Use DashScope multimodal API for pseudo-CoT generation |
| `--api-key` | None | DashScope API key (or set the `DASHSCOPE_API_KEY` env var) |
| `--teacher-model` | `qwen3-vl-plus` | DashScope teacher model used for distillation |
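For readers who want to see how the options above fit together, here is a minimal `argparse` sketch that mirrors the table. The flag names and defaults come from the table; `build_parser` and `resolve_api_key` are hypothetical names, and the actual parsing code in `scripts/prepare_geometry3k.py` may be organized differently.

```python
import argparse
import os
from typing import Optional


def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the options listed in the table (illustrative)."""
    parser = argparse.ArgumentParser(description="Prepare geometry3k SFT/RL splits")
    parser.add_argument("--sft-samples", type=int, default=500,
                        help="Number of samples for the SFT split")
    parser.add_argument("--data-dir", default=os.environ.get("DATA_DIR", "data/"),
                        help="Data directory; takes precedence over DATA_DIR")
    parser.add_argument("--use-api", action="store_true",
                        help="Call the DashScope API instead of mock mode")
    parser.add_argument("--api-key", default=None,
                        help="DashScope API key; falls back to DASHSCOPE_API_KEY")
    parser.add_argument("--teacher-model", default="qwen3-vl-plus",
                        help="DashScope teacher model used for distillation")
    return parser


def resolve_api_key(args: argparse.Namespace) -> Optional[str]:
    """Prefer the command-line key, then the environment variable."""
    return args.api_key or os.environ.get("DASHSCOPE_API_KEY")
```

Reading the environment variable only as a fallback keeps the command-line value authoritative, which matches the precedence described for `--data-dir`.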
# Run directly
python -m src.train_sft \
--config configs/sft_config.yaml \
--train_file data/sft_geometry3k.jsonl \
--output_dir output/sft_phase1
# Or use the helper script
./scripts/run_sft.sh

# Run directly
python -m src.train_grpo \
--config configs/grpo_config.yaml \
--model_name_or_path output/sft_phase1 \
--train_file data/rl_geometry3k.jsonl \
--output_dir output/grpo_phase2
# Or use the helper script
./scripts/run_grpo.sh

python -m src.eval \
--model_name_or_path output/grpo_phase2 \
--data_file data/rl_geometry3k.jsonl \
--output_file output/eval_grpo_phase2.json \
--max_samples 128

The evaluation script reports:
- `format_rate`: fraction of generations that contain a valid `<answer>...</answer>` block
- `exact_match`: exact string match on extracted answers
- `symbolic_accuracy`: symbolic equivalence accuracy using SymPy-backed verification
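The first two metrics can be approximated with a few lines of standard-library Python. This is a sketch under assumptions, not the code in `src/eval.py`: the helper names are hypothetical, a "valid" block is taken to mean a non-empty `<answer>...</answer>` pair, and `symbolic_accuracy` is left out because it depends on the SymPy-backed checker.

```python
import re
from typing import Dict, List, Optional

ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(text: str) -> Optional[str]:
    """Return the stripped content of the first <answer> block, or None."""
    match = ANSWER_RE.search(text)
    if match is None:
        return None
    answer = match.group(1).strip()
    return answer or None  # an empty block counts as missing


def summarize(generations: List[str], references: List[str]) -> Dict[str, float]:
    """Compute format_rate and exact_match over paired generations/references."""
    if len(generations) != len(references):
        raise ValueError("generations and references must have equal length")
    total = len(generations)
    if total == 0:
        return {"format_rate": 0.0, "exact_match": 0.0}
    extracted = [extract_answer(g) for g in generations]
    formatted = sum(a is not None for a in extracted)
    exact = sum(a is not None and a == ref.strip()
                for a, ref in zip(extracted, references))
    return {"format_rate": formatted / total, "exact_match": exact / total}
```

Both rates use the full sample count as the denominator, so a generation without a valid answer block lowers `exact_match` as well as `format_rate`.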
Phase 2 uses a SymPy-based reward function that:
- Extracts the answer from `<answer>...</answer>` tags
- Normalizes LaTeX strings
- Converts common LaTeX constructs to SymPy-friendly expressions when the LaTeX parser is unavailable
- Compares mathematical equivalence:
  - Direct string matching
  - Symbolic simplification (`simplify(pred - gt) == 0`)
  - Numerical evaluation fallback
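The normalization and conversion steps might look like the sketch below. It is an assumption-laden illustration, not the contents of `src/reward/reward.py`: `normalize_latex` and `latex_to_expr` are hypothetical names, and only `\frac`, `\sqrt`, `\cdot`, and `\times` are handled.

```python
import re

# Match only innermost groups (no nested braces); the loop below repeats
# the substitutions so nested constructs are resolved from the inside out.
_FRAC = re.compile(r"\\frac\{([^{}]*)\}\{([^{}]*)\}")
_SQRT = re.compile(r"\\sqrt\{([^{}]*)\}")


def normalize_latex(text: str) -> str:
    """Strip math delimiters, sizing commands, and whitespace."""
    text = text.replace("$", "")
    text = text.replace("\\left", "").replace("\\right", "")
    text = text.replace("\\dfrac", "\\frac").replace("\\tfrac", "\\frac")
    return re.sub(r"\s+", "", text)


def latex_to_expr(text: str) -> str:
    """Rewrite a small LaTeX subset into a plain expression string."""
    expr = normalize_latex(text)
    previous = None
    while previous != expr:  # iterate until no further rewrite applies
        previous = expr
        expr = _SQRT.sub(r"sqrt(\1)", expr)
        expr = _FRAC.sub(r"((\1)/(\2))", expr)
    return expr.replace("\\cdot", "*").replace("\\times", "*")
```

For example, `latex_to_expr("$\\frac{12}{5}$")` yields `((12)/(5))`. Implicit multiplication such as `2\pi` is deliberately not covered here and would need extra handling before the string could be parsed.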
from src.reward import accuracy_reward
rewards = accuracy_reward(
prompts=["Q1", "Q2"],
completions=[
"<think>Reasoning...</think> <answer>\\frac{12}{5}</answer>",
"<answer>5</answer>"
],
answer=["\\frac{24}{10}", "6"] # Ground truth
)
# Returns: [1.0, 0.0] (12/5 == 24/10, but 5 != 6)

| Predicted | Ground Truth | Equivalent |
|---|---|---|
| `12` | `12` | ✓ |
| `\frac{12}{5}` | `\frac{12}{5}` | ✓ |
| `\frac{24}{10}` | `\frac{12}{5}` | ✓ (symbolic) |
| `\sqrt{2}` | `$\sqrt{2}$` | ✓ |
| `5` | `6` | ✗ |
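The comparison cascade behind the table (string match, then symbolic simplification, then a numerical fallback) can be sketched with SymPy as follows. This is an illustration, not the repository's reward code: `is_equivalent` and the tolerance are assumptions, and the inputs are expected to be plain expression strings such as `12/5` or `sqrt(2)` with any LaTeX already converted.

```python
import sympy


def is_equivalent(pred: str, truth: str, tol: float = 1e-6) -> bool:
    """Return True if two expression strings denote the same value."""
    # 1. Direct string match: cheapest check, no parsing needed.
    if pred.strip() == truth.strip():
        return True
    # Parse both sides. sympify evaluates its input, so only use it on
    # strings you trust or have sanitized.
    try:
        diff = sympy.sympify(pred) - sympy.sympify(truth)
    except (sympy.SympifyError, SyntaxError, TypeError):
        return False
    # 2. Symbolic check: the difference simplifies to exactly zero.
    try:
        if sympy.simplify(diff) == 0:
            return True
    except TypeError:
        return False
    # 3. Numerical fallback: compare within a small tolerance.
    try:
        return abs(float(sympy.N(diff))) < tol
    except (TypeError, ValueError):
        return False  # e.g. free symbols remain, so no numeric value exists
```

Ordering the checks from cheapest to most expensive keeps the common case (an identical string) fast, while the numeric fallback catches pairs that SymPy cannot simplify symbolically.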
# Format with Ruff
uv run ruff format .
# Check linting
uv run ruff check .

python -m unittest tests/test_reward.py

python -c "from src.config import print_config; print_config()"

- Core: torch, transformers, accelerate, peft, trl
- Vision: qwen-vl-utils, pillow
- Data: datasets
- Math: sympy, antlr4-python3-runtime (4.11.1)
- API: dashscope (for teacher-model distillation)
- Utils: pyyaml, tqdm
MIT License
- Base model: Qwen2.5-VL-3B-Instruct
- Dataset: hiyouga/geometry3k
- Training framework: TRL
- Distillation API: DashScope Vision Models