FP8 Auto-Converter

A pure PyTorch implementation of FP8 quantization with stochastic rounding. No CUDA extensions or custom kernels required.

Quick Start

import torch
from main import to_fp8, from_fp8

# Convert any tensor to FP8
tensor = torch.randn(1000, dtype=torch.float32)
fp8_data, scale = to_fp8(tensor, fmt="e4m3fn")

# Convert back
recovered = from_fp8(fp8_data, fmt="e4m3fn", scaling_factor=scale)

Installation

pip install torch rich

Which Format Should I Use?

| Format | Best For              | Range    | Precision               |
|--------|-----------------------|----------|-------------------------|
| E4M3FN | Weights & activations | ±448     | Higher (3-bit mantissa) |
| E5M2   | Gradients             | ±57,344  | Lower (2-bit mantissa)  |

Rule of thumb: Use E4M3FN for forward pass, E5M2 for backward pass.
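
For example, a mixed-format step following this rule might look like the sketch below. It uses the to_fp8/from_fp8 API from the Quick Start; the shapes and magnitudes are illustrative.

import torch
from main import to_fp8, from_fp8

weights = torch.randn(256, 256)        # forward-pass tensor
grads = torch.randn(256, 256) * 1e-3   # backward-pass tensor (often tiny)

# E4M3FN for the forward pass: more mantissa bits, better precision
w_fp8, w_scale = to_fp8(weights, fmt="e4m3fn")

# E5M2 for the backward pass: wider exponent range for small gradients
g_fp8, g_scale = to_fp8(grads, fmt="e5m2")

w = from_fp8(w_fp8, fmt="e4m3fn", scaling_factor=w_scale)
g = from_fp8(g_fp8, fmt="e5m2", scaling_factor=g_scale)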

API Reference

High-Level API (Recommended)

from main import to_fp8, from_fp8, compute_scaling_factor

# Auto-scaling (recommended for best precision)
fp8_tensor, scale = to_fp8(tensor, fmt="e4m3fn")

# Manual scaling
scale = compute_scaling_factor(tensor, fmt="e4m3fn")
fp8_tensor, _ = to_fp8(tensor, fmt="e4m3fn", scaling_factor=scale)

# Deterministic rounding (faster, but biased)
fp8_tensor, scale = to_fp8(tensor, fmt="e4m3fn", stochastic=False)

# Convert back
recovered = from_fp8(fp8_tensor, fmt="e4m3fn", scaling_factor=scale)

Supported Formats

# Standard OCP formats
to_fp8(tensor, fmt="e4m3fn")   # 4-bit exp, 3-bit mantissa, no inf
to_fp8(tensor, fmt="e5m2")     # 5-bit exp, 2-bit mantissa, has inf

# FNUZ variants (no negative zero, larger bias)
to_fp8(tensor, fmt="e4m3fnuz") # bias=8 instead of 7
to_fp8(tensor, fmt="e5m2fnuz") # bias=16 instead of 15

Low-Level API

For fine-grained control:

from main import round_to_fp8_represented_as_int8, undo_int8_fp8

# Convert to FP8 (stored as uint8); the second argument is the number of
# mantissa bits (3 for E4M3, 2 for E5M2)
fp8 = round_to_fp8_represented_as_int8(tensor, 3, scaling_factor=1.0)

# Convert back, specifying the output dtype
recovered = undo_int8_fp8(fp8, 3, torch.float32, scaling_factor=1.0)

Stochastic vs Deterministic Rounding

Stochastic rounding (the default) rounds each value to one of its two neighboring representable values, with probability proportional to its proximity to each, so the rounded result equals the input in expectation. This unbiasedness preserves the mean across many samples, making it ideal for training.

# Stochastic (default) - unbiased, better for training
fp8, scale = to_fp8(tensor, stochastic=True)

# Deterministic - faster, better for inference
fp8, scale = to_fp8(tensor, stochastic=False)
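
The unbiasedness claim is easy to check empirically. The sketch below (assuming the API above) averages many independent stochastic round-trips of the same tensor: the average converges toward the input, while a single deterministic round-trip carries a fixed residual error.

import torch
from main import to_fp8, from_fp8

torch.manual_seed(0)
x = torch.randn(1000)

# Average 200 independent stochastic round-trips of the same tensor
trials = []
for _ in range(200):
    fp8, scale = to_fp8(x, fmt="e4m3fn", stochastic=True)
    trials.append(from_fp8(fp8, fmt="e4m3fn", scaling_factor=scale))
avg = torch.stack(trials).float().mean(dim=0)

fp8, scale = to_fp8(x, fmt="e4m3fn", stochastic=False)
det = from_fp8(fp8, fmt="e4m3fn", scaling_factor=scale).float()

# The averaged stochastic error should be well below the deterministic error
print("stochastic (averaged):", (avg - x).abs().mean().item())
print("deterministic:        ", (det - x).abs().mean().item())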

Performance

Benchmarked on 100,000 elements (CPU):

| Operation   | Time    | Throughput    |
|-------------|---------|---------------|
| Float → FP8 | ~1.5 ms | ~70M elem/s   |
| FP8 → Float | ~0.8 ms | ~130M elem/s  |

Compared to PyTorch native FP8 (which uses C++/CUDA):

| Metric     | Our Implementation | PyTorch Native |
|------------|--------------------|----------------|
| E4M3FN MSE | 0.000763           | 0.000763       |
| E5M2 MSE   | 0.002706           | 0.002706       |

Our implementation produces identical results to PyTorch's native FP8 types.
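
A comparison along these lines can be reproduced with the sketch below. It assumes PyTorch ≥ 2.1 (which ships the torch.float8_e4m3fn dtype) and uses deterministic rounding with scaling disabled, so both paths perform the same round-to-nearest cast.

import torch
from main import to_fp8, from_fp8

x = torch.randn(100_000)

# Ours: deterministic rounding, scaling disabled for a like-for-like cast
fp8, _ = to_fp8(x, fmt="e4m3fn", scaling_factor=1.0, stochastic=False)
ours = from_fp8(fp8, fmt="e4m3fn", scaling_factor=1.0).to(torch.float32)

# PyTorch native round trip through its built-in FP8 dtype
native = x.to(torch.float8_e4m3fn).to(torch.float32)

print("ours   MSE:", ((ours - x) ** 2).mean().item())
print("native MSE:", ((native - x) ** 2).mean().item())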

Running Tests

uv run python main.py

The test suite validates:

  • Sign preservation
  • Special values (zero, inf, NaN) per OCP spec
  • Round-trip accuracy for all formats
  • Stochastic rounding unbiasedness
  • Comparison with PyTorch native FP8

Format Specifications

| Property      | E4M3FN     | E5M2                |
|---------------|------------|---------------------|
| Exponent bits | 4          | 5                   |
| Mantissa bits | 3          | 2                   |
| Bias          | 7          | 15                  |
| Max value     | ±448       | ±57,344             |
| Has infinity  | No         | Yes                 |
| NaN encoding  | 0x7F, 0xFF | exp=31, mantissa≠0  |
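
As a worked example of the E4M3FN column, the decoder below reconstructs a value from a raw byte using the table's parameters (bias 7, 3 mantissa bits). It is an illustrative helper, not part of the library's API.

def decode_e4m3fn(byte: int) -> float:
    sign = (byte >> 7) & 0x1
    exponent = (byte >> 3) & 0xF        # 4 exponent bits
    mantissa = byte & 0x7               # 3 mantissa bits
    if byte in (0x7F, 0xFF):            # the two NaN encodings (no infinity)
        return float("nan")
    if exponent == 0:                   # subnormal: 2^(1-bias) * m/8
        value = 2.0 ** (1 - 7) * (mantissa / 8)
    else:                               # normal: 2^(e-bias) * (1 + m/8)
        value = 2.0 ** (exponent - 7) * (1 + mantissa / 8)
    return -value if sign else value

print(decode_e4m3fn(0x7E))   # 448.0, the E4M3FN maximum
print(decode_e4m3fn(0x01))   # 0.001953125 = 2**-9, the smallest subnormal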

Architecture

Conversion Flow Diagram

flowchart LR
    subgraph Input
        BF16[bfloat16 Tensor]
    end

    subgraph "Float to FP8"
        A[Scale & Extract Sign] --> B[Calculate Exponent & Mantissa]
        B --> C[Stochastic Rounding]
        C --> D[Handle Overflow/Underflow]
        D --> E[Combine Components]
    end

    subgraph Storage
        FP8[uint8 Tensor]
    end

    subgraph "FP8 to Float"
        F[Extract Components] --> G[Reconstruct Value]
        G --> H[Apply Sign & Scale]
    end

    subgraph Output
        BF16_OUT[bfloat16 Tensor]
    end

    BF16 --> A
    E --> FP8
    FP8 --> F
    H --> BF16_OUT
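
In code, the float-to-FP8 path above condenses to roughly the sketch below for E4M3FN. This is an illustrative re-derivation from the diagram, not the code in main.py; it flushes subnormals to zero to stay short.

import torch

def encode_e4m3fn_sketch(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    x = x.float() / scale                                  # Scale
    sign = (x < 0).to(torch.uint8)                         # Extract sign
    mag = x.abs()
    e = torch.floor(torch.log2(mag.clamp(min=1e-38))).clamp(-6, 8)  # Exponent
    m = (mag / 2.0 ** e - 1.0) * 8                         # Mantissa on an 8-step grid
    m = torch.floor(m + torch.rand_like(m)).clamp(min=0.0) # Stochastic rounding
    e = e + (m == 8).float()                               # Carry a rounded-up mantissa
    m = torch.where(m == 8, torch.zeros_like(m), m)
    over = (e > 8) | ((e == 8) & (m > 6))                  # Overflow: beyond +/-448
    e = torch.where(over, torch.full_like(e, 8.0), e)      # (e=8, m=7 would be NaN,
    m = torch.where(over, torch.full_like(m, 6.0), m)      #  so saturate at m=6)
    byte = (sign << 7) | ((e + 7).to(torch.uint8) << 3) | m.to(torch.uint8)  # Combine
    return torch.where(mag < 2.0 ** -6, torch.zeros_like(byte), byte)  # Flush subnormals
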
Special Value Handling

flowchart TD
    INPUT[Input Value] --> ZERO{abs < 1e-7?}
    ZERO -->|Yes| RET_ZERO[Return 0x00]
    ZERO -->|No| NAN{isNaN?}
    NAN -->|Yes| RET_NAN[Return 0x7F]
    NAN -->|No| INF{isInf?}

    INF -->|Yes| FORMAT{Format?}
    FORMAT -->|E4M3FN| CLAMP[Clamp to max<br/>+: 0x7E, -: 0xFE]
    FORMAT -->|E5M2| INF_ENC[Encode infinity<br/>+: 0x7C, -: 0xFC]

    INF -->|No| NORMAL[Normal conversion]
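
A scalar version of this decision tree, using the byte values from the diagram (E4M3FN has no infinities and clamps to ±448, i.e. 0x7E/0xFE; E5M2 encodes ±inf as 0x7C/0xFC):

import math

def encode_special(value: float, fmt: str = "e4m3fn"):
    """Return the special-case byte, or None to fall through to normal conversion."""
    if abs(value) < 1e-7:
        return 0x00                             # Zero
    if math.isnan(value):
        return 0x7F                             # NaN
    if math.isinf(value):
        if fmt == "e4m3fn":                     # No infinity: clamp to +/- max
            return 0x7E if value > 0 else 0xFE
        return 0x7C if value > 0 else 0xFC      # E5M2 has real infinities
    return None                                 # Normal conversion path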

License

MIT
