A pure PyTorch implementation of FP8 quantization with stochastic rounding. No CUDA extensions or custom kernels required.
import torch
from main import to_fp8, from_fp8
# Convert any tensor to FP8
tensor = torch.randn(1000, dtype=torch.float32)
fp8_data, scale = to_fp8(tensor, fmt="e4m3fn")
# Convert back
recovered = from_fp8(fp8_data, fmt="e4m3fn", scaling_factor=scale)

Install the dependencies:

pip install torch rich

| Format | Best For | Range | Precision |
|---|---|---|---|
| E4M3FN | Weights & activations | ±448 | Higher (3-bit mantissa) |
| E5M2 | Gradients | ±57,344 | Lower (2-bit mantissa) |
Rule of thumb: Use E4M3FN for forward pass, E5M2 for backward pass.
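Applied to a training step, that rule of thumb might look like the following sketch (the `activations` and `gradients` tensors are placeholders, not part of this repo):

```python
import torch
from main import to_fp8

activations = torch.randn(1024, 1024)  # forward-pass tensor (placeholder)
gradients = torch.randn(1024, 1024)    # backward-pass tensor (placeholder)

# Forward pass: E4M3FN favours precision (3 mantissa bits, range up to ±448)
act_fp8, act_scale = to_fp8(activations, fmt="e4m3fn")

# Backward pass: E5M2 favours range (up to ±57,344) for gradient magnitudes
grad_fp8, grad_scale = to_fp8(gradients, fmt="e5m2")
```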
from main import to_fp8, from_fp8, compute_scaling_factor
# Auto-scaling (recommended for best precision)
fp8_tensor, scale = to_fp8(tensor, fmt="e4m3fn")
# Manual scaling
scale = compute_scaling_factor(tensor, fmt="e4m3fn")
fp8_tensor, _ = to_fp8(tensor, fmt="e4m3fn", scaling_factor=scale)
# Deterministic rounding (faster, but biased)
fp8_tensor, scale = to_fp8(tensor, fmt="e4m3fn", stochastic=False)
# Convert back
recovered = from_fp8(fp8_tensor, fmt="e4m3fn", scaling_factor=scale)

# Standard OCP formats
to_fp8(tensor, fmt="e4m3fn") # 4-bit exp, 3-bit mantissa, no inf
to_fp8(tensor, fmt="e5m2") # 5-bit exp, 2-bit mantissa, has inf
# FNUZ variants (no negative zero, larger bias)
to_fp8(tensor, fmt="e4m3fnuz") # bias=8 instead of 7
to_fp8(tensor, fmt="e5m2fnuz") # bias=16 instead of 15For fine-grained control:
from main import round_to_fp8_represented_as_int8, undo_int8_fp8
# Convert to FP8 (stored as uint8)
fp8 = round_to_fp8_represented_as_int8(tensor, 3, scaling_factor=1.0)
# Convert back
recovered = undo_int8_fp8(fp8, 3, torch.float32, scaling_factor=1.0)

Stochastic rounding (default) rounds each value up or down to one of the two nearest representable FP8 values, with probability proportional to how close the value is to each. In expectation the rounded value equals the original, so the mean is preserved across many samples, which makes it well suited to training.
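A quick, informal way to see that in practice (a sketch; fixing `scaling_factor=1.0` is an assumption made so that the test value falls between two representable E4M3FN numbers):

```python
import torch
from main import to_fp8, from_fp8

# Round-trip the same value many times with a fixed scale of 1.0, so that
# 0.3456 falls between two representable E4M3FN values (0.34375 and 0.375).
x = torch.full((100_000,), 0.3456)
fp8, scale = to_fp8(x, fmt="e4m3fn", scaling_factor=1.0, stochastic=True)
recovered = from_fp8(fp8, fmt="e4m3fn", scaling_factor=scale)

print(recovered.float().mean())    # lands close to 0.3456 (unbiased)
print(recovered.float().unique())  # only the two neighbouring FP8 values
```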
# Stochastic (default) - unbiased, better for training
fp8, scale = to_fp8(tensor, stochastic=True)
# Deterministic - faster, better for inference
fp8, scale = to_fp8(tensor, stochastic=False)

Benchmarked on 100,000 elements (CPU):
| Operation | Time | Throughput |
|---|---|---|
| Float → FP8 | ~1.5ms | ~70M elem/s |
| FP8 → Float | ~0.8ms | ~130M elem/s |
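Exact numbers depend on hardware and PyTorch version; a rough reproduction could look like this sketch:

```python
import time
import torch
from main import to_fp8, from_fp8

x = torch.randn(100_000)

t0 = time.perf_counter()
fp8, scale = to_fp8(x, fmt="e4m3fn")
t1 = time.perf_counter()
recovered = from_fp8(fp8, fmt="e4m3fn", scaling_factor=scale)
t2 = time.perf_counter()

n = x.numel()
print(f"float -> fp8: {(t1 - t0) * 1e3:.2f} ms ({n / (t1 - t0) / 1e6:.0f}M elem/s)")
print(f"fp8 -> float: {(t2 - t1) * 1e3:.2f} ms ({n / (t2 - t1) / 1e6:.0f}M elem/s)")
```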
Compared to PyTorch native FP8 (which uses C++/CUDA):
| Metric | Our Implementation | PyTorch Native |
|---|---|---|
| E4M3FN MSE | 0.000763 | 0.000763 |
| E5M2 MSE | 0.002706 | 0.002706 |
Our implementation produces identical results to PyTorch's native FP8 types.
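If you want to run a similar comparison yourself, something along these lines should work (a sketch: it assumes a PyTorch build that ships `torch.float8_e4m3fn`, roughly 2.1+, and uses deterministic rounding with a fixed scale so both paths round the same way):

```python
import torch
from main import to_fp8, from_fp8

x = torch.randn(100_000)

# PyTorch's native cast (round-to-nearest), back to float32 for comparison
native = x.to(torch.float8_e4m3fn).to(torch.float32)

# Our pure-PyTorch path with deterministic rounding and scale fixed to 1.0
fp8, scale = to_fp8(x, fmt="e4m3fn", scaling_factor=1.0, stochastic=False)
ours = from_fp8(fp8, fmt="e4m3fn", scaling_factor=scale).to(torch.float32)

print("native MSE:", torch.mean((x - native) ** 2).item())
print("ours   MSE:", torch.mean((x - ours) ** 2).item())
```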
uv run python main.py

The test suite validates:
- Sign preservation
- Special values (zero, inf, NaN) per OCP spec
- Round-trip accuracy for all formats
- Stochastic rounding unbiasedness
- Comparison with PyTorch native FP8
| Property | E4M3FN | E5M2 |
|---|---|---|
| Exponent bits | 4 | 5 |
| Mantissa bits | 3 | 2 |
| Bias | 7 | 15 |
| Max value | ±448 | ±57,344 |
| Has infinity | No | Yes |
| NaN encoding | 0x7F, 0xFF | exp=31, mantissa≠0 |
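To make the layout concrete, here is a standalone decoder for a single E4M3FN byte (a hypothetical helper derived from the table above, not a function exported by `main.py`):

```python
def decode_e4m3fn(byte: int) -> float:
    """Decode one E4M3FN byte: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:        # 0x7F / 0xFF are the NaN encodings
        return float("nan")
    if exp == 0:                          # subnormal: no implicit leading 1
        return sign * (mant / 8) * 2 ** (1 - 7)
    return sign * (1 + mant / 8) * 2 ** (exp - 7)

print(decode_e4m3fn(0x3C))  # 0 0111 100 -> +2^0 * (1 + 4/8) = 1.5
print(decode_e4m3fn(0x7E))  # 0 1111 110 -> +2^8 * (1 + 6/8) = 448.0 (max value)
```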
Conversion Flow Diagram
flowchart LR
subgraph Input
BF16[bfloat16 Tensor]
end
subgraph "Float to FP8"
A[Scale & Extract Sign] --> B[Calculate Exponent & Mantissa]
B --> C[Stochastic Rounding]
C --> D[Handle Overflow/Underflow]
D --> E[Combine Components]
end
subgraph Storage
FP8[uint8 Tensor]
end
subgraph "FP8 to Float"
F[Extract Components] --> G[Reconstruct Value]
G --> H[Apply Sign & Scale]
end
subgraph Output
BF16_OUT[bfloat16 Tensor]
end
BF16 --> A
E --> FP8
FP8 --> F
H --> BF16_OUT
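The arithmetic behind the "Float to FP8" stages, sketched for one positive E4M3FN value (illustrative only; the actual implementation vectorises this over whole tensors and also handles overflow/underflow):

```python
import math
import random

x = 3.3  # already scaled; assume scale = 1.0 for illustration
sign = 0 if x >= 0 else 1

# Calculate exponent & mantissa (E4M3FN: bias 7, 3 mantissa bits)
exp = math.floor(math.log2(abs(x)))   # 1
frac = abs(x) / 2 ** exp - 1.0        # 0.65
steps = frac * 8                      # 5.2 mantissa steps of 1/8

# Stochastic rounding: round up with probability equal to the fractional part
mant = int(steps) + (random.random() < steps - int(steps))

# Combine components into one byte (overflow/underflow handling omitted here)
byte = (sign << 7) | ((exp + 7) << 3) | mant

decoded = (-1) ** sign * 2 ** exp * (1 + mant / 8)
print(hex(byte), decoded)  # 3.25 about 80% of the time, 3.5 about 20%
```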
Special Value Handling
flowchart TD
INPUT[Input Value] --> ZERO{abs < 1e-7?}
ZERO -->|Yes| RET_ZERO[Return 0x00]
ZERO -->|No| NAN{isNaN?}
NAN -->|Yes| RET_NAN[Return 0x7F]
NAN -->|No| INF{isInf?}
INF -->|Yes| FORMAT{Format?}
FORMAT -->|E4M3FN| CLAMP[Clamp to max<br/>+: 0x7E, -: 0xFE]
FORMAT -->|E5M2| INF_ENC[Encode infinity<br/>+: 0x7C, -: 0xFC]
INF -->|No| NORMAL[Normal conversion]
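A quick way to observe the two overflow behaviours from the diagram (a sketch; passing `scaling_factor=1.0` is an assumption so the infinities don't feed into auto-scaling):

```python
import torch
from main import to_fp8, from_fp8

x = torch.tensor([float("inf"), -float("inf"), 0.0, float("nan")])

# E4M3FN has no infinity encoding, so ±inf clamps to the max finite value (±448)
e4m3, s1 = to_fp8(x, fmt="e4m3fn", scaling_factor=1.0)
print(from_fp8(e4m3, fmt="e4m3fn", scaling_factor=s1))

# E5M2 keeps a real infinity encoding (0x7C / 0xFC)
e5m2, s2 = to_fp8(x, fmt="e5m2", scaling_factor=1.0)
print(from_fp8(e5m2, fmt="e5m2", scaling_factor=s2))
```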
- OCP 8-bit Floating Point Specification
- FP8 Formats for Deep Learning (arXiv)
- Stochastic Rounding - Higham
MIT